package sing

import (
	"context"
	"errors"
	"net"
	"net/netip"
	"sync"
	"time"

	"github.com/Dreamacro/clash/adapter/inbound"
	N "github.com/Dreamacro/clash/common/net"
	C "github.com/Dreamacro/clash/constant"
	"github.com/Dreamacro/clash/log"

	vmess "github.com/metacubex/sing-vmess"
	mux "github.com/sagernet/sing-mux"
	"github.com/sagernet/sing/common/buf"
	"github.com/sagernet/sing/common/bufio"
	"github.com/sagernet/sing/common/bufio/deadline"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	"github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/uot"
)

// UDPTimeout is a five-minute duration. It is not referenced in this file;
// presumably it is the default idle timeout for UDP sessions (confirm at the
// call sites).
const UDPTimeout = 5 * time.Minute

// ListenerHandler receives stream and packet connections together with sing
// metadata and forwards them to Tunnel as clash connections/packets.
type ListenerHandler struct {
	// Tunnel receives every TCP connection and UDP packet that is handled.
	Tunnel C.Tunnel
	// Type is stored in the clash metadata of every connection/packet and is
	// used as the listener name in NewError log lines.
	Type C.Type
	// Additions are passed to combineAdditions and the result is applied to
	// the clash metadata of every connection/packet.
	Additions []inbound.Addition
	// UDPTimeout is not read anywhere in this file.
	UDPTimeout time.Duration
}

// UpstreamMetadata returns a new M.Metadata that carries over only the Source
// and Destination of metadata; every other field is left at its zero value.
func UpstreamMetadata(metadata M.Metadata) M.Metadata {
	return M.Metadata{
		Source:      metadata.Source,
		Destination: metadata.Destination,
	}
}

// ConvertMetadata builds a sing M.Metadata from a clash metadata: Protocol is
// the string form of the clash Type, Source comes from SrcIP/SrcPort, and
// Destination is parsed from metadata.String() together with DstPort.
func ConvertMetadata(metadata *C.Metadata) M.Metadata {
	return M.Metadata{
		Protocol:    metadata.Type.String(),
		Source:      M.SocksaddrFrom(metadata.SrcIP, metadata.SrcPort),
		Destination: M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort),
	}
}

// IsSpecialFqdn reports whether fqdn is one of the destination names that
// ParseSpecialFqdn handles itself: the sing-mux destination, the vmess mux
// destination, or one of the two UoT magic addresses.
func (h *ListenerHandler) IsSpecialFqdn(fqdn string) bool {
	switch fqdn {
	case mux.Destination.Fqdn,
		vmess.MuxDestination.Fqdn,
		uot.MagicAddress,
		uot.LegacyMagicAddress:
		return true
	default:
		return false
	}
}

// ParseSpecialFqdn dispatches conn according to metadata.Destination.Fqdn:
//
//   - mux.Destination.Fqdn: handed to mux.HandleConnection with h as handler.
//   - vmess.MuxDestination.Fqdn: handed to vmess.HandleMuxConnection.
//   - uot.MagicAddress: a UoT request is read from conn and the connection is
//     served as a packet connection through NewPacketConnection.
//   - uot.LegacyMagicAddress: served as a packet connection with an empty UoT
//     request and an unspecified IPv4 destination.
//
// It returns an error if the FQDN is none of the above.
func (h *ListenerHandler) ParseSpecialFqdn(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
	switch metadata.Destination.Fqdn {
	case mux.Destination.Fqdn:
		return mux.HandleConnection(ctx, h, log.SingLogger, conn, UpstreamMetadata(metadata))
	case vmess.MuxDestination.Fqdn:
		return vmess.HandleMuxConnection(ctx, conn, h)
	case uot.MagicAddress:
		request, err := uot.ReadRequest(conn)
		if err != nil {
			return E.Cause(err, "read UoT request")
		}
		metadata.Destination = request.Destination
		return h.NewPacketConnection(ctx, uot.NewConn(conn, *request), metadata)
	case uot.LegacyMagicAddress:
		metadata.Destination = M.Socksaddr{Addr: netip.IPv4Unspecified()}
		return h.NewPacketConnection(ctx, uot.NewConn(conn, uot.Request{}), metadata)
	}
	return errors.New("not special fqdn")
}

// NewConnection handles a stream connection. Connections addressed to a
// special FQDN are delegated to ParseSpecialFqdn; every other connection is
// described by a TCP clash metadata, has the configured additions applied, and
// is passed to Tunnel.HandleTCPConn.
func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
	if h.IsSpecialFqdn(metadata.Destination.Fqdn) {
		return h.ParseSpecialFqdn(ctx, conn, metadata)
	}

	if deadline.NeedAdditionalReadDeadline(conn) {
		conn = N.NewDeadlineConn(conn) // conn from sing should check NeedAdditionalReadDeadline
	}

	cMetadata := &C.Metadata{
		NetWork: C.TCP,
		Type:    h.Type,
		Host:    metadata.Destination.Fqdn,
		DstIP:   metadata.Destination.Addr,
		DstPort: metadata.Destination.Port,
		SrcIP:   metadata.Source.Addr,
		SrcPort: metadata.Source.Port,
	}
	additions := combineAdditions(ctx, h.Additions, inbound.WithInAddr(conn.LocalAddr()))
	for _, addition := range additions {
		addition.Apply(cMetadata)
	}

	h.Tunnel.HandleTCPConn(conn, cMetadata) // this goroutine must exit after conn unused
	return nil
}

