Improve: udp NAT type

This commit is contained in:
gVisor bot 2020-01-31 14:43:54 +08:00
parent 15dfaba355
commit adfe73b48e
18 changed files with 71 additions and 108 deletions

View file

@ -17,9 +17,9 @@ func (s *PacketAdapter) Metadata() *C.Metadata {
} }
// NewPacket is PacketAdapter generator // NewPacket is PacketAdapter generator
func NewPacket(target socks5.Addr, packet C.UDPPacket, source C.Type, netType C.NetWork) *PacketAdapter { func NewPacket(target socks5.Addr, packet C.UDPPacket, source C.Type) *PacketAdapter {
metadata := parseSocksAddr(target) metadata := parseSocksAddr(target)
metadata.NetWork = netType metadata.NetWork = C.UDP
metadata.Type = source metadata.Type = source
if ip, port, err := parseAddr(packet.LocalAddr().String()); err == nil { if ip, port, err := parseAddr(packet.LocalAddr().String()); err == nil {
metadata.SrcIP = ip metadata.SrcIP = ip

View file

@ -30,8 +30,8 @@ func (b *Base) Type() C.AdapterType {
return b.tp return b.tp
} }
func (b *Base) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { func (b *Base) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
return nil, nil, errors.New("no support") return nil, errors.New("no support")
} }
func (b *Base) SupportUDP() bool { func (b *Base) SupportUDP() bool {

View file

@ -25,17 +25,12 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn,
return newConn(c, d), nil return newConn(c, d), nil
} }
func (d *Direct) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { func (d *Direct) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
pc, err := net.ListenPacket("udp", "") pc, err := net.ListenPacket("udp", "")
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
return newPacketConn(pc, d), nil
addr, err := resolveUDPAddr("udp", metadata.RemoteAddress())
if err != nil {
return nil, nil, err
}
return newPacketConn(pc, d), addr, nil
} }
func NewDirect() *Direct { func NewDirect() *Direct {

View file

@ -18,8 +18,8 @@ func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn,
return newConn(&NopConn{}, r), nil return newConn(&NopConn{}, r), nil
} }
func (r *Reject) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { func (r *Reject) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
return nil, nil, errors.New("match reject rule") return nil, errors.New("match reject rule")
} }
func NewReject() *Reject { func NewReject() *Reject {

View file

@ -82,24 +82,19 @@ func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (C
return newConn(c, ss), err return newConn(c, ss), err
} }
func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
pc, err := net.ListenPacket("udp", "") pc, err := net.ListenPacket("udp", "")
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
addr, err := resolveUDPAddr("udp", ss.server) addr, err := resolveUDPAddr("udp", ss.server)
if err != nil { if err != nil {
return nil, nil, err return nil, err
}
targetAddr := socks5.ParseAddr(metadata.RemoteAddress())
if targetAddr == nil {
return nil, nil, fmt.Errorf("parse address %s error: %s", metadata.String(), metadata.DstPort)
} }
pc = ss.cipher.PacketConn(pc) pc = ss.cipher.PacketConn(pc)
return newPacketConn(&ssUDPConn{PacketConn: pc, rAddr: targetAddr}, ss), addr, nil return newPacketConn(&ssUDPConn{PacketConn: pc, rAddr: addr}, ss), nil
} }
func (ss *ShadowSocks) MarshalJSON() ([]byte, error) { func (ss *ShadowSocks) MarshalJSON() ([]byte, error) {
@ -189,15 +184,15 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
type ssUDPConn struct { type ssUDPConn struct {
net.PacketConn net.PacketConn
rAddr socks5.Addr rAddr net.Addr
} }
func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
packet, err := socks5.EncodeUDPPacket(uc.rAddr, b) packet, err := socks5.EncodeUDPPacket(socks5.ParseAddrToSocksAddr(addr), b)
if err != nil { if err != nil {
return return
} }
return uc.PacketConn.WriteTo(packet[3:], addr) return uc.PacketConn.WriteTo(packet[3:], uc.rAddr)
} }
func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {

View file

@ -60,7 +60,7 @@ func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn
return newConn(c, ss), nil return newConn(c, ss), nil
} }
func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err error) { func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) {
ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel() defer cancel()
c, err := dialContext(ctx, "tcp", ss.addr) c, err := dialContext(ctx, "tcp", ss.addr)
@ -96,16 +96,6 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err
return return
} }
addr, err := net.ResolveUDPAddr("udp", bindAddr.String())
if err != nil {
return
}
targetAddr := socks5.ParseAddr(metadata.RemoteAddress())
if targetAddr == nil {
return nil, nil, fmt.Errorf("parse address %s error: %s", metadata.String(), metadata.DstPort)
}
pc, err := net.ListenPacket("udp", "") pc, err := net.ListenPacket("udp", "")
if err != nil { if err != nil {
return return
@ -119,7 +109,7 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err
pc.Close() pc.Close()
}() }()
return newPacketConn(&socksUDPConn{PacketConn: pc, rAddr: targetAddr, tcpConn: c}, ss), addr, nil return newPacketConn(&socksUDPConn{PacketConn: pc, rAddr: bindAddr.UDPAddr(), tcpConn: c}, ss), nil
} }
func NewSocks5(option Socks5Option) *Socks5 { func NewSocks5(option Socks5Option) *Socks5 {
@ -149,16 +139,16 @@ func NewSocks5(option Socks5Option) *Socks5 {
type socksUDPConn struct { type socksUDPConn struct {
net.PacketConn net.PacketConn
rAddr socks5.Addr rAddr net.Addr
tcpConn net.Conn tcpConn net.Conn
} }
func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
packet, err := socks5.EncodeUDPPacket(uc.rAddr, b) packet, err := socks5.EncodeUDPPacket(socks5.ParseAddrToSocksAddr(addr), b)
if err != nil { if err != nil {
return return
} }
return uc.PacketConn.WriteTo(packet, addr) return uc.PacketConn.WriteTo(packet, uc.rAddr)
} }
func (uc *socksUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { func (uc *socksUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {

View file

@ -42,19 +42,19 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn,
return newConn(c, v), err return newConn(c, v), err
} }
func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel() defer cancel()
c, err := dialContext(ctx, "tcp", v.server) c, err := dialContext(ctx, "tcp", v.server)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%s connect error", v.server) return nil, fmt.Errorf("%s connect error", v.server)
} }
tcpKeepAlive(c) tcpKeepAlive(c)
c, err = v.client.New(c, parseVmessAddr(metadata)) c, err = v.client.New(c, parseVmessAddr(metadata))
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("new vmess client error: %v", err) return nil, fmt.Errorf("new vmess client error: %v", err)
} }
return newPacketConn(&vmessUDPConn{Conn: c}, v), c.RemoteAddr(), nil return newPacketConn(&vmessUDPConn{Conn: c}, v), nil
} }
func NewVmess(option VmessOption) (*Vmess, error) { func NewVmess(option VmessOption) (*Vmess, error) {

View file

@ -3,7 +3,6 @@ package outboundgroup
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"net"
"github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/outbound"
"github.com/Dreamacro/clash/adapters/provider" "github.com/Dreamacro/clash/adapters/provider"
@ -31,13 +30,13 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con
return c, err return c, err
} }
func (f *Fallback) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { func (f *Fallback) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
proxy := f.findAliveProxy() proxy := f.findAliveProxy()
pc, addr, err := proxy.DialUDP(metadata) pc, err := proxy.DialUDP(metadata)
if err == nil { if err == nil {
pc.AppendToChains(f) pc.AppendToChains(f)
} }
return pc, addr, err return pc, err
} }
func (f *Fallback) SupportUDP() bool { func (f *Fallback) SupportUDP() bool {

View file

@ -74,7 +74,7 @@ func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c
return return
} }
func (lb *LoadBalance) DialUDP(metadata *C.Metadata) (pc C.PacketConn, addr net.Addr, err error) { func (lb *LoadBalance) DialUDP(metadata *C.Metadata) (pc C.PacketConn, err error) {
defer func() { defer func() {
if err == nil { if err == nil {
pc.AppendToChains(lb) pc.AppendToChains(lb)

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"net"
"github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/outbound"
"github.com/Dreamacro/clash/adapters/provider" "github.com/Dreamacro/clash/adapters/provider"
@ -27,12 +26,12 @@ func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con
return c, err return c, err
} }
func (s *Selector) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { func (s *Selector) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
pc, addr, err := s.selected.DialUDP(metadata) pc, err := s.selected.DialUDP(metadata)
if err == nil { if err == nil {
pc.AppendToChains(s) pc.AppendToChains(s)
} }
return pc, addr, err return pc, err
} }
func (s *Selector) SupportUDP() bool { func (s *Selector) SupportUDP() bool {

View file

@ -3,7 +3,6 @@ package outboundgroup
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"net"
"time" "time"
"github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/outbound"
@ -31,12 +30,12 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Co
return c, err return c, err
} }
func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
pc, addr, err := u.fast().DialUDP(metadata) pc, err := u.fast().DialUDP(metadata)
if err == nil { if err == nil {
pc.AppendToChains(u) pc.AppendToChains(u)
} }
return pc, addr, err return pc, err
} }
func (u *URLTest) proxies() []C.Proxy { func (u *URLTest) proxies() []C.Proxy {

View file

@ -9,26 +9,16 @@ type Table struct {
mapping sync.Map mapping sync.Map
} }
type element struct { func (t *Table) Set(key string, pc net.PacketConn) {
RemoteAddr net.Addr t.mapping.Store(key, pc)
RemoteConn net.PacketConn
} }
func (t *Table) Set(key string, pc net.PacketConn, addr net.Addr) { func (t *Table) Get(key string) net.PacketConn {
// set conn read timeout
t.mapping.Store(key, &element{
RemoteConn: pc,
RemoteAddr: addr,
})
}
func (t *Table) Get(key string) (net.PacketConn, net.Addr) {
item, exist := t.mapping.Load(key) item, exist := t.mapping.Load(key)
if !exist { if !exist {
return nil, nil return nil
} }
elm := item.(*element) return item.(net.PacketConn)
return elm.RemoteConn, elm.RemoteAddr
} }
func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) { func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) {

View file

@ -59,7 +59,7 @@ type ProxyAdapter interface {
Name() string Name() string
Type() AdapterType Type() AdapterType
DialContext(ctx context.Context, metadata *Metadata) (Conn, error) DialContext(ctx context.Context, metadata *Metadata) (Conn, error)
DialUDP(metadata *Metadata) (PacketConn, net.Addr, error) DialUDP(metadata *Metadata) (PacketConn, error)
SupportUDP() bool SupportUDP() bool
MarshalJSON() ([]byte, error) MarshalJSON() ([]byte, error)
} }

View file

@ -3,6 +3,7 @@ package constant
import ( import (
"encoding/json" "encoding/json"
"net" "net"
"strconv"
) )
// Socks addr type // Socks addr type
@ -70,6 +71,17 @@ func (m *Metadata) RemoteAddress() string {
return net.JoinHostPort(m.String(), m.DstPort) return net.JoinHostPort(m.String(), m.DstPort)
} }
func (m *Metadata) UDPAddr() *net.UDPAddr {
if m.NetWork != UDP || m.DstIP == nil {
return nil
}
port, _ := strconv.Atoi(m.DstPort)
return &net.UDPAddr{
IP: m.DstIP,
Port: port,
}
}
func (m *Metadata) String() string { func (m *Metadata) String() string {
if m.Host != "" { if m.Host != "" {
return m.Host return m.Host

View file

@ -58,10 +58,9 @@ func handleSocksUDP(pc net.PacketConn, buf []byte, addr net.Addr) {
} }
packet := &fakeConn{ packet := &fakeConn{
PacketConn: pc, PacketConn: pc,
remoteAddr: addr, rAddr: addr,
targetAddr: target,
payload: payload, payload: payload,
bufRef: buf, bufRef: buf,
} }
tun.AddPacket(adapters.NewPacket(target, packet, C.SOCKS, C.UDP)) tun.AddPacket(adapters.NewPacket(target, packet, C.SOCKS))
} }

View file

@ -9,10 +9,9 @@ import (
type fakeConn struct { type fakeConn struct {
net.PacketConn net.PacketConn
remoteAddr net.Addr rAddr net.Addr
targetAddr socks5.Addr payload []byte
payload []byte bufRef []byte
bufRef []byte
} }
func (c *fakeConn) Data() []byte { func (c *fakeConn) Data() []byte {
@ -21,25 +20,19 @@ func (c *fakeConn) Data() []byte {
// WriteBack wirtes UDP packet with source(ip, port) = `addr` // WriteBack wirtes UDP packet with source(ip, port) = `addr`
func (c *fakeConn) WriteBack(b []byte, addr net.Addr) (n int, err error) { func (c *fakeConn) WriteBack(b []byte, addr net.Addr) (n int, err error) {
from := c.targetAddr packet, err := socks5.EncodeUDPPacket(socks5.ParseAddrToSocksAddr(addr), b)
if addr != nil {
// if addr is provided, use the parsed addr
from = socks5.ParseAddrToSocksAddr(addr)
}
packet, err := socks5.EncodeUDPPacket(from, b)
if err != nil { if err != nil {
return return
} }
return c.PacketConn.WriteTo(packet, c.remoteAddr) return c.PacketConn.WriteTo(packet, c.rAddr)
} }
// LocalAddr returns the source IP/Port of UDP Packet // LocalAddr returns the source IP/Port of UDP Packet
func (c *fakeConn) LocalAddr() net.Addr { func (c *fakeConn) LocalAddr() net.Addr {
return c.remoteAddr return c.PacketConn.LocalAddr()
} }
func (c *fakeConn) Close() error { func (c *fakeConn) Close() error {
err := c.PacketConn.Close()
pool.BufPool.Put(c.bufRef[:cap(c.bufRef)]) pool.BufPool.Put(c.bufRef[:cap(c.bufRef)])
return err return nil
} }

View file

@ -88,21 +88,18 @@ func (t *Tunnel) handleUDPToRemote(packet C.UDPPacket, pc net.PacketConn, addr n
DefaultManager.Upload() <- int64(len(packet.Data())) DefaultManager.Upload() <- int64(len(packet.Data()))
} }
func (t *Tunnel) handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, omitSrcAddr bool, timeout time.Duration) { func (t *Tunnel) handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string) {
buf := pool.BufPool.Get().([]byte) buf := pool.BufPool.Get().([]byte)
defer pool.BufPool.Put(buf[:cap(buf)]) defer pool.BufPool.Put(buf[:cap(buf)])
defer t.natTable.Delete(key) defer t.natTable.Delete(key)
defer pc.Close() defer pc.Close()
for { for {
pc.SetReadDeadline(time.Now().Add(timeout)) pc.SetReadDeadline(time.Now().Add(udpTimeout))
n, from, err := pc.ReadFrom(buf) n, from, err := pc.ReadFrom(buf)
if err != nil { if err != nil {
return return
} }
if from != nil && omitSrcAddr {
from = nil
}
n, err = packet.WriteBack(buf[:n], from) n, err = packet.WriteBack(buf[:n], from)
if err != nil { if err != nil {

View file

@ -175,11 +175,10 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) {
return return
} }
src := packet.LocalAddr().String() key := packet.LocalAddr().String()
dst := metadata.RemoteAddress()
key := src + "-" + dst
pc, addr := t.natTable.Get(key) pc := t.natTable.Get(key)
addr := metadata.UDPAddr()
if pc != nil { if pc != nil {
t.handleUDPToRemote(packet, pc, addr) t.handleUDPToRemote(packet, pc, addr)
return return
@ -188,8 +187,6 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) {
lockKey := key + "-lock" lockKey := key + "-lock"
wg, loaded := t.natTable.GetOrCreateLock(lockKey) wg, loaded := t.natTable.GetOrCreateLock(lockKey)
isFakeIP := dns.DefaultResolver != nil && dns.DefaultResolver.IsFakeIP(metadata.DstIP)
go func() { go func() {
if !loaded { if !loaded {
wg.Add(1) wg.Add(1)
@ -201,14 +198,13 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) {
return return
} }
rawPc, nAddr, err := proxy.DialUDP(metadata) rawPc, err := proxy.DialUDP(metadata)
if err != nil { if err != nil {
log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error()) log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error())
t.natTable.Delete(lockKey) t.natTable.Delete(lockKey)
wg.Done() wg.Done()
return return
} }
addr = nAddr
pc = newUDPTracker(rawPc, DefaultManager, metadata, rule) pc = newUDPTracker(rawPc, DefaultManager, metadata, rule)
if rule != nil { if rule != nil {
@ -217,15 +213,14 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) {
log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SrcIP.String(), metadata.String()) log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SrcIP.String(), metadata.String())
} }
t.natTable.Set(key, pc, addr) t.natTable.Set(key, pc)
t.natTable.Delete(lockKey) t.natTable.Delete(lockKey)
wg.Done() wg.Done()
// in fake-ip mode, Full-Cone NAT can never achieve, fallback to omitting src Addr go t.handleUDPToLocal(packet.UDPPacket, pc, key)
go t.handleUDPToLocal(packet.UDPPacket, pc, key, isFakeIP, udpTimeout)
} }
wg.Wait() wg.Wait()
pc, addr := t.natTable.Get(key) pc := t.natTable.Get(key)
if pc != nil { if pc != nil {
t.handleUDPToRemote(packet, pc, addr) t.handleUDPToRemote(packet, pc, addr)
} }