Fix: some UDP issues (#265)

This commit is contained in:
Jason Lyu 2019-10-11 20:11:18 +08:00 committed by Dreamacro
parent 0f63682bdf
commit 4cd8b6f24f
15 changed files with 228 additions and 243 deletions

View file

@ -30,7 +30,7 @@ func (d *Direct) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) {
return nil, nil, err return nil, nil, err
} }
addr, err := resolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) addr, err := resolveUDPAddr("udp", metadata.RemoteAddress())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -58,7 +58,7 @@ func (h *Http) shakeHand(metadata *C.Metadata, rw io.ReadWriter) error {
var buf bytes.Buffer var buf bytes.Buffer
var err error var err error
addr := net.JoinHostPort(metadata.String(), metadata.DstPort) addr := metadata.RemoteAddress()
buf.WriteString("CONNECT " + addr + " HTTP/1.1\r\n") buf.WriteString("CONNECT " + addr + " HTTP/1.1\r\n")
buf.WriteString("Host: " + metadata.String() + "\r\n") buf.WriteString("Host: " + metadata.String() + "\r\n")
buf.WriteString("Proxy-Connection: Keep-Alive\r\n") buf.WriteString("Proxy-Connection: Keep-Alive\r\n")

View file

@ -7,7 +7,6 @@ import (
"net" "net"
"strconv" "strconv"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/common/structure" "github.com/Dreamacro/clash/common/structure"
obfs "github.com/Dreamacro/clash/component/simple-obfs" obfs "github.com/Dreamacro/clash/component/simple-obfs"
"github.com/Dreamacro/clash/component/socks5" "github.com/Dreamacro/clash/component/socks5"
@ -93,9 +92,9 @@ func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, er
return nil, nil, err return nil, nil, err
} }
targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) targetAddr := socks5.ParseAddr(metadata.RemoteAddress())
if err != nil { if targetAddr == nil {
return nil, nil, err return nil, nil, fmt.Errorf("parse address error: %v:%v", metadata.String(), metadata.DstPort)
} }
pc = ss.cipher.PacketConn(pc) pc = ss.cipher.PacketConn(pc)
@ -189,16 +188,15 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
type ssUDPConn struct { type ssUDPConn struct {
net.PacketConn net.PacketConn
rAddr net.Addr rAddr socks5.Addr
} }
func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
buf := pool.BufPool.Get().([]byte) packet, err := socks5.EncodeUDPPacket(uc.rAddr, b)
defer pool.BufPool.Put(buf[:cap(buf)]) if err != nil {
rAddr := socks5.ParseAddr(uc.rAddr.String()) return
copy(buf[len(rAddr):], b) }
copy(buf, rAddr) return uc.PacketConn.WriteTo(packet[3:], addr)
return uc.PacketConn.WriteTo(buf[:len(rAddr)+len(b)], addr)
} }
func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {

View file

@ -98,9 +98,9 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err
return return
} }
targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) targetAddr := socks5.ParseAddr(metadata.RemoteAddress())
if err != nil { if targetAddr == nil {
return return nil, nil, fmt.Errorf("parse address error: %v:%v", metadata.String(), metadata.DstPort)
} }
pc, err := net.ListenPacket("udp", "") pc, err := net.ListenPacket("udp", "")
@ -146,12 +146,12 @@ func NewSocks5(option Socks5Option) *Socks5 {
type socksUDPConn struct { type socksUDPConn struct {
net.PacketConn net.PacketConn
rAddr net.Addr rAddr socks5.Addr
tcpConn net.Conn tcpConn net.Conn
} }
func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
packet, err := socks5.EncodeUDPPacket(uc.rAddr.String(), b) packet, err := socks5.EncodeUDPPacket(uc.rAddr, b)
if err != nil { if err != nil {
return return
} }
@ -160,12 +160,17 @@ func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
func (uc *socksUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { func (uc *socksUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, a, e := uc.PacketConn.ReadFrom(b) n, a, e := uc.PacketConn.ReadFrom(b)
if e != nil {
return 0, nil, e
}
addr, payload, err := socks5.DecodeUDPPacket(b) addr, payload, err := socks5.DecodeUDPPacket(b)
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
// due to DecodeUDPPacket is mutable, record addr length
addrLength := len(addr)
copy(b, payload) copy(b, payload)
return n - len(addr) - 3, a, e return n - addrLength - 3, a, nil
} }
func (uc *socksUDPConn) Close() error { func (uc *socksUDPConn) Close() error {

View file

@ -86,19 +86,6 @@ func serializesSocksAddr(metadata *C.Metadata) []byte {
return bytes.Join(buf, nil) return bytes.Join(buf, nil)
} }
type fakeUDPConn struct {
net.Conn
}
func (fuc *fakeUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return fuc.Conn.Write(b)
}
func (fuc *fakeUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, err := fuc.Conn.Read(b)
return n, fuc.RemoteAddr(), err
}
func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
host, port, err := net.SplitHostPort(address) host, port, err := net.SplitHostPort(address)
if err != nil { if err != nil {

View file

@ -51,7 +51,7 @@ func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) {
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("new vmess client error: %v", err) return nil, nil, fmt.Errorf("new vmess client error: %v", err)
} }
return newPacketConn(&fakeUDPConn{Conn: c}, v), c.RemoteAddr(), nil return newPacketConn(&vmessUDPConn{Conn: c}, v), c.RemoteAddr(), nil
} }
func NewVmess(option VmessOption) (*Vmess, error) { func NewVmess(option VmessOption) (*Vmess, error) {
@ -111,3 +111,16 @@ func parseVmessAddr(metadata *C.Metadata) *vmess.DstAddr {
Port: uint(port), Port: uint(port),
} }
} }
type vmessUDPConn struct {
net.Conn
}
func (uc *vmessUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return uc.Conn.Write(b)
}
func (uc *vmessUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, err := uc.Conn.Read(b)
return n, uc.RemoteAddr(), err
}

