Chore: dialer hook should return a error
This commit is contained in:
parent
27dd1d7944
commit
8eddcd77bf
3 changed files with 58 additions and 26 deletions
|
@ -8,22 +8,26 @@ import (
|
||||||
"github.com/Dreamacro/clash/component/resolver"
|
"github.com/Dreamacro/clash/component/resolver"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Dialer() *net.Dialer {
|
func Dialer() (*net.Dialer, error) {
|
||||||
dialer := &net.Dialer{}
|
dialer := &net.Dialer{}
|
||||||
if DialerHook != nil {
|
if DialerHook != nil {
|
||||||
DialerHook(dialer)
|
if err := DialerHook(dialer); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return dialer
|
return dialer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenConfig() *net.ListenConfig {
|
func ListenConfig() (*net.ListenConfig, error) {
|
||||||
cfg := &net.ListenConfig{}
|
cfg := &net.ListenConfig{}
|
||||||
if ListenConfigHook != nil {
|
if ListenConfigHook != nil {
|
||||||
ListenConfigHook(cfg)
|
if err := ListenConfigHook(cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return cfg
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Dial(network, address string) (net.Conn, error) {
|
func Dial(network, address string) (net.Conn, error) {
|
||||||
|
@ -38,7 +42,10 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dialer := Dialer()
|
dialer, err := Dialer()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
var ip net.IP
|
var ip net.IP
|
||||||
switch network {
|
switch network {
|
||||||
|
@ -53,7 +60,9 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
if DialHook != nil {
|
if DialHook != nil {
|
||||||
DialHook(dialer, network, ip)
|
if err := DialHook(dialer, network, ip); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||||
case "tcp", "udp":
|
case "tcp", "udp":
|
||||||
|
@ -64,13 +73,17 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenPacket(network, address string) (net.PacketConn, error) {
|
func ListenPacket(network, address string) (net.PacketConn, error) {
|
||||||
lc := ListenConfig()
|
lc, err := ListenConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if ListenPacketHook != nil && address == "" {
|
if ListenPacketHook != nil && address == "" {
|
||||||
ip := ListenPacketHook()
|
ip, err := ListenPacketHook()
|
||||||
if ip != nil {
|
if err != nil {
|
||||||
address = net.JoinHostPort(ip.String(), "0")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
address = net.JoinHostPort(ip.String(), "0")
|
||||||
}
|
}
|
||||||
return lc.ListenPacket(context.Background(), network, address)
|
return lc.ListenPacket(context.Background(), network, address)
|
||||||
}
|
}
|
||||||
|
@ -95,7 +108,6 @@ func dualStackDailContext(ctx context.Context, network, address string) (net.Con
|
||||||
var primary, fallback dialResult
|
var primary, fallback dialResult
|
||||||
|
|
||||||
startRacer := func(ctx context.Context, network, host string, ipv6 bool) {
|
startRacer := func(ctx context.Context, network, host string, ipv6 bool) {
|
||||||
dialer := Dialer()
|
|
||||||
result := dialResult{ipv6: ipv6, done: true}
|
result := dialResult{ipv6: ipv6, done: true}
|
||||||
defer func() {
|
defer func() {
|
||||||
select {
|
select {
|
||||||
|
@ -107,6 +119,12 @@ func dualStackDailContext(ctx context.Context, network, address string) (net.Con
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
dialer, err := Dialer()
|
||||||
|
if err != nil {
|
||||||
|
result.error = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var ip net.IP
|
var ip net.IP
|
||||||
if ipv6 {
|
if ipv6 {
|
||||||
ip, result.error = resolver.ResolveIPv6(host)
|
ip, result.error = resolver.ResolveIPv6(host)
|
||||||
|
@ -119,7 +137,9 @@ func dualStackDailContext(ctx context.Context, network, address string) (net.Con
|
||||||
result.resolved = true
|
result.resolved = true
|
||||||
|
|
||||||
if DialHook != nil {
|
if DialHook != nil {
|
||||||
DialHook(dialer, network, ip)
|
if result.error = DialHook(dialer, network, ip); result.error != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
result.Conn, result.error = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
result.Conn, result.error = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,10 +8,10 @@ import (
|
||||||
"github.com/Dreamacro/clash/common/singledo"
|
"github.com/Dreamacro/clash/common/singledo"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DialerHookFunc = func(dialer *net.Dialer)
|
type DialerHookFunc = func(dialer *net.Dialer) error
|
||||||
type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP)
|
type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP) error
|
||||||
type ListenConfigHookFunc = func(*net.ListenConfig)
|
type ListenConfigHookFunc = func(*net.ListenConfig) error
|
||||||
type ListenPacketHookFunc = func() net.IP
|
type ListenPacketHookFunc = func() (net.IP, error)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
DialerHook DialerHookFunc
|
DialerHook DialerHookFunc
|
||||||
|
@ -70,7 +70,7 @@ func lookupUDPAddr(ip net.IP, addrs []net.Addr) (*net.UDPAddr, error) {
|
||||||
func ListenPacketWithInterface(name string) ListenPacketHookFunc {
|
func ListenPacketWithInterface(name string) ListenPacketHookFunc {
|
||||||
single := singledo.NewSingle(5 * time.Second)
|
single := singledo.NewSingle(5 * time.Second)
|
||||||
|
|
||||||
return func() net.IP {
|
return func() (net.IP, error) {
|
||||||
elm, err, _ := single.Do(func() (interface{}, error) {
|
elm, err, _ := single.Do(func() (interface{}, error) {
|
||||||
iface, err := net.InterfaceByName(name)
|
iface, err := net.InterfaceByName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -86,7 +86,7 @@ func ListenPacketWithInterface(name string) ListenPacketHookFunc {
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
addrs := elm.([]net.Addr)
|
addrs := elm.([]net.Addr)
|
||||||
|
@ -97,17 +97,17 @@ func ListenPacketWithInterface(name string) ListenPacketHookFunc {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return addr.IP
|
return addr.IP, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil, ErrAddrNotFound
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func DialerWithInterface(name string) DialHookFunc {
|
func DialerWithInterface(name string) DialHookFunc {
|
||||||
single := singledo.NewSingle(5 * time.Second)
|
single := singledo.NewSingle(5 * time.Second)
|
||||||
|
|
||||||
return func(dialer *net.Dialer, network string, ip net.IP) {
|
return func(dialer *net.Dialer, network string, ip net.IP) error {
|
||||||
elm, err, _ := single.Do(func() (interface{}, error) {
|
elm, err, _ := single.Do(func() (interface{}, error) {
|
||||||
iface, err := net.InterfaceByName(name)
|
iface, err := net.InterfaceByName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -123,7 +123,7 @@ func DialerWithInterface(name string) DialHookFunc {
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
addrs := elm.([]net.Addr)
|
addrs := elm.([]net.Addr)
|
||||||
|
@ -132,11 +132,17 @@ func DialerWithInterface(name string) DialHookFunc {
|
||||||
case "tcp", "tcp4", "tcp6":
|
case "tcp", "tcp4", "tcp6":
|
||||||
if addr, err := lookupTCPAddr(ip, addrs); err == nil {
|
if addr, err := lookupTCPAddr(ip, addrs); err == nil {
|
||||||
dialer.LocalAddr = addr
|
dialer.LocalAddr = addr
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
case "udp", "udp4", "udp6":
|
case "udp", "udp4", "udp6":
|
||||||
if addr, err := lookupUDPAddr(ip, addrs); err == nil {
|
if addr, err := lookupUDPAddr(ip, addrs); err == nil {
|
||||||
dialer.LocalAddr = addr
|
dialer.LocalAddr = addr
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,13 +34,19 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
d := dialer.Dialer()
|
d, err := dialer.Dialer()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if dialer.DialHook != nil {
|
if dialer.DialHook != nil {
|
||||||
network := "udp"
|
network := "udp"
|
||||||
if strings.HasPrefix(c.Client.Net, "tcp") {
|
if strings.HasPrefix(c.Client.Net, "tcp") {
|
||||||
network = "tcp"
|
network = "tcp"
|
||||||
}
|
}
|
||||||
dialer.DialHook(d, network, ip)
|
if err := dialer.DialHook(d, network, ip); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Client.Dialer = d
|
c.Client.Dialer = d
|
||||||
|
|
Loading…
Reference in a new issue