chore: tuic server can handle V4 and V5 in same port

This commit is contained in:
gVisor bot 2023-06-21 13:53:37 +08:00
parent 7ff7a4745d
commit 7c04b3a096
11 changed files with 474 additions and 432 deletions

View file

@ -905,9 +905,9 @@ listeners:
listen: 0.0.0.0 listen: 0.0.0.0
# rule: sub-rule-name1 # 默认使用 rules如果未找到 sub-rule 则直接使用 rules # rule: sub-rule-name1 # 默认使用 rules如果未找到 sub-rule 则直接使用 rules
# proxy: proxy # 如果不为空则直接将该入站流量交由指定proxy处理(当proxy不为空时这里的proxy名称必须合法否则会出错) # proxy: proxy # 如果不为空则直接将该入站流量交由指定proxy处理(当proxy不为空时这里的proxy名称必须合法否则会出错)
# token: # tuicV4填写可同时填写users # token: # tuicV4填写同时填写users
# - TOKEN # - TOKEN
# users: # tuicV5填写可同时填写token # users: # tuicV5填写同时填写token
# 00000000-0000-0000-0000-000000000000: PASSWORD_0 # 00000000-0000-0000-0000-000000000000: PASSWORD_0
# 00000000-0000-0000-0000-000000000001: PASSWORD_1 # 00000000-0000-0000-0000-000000000001: PASSWORD_1
# certificate: ./server.crt # certificate: ./server.crt
@ -978,9 +978,9 @@ listeners:
# tuic-server: # tuic-server:
# enable: true # enable: true
# listen: 127.0.0.1:10443 # listen: 127.0.0.1:10443
# token: # tuicV4填写可同时填写users # token: # tuicV4填写同时填写users
# - TOKEN # - TOKEN
# users: # tuicV5填写可同时填写token # users: # tuicV5填写同时填写token
# 00000000-0000-0000-0000-000000000000: PASSWORD_0 # 00000000-0000-0000-0000-000000000000: PASSWORD_0
# 00000000-0000-0000-0000-000000000001: PASSWORD_1 # 00000000-0000-0000-0000-000000000001: PASSWORD_1
# certificate: ./server.crt # certificate: ./server.crt

View file