View file

@ -1,98 +0,0 @@
package nat
import (
"net"
"runtime"
"sync"
"time"
)
type Table struct {
*table
}
type table struct {
mapping sync.Map
janitor *janitor
timeout time.Duration
}
type element struct {
Expired time.Time
RemoteAddr net.Addr
RemoteConn net.PacketConn
}
func (t *table) Set(key net.Addr, rConn net.PacketConn, rAddr net.Addr) {
// set conn read timeout
rConn.SetReadDeadline(time.Now().Add(t.timeout))
t.mapping.Store(key, &element{
RemoteAddr: rAddr,
RemoteConn: rConn,
Expired: time.Now().Add(t.timeout),
})
}
func (t *table) Get(key net.Addr) (rConn net.PacketConn, rAddr net.Addr) {
item, exist := t.mapping.Load(key)
if !exist {
return
}
elm := item.(*element)
// expired
if time.Since(elm.Expired) > 0 {
t.mapping.Delete(key)
elm.RemoteConn.Close()
return
}
// reset expired time
elm.Expired = time.Now().Add(t.timeout)
return elm.RemoteConn, elm.RemoteAddr
}
func (t *table) cleanup() {
t.mapping.Range(func(k, v interface{}) bool {
key := k.(net.Addr)
elm := v.(*element)
if time.Since(elm.Expired) > 0 {
t.mapping.Delete(key)
elm.RemoteConn.Close()
}
return true
})
}
type janitor struct {
interval time.Duration
stop chan struct{}
}
func (j *janitor) process(t *table) {
ticker := time.NewTicker(j.interval)
for {
select {
case <-ticker.C:
t.cleanup()
case <-j.stop:
ticker.Stop()
return
}
}
}
func stopJanitor(t *Table) {
t.janitor.stop <- struct{}{}
}
// New return *Cache
func New(interval time.Duration) *Table {
j := &janitor{
interval: interval,
stop: make(chan struct{}),
}
t := &table{janitor: j, timeout: interval}
go j.process(t)
T := &Table{t}
runtime.SetFinalizer(T, stopJanitor)
return T
}

