diff --git a/adapter/outbound/snell.go b/adapter/outbound/snell.go index 3b0dd4bc..16791836 100644 --- a/adapter/outbound/snell.go +++ b/adapter/outbound/snell.go @@ -27,6 +27,7 @@ type SnellOption struct { Server string `proxy:"server"` Port int `proxy:"port"` Psk string `proxy:"psk"` + UDP bool `proxy:"udp,omitempty"` Version int `proxy:"version,omitempty"` ObfsOpts map[string]interface{} `proxy:"obfs-opts,omitempty"` } @@ -85,6 +86,24 @@ func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata, opts ...d return NewConn(c, s), err } +// ListenPacketContext implements C.ProxyAdapter +func (s *Snell) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { + c, err := dialer.DialContext(ctx, "tcp", s.addr, s.Base.DialOptions(opts...)...) + if err != nil { + return nil, err + } + tcpKeepAlive(c) + c = streamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption}) + + err = snell.WriteUDPHeader(c, s.version) + if err != nil { + return nil, err + } + + pc := snell.PacketConn(c) + return newPacketConn(pc, s), nil +} + func NewSnell(option SnellOption) (*Snell, error) { addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) psk := []byte(option.Psk) @@ -106,7 +125,13 @@ func NewSnell(option SnellOption) (*Snell, error) { if option.Version == 0 { option.Version = snell.DefaultSnellVersion } - if option.Version != snell.Version1 && option.Version != snell.Version2 { + switch option.Version { + case snell.Version1, snell.Version2: + if option.UDP { + return nil, fmt.Errorf("snell version %d not support UDP", option.Version) + } + case snell.Version3: + default: return nil, fmt.Errorf("snell version error: %d", option.Version) } @@ -115,6 +140,7 @@ func NewSnell(option SnellOption) (*Snell, error) { name: option.Name, addr: addr, tp: C.Snell, + udp: option.UDP, iface: option.Interface, }, psk: psk, diff --git a/test/clash_test.go b/test/clash_test.go index 5eb9d5bd..ade2ed76 100644 --- a/test/clash_test.go +++ b/test/clash_test.go @@ -32,7 +32,7 @@ const ( ImageVmess = "v2fly/v2fly-core:latest" ImageTrojan = "trojangfw/trojan:latest" ImageTrojanGo = "p4gefau1t/trojan-go:latest" - ImageSnell = "icpz/snell-server:latest" + ImageSnell = "ghcr.io/icpz/snell-server:latest" ImageXray = "teddysun/xray:latest" ) diff --git a/test/snell_test.go b/test/snell_test.go index f9cd610c..e03e6dd5 100644 --- a/test/snell_test.go +++ b/test/snell_test.go @@ -120,6 +120,42 @@ func TestClash_Snell(t *testing.T) { testSuit(t, proxy) } +func TestClash_Snellv3(t *testing.T) { + cfg := &container.Config{ + Image: ImageSnell, + ExposedPorts: defaultExposedPorts, + Cmd: []string{"-c", "/config.conf"}, + } + hostCfg := &container.HostConfig{ + PortBindings: defaultPortBindings, + Binds: []string{fmt.Sprintf("%s:/config.conf", C.Path.Resolve("snell.conf"))}, + } + + id, err := startContainer(cfg, hostCfg, "snell") + if err != nil { + assert.FailNow(t, err.Error()) + } + + t.Cleanup(func() { + cleanContainer(id) + }) + + proxy, err := outbound.NewSnell(outbound.SnellOption{ + Name: "snell", + Server: localIP.String(), + Port: 10002, + Psk: "password", + UDP: true, + Version: 3, + }) + if err != nil { + assert.FailNow(t, err.Error()) + } + + time.Sleep(waitTime) + testSuit(t, proxy) +} + func Benchmark_Snell(b *testing.B) { cfg := &container.Config{ Image: ImageSnell, diff --git a/transport/snell/snell.go b/transport/snell/snell.go index 64807b81..4cd5fba8 100644 --- a/transport/snell/snell.go +++ b/transport/snell/snell.go @@ -6,8 +6,10 @@ import ( "fmt" "io" "net" + "sync" "github.com/Dreamacro/clash/common/pool" + "github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/go-shadowsocks2/shadowaead" ) @@ -15,13 +17,19 @@ import ( const ( Version1 = 1 Version2 = 2 + Version3 = 3 DefaultSnellVersion = Version1 + + // max packet length + maxLength = 0x3FFF ) const ( - CommandPing byte = 0 - CommandConnect byte = 1 - CommandConnectV2 byte = 5 + CommandPing byte = 0 + CommandConnect byte = 1 + CommandConnectV2 byte = 5 + CommandUDP byte = 6 + CommondUDPForward byte = 1 CommandTunnel byte = 0 CommandPong byte = 1 @@ -100,6 +108,16 @@ func WriteHeader(conn net.Conn, host string, port uint, version int) error { return nil } +func WriteUDPHeader(conn net.Conn, version int) error { + if version < Version3 { + return errors.New("unsupport UDP version") + } + + // version, command, clientID length + _, err := conn.Write([]byte{Version, CommandUDP, 0x00}) + return err +} + // HalfClose works only on version2 func HalfClose(conn net.Conn) error { if _, err := conn.Write(endSignal); err != nil { @@ -114,10 +132,147 @@ func HalfClose(conn net.Conn) error { func StreamConn(conn net.Conn, psk []byte, version int) *Snell { var cipher shadowaead.Cipher - if version == Version2 { + if version != Version1 { cipher = NewAES128GCM(psk) } else { cipher = NewChacha20Poly1305(psk) } return &Snell{Conn: shadowaead.NewConn(conn, cipher)} } + +func PacketConn(conn net.Conn) net.PacketConn { + return &packetConn{ + Conn: conn, + } +} + +func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { + buf := pool.GetBuffer() + defer pool.PutBuffer(buf) + + // compose snell UDP address format (refer: icpz/snell-server-reversed) + // a brand new wheel to replace socks5 address format, well done Yachen + buf.WriteByte(CommondUDPForward) + switch socks5Addr[0] { + case socks5.AtypDomainName: + hostLen := socks5Addr[1] + buf.Write(socks5Addr[1 : 1+1+hostLen+2]) + case socks5.AtypIPv4: + buf.Write([]byte{0x00, 0x04}) + buf.Write(socks5Addr[1 : 1+net.IPv4len+2]) + case socks5.AtypIPv6: + buf.Write([]byte{0x00, 0x06}) + buf.Write(socks5Addr[1 : 1+net.IPv6len+2]) + } + + buf.Write(payload) + _, err := w.Write(buf.Bytes()) + if err != nil { + return 0, err + } + return len(payload), nil +} + +func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { + if len(payload) <= maxLength { + return writePacket(w, socks5Addr, payload) + } + + offset := 0 + total := len(payload) + for { + cursor := offset + maxLength + if cursor > total { + cursor = total + } + + n, err := writePacket(w, socks5Addr, payload[offset:cursor]) + if err != nil { + return offset + n, err + } + + offset = cursor + if offset == total { + break + } + } + + return total, nil +} + +func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) { + buf := pool.Get(pool.UDPBufferSize) + defer pool.Put(buf) + + n, err := r.Read(buf) + headLen := 1 + if err != nil { + return nil, 0, err + } + if n < headLen { + return nil, 0, errors.New("insufficient UDP length") + } + + // parse snell UDP response address format + switch buf[0] { + case 0x04: + headLen += net.IPv4len + 2 + if n < headLen { + err = errors.New("insufficient UDP length") + break + } + buf[0] = socks5.AtypIPv4 + case 0x06: + headLen += net.IPv6len + 2 + if n < headLen { + err = errors.New("insufficient UDP length") + break + } + buf[0] = socks5.AtypIPv6 + default: + err = errors.New("ip version invalid") + } + + if err != nil { + return nil, 0, err + } + + addr := socks5.SplitAddr(buf[0:]) + if addr == nil { + return nil, 0, errors.New("remote address invalid") + } + uAddr := addr.UDPAddr() + + length := len(payload) + if n-headLen < length { + length = n - headLen + } + copy(payload[:], buf[headLen:headLen+length]) + + return uAddr, length, nil +} + +type packetConn struct { + net.Conn + rMux sync.Mutex + wMux sync.Mutex +} + +func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { + pc.wMux.Lock() + defer pc.wMux.Unlock() + + return WritePacket(pc, socks5.ParseAddr(addr.String()), b) +} + +func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { + pc.rMux.Lock() + defer pc.rMux.Unlock() + + addr, n, err := ReadPacket(pc.Conn, b) + if err != nil { + return 0, nil, err + } + + return n, addr, nil +}