chore: wireguard using internal dialer
This commit is contained in:
parent
545a79d406
commit
9cc7fdaca9
3 changed files with 57 additions and 23 deletions
|
@ -34,7 +34,7 @@ type WireGuard struct {
|
|||
bind *wireguard.ClientBind
|
||||
device *device.Device
|
||||
tunDevice wireguard.Device
|
||||
dialer *wgDialer
|
||||
dialer *wgSingDialer
|
||||
startOnce sync.Once
|
||||
startErr error
|
||||
}
|
||||
|
@ -56,16 +56,28 @@ type WireGuardOption struct {
|
|||
PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"`
|
||||
}
|
||||
|
||||
type wgDialer struct {
|
||||
options []dialer.Option
|
||||
type wgSingDialer struct {
|
||||
dialer dialer.Dialer
|
||||
}
|
||||
|
||||
func (d *wgDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, network, destination.String(), d.options...)
|
||||
var _ N.Dialer = &wgSingDialer{}
|
||||
|
||||
func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
return d.dialer.DialContext(ctx, network, destination.String())
|
||||
}
|
||||
|
||||
func (d *wgDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return dialer.ListenPacket(ctx, dialer.ParseNetwork("udp", destination.Addr), "", d.options...)
|
||||
func (d *wgSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return d.dialer.ListenPacket(ctx, "udp", "", destination.AddrPort())
|
||||
}
|
||||
|
||||
type wgNetDialer struct {
|
||||
tunDevice wireguard.Device
|
||||
}
|
||||
|
||||
var _ dialer.NetDialer = &wgNetDialer{}
|
||||
|
||||
func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address))
|
||||
}
|
||||
|
||||
func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
|
||||
|
@ -79,7 +91,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
|
|||
rmark: option.RoutingMark,
|
||||
prefer: C.NewDNSPrefer(option.IPVersion),
|
||||
},
|
||||
dialer: &wgDialer{},
|
||||
dialer: &wgSingDialer{dialer: dialer.NewDialer()},
|
||||
}
|
||||
runtime.SetFinalizer(outbound, closeWireGuard)
|
||||
|
||||
|
@ -199,7 +211,8 @@ func closeWireGuard(w *WireGuard) {
|
|||
}
|
||||
|
||||
func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
|
||||
w.dialer.options = opts
|
||||
options := w.Base.DialOptions(opts...)
|
||||
w.dialer.dialer = dialer.NewDialer(options...)
|
||||
var conn net.Conn
|
||||
w.startOnce.Do(func() {
|
||||
w.startErr = w.tunDevice.Start()
|
||||
|
@ -208,12 +221,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
|
|||
return nil, w.startErr
|
||||
}
|
||||
if !metadata.Resolved() {
|
||||
var addrs []netip.Addr
|
||||
addrs, err = resolver.LookupIP(ctx, metadata.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err = N.DialSerial(ctx, w.tunDevice, "tcp", M.ParseSocksaddr(metadata.RemoteAddress()), addrs)
|
||||
options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice}))
|
||||
conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress())
|
||||
} else {
|
||||
port, _ := strconv.Atoi(metadata.DstPort)
|
||||
conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, uint16(port)))
|
||||
|
@ -228,7 +237,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
|
|||
}
|
||||
|
||||
func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
|
||||
w.dialer.options = opts
|
||||
options := w.Base.DialOptions(opts...)
|
||||
w.dialer.dialer = dialer.NewDialer(options...)
|
||||
var pc net.PacketConn
|
||||
w.startOnce.Do(func() {
|
||||
w.startErr = w.tunDevice.Start()
|
||||
|
|
|
@ -109,7 +109,19 @@ func GetTcpConcurrent() bool {
|
|||
}
|
||||
|
||||
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
|
||||
dialer := &net.Dialer{}
|
||||
address := net.JoinHostPort(destination.String(), port)
|
||||
|
||||
netDialer := opt.netDialer
|
||||
switch netDialer.(type) {
|
||||
case nil:
|
||||
netDialer = &net.Dialer{}
|
||||
case *net.Dialer:
|
||||
netDialer = &*netDialer.(*net.Dialer) // make a copy
|
||||
default:
|
||||
return netDialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
dialer := netDialer.(*net.Dialer)
|
||||
if opt.interfaceName != "" {
|
||||
if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil {
|
||||
return nil, err
|
||||
|
@ -118,8 +130,6 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
|
|||
if opt.routingMark != 0 {
|
||||
bindMarkToDialer(opt.routingMark, dialer, network, destination)
|
||||
}
|
||||
|
||||
address := net.JoinHostPort(destination.String(), port)
|
||||
if opt.tfo {
|
||||
return dialTFO(ctx, *dialer, network, address)
|
||||
}
|
||||
|
@ -307,15 +317,15 @@ func sortationAddr(ips []netip.Addr) (ipv4s, ipv6s []netip.Addr) {
|
|||
}
|
||||
|
||||
type Dialer struct {
|
||||
Opt option
|
||||
opt option
|
||||
}
|
||||
|
||||
func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return DialContext(ctx, network, address, WithOption(d.Opt))
|
||||
return DialContext(ctx, network, address, WithOption(d.opt))
|
||||
}
|
||||
|
||||
func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
|
||||
opt := WithOption(d.Opt)
|
||||
opt := WithOption(d.opt)
|
||||
if rAddrPort.Addr().Unmap().IsLoopback() {
|
||||
// avoid "The requested address is not valid in its context."
|
||||
opt = WithInterface("")
|
||||
|
@ -325,5 +335,5 @@ func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddr
|
|||
|
||||
func NewDialer(options ...Option) Dialer {
|
||||
opt := applyOptions(options...)
|
||||
return Dialer{Opt: *opt}
|
||||
return Dialer{opt: *opt}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/Dreamacro/clash/component/resolver"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
@ -12,6 +15,10 @@ var (
|
|||
DefaultRoutingMark = atomic.NewInt32(0)
|
||||
)
|
||||
|
||||
type NetDialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type option struct {
|
||||
interfaceName string
|
||||
addrReuse bool
|
||||
|
@ -20,6 +27,7 @@ type option struct {
|
|||
prefer int
|
||||
tfo bool
|
||||
resolver resolver.Resolver
|
||||
netDialer NetDialer
|
||||
}
|
||||
|
||||
type Option func(opt *option)
|
||||
|
@ -76,6 +84,12 @@ func WithTFO(tfo bool) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithNetDialer(netDialer NetDialer) Option {
|
||||
return func(opt *option) {
|
||||
opt.netDialer = netDialer
|
||||
}
|
||||
}
|
||||
|
||||
func WithOption(o option) Option {
|
||||
return func(opt *option) {
|
||||
*opt = o
|
||||
|
|
Loading…
Reference in a new issue