46
component/nat/table.go Normal file
View file

@ -0,0 +1,46 @@
package nat
import (
"net"
"sync"
)
type Table struct {
mapping sync.Map
}
type element struct {
RemoteAddr net.Addr
RemoteConn net.PacketConn
}
func (t *Table) Set(key string, pc net.PacketConn, addr net.Addr) {
// set conn read timeout
t.mapping.Store(key, &element{
RemoteConn: pc,
RemoteAddr: addr,
})
}
func (t *Table) Get(key string) (net.PacketConn, net.Addr) {
item, exist := t.mapping.Load(key)
if !exist {
return nil, nil
}
elm := item.(*element)
return elm.RemoteConn, elm.RemoteAddr
}
func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) {
item, loaded := t.mapping.LoadOrStore(key, &sync.WaitGroup{})
return item.(*sync.WaitGroup), loaded
}
func (t *Table) Delete(key string) {
t.mapping.Delete(key)
}
// New return *Cache
func New() *Table {
return &Table{}
}

View file

@ -338,6 +338,7 @@ func ParseAddr(s string) Addr {
return addr return addr
} }
// DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`
func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
if len(packet) < 5 { if len(packet) < 5 {
err = errors.New("insufficient length of packet") err = errors.New("insufficient length of packet")
@ -360,16 +361,15 @@ func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
err = errors.New("failed to read UDP header") err = errors.New("failed to read UDP header")
} }
payload = bytes.Join([][]byte{packet[3+len(addr):]}, []byte{}) payload = packet[3+len(addr):]
return return
} }
func EncodeUDPPacket(addr string, payload []byte) (packet []byte, err error) { func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) {
rAddr := ParseAddr(addr) if addr == nil {
if rAddr == nil { err = errors.New("address is invalid")
err = errors.New("cannot parse addr")
return return
} }
packet = bytes.Join([][]byte{{0, 0, 0}, rAddr, payload}, []byte{}) packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{})
return return
} }

View file

@ -41,11 +41,18 @@ type Metadata struct {
Host string Host string
} }
func (m *Metadata) RemoteAddress() string {
return net.JoinHostPort(m.String(), m.DstPort)
}
func (m *Metadata) String() string { func (m *Metadata) String() string {
if m.Host == "" { if m.Host != "" {
return m.Host
} else if m.DstIP != nil {
return m.DstIP.String() return m.DstIP.String()
} else {
return "<nil>"
} }
return m.Host
} }
func (m *Metadata) Valid() bool { func (m *Metadata) Valid() bool {

View file

@ -1,17 +1,13 @@
package socks package socks
import ( import (
"bytes"
"net" "net"
adapters "github.com/Dreamacro/clash/adapters/inbound" adapters "github.com/Dreamacro/clash/adapters/inbound"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/socks5" "github.com/Dreamacro/clash/component/socks5"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/tunnel"
)
var (
_ = tunnel.NATInstance()
) )
type SockUDPListener struct { type SockUDPListener struct {
@ -28,17 +24,17 @@ func NewSocksUDPProxy(addr string) (*SockUDPListener, error) {
sl := &SockUDPListener{l, addr, false} sl := &SockUDPListener{l, addr, false}
go func() { go func() {
buf := pool.BufPool.Get().([]byte)
defer pool.BufPool.Put(buf[:cap(buf)])
for { for {
buf := pool.BufPool.Get().([]byte)
n, remoteAddr, err := l.ReadFrom(buf) n, remoteAddr, err := l.ReadFrom(buf)
if err != nil { if err != nil {
pool.BufPool.Put(buf[:cap(buf)])
if sl.closed { if sl.closed {
break break
} }
continue continue
} }
go handleSocksUDP(l, buf[:n], remoteAddr) handleSocksUDP(l, buf[:n], remoteAddr)
} }
}() }()
@ -54,12 +50,19 @@ func (l *SockUDPListener) Address() string {
return l.address return l.address
} }
func handleSocksUDP(c net.PacketConn, packet []byte, remoteAddr net.Addr) { func handleSocksUDP(pc net.PacketConn, buf []byte, addr net.Addr) {
target, payload, err := socks5.DecodeUDPPacket(packet) target, payload, err := socks5.DecodeUDPPacket(buf)
if err != nil { if err != nil {
// Unresolved UDP packet, do nothing // Unresolved UDP packet, return buffer to the pool
pool.BufPool.Put(buf[:cap(buf)])
return return
} }
conn := newfakeConn(c, target.String(), remoteAddr, payload) conn := &fakeConn{
PacketConn: pc,
remoteAddr: addr,
targetAddr: target,
buffer: bytes.NewBuffer(payload),
bufRef: buf,
}
tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.UDP)) tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.UDP))
} }

