fix: wireguard auto close not working
This commit is contained in:
parent
b0c56eee28
commit
30c0359666
1 changed files with 159 additions and 10 deletions
|
@ -40,6 +40,7 @@ type WireGuard struct {
|
||||||
startOnce sync.Once
|
startOnce sync.Once
|
||||||
startErr error
|
startErr error
|
||||||
resolver *dns.Resolver
|
resolver *dns.Resolver
|
||||||
|
refP *refProxyAdapter
|
||||||
}
|
}
|
||||||
|
|
||||||
type WireGuardOption struct {
|
type WireGuardOption struct {
|
||||||
|
@ -100,6 +101,20 @@ func (d *wgSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr
|
||||||
return cDialer.ListenPacket(ctx, "udp", "", destination.AddrPort())
|
return cDialer.ListenPacket(ctx, "udp", "", destination.AddrPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type wgSingErrorHandler struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ E.Handler = (*wgSingErrorHandler)(nil)
|
||||||
|
|
||||||
|
func (w wgSingErrorHandler) NewError(ctx context.Context, err error) {
|
||||||
|
if E.IsClosedOrCanceled(err) {
|
||||||
|
log.SingLogger.Debug(fmt.Sprintf("[WG](%s) connection closed: %s", w.name, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", w.name, err))
|
||||||
|
}
|
||||||
|
|
||||||
type wgNetDialer struct {
|
type wgNetDialer struct {
|
||||||
tunDevice wireguard.Device
|
tunDevice wireguard.Device
|
||||||
}
|
}
|
||||||
|
@ -174,7 +189,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
|
||||||
connectAddr = option.Addr()
|
connectAddr = option.Addr()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
outbound.bind = wireguard.NewClientBind(context.Background(), outbound, outbound.dialer, isConnect, connectAddr, reserved)
|
outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, connectAddr, reserved)
|
||||||
|
|
||||||
var localPrefixes []netip.Prefix
|
var localPrefixes []netip.Prefix
|
||||||
|
|
||||||
|
@ -312,13 +327,15 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
refP := &refProxyAdapter{}
|
||||||
|
outbound.refP = refP
|
||||||
if option.RemoteDnsResolve && len(option.Dns) > 0 {
|
if option.RemoteDnsResolve && len(option.Dns) > 0 {
|
||||||
nss, err := dns.ParseNameServer(option.Dns)
|
nss, err := dns.ParseNameServer(option.Dns)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for i := range nss {
|
for i := range nss {
|
||||||
nss[i].ProxyAdapter = outbound
|
nss[i].ProxyAdapter = refP
|
||||||
}
|
}
|
||||||
outbound.resolver = dns.NewResolver(dns.Config{
|
outbound.resolver = dns.NewResolver(dns.Config{
|
||||||
Main: nss,
|
Main: nss,
|
||||||
|
@ -329,14 +346,6 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
|
||||||
return outbound, nil
|
return outbound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WireGuard) NewError(ctx context.Context, err error) {
|
|
||||||
if E.IsClosedOrCanceled(err) {
|
|
||||||
log.SingLogger.Debug(fmt.Sprintf("[WG](%s) connection closed: %s", w.Name(), err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", w.Name(), err))
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeWireGuard(w *WireGuard) {
|
func closeWireGuard(w *WireGuard) {
|
||||||
if w.device != nil {
|
if w.device != nil {
|
||||||
w.device.Close()
|
w.device.Close()
|
||||||
|
@ -357,6 +366,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
|
||||||
if !metadata.Resolved() || w.resolver != nil {
|
if !metadata.Resolved() || w.resolver != nil {
|
||||||
r := resolver.DefaultResolver
|
r := resolver.DefaultResolver
|
||||||
if w.resolver != nil {
|
if w.resolver != nil {
|
||||||
|
w.refP.SetProxyAdapter(w)
|
||||||
|
defer w.refP.ClearProxyAdapter()
|
||||||
r = w.resolver
|
r = w.resolver
|
||||||
}
|
}
|
||||||
options = append(options, dialer.WithResolver(r))
|
options = append(options, dialer.WithResolver(r))
|
||||||
|
@ -391,6 +402,8 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat
|
||||||
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
|
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
|
||||||
r := resolver.DefaultResolver
|
r := resolver.DefaultResolver
|
||||||
if w.resolver != nil {
|
if w.resolver != nil {
|
||||||
|
w.refP.SetProxyAdapter(w)
|
||||||
|
defer w.refP.ClearProxyAdapter()
|
||||||
r = w.resolver
|
r = w.resolver
|
||||||
}
|
}
|
||||||
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r)
|
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r)
|
||||||
|
@ -414,3 +427,139 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat
|
||||||
func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool {
|
func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type refProxyAdapter struct {
|
||||||
|
proxyAdapter C.ProxyAdapter
|
||||||
|
count int
|
||||||
|
mutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) SetProxyAdapter(proxyAdapter C.ProxyAdapter) {
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
r.proxyAdapter = proxyAdapter
|
||||||
|
r.count++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) ClearProxyAdapter() {
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
r.count--
|
||||||
|
if r.count == 0 {
|
||||||
|
r.proxyAdapter = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) Name() string {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.Name()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) Type() C.AdapterType {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.Type()
|
||||||
|
}
|
||||||
|
return C.AdapterType(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) Addr() string {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.Addr()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) SupportUDP() bool {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.SupportUDP()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) SupportXUDP() bool {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.SupportXUDP()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) SupportTFO() bool {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.SupportTFO()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) MarshalJSON() ([]byte, error) {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.MarshalJSON()
|
||||||
|
}
|
||||||
|
return nil, C.ErrNotSupport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.StreamConn(c, metadata)
|
||||||
|
}
|
||||||
|
return nil, C.ErrNotSupport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.DialContext(ctx, metadata, opts...)
|
||||||
|
}
|
||||||
|
return nil, C.ErrNotSupport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.ListenPacketContext(ctx, metadata, opts...)
|
||||||
|
}
|
||||||
|
return nil, C.ErrNotSupport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) SupportUOT() bool {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.SupportUOT()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) SupportWithDialer() C.NetWork {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.SupportWithDialer()
|
||||||
|
}
|
||||||
|
return C.InvalidNet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.DialContextWithDialer(ctx, dialer, metadata)
|
||||||
|
}
|
||||||
|
return nil, C.ErrNotSupport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.PacketConn, error) {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.ListenPacketWithDialer(ctx, dialer, metadata)
|
||||||
|
}
|
||||||
|
return nil, C.ErrNotSupport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) IsL3Protocol(metadata *C.Metadata) bool {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.IsL3Protocol(metadata)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refProxyAdapter) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
|
||||||
|
if r.proxyAdapter != nil {
|
||||||
|
return r.proxyAdapter.Unwrap(metadata, touch)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ C.ProxyAdapter = (*refProxyAdapter)(nil)
|
||||||
|
|
Loading…
Reference in a new issue