From 26ce3e8814db8249a17f8bfda6ac50572d7ca6ac Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Fri, 31 Jan 2020 14:43:54 +0800 Subject: [PATCH] Improve: udp NAT type --- adapters/inbound/packet.go | 4 ++-- adapters/outbound/base.go | 4 ++-- adapters/outbound/direct.go | 11 +++-------- adapters/outbound/reject.go | 4 ++-- adapters/outbound/shadowsocks.go | 19 +++++++------------ adapters/outbound/socks5.go | 20 +++++--------------- adapters/outbound/vmess.go | 8 ++++---- adapters/outboundgroup/fallback.go | 7 +++---- adapters/outboundgroup/loadbalance.go | 2 +- adapters/outboundgroup/selector.go | 7 +++---- adapters/outboundgroup/urltest.go | 7 +++---- component/nat/table.go | 20 +++++--------------- constant/adapters.go | 2 +- constant/metadata.go | 12 ++++++++++++ proxy/socks/udp.go | 5 ++--- proxy/socks/utils.go | 21 +++++++-------------- tunnel/connection.go | 7 ++----- tunnel/tunnel.go | 19 +++++++------------ 18 files changed, 71 insertions(+), 108 deletions(-) diff --git a/adapters/inbound/packet.go b/adapters/inbound/packet.go index 59ccce85..001a579b 100644 --- a/adapters/inbound/packet.go +++ b/adapters/inbound/packet.go @@ -17,9 +17,9 @@ func (s *PacketAdapter) Metadata() *C.Metadata { } // NewPacket is PacketAdapter generator -func NewPacket(target socks5.Addr, packet C.UDPPacket, source C.Type, netType C.NetWork) *PacketAdapter { +func NewPacket(target socks5.Addr, packet C.UDPPacket, source C.Type) *PacketAdapter { metadata := parseSocksAddr(target) - metadata.NetWork = netType + metadata.NetWork = C.UDP metadata.Type = source if ip, port, err := parseAddr(packet.LocalAddr().String()); err == nil { metadata.SrcIP = ip diff --git a/adapters/outbound/base.go b/adapters/outbound/base.go index 3dc97825..df1c61bc 100644 --- a/adapters/outbound/base.go +++ b/adapters/outbound/base.go @@ -30,8 +30,8 @@ func (b *Base) Type() C.AdapterType { return b.tp } -func (b *Base) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { - return nil, nil, errors.New("no support") +func (b *Base) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { + return nil, errors.New("no support") } func (b *Base) SupportUDP() bool { diff --git a/adapters/outbound/direct.go b/adapters/outbound/direct.go index 061d0361..d2b7a9ab 100644 --- a/adapters/outbound/direct.go +++ b/adapters/outbound/direct.go @@ -25,17 +25,12 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, return newConn(c, d), nil } -func (d *Direct) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { +func (d *Direct) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { pc, err := net.ListenPacket("udp", "") if err != nil { - return nil, nil, err + return nil, err } - - addr, err := resolveUDPAddr("udp", metadata.RemoteAddress()) - if err != nil { - return nil, nil, err - } - return newPacketConn(pc, d), addr, nil + return newPacketConn(pc, d), nil } func NewDirect() *Direct { diff --git a/adapters/outbound/reject.go b/adapters/outbound/reject.go index 5ef61540..7daba1a2 100644 --- a/adapters/outbound/reject.go +++ b/adapters/outbound/reject.go @@ -18,8 +18,8 @@ func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, return newConn(&NopConn{}, r), nil } -func (r *Reject) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { - return nil, nil, errors.New("match reject rule") +func (r *Reject) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { + return nil, errors.New("match reject rule") } func NewReject() *Reject { diff --git a/adapters/outbound/shadowsocks.go b/adapters/outbound/shadowsocks.go index 83e1a8b8..217c8e2e 100644 --- a/adapters/outbound/shadowsocks.go +++ b/adapters/outbound/shadowsocks.go @@ -82,24 +82,19 @@ func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (C return newConn(c, ss), err } -func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { +func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { pc, err := net.ListenPacket("udp", "") if err != nil { - return nil, nil, err + return nil, err } addr, err := resolveUDPAddr("udp", ss.server) if err != nil { - return nil, nil, err - } - - targetAddr := socks5.ParseAddr(metadata.RemoteAddress()) - if targetAddr == nil { - return nil, nil, fmt.Errorf("parse address %s error: %s", metadata.String(), metadata.DstPort) + return nil, err } pc = ss.cipher.PacketConn(pc) - return newPacketConn(&ssUDPConn{PacketConn: pc, rAddr: targetAddr}, ss), addr, nil + return newPacketConn(&ssUDPConn{PacketConn: pc, rAddr: addr}, ss), nil } func (ss *ShadowSocks) MarshalJSON() ([]byte, error) { @@ -189,15 +184,15 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) { type ssUDPConn struct { net.PacketConn - rAddr socks5.Addr + rAddr net.Addr } func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - packet, err := socks5.EncodeUDPPacket(uc.rAddr, b) + packet, err := socks5.EncodeUDPPacket(socks5.ParseAddrToSocksAddr(addr), b) if err != nil { return } - return uc.PacketConn.WriteTo(packet[3:], addr) + return uc.PacketConn.WriteTo(packet[3:], uc.rAddr) } func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { diff --git a/adapters/outbound/socks5.go b/adapters/outbound/socks5.go index 7632be30..79cf05b7 100644 --- a/adapters/outbound/socks5.go +++ b/adapters/outbound/socks5.go @@ -60,7 +60,7 @@ func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn return newConn(c, ss), nil } -func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err error) { +func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) defer cancel() c, err := dialContext(ctx, "tcp", ss.addr) @@ -96,16 +96,6 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err return } - addr, err := net.ResolveUDPAddr("udp", bindAddr.String()) - if err != nil { - return - } - - targetAddr := socks5.ParseAddr(metadata.RemoteAddress()) - if targetAddr == nil { - return nil, nil, fmt.Errorf("parse address %s error: %s", metadata.String(), metadata.DstPort) - } - pc, err := net.ListenPacket("udp", "") if err != nil { return @@ -119,7 +109,7 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err pc.Close() }() - return newPacketConn(&socksUDPConn{PacketConn: pc, rAddr: targetAddr, tcpConn: c}, ss), addr, nil + return newPacketConn(&socksUDPConn{PacketConn: pc, rAddr: bindAddr.UDPAddr(), tcpConn: c}, ss), nil } func NewSocks5(option Socks5Option) *Socks5 { @@ -149,16 +139,16 @@ func NewSocks5(option Socks5Option) *Socks5 { type socksUDPConn struct { net.PacketConn - rAddr socks5.Addr + rAddr net.Addr tcpConn net.Conn } func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - packet, err := socks5.EncodeUDPPacket(uc.rAddr, b) + packet, err := socks5.EncodeUDPPacket(socks5.ParseAddrToSocksAddr(addr), b) if err != nil { return } - return uc.PacketConn.WriteTo(packet, addr) + return uc.PacketConn.WriteTo(packet, uc.rAddr) } func (uc *socksUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go index c653b106..a7a5016d 100644 --- a/adapters/outbound/vmess.go +++ b/adapters/outbound/vmess.go @@ -42,19 +42,19 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, return newConn(c, v), err } -func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { +func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) defer cancel() c, err := dialContext(ctx, "tcp", v.server) if err != nil { - return nil, nil, fmt.Errorf("%s connect error", v.server) + return nil, fmt.Errorf("%s connect error", v.server) } tcpKeepAlive(c) c, err = v.client.New(c, parseVmessAddr(metadata)) if err != nil { - return nil, nil, fmt.Errorf("new vmess client error: %v", err) + return nil, fmt.Errorf("new vmess client error: %v", err) } - return newPacketConn(&vmessUDPConn{Conn: c}, v), c.RemoteAddr(), nil + return newPacketConn(&vmessUDPConn{Conn: c}, v), nil } func NewVmess(option VmessOption) (*Vmess, error) { diff --git a/adapters/outboundgroup/fallback.go b/adapters/outboundgroup/fallback.go index 2104c39e..7a344d09 100644 --- a/adapters/outboundgroup/fallback.go +++ b/adapters/outboundgroup/fallback.go @@ -3,7 +3,6 @@ package outboundgroup import ( "context" "encoding/json" - "net" "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/provider" @@ -31,13 +30,13 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con return c, err } -func (f *Fallback) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { +func (f *Fallback) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { proxy := f.findAliveProxy() - pc, addr, err := proxy.DialUDP(metadata) + pc, err := proxy.DialUDP(metadata) if err == nil { pc.AppendToChains(f) } - return pc, addr, err + return pc, err } func (f *Fallback) SupportUDP() bool { diff --git a/adapters/outboundgroup/loadbalance.go b/adapters/outboundgroup/loadbalance.go index 78a942e0..1495cda2 100644 --- a/adapters/outboundgroup/loadbalance.go +++ b/adapters/outboundgroup/loadbalance.go @@ -74,7 +74,7 @@ func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c return } -func (lb *LoadBalance) DialUDP(metadata *C.Metadata) (pc C.PacketConn, addr net.Addr, err error) { +func (lb *LoadBalance) DialUDP(metadata *C.Metadata) (pc C.PacketConn, err error) { defer func() { if err == nil { pc.AppendToChains(lb) diff --git a/adapters/outboundgroup/selector.go b/adapters/outboundgroup/selector.go index fd9ef041..533a6126 100644 --- a/adapters/outboundgroup/selector.go +++ b/adapters/outboundgroup/selector.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "net" "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/provider" @@ -27,12 +26,12 @@ func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con return c, err } -func (s *Selector) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { - pc, addr, err := s.selected.DialUDP(metadata) +func (s *Selector) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { + pc, err := s.selected.DialUDP(metadata) if err == nil { pc.AppendToChains(s) } - return pc, addr, err + return pc, err } func (s *Selector) SupportUDP() bool { diff --git a/adapters/outboundgroup/urltest.go b/adapters/outboundgroup/urltest.go index cf1ad138..c985eea8 100644 --- a/adapters/outboundgroup/urltest.go +++ b/adapters/outboundgroup/urltest.go @@ -3,7 +3,6 @@ package outboundgroup import ( "context" "encoding/json" - "net" "time" "github.com/Dreamacro/clash/adapters/outbound" @@ -31,12 +30,12 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Co return c, err } -func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { - pc, addr, err := u.fast().DialUDP(metadata) +func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { + pc, err := u.fast().DialUDP(metadata) if err == nil { pc.AppendToChains(u) } - return pc, addr, err + return pc, err } func (u *URLTest) proxies() []C.Proxy { diff --git a/component/nat/table.go b/component/nat/table.go index eac98467..a88a00f0 100644 --- a/component/nat/table.go +++ b/component/nat/table.go @@ -9,26 +9,16 @@ type Table struct { mapping sync.Map } -type element struct { - RemoteAddr net.Addr - RemoteConn net.PacketConn +func (t *Table) Set(key string, pc net.PacketConn) { + t.mapping.Store(key, pc) } -func (t *Table) Set(key string, pc net.PacketConn, addr net.Addr) { - // set conn read timeout - t.mapping.Store(key, &element{ - RemoteConn: pc, - RemoteAddr: addr, - }) -} - -func (t *Table) Get(key string) (net.PacketConn, net.Addr) { +func (t *Table) Get(key string) net.PacketConn { item, exist := t.mapping.Load(key) if !exist { - return nil, nil + return nil } - elm := item.(*element) - return elm.RemoteConn, elm.RemoteAddr + return item.(net.PacketConn) } func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) { diff --git a/constant/adapters.go b/constant/adapters.go index f05a23fa..f7614c3a 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -59,7 +59,7 @@ type ProxyAdapter interface { Name() string Type() AdapterType DialContext(ctx context.Context, metadata *Metadata) (Conn, error) - DialUDP(metadata *Metadata) (PacketConn, net.Addr, error) + DialUDP(metadata *Metadata) (PacketConn, error) SupportUDP() bool MarshalJSON() ([]byte, error) } diff --git a/constant/metadata.go b/constant/metadata.go index afcd3443..8bdbd7bc 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -3,6 +3,7 @@ package constant import ( "encoding/json" "net" + "strconv" ) // Socks addr type @@ -70,6 +71,17 @@ func (m *Metadata) RemoteAddress() string { return net.JoinHostPort(m.String(), m.DstPort) } +func (m *Metadata) UDPAddr() *net.UDPAddr { + if m.NetWork != UDP || m.DstIP == nil { + return nil + } + port, _ := strconv.Atoi(m.DstPort) + return &net.UDPAddr{ + IP: m.DstIP, + Port: port, + } +} + func (m *Metadata) String() string { if m.Host != "" { return m.Host diff --git a/proxy/socks/udp.go b/proxy/socks/udp.go index 0e26cc14..2239af3c 100644 --- a/proxy/socks/udp.go +++ b/proxy/socks/udp.go @@ -58,10 +58,9 @@ func handleSocksUDP(pc net.PacketConn, buf []byte, addr net.Addr) { } packet := &fakeConn{ PacketConn: pc, - remoteAddr: addr, - targetAddr: target, + rAddr: addr, payload: payload, bufRef: buf, } - tun.AddPacket(adapters.NewPacket(target, packet, C.SOCKS, C.UDP)) + tun.AddPacket(adapters.NewPacket(target, packet, C.SOCKS)) } diff --git a/proxy/socks/utils.go b/proxy/socks/utils.go index 8cbe57e9..51a1911e 100644 --- a/proxy/socks/utils.go +++ b/proxy/socks/utils.go @@ -9,10 +9,9 @@ import ( type fakeConn struct { net.PacketConn - remoteAddr net.Addr - targetAddr socks5.Addr - payload []byte - bufRef []byte + rAddr net.Addr + payload []byte + bufRef []byte } func (c *fakeConn) Data() []byte { @@ -21,25 +20,19 @@ func (c *fakeConn) Data() []byte { // WriteBack wirtes UDP packet with source(ip, port) = `addr` func (c *fakeConn) WriteBack(b []byte, addr net.Addr) (n int, err error) { - from := c.targetAddr - if addr != nil { - // if addr is provided, use the parsed addr - from = socks5.ParseAddrToSocksAddr(addr) - } - packet, err := socks5.EncodeUDPPacket(from, b) + packet, err := socks5.EncodeUDPPacket(socks5.ParseAddrToSocksAddr(addr), b) if err != nil { return } - return c.PacketConn.WriteTo(packet, c.remoteAddr) + return c.PacketConn.WriteTo(packet, c.rAddr) } // LocalAddr returns the source IP/Port of UDP Packet func (c *fakeConn) LocalAddr() net.Addr { - return c.remoteAddr + return c.PacketConn.LocalAddr() } func (c *fakeConn) Close() error { - err := c.PacketConn.Close() pool.BufPool.Put(c.bufRef[:cap(c.bufRef)]) - return err + return nil } diff --git a/tunnel/connection.go b/tunnel/connection.go index 825d0b34..6771aa40 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -88,21 +88,18 @@ func (t *Tunnel) handleUDPToRemote(packet C.UDPPacket, pc net.PacketConn, addr n DefaultManager.Upload() <- int64(len(packet.Data())) } -func (t *Tunnel) handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, omitSrcAddr bool, timeout time.Duration) { +func (t *Tunnel) handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string) { buf := pool.BufPool.Get().([]byte) defer pool.BufPool.Put(buf[:cap(buf)]) defer t.natTable.Delete(key) defer pc.Close() for { - pc.SetReadDeadline(time.Now().Add(timeout)) + pc.SetReadDeadline(time.Now().Add(udpTimeout)) n, from, err := pc.ReadFrom(buf) if err != nil { return } - if from != nil && omitSrcAddr { - from = nil - } n, err = packet.WriteBack(buf[:n], from) if err != nil { diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 61f08971..50ba1cd6 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -175,11 +175,10 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) { return } - src := packet.LocalAddr().String() - dst := metadata.RemoteAddress() - key := src + "-" + dst + key := packet.LocalAddr().String() - pc, addr := t.natTable.Get(key) + pc := t.natTable.Get(key) + addr := metadata.UDPAddr() if pc != nil { t.handleUDPToRemote(packet, pc, addr) return @@ -188,8 +187,6 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) { lockKey := key + "-lock" wg, loaded := t.natTable.GetOrCreateLock(lockKey) - isFakeIP := dns.DefaultResolver != nil && dns.DefaultResolver.IsFakeIP(metadata.DstIP) - go func() { if !loaded { wg.Add(1) @@ -201,14 +198,13 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) { return } - rawPc, nAddr, err := proxy.DialUDP(metadata) + rawPc, err := proxy.DialUDP(metadata) if err != nil { log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error()) t.natTable.Delete(lockKey) wg.Done() return } - addr = nAddr pc = newUDPTracker(rawPc, DefaultManager, metadata, rule) if rule != nil { @@ -217,15 +213,14 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) { log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SrcIP.String(), metadata.String()) } - t.natTable.Set(key, pc, addr) + t.natTable.Set(key, pc) t.natTable.Delete(lockKey) wg.Done() - // in fake-ip mode, Full-Cone NAT can never achieve, fallback to omitting src Addr - go t.handleUDPToLocal(packet.UDPPacket, pc, key, isFakeIP, udpTimeout) + go t.handleUDPToLocal(packet.UDPPacket, pc, key) } wg.Wait() - pc, addr := t.natTable.Get(key) + pc := t.natTable.Get(key) if pc != nil { t.handleUDPToRemote(packet, pc, addr) }