chore: wireguard using internal dialer

This commit is contained in:
wwqgtxx 2023-03-07 09:30:51 +08:00
parent 545a79d406
commit 9cc7fdaca9
3 changed files with 57 additions and 23 deletions

View file

@ -34,7 +34,7 @@ type WireGuard struct {
bind *wireguard.ClientBind bind *wireguard.ClientBind
device *device.Device device *device.Device
tunDevice wireguard.Device tunDevice wireguard.Device
dialer *wgDialer dialer *wgSingDialer
startOnce sync.Once startOnce sync.Once
startErr error startErr error
} }
@ -56,16 +56,28 @@ type WireGuardOption struct {
PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"` PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"`
} }
type wgDialer struct { type wgSingDialer struct {
options []dialer.Option dialer dialer.Dialer
} }
func (d *wgDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { var _ N.Dialer = &wgSingDialer{}
return dialer.DialContext(ctx, network, destination.String(), d.options...)
func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return d.dialer.DialContext(ctx, network, destination.String())
} }
func (d *wgDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *wgSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return dialer.ListenPacket(ctx, dialer.ParseNetwork("udp", destination.Addr), "", d.options...) return d.dialer.ListenPacket(ctx, "udp", "", destination.AddrPort())
}
type wgNetDialer struct {
tunDevice wireguard.Device
}
var _ dialer.NetDialer = &wgNetDialer{}
func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address))
} }
func NewWireGuard(option WireGuardOption) (*WireGuard, error) { func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
@ -79,7 +91,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: C.NewDNSPrefer(option.IPVersion),
}, },
dialer: &wgDialer{}, dialer: &wgSingDialer{dialer: dialer.NewDialer()},
} }
runtime.SetFinalizer(outbound, closeWireGuard) runtime.SetFinalizer(outbound, closeWireGuard)
@ -199,7 +211,8 @@ func closeWireGuard(w *WireGuard) {
} }
func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
w.dialer.options = opts options := w.Base.DialOptions(opts...)
w.dialer.dialer = dialer.NewDialer(options...)
var conn net.Conn var conn net.Conn
w.startOnce.Do(func() { w.startOnce.Do(func() {
w.startErr = w.tunDevice.Start() w.startErr = w.tunDevice.Start()
@ -208,12 +221,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
return nil, w.startErr return nil, w.startErr
} }
if !metadata.Resolved() { if !metadata.Resolved() {
var addrs []netip.Addr options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice}))
addrs, err = resolver.LookupIP(ctx, metadata.Host) conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress())
if err != nil {
return nil, err
}
conn, err = N.DialSerial(ctx, w.tunDevice, "tcp", M.ParseSocksaddr(metadata.RemoteAddress()), addrs)
} else { } else {
port, _ := strconv.Atoi(metadata.DstPort) port, _ := strconv.Atoi(metadata.DstPort)
conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, uint16(port))) conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, uint16(port)))
@ -228,7 +237,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
} }
func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
w.dialer.options = opts options := w.Base.DialOptions(opts...)
w.dialer.dialer = dialer.NewDialer(options...)
var pc net.PacketConn var pc net.PacketConn
w.startOnce.Do(func() { w.startOnce.Do(func() {
w.startErr = w.tunDevice.Start() w.startErr = w.tunDevice.Start()

View file

@ -109,7 +109,19 @@ func GetTcpConcurrent() bool {
} }
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) { func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
dialer := &net.Dialer{} address := net.JoinHostPort(destination.String(), port)
netDialer := opt.netDialer
switch netDialer.(type) {
case nil:
netDialer = &net.Dialer{}
case *net.Dialer:
netDialer = &*netDialer.(*net.Dialer) // make a copy
default:
return netDialer.DialContext(ctx, network, address)
}
dialer := netDialer.(*net.Dialer)
if opt.interfaceName != "" { if opt.interfaceName != "" {
if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil { if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil {
return nil, err return nil, err
@ -118,8 +130,6 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
if opt.routingMark != 0 { if opt.routingMark != 0 {
bindMarkToDialer(opt.routingMark, dialer, network, destination) bindMarkToDialer(opt.routingMark, dialer, network, destination)
} }
address := net.JoinHostPort(destination.String(), port)
if opt.tfo { if opt.tfo {
return dialTFO(ctx, *dialer, network, address) return dialTFO(ctx, *dialer, network, address)
} }
@ -307,15 +317,15 @@ func sortationAddr(ips []netip.Addr) (ipv4s, ipv6s []netip.Addr) {
} }
type Dialer struct { type Dialer struct {
Opt option opt option
} }
func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return DialContext(ctx, network, address, WithOption(d.Opt)) return DialContext(ctx, network, address, WithOption(d.opt))
} }
func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
opt := WithOption(d.Opt) opt := WithOption(d.opt)
if rAddrPort.Addr().Unmap().IsLoopback() { if rAddrPort.Addr().Unmap().IsLoopback() {
// avoid "The requested address is not valid in its context." // avoid "The requested address is not valid in its context."
opt = WithInterface("") opt = WithInterface("")
@ -325,5 +335,5 @@ func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddr
func NewDialer(options ...Option) Dialer { func NewDialer(options ...Option) Dialer {
opt := applyOptions(options...) opt := applyOptions(options...)
return Dialer{Opt: *opt} return Dialer{opt: *opt}
} }

View file

@ -1,6 +1,9 @@
package dialer package dialer
import ( import (
"context"
"net"
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
"go.uber.org/atomic" "go.uber.org/atomic"
@ -12,6 +15,10 @@ var (
DefaultRoutingMark = atomic.NewInt32(0) DefaultRoutingMark = atomic.NewInt32(0)
) )
type NetDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
type option struct { type option struct {
interfaceName string interfaceName string
addrReuse bool addrReuse bool
@ -20,6 +27,7 @@ type option struct {
prefer int prefer int
tfo bool tfo bool
resolver resolver.Resolver resolver resolver.Resolver
netDialer NetDialer
} }
type Option func(opt *option) type Option func(opt *option)
@ -76,6 +84,12 @@ func WithTFO(tfo bool) Option {
} }
} }
func WithNetDialer(netDialer NetDialer) Option {
return func(opt *option) {
opt.netDialer = netDialer
}
}
func WithOption(o option) Option { func WithOption(o option) Option {
return func(opt *option) { return func(opt *option) {
*opt = o *opt = o