chore: support golang1.20's dialer.ControlContext

This commit is contained in:
wwqgtxx 2023-02-13 11:14:19 +08:00
parent ce8929d153
commit ae42d35184
9 changed files with 96 additions and 65 deletions

View file

@ -1,6 +1,7 @@
package dialer
import (
"context"
"net"
"net/netip"
"syscall"
@ -10,16 +11,8 @@ import (
"golang.org/x/sys/unix"
)
type controlFn = func(network, address string, c syscall.RawConn) error
func bindControl(ifaceIdx int, chain controlFn) controlFn {
return func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
func bindControl(ifaceIdx int) controlFn {
return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
addrPort, err := netip.ParseAddrPort(address)
if err == nil && !addrPort.Addr().IsGlobalUnicast() {
return
@ -49,7 +42,7 @@ func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ netip.A
return err
}
dialer.Control = bindControl(ifaceObj.Index, dialer.Control)
addControlToDialer(dialer, bindControl(ifaceObj.Index))
return nil
}
@ -59,7 +52,7 @@ func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address
return "", err
}
lc.Control = bindControl(ifaceObj.Index, lc.Control)
addControlToListenConfig(lc, bindControl(ifaceObj.Index))
return address, nil
}

View file

@ -1,6 +1,7 @@
package dialer
import (
"context"
"net"
"net/netip"
"syscall"
@ -8,16 +9,8 @@ import (
"golang.org/x/sys/unix"
)
type controlFn = func(network, address string, c syscall.RawConn) error
func bindControl(ifaceName string, chain controlFn) controlFn {
return func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
func bindControl(ifaceName string) controlFn {
return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
addrPort, err := netip.ParseAddrPort(address)
if err == nil && !addrPort.Addr().IsGlobalUnicast() {
return
@ -37,13 +30,13 @@ func bindControl(ifaceName string, chain controlFn) controlFn {
}
func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ netip.Addr) error {
dialer.Control = bindControl(ifaceName, dialer.Control)
addControlToDialer(dialer, bindControl(ifaceName))
return nil
}
func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address string) (string, error) {
lc.Control = bindControl(ifaceName, lc.Control)
addControlToListenConfig(lc, bindControl(ifaceName))
return address, nil
}

View file

@ -1,6 +1,7 @@
package dialer
import (
"context"
"encoding/binary"
"net"
"net/netip"
@ -26,16 +27,8 @@ func bind6(handle syscall.Handle, ifaceIdx int) error {
return syscall.SetsockoptInt(handle, syscall.IPPROTO_IPV6, IPV6_UNICAST_IF, ifaceIdx)
}
type controlFn = func(network, address string, c syscall.RawConn) error
func bindControl(ifaceIdx int, chain controlFn) controlFn {
return func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
func bindControl(ifaceIdx int) controlFn {
return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
addrPort, err := netip.ParseAddrPort(address)
if err == nil && !addrPort.Addr().IsGlobalUnicast() {
return
@ -69,7 +62,7 @@ func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ netip.A
return err
}
dialer.Control = bindControl(ifaceObj.Index, dialer.Control)
addControlToDialer(dialer, bindControl(ifaceObj.Index))
return nil
}
@ -79,7 +72,7 @@ func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address
return "", err
}
lc.Control = bindControl(ifaceObj.Index, lc.Control)
addControlToListenConfig(lc, bindControl(ifaceObj.Index))
return address, nil
}

View file

@ -0,0 +1,22 @@
package dialer
import (
"context"
"net"
"syscall"
)
type controlFn = func(ctx context.Context, network, address string, c syscall.RawConn) error
func addControlToListenConfig(lc *net.ListenConfig, fn controlFn) {
llc := *lc
lc.Control = func(network, address string, c syscall.RawConn) (err error) {
switch {
case llc.Control != nil:
if err = llc.Control(network, address, c); err != nil {
return
}
}
return fn(context.Background(), network, address, c)
}
}

View file

@ -0,0 +1,22 @@
//go:build !go1.20
package dialer
import (
"context"
"net"
"syscall"
)
func addControlToDialer(d *net.Dialer, fn controlFn) {
ld := *d
d.Control = func(network, address string, c syscall.RawConn) (err error) {
switch {
case ld.Control != nil:
if err = ld.Control(network, address, c); err != nil {
return
}
}
return fn(context.Background(), network, address, c)
}
}

View file

@ -0,0 +1,26 @@
//go:build go1.20
package dialer
import (
"context"
"net"
"syscall"
)
func addControlToDialer(d *net.Dialer, fn controlFn) {
ld := *d
d.ControlContext = func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
switch {
case ld.ControlContext != nil:
if err = ld.ControlContext(ctx, network, address, c); err != nil {
return
}
case ld.Control != nil:
if err = ld.Control(network, address, c); err != nil {
return
}
}
return fn(ctx, network, address, c)
}
}

View file

@ -3,26 +3,22 @@
package dialer
import (
"context"
"net"
"net/netip"
"syscall"
)
func bindMarkToDialer(mark int, dialer *net.Dialer, _ string, _ netip.Addr) {
dialer.Control = bindMarkToControl(mark, dialer.Control)
addControlToDialer(dialer, bindMarkToControl(mark))
}
func bindMarkToListenConfig(mark int, lc *net.ListenConfig, _, _ string) {
lc.Control = bindMarkToControl(mark, lc.Control)
addControlToListenConfig(lc, bindMarkToControl(mark))
}
func bindMarkToControl(mark int, chain controlFn) controlFn {
return func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
func bindMarkToControl(mark int) controlFn {
return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
addrPort, err := netip.ParseAddrPort(address)
if err == nil && !addrPort.Addr().IsGlobalUnicast() {

View file

@ -3,6 +3,7 @@
package dialer
import (
"context"
"net"
"syscall"
@ -10,18 +11,10 @@ import (
)
func addrReuseToListenConfig(lc *net.ListenConfig) {
chain := lc.Control
lc.Control = func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
addControlToListenConfig(lc, func(ctx context.Context, network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
})
}
})
}

View file

@ -1,6 +1,7 @@
package dialer
import (
"context"
"net"
"syscall"
@ -8,17 +9,9 @@ import (
)
func addrReuseToListenConfig(lc *net.ListenConfig) {
chain := lc.Control
lc.Control = func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
addControlToListenConfig(lc, func(ctx context.Context, network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1)
})
}
})
}