Fix: some UDP issues (#265)
This commit is contained in:
parent
0f63682bdf
commit
4cd8b6f24f
15 changed files with 228 additions and 243 deletions
|
@ -30,7 +30,7 @@ func (d *Direct) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) {
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
addr, err := resolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort))
|
||||
addr, err := resolveUDPAddr("udp", metadata.RemoteAddress())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ func (h *Http) shakeHand(metadata *C.Metadata, rw io.ReadWriter) error {
|
|||
var buf bytes.Buffer
|
||||
var err error
|
||||
|
||||
addr := net.JoinHostPort(metadata.String(), metadata.DstPort)
|
||||
addr := metadata.RemoteAddress()
|
||||
buf.WriteString("CONNECT " + addr + " HTTP/1.1\r\n")
|
||||
buf.WriteString("Host: " + metadata.String() + "\r\n")
|
||||
buf.WriteString("Proxy-Connection: Keep-Alive\r\n")
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
"github.com/Dreamacro/clash/common/structure"
|
||||
obfs "github.com/Dreamacro/clash/component/simple-obfs"
|
||||
"github.com/Dreamacro/clash/component/socks5"
|
||||
|
@ -93,9 +92,9 @@ func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, er
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
targetAddr := socks5.ParseAddr(metadata.RemoteAddress())
|
||||
if targetAddr == nil {
|
||||
return nil, nil, fmt.Errorf("parse address error: %v:%v", metadata.String(), metadata.DstPort)
|
||||
}
|
||||
|
||||
pc = ss.cipher.PacketConn(pc)
|
||||
|
@ -189,16 +188,15 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
|
|||
|
||||
type ssUDPConn struct {
|
||||
net.PacketConn
|
||||
rAddr net.Addr
|
||||
rAddr socks5.Addr
|
||||
}
|
||||
|
||||
func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
defer pool.BufPool.Put(buf[:cap(buf)])
|
||||
rAddr := socks5.ParseAddr(uc.rAddr.String())
|
||||
copy(buf[len(rAddr):], b)
|
||||
copy(buf, rAddr)
|
||||
return uc.PacketConn.WriteTo(buf[:len(rAddr)+len(b)], addr)
|
||||
func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
packet, err := socks5.EncodeUDPPacket(uc.rAddr, b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return uc.PacketConn.WriteTo(packet[3:], addr)
|
||||
}
|
||||
|
||||
func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
|
|
|
@ -98,9 +98,9 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err
|
|||
return
|
||||
}
|
||||
|
||||
targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort))
|
||||
if err != nil {
|
||||
return
|
||||
targetAddr := socks5.ParseAddr(metadata.RemoteAddress())
|
||||
if targetAddr == nil {
|
||||
return nil, nil, fmt.Errorf("parse address error: %v:%v", metadata.String(), metadata.DstPort)
|
||||
}
|
||||
|
||||
pc, err := net.ListenPacket("udp", "")
|
||||
|
@ -146,12 +146,12 @@ func NewSocks5(option Socks5Option) *Socks5 {
|
|||
|
||||
type socksUDPConn struct {
|
||||
net.PacketConn
|
||||
rAddr net.Addr
|
||||
rAddr socks5.Addr
|
||||
tcpConn net.Conn
|
||||
}
|
||||
|
||||
func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
packet, err := socks5.EncodeUDPPacket(uc.rAddr.String(), b)
|
||||
packet, err := socks5.EncodeUDPPacket(uc.rAddr, b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -160,12 +160,17 @@ func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|||
|
||||
func (uc *socksUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
n, a, e := uc.PacketConn.ReadFrom(b)
|
||||
if e != nil {
|
||||
return 0, nil, e
|
||||
}
|
||||
addr, payload, err := socks5.DecodeUDPPacket(b)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
// due to DecodeUDPPacket is mutable, record addr length
|
||||
addrLength := len(addr)
|
||||
copy(b, payload)
|
||||
return n - len(addr) - 3, a, e
|
||||
return n - addrLength - 3, a, nil
|
||||
}
|
||||
|
||||
func (uc *socksUDPConn) Close() error {
|
||||
|
|
|
@ -86,19 +86,6 @@ func serializesSocksAddr(metadata *C.Metadata) []byte {
|
|||
return bytes.Join(buf, nil)
|
||||
}
|
||||
|
||||
type fakeUDPConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (fuc *fakeUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
return fuc.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (fuc *fakeUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
n, err := fuc.Conn.Read(b)
|
||||
return n, fuc.RemoteAddr(), err
|
||||
}
|
||||
|
||||
func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
|
|
|
@ -51,7 +51,7 @@ func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) {
|
|||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("new vmess client error: %v", err)
|
||||
}
|
||||
return newPacketConn(&fakeUDPConn{Conn: c}, v), c.RemoteAddr(), nil
|
||||
return newPacketConn(&vmessUDPConn{Conn: c}, v), c.RemoteAddr(), nil
|
||||
}
|
||||
|
||||
func NewVmess(option VmessOption) (*Vmess, error) {
|
||||
|
@ -111,3 +111,16 @@ func parseVmessAddr(metadata *C.Metadata) *vmess.DstAddr {
|
|||
Port: uint(port),
|
||||
}
|
||||
}
|
||||
|
||||
type vmessUDPConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (uc *vmessUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
return uc.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (uc *vmessUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
n, err := uc.Conn.Read(b)
|
||||
return n, uc.RemoteAddr(), err
|
||||
}
|
||||
|
|
|
@ -1,98 +0,0 @@
|
|||
package nat
|
||||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
*table
|
||||
}
|
||||
|
||||
type table struct {
|
||||
mapping sync.Map
|
||||
janitor *janitor
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
type element struct {
|
||||
Expired time.Time
|
||||
RemoteAddr net.Addr
|
||||
RemoteConn net.PacketConn
|
||||
}
|
||||
|
||||
func (t *table) Set(key net.Addr, rConn net.PacketConn, rAddr net.Addr) {
|
||||
// set conn read timeout
|
||||
rConn.SetReadDeadline(time.Now().Add(t.timeout))
|
||||
t.mapping.Store(key, &element{
|
||||
RemoteAddr: rAddr,
|
||||
RemoteConn: rConn,
|
||||
Expired: time.Now().Add(t.timeout),
|
||||
})
|
||||
}
|
||||
|
||||
func (t *table) Get(key net.Addr) (rConn net.PacketConn, rAddr net.Addr) {
|
||||
item, exist := t.mapping.Load(key)
|
||||
if !exist {
|
||||
return
|
||||
}
|
||||
elm := item.(*element)
|
||||
// expired
|
||||
if time.Since(elm.Expired) > 0 {
|
||||
t.mapping.Delete(key)
|
||||
elm.RemoteConn.Close()
|
||||
return
|
||||
}
|
||||
// reset expired time
|
||||
elm.Expired = time.Now().Add(t.timeout)
|
||||
return elm.RemoteConn, elm.RemoteAddr
|
||||
}
|
||||
|
||||
func (t *table) cleanup() {
|
||||
t.mapping.Range(func(k, v interface{}) bool {
|
||||
key := k.(net.Addr)
|
||||
elm := v.(*element)
|
||||
if time.Since(elm.Expired) > 0 {
|
||||
t.mapping.Delete(key)
|
||||
elm.RemoteConn.Close()
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
type janitor struct {
|
||||
interval time.Duration
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
func (j *janitor) process(t *table) {
|
||||
ticker := time.NewTicker(j.interval)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.cleanup()
|
||||
case <-j.stop:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stopJanitor(t *Table) {
|
||||
t.janitor.stop <- struct{}{}
|
||||
}
|
||||
|
||||
// New return *Cache
|
||||
func New(interval time.Duration) *Table {
|
||||
j := &janitor{
|
||||
interval: interval,
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
t := &table{janitor: j, timeout: interval}
|
||||
go j.process(t)
|
||||
T := &Table{t}
|
||||
runtime.SetFinalizer(T, stopJanitor)
|
||||
return T
|
||||
}
|
46
component/nat/table.go
Normal file
46
component/nat/table.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package nat
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
mapping sync.Map
|
||||
}
|
||||
|
||||
type element struct {
|
||||
RemoteAddr net.Addr
|
||||
RemoteConn net.PacketConn
|
||||
}
|
||||
|
||||
func (t *Table) Set(key string, pc net.PacketConn, addr net.Addr) {
|
||||
// set conn read timeout
|
||||
t.mapping.Store(key, &element{
|
||||
RemoteConn: pc,
|
||||
RemoteAddr: addr,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Table) Get(key string) (net.PacketConn, net.Addr) {
|
||||
item, exist := t.mapping.Load(key)
|
||||
if !exist {
|
||||
return nil, nil
|
||||
}
|
||||
elm := item.(*element)
|
||||
return elm.RemoteConn, elm.RemoteAddr
|
||||
}
|
||||
|
||||
func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) {
|
||||
item, loaded := t.mapping.LoadOrStore(key, &sync.WaitGroup{})
|
||||
return item.(*sync.WaitGroup), loaded
|
||||
}
|
||||
|
||||
func (t *Table) Delete(key string) {
|
||||
t.mapping.Delete(key)
|
||||
}
|
||||
|
||||
// New return *Cache
|
||||
func New() *Table {
|
||||
return &Table{}
|
||||
}
|
|
@ -338,6 +338,7 @@ func ParseAddr(s string) Addr {
|
|||
return addr
|
||||
}
|
||||
|
||||
// DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`
|
||||
func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
|
||||
if len(packet) < 5 {
|
||||
err = errors.New("insufficient length of packet")
|
||||
|
@ -360,16 +361,15 @@ func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
|
|||
err = errors.New("failed to read UDP header")
|
||||
}
|
||||
|
||||
payload = bytes.Join([][]byte{packet[3+len(addr):]}, []byte{})
|
||||
payload = packet[3+len(addr):]
|
||||
return
|
||||
}
|
||||
|
||||
func EncodeUDPPacket(addr string, payload []byte) (packet []byte, err error) {
|
||||
rAddr := ParseAddr(addr)
|
||||
if rAddr == nil {
|
||||
err = errors.New("cannot parse addr")
|
||||
func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) {
|
||||
if addr == nil {
|
||||
err = errors.New("address is invalid")
|
||||
return
|
||||
}
|
||||
packet = bytes.Join([][]byte{{0, 0, 0}, rAddr, payload}, []byte{})
|
||||
packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -41,11 +41,18 @@ type Metadata struct {
|
|||
Host string
|
||||
}
|
||||
|
||||
func (m *Metadata) RemoteAddress() string {
|
||||
return net.JoinHostPort(m.String(), m.DstPort)
|
||||
}
|
||||
|
||||
func (m *Metadata) String() string {
|
||||
if m.Host == "" {
|
||||
if m.Host != "" {
|
||||
return m.Host
|
||||
} else if m.DstIP != nil {
|
||||
return m.DstIP.String()
|
||||
} else {
|
||||
return "<nil>"
|
||||
}
|
||||
return m.Host
|
||||
}
|
||||
|
||||
func (m *Metadata) Valid() bool {
|
||||
|
|
|
@ -1,17 +1,13 @@
|
|||
package socks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
|
||||
adapters "github.com/Dreamacro/clash/adapters/inbound"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
"github.com/Dreamacro/clash/component/socks5"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
"github.com/Dreamacro/clash/tunnel"
|
||||
)
|
||||
|
||||
var (
|
||||
_ = tunnel.NATInstance()
|
||||
)
|
||||
|
||||
type SockUDPListener struct {
|
||||
|
@ -28,17 +24,17 @@ func NewSocksUDPProxy(addr string) (*SockUDPListener, error) {
|
|||
|
||||
sl := &SockUDPListener{l, addr, false}
|
||||
go func() {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
defer pool.BufPool.Put(buf[:cap(buf)])
|
||||
for {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
n, remoteAddr, err := l.ReadFrom(buf)
|
||||
if err != nil {
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
if sl.closed {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
go handleSocksUDP(l, buf[:n], remoteAddr)
|
||||
handleSocksUDP(l, buf[:n], remoteAddr)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -54,12 +50,19 @@ func (l *SockUDPListener) Address() string {
|
|||
return l.address
|
||||
}
|
||||
|
||||
func handleSocksUDP(c net.PacketConn, packet []byte, remoteAddr net.Addr) {
|
||||
target, payload, err := socks5.DecodeUDPPacket(packet)
|
||||
func handleSocksUDP(pc net.PacketConn, buf []byte, addr net.Addr) {
|
||||
target, payload, err := socks5.DecodeUDPPacket(buf)
|
||||
if err != nil {
|
||||
// Unresolved UDP packet, do nothing
|
||||
// Unresolved UDP packet, return buffer to the pool
|
||||
pool.BufPool.Put(buf[:cap(buf)])
|
||||
return
|
||||
}
|
||||
conn := newfakeConn(c, target.String(), remoteAddr, payload)
|
||||
conn := &fakeConn{
|
||||
PacketConn: pc,
|
||||
remoteAddr: addr,
|
||||
targetAddr: target,
|
||||
buffer: bytes.NewBuffer(payload),
|
||||
bufRef: buf,
|
||||
}
|
||||
tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.UDP))
|
||||
}
|
||||
|
|
|
@ -4,24 +4,16 @@ import (
|
|||
"bytes"
|
||||
"net"
|
||||
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
"github.com/Dreamacro/clash/component/socks5"
|
||||
)
|
||||
|
||||
type fakeConn struct {
|
||||
net.PacketConn
|
||||
target string
|
||||
remoteAddr net.Addr
|
||||
targetAddr socks5.Addr
|
||||
buffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func newfakeConn(conn net.PacketConn, target string, remoteAddr net.Addr, buf []byte) *fakeConn {
|
||||
buffer := bytes.NewBuffer(buf)
|
||||
return &fakeConn{
|
||||
PacketConn: conn,
|
||||
target: target,
|
||||
buffer: buffer,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
bufRef []byte
|
||||
}
|
||||
|
||||
func (c *fakeConn) Read(b []byte) (n int, err error) {
|
||||
|
@ -29,7 +21,7 @@ func (c *fakeConn) Read(b []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
func (c *fakeConn) Write(b []byte) (n int, err error) {
|
||||
packet, err := socks5.EncodeUDPPacket(c.target, b)
|
||||
packet, err := socks5.EncodeUDPPacket(c.targetAddr, b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -39,3 +31,9 @@ func (c *fakeConn) Write(b []byte) (n int, err error) {
|
|||
func (c *fakeConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *fakeConn) Close() error {
|
||||
err := c.PacketConn.Close()
|
||||
pool.BufPool.Put(c.bufRef[:cap(c.bufRef)])
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -86,11 +86,14 @@ func (t *Tunnel) handleUDPToRemote(conn net.Conn, pc net.PacketConn, addr net.Ad
|
|||
t.traffic.Up() <- int64(n)
|
||||
}
|
||||
|
||||
func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn) {
|
||||
func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string, timeout time.Duration) {
|
||||
buf := pool.BufPool.Get().([]byte)
|
||||
defer pool.BufPool.Put(buf[:cap(buf)])
|
||||
defer t.natTable.Delete(key)
|
||||
defer pc.Close()
|
||||
|
||||
for {
|
||||
pc.SetReadDeadline(time.Now().Add(timeout))
|
||||
n, _, err := pc.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
nat "github.com/Dreamacro/clash/component/nat-table"
|
||||
)
|
||||
|
||||
var (
|
||||
natTable *nat.Table
|
||||
natOnce sync.Once
|
||||
|
||||
natTimeout = 120 * time.Second
|
||||
)
|
||||
|
||||
func NATInstance() *nat.Table {
|
||||
natOnce.Do(func() {
|
||||
natTable = nat.New(natTimeout)
|
||||
})
|
||||
return natTable
|
||||
}
|
157
tunnel/tunnel.go
157
tunnel/tunnel.go
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
InboundAdapter "github.com/Dreamacro/clash/adapters/inbound"
|
||||
"github.com/Dreamacro/clash/component/nat"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
"github.com/Dreamacro/clash/dns"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
|
@ -17,11 +18,16 @@ import (
|
|||
var (
|
||||
tunnel *Tunnel
|
||||
once sync.Once
|
||||
|
||||
// default timeout for UDP session
|
||||
udpTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
// Tunnel handle relay inbound proxy and outbound proxy
|
||||
type Tunnel struct {
|
||||
queue *channels.InfiniteChannel
|
||||
tcpQueue *channels.InfiniteChannel
|
||||
udpQueue *channels.InfiniteChannel
|
||||
natTable *nat.Table
|
||||
rules []C.Rule
|
||||
proxies map[string]C.Proxy
|
||||
configMux *sync.RWMutex
|
||||
|
@ -36,7 +42,12 @@ type Tunnel struct {
|
|||
|
||||
// Add request to queue
|
||||
func (t *Tunnel) Add(req C.ServerAdapter) {
|
||||
t.queue.In() <- req
|
||||
switch req.Metadata().NetWork {
|
||||
case C.TCP:
|
||||
t.tcpQueue.In() <- req
|
||||
case C.UDP:
|
||||
t.udpQueue.In() <- req
|
||||
}
|
||||
}
|
||||
|
||||
// Traffic return traffic of all connections
|
||||
|
@ -86,11 +97,18 @@ func (t *Tunnel) SetMode(mode Mode) {
|
|||
}
|
||||
|
||||
func (t *Tunnel) process() {
|
||||
queue := t.queue.Out()
|
||||
for {
|
||||
elm := <-queue
|
||||
go func() {
|
||||
queue := t.udpQueue.Out()
|
||||
for elm := range queue {
|
||||
conn := elm.(C.ServerAdapter)
|
||||
t.handleUDPConn(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
queue := t.tcpQueue.Out()
|
||||
for elm := range queue {
|
||||
conn := elm.(C.ServerAdapter)
|
||||
go t.handleConn(conn)
|
||||
go t.handleTCPConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,26 +120,7 @@ func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool {
|
|||
return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.IsFakeIP()) && metadata.Host == "" && metadata.DstIP != nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) handleConn(localConn C.ServerAdapter) {
|
||||
defer func() {
|
||||
var conn net.Conn
|
||||
switch adapter := localConn.(type) {
|
||||
case *InboundAdapter.HTTPAdapter:
|
||||
conn = adapter.Conn
|
||||
case *InboundAdapter.SocketAdapter:
|
||||
conn = adapter.Conn
|
||||
}
|
||||
if _, ok := conn.(*net.TCPConn); ok {
|
||||
localConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
metadata := localConn.Metadata()
|
||||
if !metadata.Valid() {
|
||||
log.Warnln("[Metadata] not valid: %#v", metadata)
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
|
||||
// preprocess enhanced-mode metadata
|
||||
if t.needLookupIP(metadata) {
|
||||
host, exist := dns.DefaultResolver.IPToHost(*metadata.DstIP)
|
||||
|
@ -146,43 +145,87 @@ func (t *Tunnel) handleConn(localConn C.ServerAdapter) {
|
|||
var err error
|
||||
proxy, rule, err = t.match(metadata)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch metadata.NetWork {
|
||||
case C.TCP:
|
||||
t.handleTCPConn(localConn, metadata, proxy, rule)
|
||||
case C.UDP:
|
||||
t.handleUDPConn(localConn, metadata, proxy, rule)
|
||||
}
|
||||
return proxy, rule, nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter, metadata *C.Metadata, proxy C.Proxy, rule C.Rule) {
|
||||
pc, addr := natTable.Get(localConn.RemoteAddr())
|
||||
if pc == nil {
|
||||
rawPc, nAddr, err := proxy.DialUDP(metadata)
|
||||
addr = nAddr
|
||||
pc = rawPc
|
||||
if err != nil {
|
||||
log.Warnln("dial %s error: %s", proxy.Name(), err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if rule != nil {
|
||||
log.Infoln("%s --> %v match %s using %s", metadata.SrcIP.String(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String())
|
||||
} else {
|
||||
log.Infoln("%s --> %v doesn't match any rule using DIRECT", metadata.SrcIP.String(), metadata.String())
|
||||
}
|
||||
|
||||
natTable.Set(localConn.RemoteAddr(), pc, addr)
|
||||
go t.handleUDPToLocal(localConn, pc)
|
||||
func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter) {
|
||||
metadata := localConn.Metadata()
|
||||
if !metadata.Valid() {
|
||||
log.Warnln("[Metadata] not valid: %#v", metadata)
|
||||
return
|
||||
}
|
||||
|
||||
t.handleUDPToRemote(localConn, pc, addr)
|
||||
src := localConn.RemoteAddr().String()
|
||||
dst := metadata.RemoteAddress()
|
||||
key := src + "-" + dst
|
||||
|
||||
pc, addr := t.natTable.Get(key)
|
||||
if pc != nil {
|
||||
t.handleUDPToRemote(localConn, pc, addr)
|
||||
return
|
||||
}
|
||||
|
||||
lockKey := key + "-lock"
|
||||
wg, loaded := t.natTable.GetOrCreateLock(lockKey)
|
||||
go func() {
|
||||
if !loaded {
|
||||
wg.Add(1)
|
||||
proxy, rule, err := t.resolveMetadata(metadata)
|
||||
if err != nil {
|
||||
log.Warnln("Parse metadata failed: %s", err.Error())
|
||||
t.natTable.Delete(lockKey)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
|
||||
rawPc, nAddr, err := proxy.DialUDP(metadata)
|
||||
if err != nil {
|
||||
log.Warnln("dial %s error: %s", proxy.Name(), err.Error())
|
||||
t.natTable.Delete(lockKey)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
pc = rawPc
|
||||
addr = nAddr
|
||||
|
||||
if rule != nil {
|
||||
log.Infoln("%s --> %v match %s using %s", metadata.SrcIP.String(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String())
|
||||
} else {
|
||||
log.Infoln("%s --> %v doesn't match any rule using DIRECT", metadata.SrcIP.String(), metadata.String())
|
||||
}
|
||||
|
||||
t.natTable.Set(key, pc, addr)
|
||||
t.natTable.Delete(lockKey)
|
||||
wg.Done()
|
||||
go t.handleUDPToLocal(localConn, pc, key, udpTimeout)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
pc, addr := t.natTable.Get(key)
|
||||
if pc != nil {
|
||||
t.handleUDPToRemote(localConn, pc, addr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter, metadata *C.Metadata, proxy C.Proxy, rule C.Rule) {
|
||||
func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) {
|
||||
defer localConn.Close()
|
||||
|
||||
metadata := localConn.Metadata()
|
||||
if !metadata.Valid() {
|
||||
log.Warnln("[Metadata] not valid: %#v", metadata)
|
||||
return
|
||||
}
|
||||
|
||||
proxy, rule, err := t.resolveMetadata(metadata)
|
||||
if err != nil {
|
||||
log.Warnln("Parse metadata failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
remoteConn, err := proxy.Dial(metadata)
|
||||
if err != nil {
|
||||
log.Warnln("dial %s error: %s", proxy.Name(), err.Error())
|
||||
|
@ -253,7 +296,9 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
|
|||
|
||||
func newTunnel() *Tunnel {
|
||||
return &Tunnel{
|
||||
queue: channels.NewInfiniteChannel(),
|
||||
tcpQueue: channels.NewInfiniteChannel(),
|
||||
udpQueue: channels.NewInfiniteChannel(),
|
||||
natTable: nat.New(),
|
||||
proxies: make(map[string]C.Proxy),
|
||||
configMux: &sync.RWMutex{},
|
||||
traffic: C.NewTraffic(time.Second),
|
||||
|
|
Loading…
Reference in a new issue