// NewPacketConnection reads packets from conn in a loop and passes each one,
// wrapped in a packet together with a UDP clash metadata, to
// Tunnel.HandleUDPPacket.
//
// The loop ends when a read fails: it returns nil if ShouldIgnorePacketError
// accepts the error and the error itself otherwise. On return conn is closed
// and the connection pointer shared with outstanding packets is set to nil, so
// later WriteBack calls fail instead of writing to a closed connection.
func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata M.Metadata) error {
	if deadline.NeedAdditionalReadDeadline(conn) {
		conn = deadline.NewFallbackPacketConn(bufio.NewNetPacketConn(conn)) // conn from sing should check NeedAdditionalReadDeadline
	}
	defer func() { _ = conn.Close() }()

	mutex := sync.Mutex{}
	conn2 := conn // a new interface to set nil in defer
	defer func() {
		mutex.Lock() // this goroutine must exit after all conn.WritePacket() is not running
		defer mutex.Unlock()
		conn2 = nil
	}()

	// buff always refers to the buffer allocated for the current iteration
	// (or nil if none was allocated yet), so the error path can release it.
	var buff *buf.Buffer
	newBuffer := func() *buf.Buffer {
		buff = buf.NewPacket() // do not use stack buffer
		return buff
	}
	readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn)
	if isReadWaiter {
		readWaiter.InitializeReadWaiter(newBuffer)
	}

	// Neither metadata nor conn is modified inside the loop, so both addresses
	// are looked up once (assumes LocalAddr is stable for a given conn).
	rAddr := metadata.Source.UDPAddr()
	lAddr := conn.LocalAddr()

	for {
		var (
			dest M.Socksaddr
			err  error
		)
		buff = nil // clear last loop status, avoid repeat release
		if isReadWaiter {
			dest, err = readWaiter.WaitReadPacket()
		} else {
			dest, err = conn.ReadPacket(newBuffer())
		}
		if err != nil {
			if buff != nil {
				buff.Release()
			}
			if ShouldIgnorePacketError(err) {
				return nil
			}
			return err
		}

		cPacket := &packet{
			conn:  &conn2,
			mutex: &mutex,
			rAddr: rAddr,
			lAddr: lAddr,
			buff:  buff,
		}
		cMetadata := &C.Metadata{
			NetWork: C.UDP,
			Type:    h.Type,
			Host:    dest.Fqdn,
			DstIP:   dest.Addr,
			DstPort: dest.Port,
			SrcIP:   metadata.Source.Addr,
			SrcPort: metadata.Source.Port,
		}
		additions := combineAdditions(ctx, h.Additions, inbound.WithInAddr(lAddr))
		for _, addition := range additions {
			addition.Apply(cMetadata)
		}

		h.Tunnel.HandleUDPPacket(cPacket, cMetadata)
	}
}

// NewError logs err as a warning, prefixed with the listener type. ctx is not
// used.
func (h *ListenerHandler) NewError(ctx context.Context, err error) {
	log.Warnln("%s listener get error: %+v", h.Type.String(), err)
}

// ShouldIgnorePacketError reports whether err is a timeout, closed or canceled
// error, as classified by E.IsTimeout, E.IsClosed and E.IsCanceled.
func ShouldIgnorePacketError(err error) bool {
	// ignore simple error
	return E.IsTimeout(err) || E.IsClosed(err) || E.IsCanceled(err)
}

// packet is one received UDP packet together with what is needed to answer it.
type packet struct {
	// conn points at the connection variable owned by NewPacketConnection;
	// the pointed-to value is set to nil when that function returns.
	conn *network.PacketConn
	// mutex guards the value conn points to.
	mutex *sync.Mutex
	rAddr net.Addr    // source address of the packet
	lAddr net.Addr    // local address of the receiving connection
	buff  *buf.Buffer // packet payload; released by Drop
}

// Data returns the payload bytes of the packet.
func (c *packet) Data() []byte {
	return c.buff.Bytes()
}

// WriteBack writes UDP packet with source(ip, port) = `addr`
//
// b is copied into a freshly allocated packet buffer which is then written to
// the underlying connection. n is the number of bytes copied into that buffer
// and is returned even if the subsequent write fails. An error is returned if
// addr is nil, if the copy fails, if the connection has already been closed,
// or if the write itself fails.
func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) {
	if addr == nil {
		return 0, errors.New("address is invalid")
	}

	buff := buf.NewPacket()
	defer buff.Release()
	n, err = buff.Write(b)
	if err != nil {
		return n, err
	}

	c.mutex.Lock()
	defer c.mutex.Unlock()
	conn := *c.conn
	if conn == nil {
		return n, errors.New("writeBack to closed connection")
	}

	return n, conn.WritePacket(buff, M.SocksaddrFromNet(addr))
}

// LocalAddr returns the source IP/Port of UDP Packet
func (c *packet) LocalAddr() net.Addr {
	return c.rAddr
}

// Drop releases the buffer that holds the packet payload.
func (c *packet) Drop() {
	c.buff.Release()
}

// InAddr returns the local address of the connection the packet arrived on.
func (c *packet) InAddr() net.Addr {
	return c.lAddr
}