@ -26,7 +26,7 @@ type Listener struct {
closed bool closed bool
config LC.TuicServer config LC.TuicServer
udpListeners []net.PacketConn udpListeners []net.PacketConn
servers []tuic.Server servers []*tuic.Server
} }
func New(config LC.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- C.PacketAdapter, additions ...inbound.Addition) (*Listener, error) { func New(config LC.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- C.PacketAdapter, additions ...inbound.Addition) (*Listener, error) {
@ -102,42 +102,29 @@ func New(config LC.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- C.Packet
return nil return nil
} }
var optionV4 *tuic.ServerOptionV4 option := &tuic.ServerOption{
var optionV5 *tuic.ServerOptionV5 HandleTcpFn: handleTcpFn,
HandleUdpFn: handleUdpFn,
TlsConfig: tlsConfig,
QuicConfig: quicConfig,
CongestionController: config.CongestionController,
AuthenticationTimeout: time.Duration(config.AuthenticationTimeout) * time.Millisecond,
MaxUdpRelayPacketSize: config.MaxUdpRelayPacketSize,
CWND: config.CWND,
}
if len(config.Token) > 0 { if len(config.Token) > 0 {
tokens := make([][32]byte, len(config.Token)) tokens := make([][32]byte, len(config.Token))
for i, token := range config.Token { for i, token := range config.Token {
tokens[i] = tuic.GenTKN(token) tokens[i] = tuic.GenTKN(token)
} }
option.Tokens = tokens
optionV4 = &tuic.ServerOptionV4{
HandleTcpFn: handleTcpFn,
HandleUdpFn: handleUdpFn,
TlsConfig: tlsConfig,
QuicConfig: quicConfig,
Tokens: tokens,
CongestionController: config.CongestionController,
AuthenticationTimeout: time.Duration(config.AuthenticationTimeout) * time.Millisecond,
MaxUdpRelayPacketSize: config.MaxUdpRelayPacketSize,
CWND: config.CWND,
} }
} else { if len(config.Users) > 0 {
users := make(map[[16]byte]string) users := make(map[[16]byte]string)
for _uuid, password := range config.Users { for _uuid, password := range config.Users {
users[uuid.FromStringOrNil(_uuid)] = password users[uuid.FromStringOrNil(_uuid)] = password
} }
option.Users = users
optionV5 = &tuic.ServerOptionV5{
HandleTcpFn: handleTcpFn,
HandleUdpFn: handleUdpFn,
TlsConfig: tlsConfig,
QuicConfig: quicConfig,
Users: users,
CongestionController: config.CongestionController,
AuthenticationTimeout: time.Duration(config.AuthenticationTimeout) * time.Millisecond,
MaxUdpRelayPacketSize: config.MaxUdpRelayPacketSize,
CWND: config.CWND,
}
} }
sl := &Listener{false, config, nil, nil} sl := &Listener{false, config, nil, nil}
@ -157,12 +144,8 @@ func New(config LC.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- C.Packet
sl.udpListeners = append(sl.udpListeners, ul) sl.udpListeners = append(sl.udpListeners, ul)
var server tuic.Server var server *tuic.Server
if optionV4 != nil { server, err = tuic.NewServer(option, ul)
server, err = tuic.NewServerV4(optionV4, ul)
} else {
server, err = tuic.NewServerV5(optionV5, ul)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,11 +1,13 @@
package common package common
import ( import (
"bufio"
"context" "context"
"errors" "errors"
"net" "net"
"time" "time"
N "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/metacubex/quic-go" "github.com/metacubex/quic-go"
@ -28,9 +30,12 @@ type Client interface {
Close() Close()
} }
type Server interface { type ServerHandler interface {
Serve() error AuthOk() bool
Close() error HandleTimeout()
HandleStream(conn *N.BufferedConn) (err error)
HandleMessage(message []byte) (err error)
HandleUniStream(reader *bufio.Reader) (err error)
} }
type UdpRelayMode uint8 type UdpRelayMode uint8

234
transport/tuic/server.go Normal file
View file

@ -0,0 +1,234 @@
package tuic
import (
"bufio"
"context"
"crypto/tls"
"net"
"time"
"github.com/Dreamacro/clash/adapter/inbound"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/utils"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5"
"github.com/Dreamacro/clash/transport/tuic/common"
v4 "github.com/Dreamacro/clash/transport/tuic/v4"
v5 "github.com/Dreamacro/clash/transport/tuic/v5"
"github.com/gofrs/uuid/v5"
"github.com/metacubex/quic-go"
)
type ServerOption struct {
HandleTcpFn func(conn net.Conn, addr socks5.Addr, additions ...inbound.Addition) error
HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket, additions ...inbound.Addition) error
TlsConfig *tls.Config
QuicConfig *quic.Config
Tokens [][32]byte // V4 special
Users map[[16]byte]string // V5 special
CongestionController string
AuthenticationTimeout time.Duration
MaxUdpRelayPacketSize int
CWND int
}
type Server struct {
*ServerOption
optionV4 *v4.ServerOption
optionV5 *v5.ServerOption
listener *quic.EarlyListener
}
func (s *Server) Serve() error {
for {
conn, err := s.listener.Accept(context.Background())
if err != nil {
return err
}
common.SetCongestionController(conn, s.CongestionController, s.CWND)
h := &serverHandler{
Server: s,
quicConn: conn,
uuid: utils.NewUUIDV4(),
}
if h.optionV4 != nil {
h.v4Handler = v4.NewServerHandler(h.optionV4, conn, h.uuid)
}
if h.optionV5 != nil {
h.v5Handler = v5.NewServerHandler(h.optionV5, conn, h.uuid)
}
go h.handle()
}
}
func (s *Server) Close() error {
return s.listener.Close()
}
type serverHandler struct {
*Server
quicConn quic.EarlyConnection
uuid uuid.UUID
v4Handler common.ServerHandler
v5Handler common.ServerHandler
}
func (s *serverHandler) handle() {
go func() {
_ = s.handleUniStream()
}()
go func() {
_ = s.handleStream()
}()
go func() {
_ = s.handleMessage()
}()
<-s.quicConn.HandshakeComplete()
time.AfterFunc(s.AuthenticationTimeout, func() {
if s.v4Handler != nil {
if s.v4Handler.AuthOk() {
return
}
}
if s.v5Handler != nil {
if s.v5Handler.AuthOk() {
return
}
}
if s.v4Handler != nil {
s.v4Handler.HandleTimeout()
}
if s.v5Handler != nil {
s.v5Handler.HandleTimeout()
}
})
}
func (s *serverHandler) handleMessage() (err error) {
for {
var message []byte
message, err = s.quicConn.ReceiveMessage()
if err != nil {
return err
}
go func() (err error) {
if len(message) > 0 {
switch message[0] {
case v4.VER:
if s.v4Handler != nil {
return s.v4Handler.HandleMessage(message)
}
case v5.VER:
if s.v5Handler != nil {
return s.v5Handler.HandleMessage(message)
}
}
}
return
}()
}
}
func (s *serverHandler) handleStream() (err error) {
for {
var quicStream quic.Stream
quicStream, err = s.quicConn.AcceptStream(context.Background())
if err != nil {
return err
}
go func() (err error) {
stream := common.NewQuicStreamConn(
quicStream,
s.quicConn.LocalAddr(),
s.quicConn.RemoteAddr(),
nil,
)
conn := N.NewBufferedConn(stream)
verBytes, err := conn.Peek(1)
if err != nil {
_ = conn.Close()
return err
}
switch verBytes[0] {
case v4.VER:
if s.v4Handler != nil {
return s.v4Handler.HandleStream(conn)
}
case v5.VER:
if s.v5Handler != nil {
return s.v5Handler.HandleStream(conn)
}
}
return
}()
}
}
func (s *serverHandler) handleUniStream() (err error) {
for {
var stream quic.ReceiveStream
stream, err = s.quicConn.AcceptUniStream(context.Background())
if err != nil {
return err
}
go func() (err error) {
defer func() {
stream.CancelRead(0)
}()
reader := bufio.NewReader(stream)
verBytes, err := reader.Peek(1)
if err != nil {
return err
}
switch verBytes[0] {
case v4.VER:
if s.v4Handler != nil {
return s.v4Handler.HandleUniStream(reader)
}
case v5.VER:
if s.v5Handler != nil {
return s.v5Handler.HandleUniStream(reader)
}
}
return
}()
}
}
func NewServer(option *ServerOption, pc net.PacketConn) (*Server, error) {
listener, err := quic.ListenEarly(pc, option.TlsConfig, option.QuicConfig)
if err != nil {
return nil, err
}
server := &Server{
ServerOption: option,
listener: listener,
}
if len(option.Tokens) > 0 {
server.optionV4 = &v4.ServerOption{
HandleTcpFn: option.HandleTcpFn,
HandleUdpFn: option.HandleUdpFn,
Tokens: option.Tokens,
MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize,
}
}
if len(option.Users) > 0 {
server.optionV5 = &v5.ServerOption{
HandleTcpFn: option.HandleTcpFn,
HandleUdpFn: option.HandleUdpFn,
Users: option.Users,
MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize,
}
}
return server, nil
}

View file

@ -1,8 +1,6 @@
package tuic package tuic
import ( import (
"net"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/tuic/common" "github.com/Dreamacro/clash/transport/tuic/common"
v4 "github.com/Dreamacro/clash/transport/tuic/v4" v4 "github.com/Dreamacro/clash/transport/tuic/v4"
@ -26,19 +24,6 @@ type DialFunc = common.DialFunc
var TooManyOpenStreams = common.TooManyOpenStreams var TooManyOpenStreams = common.TooManyOpenStreams
type ServerOptionV4 = v4.ServerOption
type ServerOptionV5 = v5.ServerOption
type Server = common.Server
func NewServerV4(option *ServerOptionV4, pc net.PacketConn) (Server, error) {
return v4.NewServer(option, pc)
}
func NewServerV5(option *ServerOptionV5, pc net.PacketConn) (Server, error) {
return v5.NewServer(option, pc)
}
const DefaultStreamReceiveWindow = common.DefaultStreamReceiveWindow const DefaultStreamReceiveWindow = common.DefaultStreamReceiveWindow
const DefaultConnectionReceiveWindow = common.DefaultConnectionReceiveWindow const DefaultConnectionReceiveWindow = common.DefaultConnectionReceiveWindow

View file

@ -3,9 +3,9 @@ package v4
import ( import (
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/Dreamacro/clash/common/atomic"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/transport/tuic/common" "github.com/Dreamacro/clash/transport/tuic/common"

View file

@ -36,6 +36,8 @@ const (
ResponseType = CommandType(0xff) ResponseType = CommandType(0xff)
) )
const VER byte = 0x04
func (c CommandType) String() string { func (c CommandType) String() string {
switch c { switch c {
case AuthenticateType: case AuthenticateType:
@ -66,7 +68,7 @@ type CommandHead struct {
func NewCommandHead(TYPE CommandType) CommandHead { func NewCommandHead(TYPE CommandType) CommandHead {
return CommandHead{ return CommandHead{
VER: 0x04, VER: VER,
TYPE: TYPE, TYPE: TYPE,
} }
} }

View file

@ -3,18 +3,14 @@ package v4
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls"
"fmt" "fmt"
"net" "net"
"sync" "sync"
"sync/atomic"
"time"
"github.com/Dreamacro/clash/adapter/inbound" "github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/common/atomic"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/common/utils"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/socks5"
"github.com/Dreamacro/clash/transport/tuic/common" "github.com/Dreamacro/clash/transport/tuic/common"
@ -27,106 +23,55 @@ type ServerOption struct {
HandleTcpFn func(conn net.Conn, addr socks5.Addr, additions ...inbound.Addition) error HandleTcpFn func(conn net.Conn, addr socks5.Addr, additions ...inbound.Addition) error
HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket, additions ...inbound.Addition) error HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket, additions ...inbound.Addition) error
TlsConfig *tls.Config
QuicConfig *quic.Config
Tokens [][32]byte Tokens [][32]byte
CongestionController string
AuthenticationTimeout time.Duration
MaxUdpRelayPacketSize int MaxUdpRelayPacketSize int
CWND int
} }
type Server struct { func NewServerHandler(option *ServerOption, quicConn quic.EarlyConnection, uuid uuid.UUID) common.ServerHandler {
*ServerOption return &serverHandler{
listener *quic.EarlyListener
}
func NewServer(option *ServerOption, pc net.PacketConn) (*Server, error) {
listener, err := quic.ListenEarly(pc, option.TlsConfig, option.QuicConfig)
if err != nil {
return nil, err
}
return &Server{
ServerOption: option, ServerOption: option,
listener: listener, quicConn: quicConn,
}, err uuid: uuid,
}
func (s *Server) Serve() error {
for {
conn, err := s.listener.Accept(context.Background())
if err != nil {
return err
}
common.SetCongestionController(conn, s.CongestionController, s.CWND)
h := &serverHandler{
Server: s,
quicConn: conn,
uuid: utils.NewUUIDV4(),
authCh: make(chan struct{}), authCh: make(chan struct{}),
} }
go h.handle()
}
}
func (s *Server) Close() error {
return s.listener.Close()
} }
type serverHandler struct { type serverHandler struct {
*Server *ServerOption
quicConn quic.EarlyConnection quicConn quic.EarlyConnection
uuid uuid.UUID uuid uuid.UUID
authCh chan struct{} authCh chan struct{}
authOk bool authOk atomic.Bool
authOnce sync.Once authOnce sync.Once
udpInputMap sync.Map udpInputMap sync.Map
} }
func (s *serverHandler) handle() { func (s *serverHandler) AuthOk() bool {
go func() { return s.authOk.Load()
_ = s.handleUniStream() }
}()
go func() {
_ = s.handleStream()
}()
go func() {
_ = s.handleMessage()
}()
<-s.quicConn.HandshakeComplete() func (s *serverHandler) HandleTimeout() {
time.AfterFunc(s.AuthenticationTimeout, func() {
s.authOnce.Do(func() { s.authOnce.Do(func() {
_ = s.quicConn.CloseWithError(AuthenticationTimeout, "AuthenticationTimeout") _ = s.quicConn.CloseWithError(AuthenticationTimeout, "AuthenticationTimeout")
s.authOk = false s.authOk.Store(false)
close(s.authCh) close(s.authCh)
}) })
})
} }
func (s *serverHandler) handleMessage() (err error) { func (s *serverHandler) HandleMessage(message []byte) (err error) {
for {
var message []byte
message, err = s.quicConn.ReceiveMessage()
if err != nil {
return err
}
go func() (err error) {
buffer := bytes.NewBuffer(message) buffer := bytes.NewBuffer(message)
packet, err := ReadPacket(buffer) packet, err := ReadPacket(buffer)
if err != nil { if err != nil {
return return
} }
return s.parsePacket(packet, common.NATIVE) return s.parsePacket(packet, common.NATIVE)
}()
}
} }
func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) { func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) {
<-s.authCh <-s.authCh
if !s.authOk { if !s.authOk.Load() {
return return
} }
var assocId uint32 var assocId uint32
@ -157,27 +102,13 @@ func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayM
}) })
} }
func (s *serverHandler) handleStream() (err error) { func (s *serverHandler) HandleStream(conn *N.BufferedConn) (err error) {
for {
var quicStream quic.Stream
quicStream, err = s.quicConn.AcceptStream(context.Background())
if err != nil {
return err
}
go func() (err error) {
stream := common.NewQuicStreamConn(
quicStream,
s.quicConn.LocalAddr(),
s.quicConn.RemoteAddr(),
nil,
)
conn := N.NewBufferedConn(stream)
connect, err := ReadConnect(conn) connect, err := ReadConnect(conn)
if err != nil { if err != nil {
return err return err
} }
<-s.authCh <-s.authCh
if !s.authOk { if !s.authOk.Load() {
return conn.Close() return conn.Close()
} }
@ -194,29 +125,16 @@ func (s *serverHandler) handleStream() (err error) {
_ = conn.Close() _ = conn.Close()
return err return err
} }
_, err = buf.WriteTo(stream) _, err = buf.WriteTo(conn)
if err != nil { if err != nil {
_ = conn.Close() _ = conn.Close()
return err return err
} }
return return
}()
}
} }
func (s *serverHandler) handleUniStream() (err error) { func (s *serverHandler) HandleUniStream(reader *bufio.Reader) (err error) {
for {
var stream quic.ReceiveStream
stream, err = s.quicConn.AcceptUniStream(context.Background())
if err != nil {
return err
}
go func() (err error) {
defer func() {
stream.CancelRead(0)
}()
reader := bufio.NewReader(stream)
commandHead, err := ReadCommandHead(reader) commandHead, err := ReadCommandHead(reader)
if err != nil { if err != nil {
return return
@ -239,7 +157,7 @@ func (s *serverHandler) handleUniStream() (err error) {
if !authOk { if !authOk {
_ = s.quicConn.CloseWithError(AuthenticationFailed, "AuthenticationFailed") _ = s.quicConn.CloseWithError(AuthenticationFailed, "AuthenticationFailed")
} }
s.authOk = authOk s.authOk.Store(authOk)
close(s.authCh) close(s.authCh)
}) })
case PacketType: case PacketType:
@ -268,8 +186,6 @@ func (s *serverHandler) handleUniStream() (err error) {
heartbeat.BytesLen() heartbeat.BytesLen()
} }
return return
}()
}
} }
type serverUDPPacket struct { type serverUDPPacket struct {

View file

@ -4,9 +4,9 @@ import (
"errors" "errors"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/Dreamacro/clash/common/atomic"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/transport/tuic/common" "github.com/Dreamacro/clash/transport/tuic/common"

View file

@ -35,6 +35,8 @@ const (
HeartbeatType = CommandType(0x04) HeartbeatType = CommandType(0x04)
) )
const VER byte = 0x05
func (c CommandType) String() string { func (c CommandType) String() string {
switch c { switch c {
case AuthenticateType: case AuthenticateType:
@ -63,7 +65,7 @@ type CommandHead struct {
func NewCommandHead(TYPE CommandType) CommandHead { func NewCommandHead(TYPE CommandType) CommandHead {
return CommandHead{ return CommandHead{
VER: 0x05, VER: VER,
TYPE: TYPE, TYPE: TYPE,
} }
} }

View file

@ -3,18 +3,13 @@ package v5
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls"
"fmt" "fmt"
"net" "net"
"sync" "sync"
"sync/atomic"
"time"
"github.com/Dreamacro/clash/adapter/inbound" "github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/common/atomic"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/utils"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/socks5"
"github.com/Dreamacro/clash/transport/tuic/common" "github.com/Dreamacro/clash/transport/tuic/common"
@ -27,94 +22,45 @@ type ServerOption struct {
HandleTcpFn func(conn net.Conn, addr socks5.Addr, additions ...inbound.Addition) error HandleTcpFn func(conn net.Conn, addr socks5.Addr, additions ...inbound.Addition) error
HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket, additions ...inbound.Addition) error HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket, additions ...inbound.Addition) error
TlsConfig *tls.Config
QuicConfig *quic.Config
Users map[[16]byte]string Users map[[16]byte]string
CongestionController string
AuthenticationTimeout time.Duration
MaxUdpRelayPacketSize int MaxUdpRelayPacketSize int
CWND int
} }
type Server struct { func NewServerHandler(option *ServerOption, quicConn quic.EarlyConnection, uuid uuid.UUID) common.ServerHandler {
*ServerOption return &serverHandler{
listener *quic.EarlyListener
}
func NewServer(option *ServerOption, pc net.PacketConn) (*Server, error) {
listener, err := quic.ListenEarly(pc, option.TlsConfig, option.QuicConfig)
if err != nil {
return nil, err
}
return &Server{
ServerOption: option, ServerOption: option,
listener: listener, quicConn: quicConn,
}, err uuid: uuid,
}
func (s *Server) Serve() error {
for {
conn, err := s.listener.Accept(context.Background())
if err != nil {
return err
}
common.SetCongestionController(conn, s.CongestionController, s.CWND)
h := &serverHandler{
Server: s,
quicConn: conn,
uuid: utils.NewUUIDV4(),
authCh: make(chan struct{}), authCh: make(chan struct{}),
} }
go h.handle()
}
}
func (s *Server) Close() error {
return s.listener.Close()
} }
type serverHandler struct { type serverHandler struct {
*Server *ServerOption
quicConn quic.EarlyConnection quicConn quic.EarlyConnection
uuid uuid.UUID uuid uuid.UUID
authCh chan struct{} authCh chan struct{}
authOk bool authOk atomic.Bool
authUUID string authUUID atomic.TypedValue[string]
authOnce sync.Once authOnce sync.Once
udpInputMap sync.Map udpInputMap sync.Map
} }
func (s *serverHandler) handle() { func (s *serverHandler) AuthOk() bool {
go func() { return s.authOk.Load()
_ = s.handleUniStream() }
}()
go func() {
_ = s.handleStream()
}()
go func() {
_ = s.handleMessage()
}()
<-s.quicConn.HandshakeComplete() func (s *serverHandler) HandleTimeout() {
time.AfterFunc(s.AuthenticationTimeout, func() {
s.authOnce.Do(func() { s.authOnce.Do(func() {
_ = s.quicConn.CloseWithError(AuthenticationTimeout, "AuthenticationTimeout") _ = s.quicConn.CloseWithError(AuthenticationTimeout, "AuthenticationTimeout")
s.authOk = false s.authOk.Store(false)
close(s.authCh) close(s.authCh)
}) })
})
} }
func (s *serverHandler) handleMessage() (err error) { func (s *serverHandler) HandleMessage(message []byte) (err error) {
for {
var message []byte
message, err = s.quicConn.ReceiveMessage()
if err != nil {
return err
}
go func() (err error) {
reader := bytes.NewBuffer(message) reader := bytes.NewBuffer(message)
commandHead, err := ReadCommandHead(reader) commandHead, err := ReadCommandHead(reader)
if err != nil { if err != nil {
@ -137,13 +83,11 @@ func (s *serverHandler) handleMessage() (err error) {
heartbeat.BytesLen() heartbeat.BytesLen()
} }
return return
}()
}
} }
func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) { func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) {
<-s.authCh <-s.authCh
if !s.authOk { if !s.authOk.Load() {
return return
} }
var assocId uint16 var assocId uint16
@ -175,55 +119,28 @@ func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayM
pc: pc, pc: pc,
packet: packetPtr, packet: packetPtr,
rAddr: N.NewCustomAddr("tuic", fmt.Sprintf("tuic-%s-%d", s.uuid, assocId), s.quicConn.RemoteAddr()), // for tunnel's handleUDPConn rAddr: N.NewCustomAddr("tuic", fmt.Sprintf("tuic-%s-%d", s.uuid, assocId), s.quicConn.RemoteAddr()), // for tunnel's handleUDPConn
}, inbound.WithInUser(s.authUUID)) }, inbound.WithInUser(s.authUUID.Load()))
} }
func (s *serverHandler) handleStream() (err error) { func (s *serverHandler) HandleStream(conn *N.BufferedConn) (err error) {
for {
var quicStream quic.Stream
quicStream, err = s.quicConn.AcceptStream(context.Background())
if err != nil {
return err
}
go func() (err error) {
stream := common.NewQuicStreamConn(
quicStream,
s.quicConn.LocalAddr(),
s.quicConn.RemoteAddr(),
nil,
)
conn := N.NewBufferedConn(stream)
connect, err := ReadConnect(conn) connect, err := ReadConnect(conn)
if err != nil { if err != nil {
return err return err
} }
<-s.authCh <-s.authCh
if !s.authOk { if !s.authOk.Load() {
return conn.Close() return conn.Close()
} }
err = s.HandleTcpFn(conn, connect.ADDR.SocksAddr(), inbound.WithInUser(s.authUUID)) err = s.HandleTcpFn(conn, connect.ADDR.SocksAddr(), inbound.WithInUser(s.authUUID.Load()))
if err != nil { if err != nil {
_ = conn.Close() _ = conn.Close()
return err return err
} }
return return
}()
}
} }
func (s *serverHandler) handleUniStream() (err error) { func (s *serverHandler) HandleUniStream(reader *bufio.Reader) (err error) {
for {
var stream quic.ReceiveStream
stream, err = s.quicConn.AcceptUniStream(context.Background())
if err != nil {
return err
}
go func() (err error) {
defer func() {
stream.CancelRead(0)
}()
reader := bufio.NewReader(stream)
commandHead, err := ReadCommandHead(reader) commandHead, err := ReadCommandHead(reader)
if err != nil { if err != nil {
return return
@ -252,8 +169,8 @@ func (s *serverHandler) handleUniStream() (err error) {
if !authOk { if !authOk {
_ = s.quicConn.CloseWithError(AuthenticationFailed, "AuthenticationFailed") _ = s.quicConn.CloseWithError(AuthenticationFailed, "AuthenticationFailed")
} }
s.authOk = authOk s.authOk.Store(authOk)
s.authUUID = authUUID.String() s.authUUID.Store(authUUID.String())
close(s.authCh) close(s.authCh)
}) })
case PacketType: case PacketType:
@ -275,8 +192,6 @@ func (s *serverHandler) handleUniStream() (err error) {
} }
} }
return return
}()
}
} }
type serverUDPInput struct { type serverUDPInput struct {