Chore: dialer hook should return a error

This commit is contained in:
Dreamacro 2020-04-24 23:48:55 +08:00
parent 27dd1d7944
commit 8eddcd77bf
3 changed files with 58 additions and 26 deletions

View file

@ -8,22 +8,26 @@ import (
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
) )
func Dialer() *net.Dialer { func Dialer() (*net.Dialer, error) {
dialer := &net.Dialer{} dialer := &net.Dialer{}
if DialerHook != nil { if DialerHook != nil {
DialerHook(dialer) if err := DialerHook(dialer); err != nil {
return nil, err
}
} }
return dialer return dialer, nil
} }
func ListenConfig() *net.ListenConfig { func ListenConfig() (*net.ListenConfig, error) {
cfg := &net.ListenConfig{} cfg := &net.ListenConfig{}
if ListenConfigHook != nil { if ListenConfigHook != nil {
ListenConfigHook(cfg) if err := ListenConfigHook(cfg); err != nil {
return nil, err
}
} }
return cfg return cfg, nil
} }
func Dial(network, address string) (net.Conn, error) { func Dial(network, address string) (net.Conn, error) {
@ -38,7 +42,10 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error)
return nil, err return nil, err
} }
dialer := Dialer() dialer, err := Dialer()
if err != nil {
return nil, err
}
var ip net.IP var ip net.IP
switch network { switch network {
@ -53,7 +60,9 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }
if DialHook != nil { if DialHook != nil {
DialHook(dialer, network, ip) if err := DialHook(dialer, network, ip); err != nil {
return nil, err
}
} }
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
case "tcp", "udp": case "tcp", "udp":
@ -64,13 +73,17 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }
func ListenPacket(network, address string) (net.PacketConn, error) { func ListenPacket(network, address string) (net.PacketConn, error) {
lc := ListenConfig() lc, err := ListenConfig()
if err != nil {
return nil, err
}
if ListenPacketHook != nil && address == "" { if ListenPacketHook != nil && address == "" {
ip := ListenPacketHook() ip, err := ListenPacketHook()
if ip != nil { if err != nil {
address = net.JoinHostPort(ip.String(), "0") return nil, err
} }
address = net.JoinHostPort(ip.String(), "0")
} }
return lc.ListenPacket(context.Background(), network, address) return lc.ListenPacket(context.Background(), network, address)
} }
@ -95,7 +108,6 @@ func dualStackDailContext(ctx context.Context, network, address string) (net.Con
var primary, fallback dialResult var primary, fallback dialResult
startRacer := func(ctx context.Context, network, host string, ipv6 bool) { startRacer := func(ctx context.Context, network, host string, ipv6 bool) {
dialer := Dialer()
result := dialResult{ipv6: ipv6, done: true} result := dialResult{ipv6: ipv6, done: true}
defer func() { defer func() {
select { select {
@ -107,6 +119,12 @@ func dualStackDailContext(ctx context.Context, network, address string) (net.Con
} }
}() }()
dialer, err := Dialer()
if err != nil {
result.error = err
return
}
var ip net.IP var ip net.IP
if ipv6 { if ipv6 {
ip, result.error = resolver.ResolveIPv6(host) ip, result.error = resolver.ResolveIPv6(host)
@ -119,7 +137,9 @@ func dualStackDailContext(ctx context.Context, network, address string) (net.Con
result.resolved = true result.resolved = true
if DialHook != nil { if DialHook != nil {
DialHook(dialer, network, ip) if result.error = DialHook(dialer, network, ip); result.error != nil {
return
}
} }
result.Conn, result.error = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) result.Conn, result.error = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
} }

View file

@ -8,10 +8,10 @@ import (
"github.com/Dreamacro/clash/common/singledo" "github.com/Dreamacro/clash/common/singledo"
) )
type DialerHookFunc = func(dialer *net.Dialer) type DialerHookFunc = func(dialer *net.Dialer) error
type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP) type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP) error
type ListenConfigHookFunc = func(*net.ListenConfig) type ListenConfigHookFunc = func(*net.ListenConfig) error
type ListenPacketHookFunc = func() net.IP type ListenPacketHookFunc = func() (net.IP, error)
var ( var (
DialerHook DialerHookFunc DialerHook DialerHookFunc
@ -70,7 +70,7 @@ func lookupUDPAddr(ip net.IP, addrs []net.Addr) (*net.UDPAddr, error) {
func ListenPacketWithInterface(name string) ListenPacketHookFunc { func ListenPacketWithInterface(name string) ListenPacketHookFunc {
single := singledo.NewSingle(5 * time.Second) single := singledo.NewSingle(5 * time.Second)
return func() net.IP { return func() (net.IP, error) {
elm, err, _ := single.Do(func() (interface{}, error) { elm, err, _ := single.Do(func() (interface{}, error) {
iface, err := net.InterfaceByName(name) iface, err := net.InterfaceByName(name)
if err != nil { if err != nil {
@ -86,7 +86,7 @@ func ListenPacketWithInterface(name string) ListenPacketHookFunc {
}) })
if err != nil { if err != nil {
return nil return nil, err
} }
addrs := elm.([]net.Addr) addrs := elm.([]net.Addr)
@ -97,17 +97,17 @@ func ListenPacketWithInterface(name string) ListenPacketHookFunc {
continue continue
} }
return addr.IP return addr.IP, nil
} }
return nil return nil, ErrAddrNotFound
} }
} }
func DialerWithInterface(name string) DialHookFunc { func DialerWithInterface(name string) DialHookFunc {
single := singledo.NewSingle(5 * time.Second) single := singledo.NewSingle(5 * time.Second)
return func(dialer *net.Dialer, network string, ip net.IP) { return func(dialer *net.Dialer, network string, ip net.IP) error {
elm, err, _ := single.Do(func() (interface{}, error) { elm, err, _ := single.Do(func() (interface{}, error) {
iface, err := net.InterfaceByName(name) iface, err := net.InterfaceByName(name)
if err != nil { if err != nil {
@ -123,7 +123,7 @@ func DialerWithInterface(name string) DialHookFunc {
}) })
if err != nil { if err != nil {
return return err
} }
addrs := elm.([]net.Addr) addrs := elm.([]net.Addr)
@ -132,11 +132,17 @@ func DialerWithInterface(name string) DialHookFunc {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
if addr, err := lookupTCPAddr(ip, addrs); err == nil { if addr, err := lookupTCPAddr(ip, addrs); err == nil {
dialer.LocalAddr = addr dialer.LocalAddr = addr
} else {
return err
} }
case "udp", "udp4", "udp6": case "udp", "udp4", "udp6":
if addr, err := lookupUDPAddr(ip, addrs); err == nil { if addr, err := lookupUDPAddr(ip, addrs); err == nil {
dialer.LocalAddr = addr dialer.LocalAddr = addr
} else {
return err
} }
} }
return nil
} }
} }

View file

@ -34,13 +34,19 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err
} }
} }
d := dialer.Dialer() d, err := dialer.Dialer()
if err != nil {
return nil, err
}
if dialer.DialHook != nil { if dialer.DialHook != nil {
network := "udp" network := "udp"
if strings.HasPrefix(c.Client.Net, "tcp") { if strings.HasPrefix(c.Client.Net, "tcp") {
network = "tcp" network = "tcp"
} }
dialer.DialHook(d, network, ip) if err := dialer.DialHook(d, network, ip); err != nil {
return nil, err
}
} }
c.Client.Dialer = d c.Client.Dialer = d