View file

@ -4,24 +4,16 @@ import (
"bytes" "bytes"
"net" "net"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/socks5" "github.com/Dreamacro/clash/component/socks5"
) )
type fakeConn struct { type fakeConn struct {
net.PacketConn net.PacketConn
target string
remoteAddr net.Addr remoteAddr net.Addr
targetAddr socks5.Addr
buffer *bytes.Buffer buffer *bytes.Buffer
} bufRef []byte
func newfakeConn(conn net.PacketConn, target string, remoteAddr net.Addr, buf []byte) *fakeConn {
buffer := bytes.NewBuffer(buf)
return &fakeConn{
PacketConn: conn,
target: target,
buffer: buffer,
remoteAddr: remoteAddr,
}
} }
func (c *fakeConn) Read(b []byte) (n int, err error) { func (c *fakeConn) Read(b []byte) (n int, err error) {
@ -29,7 +21,7 @@ func (c *fakeConn) Read(b []byte) (n int, err error) {
} }
func (c *fakeConn) Write(b []byte) (n int, err error) { func (c *fakeConn) Write(b []byte) (n int, err error) {
packet, err := socks5.EncodeUDPPacket(c.target, b) packet, err := socks5.EncodeUDPPacket(c.targetAddr, b)
if err != nil { if err != nil {
return return
} }
@ -39,3 +31,9 @@ func (c *fakeConn) Write(b []byte) (n int, err error) {
func (c *fakeConn) RemoteAddr() net.Addr { func (c *fakeConn) RemoteAddr() net.Addr {
return c.remoteAddr return c.remoteAddr
} }
func (c *fakeConn) Close() error {
err := c.PacketConn.Close()
pool.BufPool.Put(c.bufRef[:cap(c.bufRef)])
return err
}

View file

@ -86,11 +86,14 @@ func (t *Tunnel) handleUDPToRemote(conn net.Conn, pc net.PacketConn, addr net.Ad
t.traffic.Up() <- int64(n) t.traffic.Up() <- int64(n)
} }
func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn) { func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string, timeout time.Duration) {
buf := pool.BufPool.Get().([]byte) buf := pool.BufPool.Get().([]byte)
defer pool.BufPool.Put(buf[:cap(buf)]) defer pool.BufPool.Put(buf[:cap(buf)])
defer t.natTable.Delete(key)
defer pc.Close()
for { for {
pc.SetReadDeadline(time.Now().Add(timeout))
n, _, err := pc.ReadFrom(buf) n, _, err := pc.ReadFrom(buf)
if err != nil { if err != nil {
return return

View file

@ -1,22 +0,0 @@
package tunnel
import (
"sync"
"time"
nat "github.com/Dreamacro/clash/component/nat-table"
)
var (
natTable *nat.Table
natOnce sync.Once
natTimeout = 120 * time.Second
)
func NATInstance() *nat.Table {
natOnce.Do(func() {
natTable = nat.New(natTimeout)
})
return natTable
}

View file

@ -7,6 +7,7 @@ import (
"time" "time"
InboundAdapter "github.com/Dreamacro/clash/adapters/inbound" InboundAdapter "github.com/Dreamacro/clash/adapters/inbound"
"github.com/Dreamacro/clash/component/nat"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/dns"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
@ -17,11 +18,16 @@ import (
var ( var (
tunnel *Tunnel tunnel *Tunnel
once sync.Once once sync.Once
// default timeout for UDP session
udpTimeout = 60 * time.Second
) )
// Tunnel handle relay inbound proxy and outbound proxy // Tunnel handle relay inbound proxy and outbound proxy
type Tunnel struct { type Tunnel struct {
queue *channels.InfiniteChannel tcpQueue *channels.InfiniteChannel
udpQueue *channels.InfiniteChannel
natTable *nat.Table
rules []C.Rule rules []C.Rule
proxies map[string]C.Proxy proxies map[string]C.Proxy
configMux *sync.RWMutex configMux *sync.RWMutex
@ -36,7 +42,12 @@ type Tunnel struct {
// Add request to queue // Add request to queue
func (t *Tunnel) Add(req C.ServerAdapter) { func (t *Tunnel) Add(req C.ServerAdapter) {
t.queue.In() <- req switch req.Metadata().NetWork {
case C.TCP:
t.tcpQueue.In() <- req
case C.UDP:
t.udpQueue.In() <- req
}
} }
// Traffic return traffic of all connections // Traffic return traffic of all connections
@ -86,11 +97,18 @@ func (t *Tunnel) SetMode(mode Mode) {
} }
func (t *Tunnel) process() { func (t *Tunnel) process() {
queue := t.queue.Out() go func() {
for { queue := t.udpQueue.Out()
elm := <-queue for elm := range queue {
conn := elm.(C.ServerAdapter)
t.handleUDPConn(conn)
}
}()
queue := t.tcpQueue.Out()
for elm := range queue {
conn := elm.(C.ServerAdapter) conn := elm.(C.ServerAdapter)
go t.handleConn(conn) go t.handleTCPConn(conn)
} }
} }
@ -102,26 +120,7 @@ func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool {
return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.IsFakeIP()) && metadata.Host == "" && metadata.DstIP != nil return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.IsFakeIP()) && metadata.Host == "" && metadata.DstIP != nil
} }
func (t *Tunnel) handleConn(localConn C.ServerAdapter) { func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
defer func() {
var conn net.Conn
switch adapter := localConn.(type) {
case *InboundAdapter.HTTPAdapter:
conn = adapter.Conn
case *InboundAdapter.SocketAdapter:
conn = adapter.Conn
}
if _, ok := conn.(*net.TCPConn); ok {
localConn.Close()
}
}()
metadata := localConn.Metadata()
if !metadata.Valid() {
log.Warnln("[Metadata] not valid: %#v", metadata)
return
}
// preprocess enhanced-mode metadata // preprocess enhanced-mode metadata
if t.needLookupIP(metadata) { if t.needLookupIP(metadata) {
host, exist := dns.DefaultResolver.IPToHost(*metadata.DstIP) host, exist := dns.DefaultResolver.IPToHost(*metadata.DstIP)
@ -146,43 +145,87 @@ func (t *Tunnel) handleConn(localConn C.ServerAdapter) {
var err error var err error
proxy, rule, err = t.match(metadata) proxy, rule, err = t.match(metadata)
if err != nil { if err != nil {
return return nil, nil, err
} }
} }
return proxy, rule, nil
switch metadata.NetWork {
case C.TCP:
t.handleTCPConn(localConn, metadata, proxy, rule)
case C.UDP:
t.handleUDPConn(localConn, metadata, proxy, rule)
}
} }
func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter, metadata *C.Metadata, proxy C.Proxy, rule C.Rule) { func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter) {
pc, addr := natTable.Get(localConn.RemoteAddr()) metadata := localConn.Metadata()
if pc == nil { if !metadata.Valid() {
rawPc, nAddr, err := proxy.DialUDP(metadata) log.Warnln("[Metadata] not valid: %#v", metadata)
addr = nAddr return
pc = rawPc
if err != nil {
log.Warnln("dial %s error: %s", proxy.Name(), err.Error())
return
}
if rule != nil {
log.Infoln("%s --> %v match %s using %s", metadata.SrcIP.String(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String())
} else {
log.Infoln("%s --> %v doesn't match any rule using DIRECT", metadata.SrcIP.String(), metadata.String())
}
natTable.Set(localConn.RemoteAddr(), pc, addr)
go t.handleUDPToLocal(localConn, pc)
} }
t.handleUDPToRemote(localConn, pc, addr) src := localConn.RemoteAddr().String()
dst := metadata.RemoteAddress()
key := src + "-" + dst
pc, addr := t.natTable.Get(key)
if pc != nil {
t.handleUDPToRemote(localConn, pc, addr)
return
}
lockKey := key + "-lock"
wg, loaded := t.natTable.GetOrCreateLock(lockKey)
go func() {
if !loaded {
wg.Add(1)
proxy, rule, err := t.resolveMetadata(metadata)
if err != nil {
log.Warnln("Parse metadata failed: %s", err.Error())
t.natTable.Delete(lockKey)
wg.Done()
return
}
rawPc, nAddr, err := proxy.DialUDP(metadata)
if err != nil {
log.Warnln("dial %s error: %s", proxy.Name(), err.Error())
t.natTable.Delete(lockKey)
wg.Done()
return
}
pc = rawPc
addr = nAddr
if rule != nil {
log.Infoln("%s --> %v match %s using %s", metadata.SrcIP.String(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String())
} else {
log.Infoln("%s --> %v doesn't match any rule using DIRECT", metadata.SrcIP.String(), metadata.String())
}
t.natTable.Set(key, pc, addr)
t.natTable.Delete(lockKey)
wg.Done()
go t.handleUDPToLocal(localConn, pc, key, udpTimeout)
}
wg.Wait()
pc, addr := t.natTable.Get(key)
if pc != nil {
t.handleUDPToRemote(localConn, pc, addr)
}
}()
} }
func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter, metadata *C.Metadata, proxy C.Proxy, rule C.Rule) { func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) {
defer localConn.Close()
metadata := localConn.Metadata()
if !metadata.Valid() {
log.Warnln("[Metadata] not valid: %#v", metadata)
return
}
proxy, rule, err := t.resolveMetadata(metadata)
if err != nil {
log.Warnln("Parse metadata failed: %v", err)
return
}
remoteConn, err := proxy.Dial(metadata) remoteConn, err := proxy.Dial(metadata)
if err != nil { if err != nil {
log.Warnln("dial %s error: %s", proxy.Name(), err.Error()) log.Warnln("dial %s error: %s", proxy.Name(), err.Error())
@ -253,7 +296,9 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
func newTunnel() *Tunnel { func newTunnel() *Tunnel {
return &Tunnel{ return &Tunnel{
queue: channels.NewInfiniteChannel(), tcpQueue: channels.NewInfiniteChannel(),
udpQueue: channels.NewInfiniteChannel(),
natTable: nat.New(),
proxies: make(map[string]C.Proxy), proxies: make(map[string]C.Proxy),
configMux: &sync.RWMutex{}, configMux: &sync.RWMutex{},
traffic: C.NewTraffic(time.Second), traffic: C.NewTraffic(time.Second),