chore: wireguard using internal dialer
This commit is contained in:
parent
545a79d406
commit
9cc7fdaca9
3 changed files with 57 additions and 23 deletions
|
@ -34,7 +34,7 @@ type WireGuard struct {
|
||||||
bind *wireguard.ClientBind
|
bind *wireguard.ClientBind
|
||||||
device *device.Device
|
device *device.Device
|
||||||
tunDevice wireguard.Device
|
tunDevice wireguard.Device
|
||||||
dialer *wgDialer
|
dialer *wgSingDialer
|
||||||
startOnce sync.Once
|
startOnce sync.Once
|
||||||
startErr error
|
startErr error
|
||||||
}
|
}
|
||||||
|
@ -56,16 +56,28 @@ type WireGuardOption struct {
|
||||||
PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"`
|
PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type wgDialer struct {
|
type wgSingDialer struct {
|
||||||
options []dialer.Option
|
dialer dialer.Dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *wgDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
var _ N.Dialer = &wgSingDialer{}
|
||||||
return dialer.DialContext(ctx, network, destination.String(), d.options...)
|
|
||||||
|
func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
|
return d.dialer.DialContext(ctx, network, destination.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *wgDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
func (d *wgSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
return dialer.ListenPacket(ctx, dialer.ParseNetwork("udp", destination.Addr), "", d.options...)
|
return d.dialer.ListenPacket(ctx, "udp", "", destination.AddrPort())
|
||||||
|
}
|
||||||
|
|
||||||
|
type wgNetDialer struct {
|
||||||
|
tunDevice wireguard.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ dialer.NetDialer = &wgNetDialer{}
|
||||||
|
|
||||||
|
func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address))
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
|
func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
|
||||||
|
@ -79,7 +91,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
|
||||||
rmark: option.RoutingMark,
|
rmark: option.RoutingMark,
|
||||||
prefer: C.NewDNSPrefer(option.IPVersion),
|
prefer: C.NewDNSPrefer(option.IPVersion),
|
||||||
},
|
},
|
||||||
dialer: &wgDialer{},
|
dialer: &wgSingDialer{dialer: dialer.NewDialer()},
|
||||||
}
|
}
|
||||||
runtime.SetFinalizer(outbound, closeWireGuard)
|
runtime.SetFinalizer(outbound, closeWireGuard)
|
||||||
|
|
||||||
|
@ -199,7 +211,8 @@ func closeWireGuard(w *WireGuard) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
|
func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
|
||||||
w.dialer.options = opts
|
options := w.Base.DialOptions(opts...)
|
||||||
|
w.dialer.dialer = dialer.NewDialer(options...)
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
w.startOnce.Do(func() {
|
w.startOnce.Do(func() {
|
||||||
w.startErr = w.tunDevice.Start()
|
w.startErr = w.tunDevice.Start()
|
||||||
|
@ -208,12 +221,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
|
||||||
return nil, w.startErr
|
return nil, w.startErr
|
||||||
}
|
}
|
||||||
if !metadata.Resolved() {
|
if !metadata.Resolved() {
|
||||||
var addrs []netip.Addr
|
options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice}))
|
||||||
addrs, err = resolver.LookupIP(ctx, metadata.Host)
|
conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress())
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn, err = N.DialSerial(ctx, w.tunDevice, "tcp", M.ParseSocksaddr(metadata.RemoteAddress()), addrs)
|
|
||||||
} else {
|
} else {
|
||||||
port, _ := strconv.Atoi(metadata.DstPort)
|
port, _ := strconv.Atoi(metadata.DstPort)
|
||||||
conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, uint16(port)))
|
conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, uint16(port)))
|
||||||
|
@ -228,7 +237,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
|
func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
|
||||||
w.dialer.options = opts
|
options := w.Base.DialOptions(opts...)
|
||||||
|
w.dialer.dialer = dialer.NewDialer(options...)
|
||||||
var pc net.PacketConn
|
var pc net.PacketConn
|
||||||
w.startOnce.Do(func() {
|
w.startOnce.Do(func() {
|
||||||
w.startErr = w.tunDevice.Start()
|
w.startErr = w.tunDevice.Start()
|
||||||
|
|
|
@ -109,7 +109,19 @@ func GetTcpConcurrent() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
|
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
|
||||||
dialer := &net.Dialer{}
|
address := net.JoinHostPort(destination.String(), port)
|
||||||
|
|
||||||
|
netDialer := opt.netDialer
|
||||||
|
switch netDialer.(type) {
|
||||||
|
case nil:
|
||||||
|
netDialer = &net.Dialer{}
|
||||||
|
case *net.Dialer:
|
||||||
|
netDialer = &*netDialer.(*net.Dialer) // make a copy
|
||||||
|
default:
|
||||||
|
return netDialer.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := netDialer.(*net.Dialer)
|
||||||
if opt.interfaceName != "" {
|
if opt.interfaceName != "" {
|
||||||
if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil {
|
if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -118,8 +130,6 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
|
||||||
if opt.routingMark != 0 {
|
if opt.routingMark != 0 {
|
||||||
bindMarkToDialer(opt.routingMark, dialer, network, destination)
|
bindMarkToDialer(opt.routingMark, dialer, network, destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
address := net.JoinHostPort(destination.String(), port)
|
|
||||||
if opt.tfo {
|
if opt.tfo {
|
||||||
return dialTFO(ctx, *dialer, network, address)
|
return dialTFO(ctx, *dialer, network, address)
|
||||||
}
|
}
|
||||||
|
@ -307,15 +317,15 @@ func sortationAddr(ips []netip.Addr) (ipv4s, ipv6s []netip.Addr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Dialer struct {
|
type Dialer struct {
|
||||||
Opt option
|
opt option
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return DialContext(ctx, network, address, WithOption(d.Opt))
|
return DialContext(ctx, network, address, WithOption(d.opt))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
|
func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
|
||||||
opt := WithOption(d.Opt)
|
opt := WithOption(d.opt)
|
||||||
if rAddrPort.Addr().Unmap().IsLoopback() {
|
if rAddrPort.Addr().Unmap().IsLoopback() {
|
||||||
// avoid "The requested address is not valid in its context."
|
// avoid "The requested address is not valid in its context."
|
||||||
opt = WithInterface("")
|
opt = WithInterface("")
|
||||||
|
@ -325,5 +335,5 @@ func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddr
|
||||||
|
|
||||||
func NewDialer(options ...Option) Dialer {
|
func NewDialer(options ...Option) Dialer {
|
||||||
opt := applyOptions(options...)
|
opt := applyOptions(options...)
|
||||||
return Dialer{Opt: *opt}
|
return Dialer{opt: *opt}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
package dialer
|
package dialer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
"github.com/Dreamacro/clash/component/resolver"
|
"github.com/Dreamacro/clash/component/resolver"
|
||||||
|
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
|
@ -12,6 +15,10 @@ var (
|
||||||
DefaultRoutingMark = atomic.NewInt32(0)
|
DefaultRoutingMark = atomic.NewInt32(0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type NetDialer interface {
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
type option struct {
|
type option struct {
|
||||||
interfaceName string
|
interfaceName string
|
||||||
addrReuse bool
|
addrReuse bool
|
||||||
|
@ -20,6 +27,7 @@ type option struct {
|
||||||
prefer int
|
prefer int
|
||||||
tfo bool
|
tfo bool
|
||||||
resolver resolver.Resolver
|
resolver resolver.Resolver
|
||||||
|
netDialer NetDialer
|
||||||
}
|
}
|
||||||
|
|
||||||
type Option func(opt *option)
|
type Option func(opt *option)
|
||||||
|
@ -76,6 +84,12 @@ func WithTFO(tfo bool) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithNetDialer(netDialer NetDialer) Option {
|
||||||
|
return func(opt *option) {
|
||||||
|
opt.netDialer = netDialer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithOption(o option) Option {
|
func WithOption(o option) Option {
|
||||||
return func(opt *option) {
|
return func(opt *option) {
|
||||||
*opt = o
|
*opt = o
|
||||||
|
|
Loading…
Reference in a new issue