diff --git a/adapters/outbound/shadowsocks.go b/adapters/outbound/shadowsocks.go index b05f5830..a77c1a9a 100644 --- a/adapters/outbound/shadowsocks.go +++ b/adapters/outbound/shadowsocks.go @@ -92,13 +92,13 @@ func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, return nil, nil, err } - remoteAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) + targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) if err != nil { return nil, nil, err } pc = ss.cipher.PacketConn(pc) - return &ssUDPConn{PacketConn: pc, rAddr: remoteAddr}, addr, nil + return &ssUDPConn{PacketConn: pc, rAddr: targetAddr}, addr, nil } func (ss *ShadowSocks) MarshalJSON() ([]byte, error) { diff --git a/adapters/outbound/socks5.go b/adapters/outbound/socks5.go index 6ec611c6..1b08bf87 100644 --- a/adapters/outbound/socks5.go +++ b/adapters/outbound/socks5.go @@ -3,6 +3,8 @@ package adapters import ( "crypto/tls" "fmt" + "io" + "io/ioutil" "net" "strconv" @@ -51,24 +53,31 @@ func (ss *Socks5) Dial(metadata *C.Metadata) (net.Conn, error) { Password: ss.pass, } } - if err := socks5.ClientHandshake(c, serializesSocksAddr(metadata), socks5.CmdConnect, user); err != nil { + if _, err := socks5.ClientHandshake(c, serializesSocksAddr(metadata), socks5.CmdConnect, user); err != nil { return nil, err } return c, nil } -func (ss *Socks5) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { +func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ net.PacketConn, _ net.Addr, err error) { c, err := dialTimeout("tcp", ss.addr, tcpTimeout) + if err != nil { + err = fmt.Errorf("%s connect error", ss.addr) + return + } - if err == nil && ss.tls { + if ss.tls { cc := tls.Client(c, ss.tlsConfig) err = cc.Handshake() c = cc } - if err != nil { - return nil, nil, fmt.Errorf("%s connect error", ss.addr) - } + defer func() { + if err != nil { + c.Close() + } + }() + tcpKeepAlive(c) var user *socks5.User if ss.user != "" { @@ -78,10 +87,36 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error } } - if err := socks5.ClientHandshake(c, serializesSocksAddr(metadata), socks5.CmdUDPAssociate, user); err != nil { - return nil, nil, err + bindAddr, err := socks5.ClientHandshake(c, serializesSocksAddr(metadata), socks5.CmdUDPAssociate, user) + if err != nil { + err = fmt.Errorf("%v client hanshake error", err) + return } - return &fakeUDPConn{Conn: c}, c.LocalAddr(), nil + + addr, err := net.ResolveUDPAddr("udp", bindAddr.String()) + if err != nil { + return + } + + targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) + if err != nil { + return + } + + pc, err := net.ListenPacket("udp", "") + if err != nil { + return + } + + go func() { + io.Copy(ioutil.Discard, c) + c.Close() + // A UDP association terminates when the TCP connection that the UDP + // ASSOCIATE request arrived on terminates. RFC1928 + pc.Close() + }() + + return &socksUDPConn{PacketConn: pc, rAddr: targetAddr}, addr, nil } func NewSocks5(option Socks5Option) *Socks5 { @@ -108,3 +143,26 @@ func NewSocks5(option Socks5Option) *Socks5 { tlsConfig: tlsConfig, } } + +type socksUDPConn struct { + net.PacketConn + rAddr net.Addr +} + +func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + packet, err := socks5.EncodeUDPPacket(uc.rAddr.String(), b) + if err != nil { + return + } + return uc.PacketConn.WriteTo(packet, addr) +} + +func (uc *socksUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, a, e := uc.PacketConn.ReadFrom(b) + addr, payload, err := socks5.DecodeUDPPacket(b) + if err != nil { + return 0, nil, err + } + copy(b, payload) + return n - len(addr) - 3, a, e +} diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go index 7d74936d..825af42c 100644 --- a/adapters/outbound/vmess.go +++ b/adapters/outbound/vmess.go @@ -48,7 +48,10 @@ func (v *Vmess) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) } tcpKeepAlive(c) c, err = v.client.New(c, parseVmessAddr(metadata)) - return &fakeUDPConn{Conn: c}, c.LocalAddr(), err + if err != nil { + return nil, nil, fmt.Errorf("new vmess client error: %v", err) + } + return &fakeUDPConn{Conn: c}, c.RemoteAddr(), nil } func NewVmess(option VmessOption) (*Vmess, error) { @@ -74,7 +77,7 @@ func NewVmess(option VmessOption) (*Vmess, error) { Base: &Base{ name: option.Name, tp: C.Vmess, - udp: option.UDP, + udp: true, }, server: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), client: client, diff --git a/component/nat-table/nat.go b/component/nat-table/nat.go new file mode 100644 index 00000000..06b38671 --- /dev/null +++ b/component/nat-table/nat.go @@ -0,0 +1,98 @@ +package nat + +import ( + "net" + "runtime" + "sync" + "time" +) + +type Table struct { + *table +} + +type table struct { + mapping sync.Map + janitor *janitor + timeout time.Duration +} + +type element struct { + Expired time.Time + RemoteAddr net.Addr + RemoteConn net.PacketConn +} + +func (t *table) Set(key net.Addr, rConn net.PacketConn, rAddr net.Addr) { + // set conn read timeout + rConn.SetReadDeadline(time.Now().Add(t.timeout)) + t.mapping.Store(key, &element{ + RemoteAddr: rAddr, + RemoteConn: rConn, + Expired: time.Now().Add(t.timeout), + }) +} + +func (t *table) Get(key net.Addr) (rConn net.PacketConn, rAddr net.Addr) { + item, exist := t.mapping.Load(key) + if !exist { + return + } + elm := item.(*element) + // expired + if time.Since(elm.Expired) > 0 { + t.mapping.Delete(key) + elm.RemoteConn.Close() + return + } + // reset expired time + elm.Expired = time.Now().Add(t.timeout) + return elm.RemoteConn, elm.RemoteAddr +} + +func (t *table) cleanup() { + t.mapping.Range(func(k, v interface{}) bool { + key := k.(net.Addr) + elm := v.(*element) + if time.Since(elm.Expired) > 0 { + t.mapping.Delete(key) + elm.RemoteConn.Close() + } + return true + }) +} + +type janitor struct { + interval time.Duration + stop chan struct{} +} + +func (j *janitor) process(t *table) { + ticker := time.NewTicker(j.interval) + for { + select { + case <-ticker.C: + t.cleanup() + case <-j.stop: + ticker.Stop() + return + } + } +} + +func stopJanitor(t *Table) { + t.janitor.stop <- struct{}{} +} + +// New return *Cache +func New(interval time.Duration) *Table { + j := &janitor{ + interval: interval, + stop: make(chan struct{}), + } + t := &table{janitor: j, timeout: interval} + go j.process(t) + T := &Table{t} + runtime.SetFinalizer(T, stopJanitor) + return T +} diff --git a/component/socks5/socks5.go b/component/socks5/socks5.go index e945fe2c..b16c3b69 100644 --- a/component/socks5/socks5.go +++ b/component/socks5/socks5.go @@ -41,7 +41,25 @@ const MaxAddrLen = 1 + 1 + 255 + 2 const MaxAuthLen = 255 // Addr represents a SOCKS address as defined in RFC 1928 section 5. -type Addr = []byte +type Addr []byte + +func (a Addr) String() string { + var host, port string + + switch a[0] { + case AtypDomainName: + host = string(a[2 : 2+int(a[1])]) + port = strconv.Itoa((int(a[2+int(a[1])]) << 8) | int(a[2+int(a[1])+1])) + case AtypIPv4: + host = net.IP(a[1 : 1+net.IPv4len]).String() + port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1])) + case AtypIPv6: + host = net.IP(a[1 : 1+net.IPv6len]).String() + port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1])) + } + + return net.JoinHostPort(host, port) +} // SOCKS errors as defined in RFC 1928 section 6. const ( @@ -138,23 +156,33 @@ func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, return } - if buf[1] != CmdConnect && buf[1] != CmdUDPAssociate { - err = ErrCommandNotSupported - return - } - command = buf[1] addr, err = readAddr(rw, buf) if err != nil { return } - // write VER REP RSV ATYP BND.ADDR BND.PORT - _, err = rw.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) + + switch command { + case CmdConnect, CmdUDPAssociate: + // Acquire server listened address info + localAddr := ParseAddr(rw.LocalAddr().String()) + if localAddr == nil { + err = ErrAddressNotSupported + } else { + // write VER REP RSV ATYP BND.ADDR BND.PORT + _, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{})) + } + case CmdBind: + fallthrough + default: + err = ErrCommandNotSupported + } + return } // ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side. -func ClientHandshake(rw io.ReadWriter, addr Addr, cammand Command, user *User) error { +func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) { buf := make([]byte, MaxAddrLen) var err error @@ -165,16 +193,16 @@ func ClientHandshake(rw io.ReadWriter, addr Addr, cammand Command, user *User) e _, err = rw.Write([]byte{5, 1, 0}) } if err != nil { - return err + return nil, err } // VER, METHOD if _, err := io.ReadFull(rw, buf[:2]); err != nil { - return err + return nil, err } if buf[0] != 5 { - return errors.New("SOCKS version error") + return nil, errors.New("SOCKS version error") } if buf[1] == 2 { @@ -187,30 +215,31 @@ func ClientHandshake(rw io.ReadWriter, addr Addr, cammand Command, user *User) e authMsg.WriteString(user.Password) if _, err := rw.Write(authMsg.Bytes()); err != nil { - return err + return nil, err } if _, err := io.ReadFull(rw, buf[:2]); err != nil { - return err + return nil, err } if buf[1] != 0 { - return errors.New("rejected username/password") + return nil, errors.New("rejected username/password") } } else if buf[1] != 0 { - return errors.New("SOCKS need auth") + return nil, errors.New("SOCKS need auth") } // VER, CMD, RSV, ADDR - if _, err := rw.Write(bytes.Join([][]byte{{5, cammand, 0}, addr}, []byte(""))); err != nil { - return err + if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil { + return nil, err } - if _, err := io.ReadFull(rw, buf[:10]); err != nil { - return err + // VER, REP, RSV + if _, err := io.ReadFull(rw, buf[:3]); err != nil { + return nil, err } - return nil + return readAddr(rw, buf) } func readAddr(r io.Reader, b []byte) (Addr, error) { @@ -307,3 +336,39 @@ func ParseAddr(s string) Addr { return addr } + +func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { + if len(packet) < 5 { + err = errors.New("insufficient length of packet") + return + } + + // packet[0] and packet[1] are reserved + if !bytes.Equal(packet[:2], []byte{0, 0}) { + err = errors.New("reserved fields should be zero") + return + } + + if packet[2] != 0 /* fragments */ { + err = errors.New("discarding fragmented payload") + return + } + + addr = SplitAddr(packet[3:]) + if addr == nil { + err = errors.New("failed to read UDP header") + } + + payload = bytes.Join([][]byte{packet[3+len(addr):]}, []byte{}) + return +} + +func EncodeUDPPacket(addr string, payload []byte) (packet []byte, err error) { + rAddr := ParseAddr(addr) + if rAddr == nil { + err = errors.New("cannot parse addr") + return + } + packet = bytes.Join([][]byte{{0, 0, 0}, rAddr, payload}, []byte{}) + return +} diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 82d5f6e0..361afe1e 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -103,7 +103,6 @@ func updateGeneral(general *config.General) { T.Instance().SetMode(general.Mode) allowLan := general.AllowLan - P.SetAllowLan(allowLan) if err := P.ReCreateHTTP(general.Port); err != nil { diff --git a/proxy/listener.go b/proxy/listener.go index 7fbcfd77..98ef4ec8 100644 --- a/proxy/listener.go +++ b/proxy/listener.go @@ -13,9 +13,10 @@ import ( var ( allowLan = false - socksListener *socks.SockListener - httpListener *http.HttpListener - redirListener *redir.RedirListener + socksListener *socks.SockListener + socksUDPListener *socks.SockUDPListener + httpListener *http.HttpListener + redirListener *redir.RedirListener ) type listener interface { @@ -82,6 +83,30 @@ func ReCreateSocks(port int) error { return err } + return reCreateSocksUDP(port) +} + +func reCreateSocksUDP(port int) error { + addr := genAddr(port, allowLan) + + if socksUDPListener != nil { + if socksUDPListener.Address() == addr { + return nil + } + socksUDPListener.Close() + socksUDPListener = nil + } + + if portIsZero(addr) { + return nil + } + + var err error + socksUDPListener, err = socks.NewSocksUDPProxy(addr) + if err != nil { + return err + } + return nil } diff --git a/proxy/socks/tcp.go b/proxy/socks/tcp.go index 47bfa037..1d080dd6 100644 --- a/proxy/socks/tcp.go +++ b/proxy/socks/tcp.go @@ -1,6 +1,8 @@ package socks import ( + "io" + "io/ioutil" "net" adapters "github.com/Dreamacro/clash/adapters/inbound" @@ -62,7 +64,8 @@ func handleSocks(conn net.Conn) { } conn.(*net.TCPConn).SetKeepAlive(true) if command == socks5.CmdUDPAssociate { - tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.UDP)) + defer conn.Close() + io.Copy(ioutil.Discard, conn) return } tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.TCP)) diff --git a/proxy/socks/udp.go b/proxy/socks/udp.go new file mode 100644 index 00000000..63b63ba5 --- /dev/null +++ b/proxy/socks/udp.go @@ -0,0 +1,65 @@ +package socks + +import ( + "net" + + adapters "github.com/Dreamacro/clash/adapters/inbound" + "github.com/Dreamacro/clash/common/pool" + "github.com/Dreamacro/clash/component/socks5" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/tunnel" +) + +var ( + _ = tunnel.NATInstance() +) + +type SockUDPListener struct { + net.PacketConn + address string + closed bool +} + +func NewSocksUDPProxy(addr string) (*SockUDPListener, error) { + l, err := net.ListenPacket("udp", addr) + if err != nil { + return nil, err + } + + sl := &SockUDPListener{l, addr, false} + go func() { + buf := pool.BufPool.Get().([]byte) + defer pool.BufPool.Put(buf[:cap(buf)]) + for { + n, remoteAddr, err := l.ReadFrom(buf) + if err != nil { + if sl.closed { + break + } + continue + } + go handleSocksUDP(l, buf[:n], remoteAddr) + } + }() + + return sl, nil +} + +func (l *SockUDPListener) Close() error { + l.closed = true + return l.PacketConn.Close() +} + +func (l *SockUDPListener) Address() string { + return l.address +} + +func handleSocksUDP(c net.PacketConn, packet []byte, remoteAddr net.Addr) { + target, payload, err := socks5.DecodeUDPPacket(packet) + if err != nil { + // Unresolved UDP packet, do nothing + return + } + conn := newfakeConn(c, target.String(), remoteAddr, payload) + tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.UDP)) +} diff --git a/proxy/socks/utils.go b/proxy/socks/utils.go new file mode 100644 index 00000000..2ecee3c3 --- /dev/null +++ b/proxy/socks/utils.go @@ -0,0 +1,41 @@ +package socks + +import ( + "bytes" + "net" + + "github.com/Dreamacro/clash/component/socks5" +) + +type fakeConn struct { + net.PacketConn + target string + remoteAddr net.Addr + buffer *bytes.Buffer +} + +func newfakeConn(conn net.PacketConn, target string, remoteAddr net.Addr, buf []byte) *fakeConn { + buffer := bytes.NewBuffer(buf) + return &fakeConn{ + PacketConn: conn, + target: target, + buffer: buffer, + remoteAddr: remoteAddr, + } +} + +func (c *fakeConn) Read(b []byte) (n int, err error) { + return c.buffer.Read(b) +} + +func (c *fakeConn) Write(b []byte) (n int, err error) { + packet, err := socks5.EncodeUDPPacket(c.target, b) + if err != nil { + return + } + return c.PacketConn.WriteTo(packet, c.remoteAddr) +} + +func (c *fakeConn) RemoteAddr() net.Addr { + return c.remoteAddr +} diff --git a/tunnel/connection.go b/tunnel/connection.go index a8928a62..36c698d2 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -63,54 +63,43 @@ func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { } } +func (t *Tunnel) handleUDPToRemote(conn net.Conn, pc net.PacketConn, addr net.Addr) { + buf := pool.BufPool.Get().([]byte) + defer pool.BufPool.Put(buf[:cap(buf)]) + + n, err := conn.Read(buf) + if err != nil { + return + } + if _, err = pc.WriteTo(buf[:n], addr); err != nil { + return + } + t.traffic.Up() <- int64(n) +} + +func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn) { + buf := pool.BufPool.Get().([]byte) + defer pool.BufPool.Put(buf[:cap(buf)]) + + for { + n, _, err := pc.ReadFrom(buf) + if err != nil { + return + } + + n, err = conn.Write(buf[:n]) + if err != nil { + return + } + t.traffic.Down() <- int64(n) + } +} + func (t *Tunnel) handleSocket(request *adapters.SocketAdapter, outbound net.Conn) { conn := newTrafficTrack(outbound, t.traffic) relay(request, conn) } -func (t *Tunnel) handleUDPOverTCP(conn net.Conn, pc net.PacketConn, addr net.Addr) error { - ch := make(chan error, 1) - - go func() { - buf := pool.BufPool.Get().([]byte) - defer pool.BufPool.Put(buf) - for { - n, err := conn.Read(buf) - if err != nil { - ch <- err - return - } - pc.SetReadDeadline(time.Now().Add(120 * time.Second)) - if _, err = pc.WriteTo(buf[:n], addr); err != nil { - ch <- err - return - } - t.traffic.Up() <- int64(n) - ch <- nil - } - }() - - buf := pool.BufPool.Get().([]byte) - defer pool.BufPool.Put(buf) - - for { - pc.SetReadDeadline(time.Now().Add(120 * time.Second)) - n, _, err := pc.ReadFrom(buf) - if err != nil { - break - } - - if _, err := conn.Write(buf[:n]); err != nil { - break - } - - t.traffic.Down() <- int64(n) - } - - <-ch - return nil -} - // relay copies between left and right bidirectionally. func relay(leftConn, rightConn net.Conn) { ch := make(chan error) diff --git a/tunnel/session.go b/tunnel/session.go new file mode 100644 index 00000000..4433deae --- /dev/null +++ b/tunnel/session.go @@ -0,0 +1,22 @@ +package tunnel + +import ( + "sync" + "time" + + nat "github.com/Dreamacro/clash/component/nat-table" +) + +var ( + natTable *nat.Table + natOnce sync.Once + + natTimeout = 120 * time.Second +) + +func NATInstance() *nat.Table { + natOnce.Do(func() { + natTable = nat.New(natTimeout) + }) + return natTable +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index fedde7f2..984af7b9 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -103,9 +103,20 @@ func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool { } func (t *Tunnel) handleConn(localConn C.ServerAdapter) { - defer localConn.Close() - metadata := localConn.Metadata() + defer func() { + var conn net.Conn + switch adapter := localConn.(type) { + case *InboundAdapter.HTTPAdapter: + conn = adapter.Conn + case *InboundAdapter.SocketAdapter: + conn = adapter.Conn + } + if _, ok := conn.(*net.TCPConn); ok { + localConn.Close() + } + }() + metadata := localConn.Metadata() if !metadata.Valid() { log.Warnln("[Metadata] not valid: %#v", metadata) return @@ -138,18 +149,32 @@ func (t *Tunnel) handleConn(localConn C.ServerAdapter) { } } - if metadata.NetWork == C.UDP { - pc, addr, err := proxy.DialUDP(metadata) + switch metadata.NetWork { + case C.TCP: + t.handleTCPConn(localConn, metadata, proxy) + case C.UDP: + t.handleUDPConn(localConn, metadata, proxy) + } +} + +func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter, metadata *C.Metadata, proxy C.Proxy) { + pc, addr := natTable.Get(localConn.RemoteAddr()) + if pc == nil { + var err error + pc, addr, err = proxy.DialUDP(metadata) if err != nil { log.Warnln("Proxy[%s] connect [%s --> %s] error: %s", proxy.Name(), metadata.SrcIP.String(), metadata.String(), err.Error()) return } - defer pc.Close() - t.handleUDPOverTCP(localConn, pc, addr) - return + natTable.Set(localConn.RemoteAddr(), pc, addr) + go t.handleUDPToLocal(localConn, pc) } + t.handleUDPToRemote(localConn, pc, addr) +} + +func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter, metadata *C.Metadata, proxy C.Proxy) { remoConn, err := proxy.Dial(metadata) if err != nil { log.Warnln("Proxy[%s] connect [%s --> %s] error: %s", proxy.Name(), metadata.SrcIP.String(), metadata.String(), err.Error()) @@ -196,6 +221,7 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, error) { } if metadata.NetWork == C.UDP && !adapter.SupportUDP() { + log.Debugln("%v UDP is not supported", adapter.Name()) continue }