From 8eddcd77bfa2cb8a03027ba5bc0dc6cc062f2608 Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Fri, 24 Apr 2020 23:48:55 +0800 Subject: [PATCH] Chore: dialer hook should return a error --- component/dialer/dialer.go | 48 +++++++++++++++++++++++++++----------- component/dialer/hook.go | 26 +++++++++++++-------- dns/client.go | 10 ++++++-- 3 files changed, 58 insertions(+), 26 deletions(-) diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 6dcecdc1..5cb9badb 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -8,22 +8,26 @@ import ( "github.com/Dreamacro/clash/component/resolver" ) -func Dialer() *net.Dialer { +func Dialer() (*net.Dialer, error) { dialer := &net.Dialer{} if DialerHook != nil { - DialerHook(dialer) + if err := DialerHook(dialer); err != nil { + return nil, err + } } - return dialer + return dialer, nil } -func ListenConfig() *net.ListenConfig { +func ListenConfig() (*net.ListenConfig, error) { cfg := &net.ListenConfig{} if ListenConfigHook != nil { - ListenConfigHook(cfg) + if err := ListenConfigHook(cfg); err != nil { + return nil, err + } } - return cfg + return cfg, nil } func Dial(network, address string) (net.Conn, error) { @@ -38,7 +42,10 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error) return nil, err } - dialer := Dialer() + dialer, err := Dialer() + if err != nil { + return nil, err + } var ip net.IP switch network { @@ -53,7 +60,9 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error) } if DialHook != nil { - DialHook(dialer, network, ip) + if err := DialHook(dialer, network, ip); err != nil { + return nil, err + } } return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) case "tcp", "udp": @@ -64,13 +73,17 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error) } func ListenPacket(network, address string) (net.PacketConn, error) { - lc := ListenConfig() + lc, err := ListenConfig() + if err != nil { + return nil, err + } if ListenPacketHook != nil && address == "" { - ip := ListenPacketHook() - if ip != nil { - address = net.JoinHostPort(ip.String(), "0") + ip, err := ListenPacketHook() + if err != nil { + return nil, err } + address = net.JoinHostPort(ip.String(), "0") } return lc.ListenPacket(context.Background(), network, address) } @@ -95,7 +108,6 @@ func dualStackDailContext(ctx context.Context, network, address string) (net.Con var primary, fallback dialResult startRacer := func(ctx context.Context, network, host string, ipv6 bool) { - dialer := Dialer() result := dialResult{ipv6: ipv6, done: true} defer func() { select { @@ -107,6 +119,12 @@ func dualStackDailContext(ctx context.Context, network, address string) (net.Con } }() + dialer, err := Dialer() + if err != nil { + result.error = err + return + } + var ip net.IP if ipv6 { ip, result.error = resolver.ResolveIPv6(host) @@ -119,7 +137,9 @@ func dualStackDailContext(ctx context.Context, network, address string) (net.Con result.resolved = true if DialHook != nil { - DialHook(dialer, network, ip) + if result.error = DialHook(dialer, network, ip); result.error != nil { + return + } } result.Conn, result.error = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) } diff --git a/component/dialer/hook.go b/component/dialer/hook.go index 3546a393..d4c955ab 100644 --- a/component/dialer/hook.go +++ b/component/dialer/hook.go @@ -8,10 +8,10 @@ import ( "github.com/Dreamacro/clash/common/singledo" ) -type DialerHookFunc = func(dialer *net.Dialer) -type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP) -type ListenConfigHookFunc = func(*net.ListenConfig) -type ListenPacketHookFunc = func() net.IP +type DialerHookFunc = func(dialer *net.Dialer) error +type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP) error +type ListenConfigHookFunc = func(*net.ListenConfig) error +type ListenPacketHookFunc = func() (net.IP, error) var ( DialerHook DialerHookFunc @@ -70,7 +70,7 @@ func lookupUDPAddr(ip net.IP, addrs []net.Addr) (*net.UDPAddr, error) { func ListenPacketWithInterface(name string) ListenPacketHookFunc { single := singledo.NewSingle(5 * time.Second) - return func() net.IP { + return func() (net.IP, error) { elm, err, _ := single.Do(func() (interface{}, error) { iface, err := net.InterfaceByName(name) if err != nil { @@ -86,7 +86,7 @@ func ListenPacketWithInterface(name string) ListenPacketHookFunc { }) if err != nil { - return nil + return nil, err } addrs := elm.([]net.Addr) @@ -97,17 +97,17 @@ func ListenPacketWithInterface(name string) ListenPacketHookFunc { continue } - return addr.IP + return addr.IP, nil } - return nil + return nil, ErrAddrNotFound } } func DialerWithInterface(name string) DialHookFunc { single := singledo.NewSingle(5 * time.Second) - return func(dialer *net.Dialer, network string, ip net.IP) { + return func(dialer *net.Dialer, network string, ip net.IP) error { elm, err, _ := single.Do(func() (interface{}, error) { iface, err := net.InterfaceByName(name) if err != nil { @@ -123,7 +123,7 @@ func DialerWithInterface(name string) DialHookFunc { }) if err != nil { - return + return err } addrs := elm.([]net.Addr) @@ -132,11 +132,17 @@ func DialerWithInterface(name string) DialHookFunc { case "tcp", "tcp4", "tcp6": if addr, err := lookupTCPAddr(ip, addrs); err == nil { dialer.LocalAddr = addr + } else { + return err } case "udp", "udp4", "udp6": if addr, err := lookupUDPAddr(ip, addrs); err == nil { dialer.LocalAddr = addr + } else { + return err } } + + return nil } } diff --git a/dns/client.go b/dns/client.go index f12b0e01..a40888dd 100644 --- a/dns/client.go +++ b/dns/client.go @@ -34,13 +34,19 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err } } - d := dialer.Dialer() + d, err := dialer.Dialer() + if err != nil { + return nil, err + } + if dialer.DialHook != nil { network := "udp" if strings.HasPrefix(c.Client.Net, "tcp") { network = "tcp" } - dialer.DialHook(d, network, ip) + if err := dialer.DialHook(d, network, ip); err != nil { + return nil, err + } } c.Client.Dialer = d