From ae42d351844628fa5ab2edc95813de6f9fd3f93f Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 13 Feb 2023 11:14:19 +0800 Subject: [PATCH] chore: support golang1.20's dialer.ControlContext --- component/dialer/bind_darwin.go | 17 +++++------------ component/dialer/bind_linux.go | 17 +++++------------ component/dialer/bind_windows.go | 17 +++++------------ component/dialer/control.go | 22 ++++++++++++++++++++++ component/dialer/control_go119.go | 22 ++++++++++++++++++++++ component/dialer/control_go120.go | 26 ++++++++++++++++++++++++++ component/dialer/mark_linux.go | 14 +++++--------- component/dialer/reuse_unix.go | 13 +++---------- component/dialer/reuse_windows.go | 13 +++---------- 9 files changed, 96 insertions(+), 65 deletions(-) create mode 100644 component/dialer/control.go create mode 100644 component/dialer/control_go119.go create mode 100644 component/dialer/control_go120.go diff --git a/component/dialer/bind_darwin.go b/component/dialer/bind_darwin.go index 8e88b461..8705a708 100644 --- a/component/dialer/bind_darwin.go +++ b/component/dialer/bind_darwin.go @@ -1,6 +1,7 @@ package dialer import ( + "context" "net" "net/netip" "syscall" @@ -10,16 +11,8 @@ import ( "golang.org/x/sys/unix" ) -type controlFn = func(network, address string, c syscall.RawConn) error - -func bindControl(ifaceIdx int, chain controlFn) controlFn { - return func(network, address string, c syscall.RawConn) (err error) { - defer func() { - if err == nil && chain != nil { - err = chain(network, address, c) - } - }() - +func bindControl(ifaceIdx int) controlFn { + return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) { addrPort, err := netip.ParseAddrPort(address) if err == nil && !addrPort.Addr().IsGlobalUnicast() { return @@ -49,7 +42,7 @@ func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ netip.A return err } - dialer.Control = bindControl(ifaceObj.Index, dialer.Control) + addControlToDialer(dialer, bindControl(ifaceObj.Index)) return nil } @@ -59,7 +52,7 @@ func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address return "", err } - lc.Control = bindControl(ifaceObj.Index, lc.Control) + addControlToListenConfig(lc, bindControl(ifaceObj.Index)) return address, nil } diff --git a/component/dialer/bind_linux.go b/component/dialer/bind_linux.go index 57a2e0c1..1ec98f3d 100644 --- a/component/dialer/bind_linux.go +++ b/component/dialer/bind_linux.go @@ -1,6 +1,7 @@ package dialer import ( + "context" "net" "net/netip" "syscall" @@ -8,16 +9,8 @@ import ( "golang.org/x/sys/unix" ) -type controlFn = func(network, address string, c syscall.RawConn) error - -func bindControl(ifaceName string, chain controlFn) controlFn { - return func(network, address string, c syscall.RawConn) (err error) { - defer func() { - if err == nil && chain != nil { - err = chain(network, address, c) - } - }() - +func bindControl(ifaceName string) controlFn { + return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) { addrPort, err := netip.ParseAddrPort(address) if err == nil && !addrPort.Addr().IsGlobalUnicast() { return @@ -37,13 +30,13 @@ func bindControl(ifaceName string, chain controlFn) controlFn { } func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ netip.Addr) error { - dialer.Control = bindControl(ifaceName, dialer.Control) + addControlToDialer(dialer, bindControl(ifaceName)) return nil } func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address string) (string, error) { - lc.Control = bindControl(ifaceName, lc.Control) + addControlToListenConfig(lc, bindControl(ifaceName)) return address, nil } diff --git a/component/dialer/bind_windows.go b/component/dialer/bind_windows.go index 87b39bc2..b680e90f 100644 --- a/component/dialer/bind_windows.go +++ b/component/dialer/bind_windows.go @@ -1,6 +1,7 @@ package dialer import ( + "context" "encoding/binary" "net" "net/netip" @@ -26,16 +27,8 @@ func bind6(handle syscall.Handle, ifaceIdx int) error { return syscall.SetsockoptInt(handle, syscall.IPPROTO_IPV6, IPV6_UNICAST_IF, ifaceIdx) } -type controlFn = func(network, address string, c syscall.RawConn) error - -func bindControl(ifaceIdx int, chain controlFn) controlFn { - return func(network, address string, c syscall.RawConn) (err error) { - defer func() { - if err == nil && chain != nil { - err = chain(network, address, c) - } - }() - +func bindControl(ifaceIdx int) controlFn { + return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) { addrPort, err := netip.ParseAddrPort(address) if err == nil && !addrPort.Addr().IsGlobalUnicast() { return @@ -69,7 +62,7 @@ func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ netip.A return err } - dialer.Control = bindControl(ifaceObj.Index, dialer.Control) + addControlToDialer(dialer, bindControl(ifaceObj.Index)) return nil } @@ -79,7 +72,7 @@ func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address return "", err } - lc.Control = bindControl(ifaceObj.Index, lc.Control) + addControlToListenConfig(lc, bindControl(ifaceObj.Index)) return address, nil } diff --git a/component/dialer/control.go b/component/dialer/control.go new file mode 100644 index 00000000..26b1db76 --- /dev/null +++ b/component/dialer/control.go @@ -0,0 +1,22 @@ +package dialer + +import ( + "context" + "net" + "syscall" +) + +type controlFn = func(ctx context.Context, network, address string, c syscall.RawConn) error + +func addControlToListenConfig(lc *net.ListenConfig, fn controlFn) { + llc := *lc + lc.Control = func(network, address string, c syscall.RawConn) (err error) { + switch { + case llc.Control != nil: + if err = llc.Control(network, address, c); err != nil { + return + } + } + return fn(context.Background(), network, address, c) + } +} diff --git a/component/dialer/control_go119.go b/component/dialer/control_go119.go new file mode 100644 index 00000000..ec980586 --- /dev/null +++ b/component/dialer/control_go119.go @@ -0,0 +1,22 @@ +//go:build !go1.20 + +package dialer + +import ( + "context" + "net" + "syscall" +) + +func addControlToDialer(d *net.Dialer, fn controlFn) { + ld := *d + d.Control = func(network, address string, c syscall.RawConn) (err error) { + switch { + case ld.Control != nil: + if err = ld.Control(network, address, c); err != nil { + return + } + } + return fn(context.Background(), network, address, c) + } +} diff --git a/component/dialer/control_go120.go b/component/dialer/control_go120.go new file mode 100644 index 00000000..65e33f9c --- /dev/null +++ b/component/dialer/control_go120.go @@ -0,0 +1,26 @@ +//go:build go1.20 + +package dialer + +import ( + "context" + "net" + "syscall" +) + +func addControlToDialer(d *net.Dialer, fn controlFn) { + ld := *d + d.ControlContext = func(ctx context.Context, network, address string, c syscall.RawConn) (err error) { + switch { + case ld.ControlContext != nil: + if err = ld.ControlContext(ctx, network, address, c); err != nil { + return + } + case ld.Control != nil: + if err = ld.Control(network, address, c); err != nil { + return + } + } + return fn(ctx, network, address, c) + } +} diff --git a/component/dialer/mark_linux.go b/component/dialer/mark_linux.go index eaba5cf7..996c3dea 100644 --- a/component/dialer/mark_linux.go +++ b/component/dialer/mark_linux.go @@ -3,26 +3,22 @@ package dialer import ( + "context" "net" "net/netip" "syscall" ) func bindMarkToDialer(mark int, dialer *net.Dialer, _ string, _ netip.Addr) { - dialer.Control = bindMarkToControl(mark, dialer.Control) + addControlToDialer(dialer, bindMarkToControl(mark)) } func bindMarkToListenConfig(mark int, lc *net.ListenConfig, _, _ string) { - lc.Control = bindMarkToControl(mark, lc.Control) + addControlToListenConfig(lc, bindMarkToControl(mark)) } -func bindMarkToControl(mark int, chain controlFn) controlFn { - return func(network, address string, c syscall.RawConn) (err error) { - defer func() { - if err == nil && chain != nil { - err = chain(network, address, c) - } - }() +func bindMarkToControl(mark int) controlFn { + return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) { addrPort, err := netip.ParseAddrPort(address) if err == nil && !addrPort.Addr().IsGlobalUnicast() { diff --git a/component/dialer/reuse_unix.go b/component/dialer/reuse_unix.go index 85fe5e5e..a0cf7388 100644 --- a/component/dialer/reuse_unix.go +++ b/component/dialer/reuse_unix.go @@ -3,6 +3,7 @@ package dialer import ( + "context" "net" "syscall" @@ -10,18 +11,10 @@ import ( ) func addrReuseToListenConfig(lc *net.ListenConfig) { - chain := lc.Control - - lc.Control = func(network, address string, c syscall.RawConn) (err error) { - defer func() { - if err == nil && chain != nil { - err = chain(network, address, c) - } - }() - + addControlToListenConfig(lc, func(ctx context.Context, network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) }) - } + }) } diff --git a/component/dialer/reuse_windows.go b/component/dialer/reuse_windows.go index 77fcf7ab..b8d0d809 100644 --- a/component/dialer/reuse_windows.go +++ b/component/dialer/reuse_windows.go @@ -1,6 +1,7 @@ package dialer import ( + "context" "net" "syscall" @@ -8,17 +9,9 @@ import ( ) func addrReuseToListenConfig(lc *net.ListenConfig) { - chain := lc.Control - - lc.Control = func(network, address string, c syscall.RawConn) (err error) { - defer func() { - if err == nil && chain != nil { - err = chain(network, address, c) - } - }() - + addControlToListenConfig(lc, func(ctx context.Context, network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1) }) - } + }) }