mihomo/listener/sing/sing.go
2023-11-18 00:07:07 +08:00

251 lines
6.3 KiB
Go

package sing
import (
"context"
"errors"
"net"
"net/netip"
"sync"
"time"
"github.com/metacubex/mihomo/adapter/inbound"
N "github.com/metacubex/mihomo/common/net"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
vmess "github.com/metacubex/sing-vmess"
mux "github.com/sagernet/sing-mux"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/bufio/deadline"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
)
const UDPTimeout = 5 * time.Minute
type ListenerHandler struct {
Tunnel C.Tunnel
Type C.Type
Additions []inbound.Addition
UDPTimeout time.Duration
MuxService *mux.Service
}
func UpstreamMetadata(metadata M.Metadata) M.Metadata {
return M.Metadata{
Source: metadata.Source,
Destination: metadata.Destination,
}
}
func ConvertMetadata(metadata *C.Metadata) M.Metadata {
return M.Metadata{
Protocol: metadata.Type.String(),
Source: M.SocksaddrFrom(metadata.SrcIP, metadata.SrcPort),
Destination: M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort),
}
}
func (h *ListenerHandler) ConfigureSingMux() (err error) {
h.MuxService, err = mux.NewService(mux.ServiceOptions{
NewStreamContext: func(ctx context.Context, conn net.Conn) context.Context {
return ctx
},
Logger: log.SingLogger,
Handler: h,
Brutal: mux.BrutalOptions{
// TODO: sing-mux tcp brutal inbound
},
})
return
}
func (h *ListenerHandler) IsSpecialFqdn(fqdn string) bool {
switch fqdn {
case mux.Destination.Fqdn:
case vmess.MuxDestination.Fqdn:
case uot.MagicAddress:
case uot.LegacyMagicAddress:
default:
return false
}
return true
}
func (h *ListenerHandler) ParseSpecialFqdn(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
switch metadata.Destination.Fqdn {
case mux.Destination.Fqdn:
if h.MuxService == nil {
err := h.ConfigureSingMux()
if err != nil {
return err
}
}
return h.MuxService.NewConnection(ctx, conn, UpstreamMetadata(metadata))
case vmess.MuxDestination.Fqdn:
return vmess.HandleMuxConnection(ctx, conn, h)
case uot.MagicAddress:
request, err := uot.ReadRequest(conn)
if err != nil {
return E.Cause(err, "read UoT request")
}
metadata.Destination = request.Destination
return h.NewPacketConnection(ctx, uot.NewConn(conn, *request), metadata)
case uot.LegacyMagicAddress:
metadata.Destination = M.Socksaddr{Addr: netip.IPv4Unspecified()}
return h.NewPacketConnection(ctx, uot.NewConn(conn, uot.Request{}), metadata)
}
return errors.New("not special fqdn")
}
func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
if h.IsSpecialFqdn(metadata.Destination.Fqdn) {
return h.ParseSpecialFqdn(ctx, conn, metadata)
}
if deadline.NeedAdditionalReadDeadline(conn) {
conn = N.NewDeadlineConn(conn) // conn from sing should check NeedAdditionalReadDeadline
}
cMetadata := &C.Metadata{
NetWork: C.TCP,
Type: h.Type,
}
inbound.ApplyAdditions(cMetadata, inbound.WithDstAddr(metadata.Destination), inbound.WithSrcAddr(metadata.Source), inbound.WithInAddr(conn.LocalAddr()))
inbound.ApplyAdditions(cMetadata, getAdditions(ctx)...)
inbound.ApplyAdditions(cMetadata, h.Additions...)
h.Tunnel.HandleTCPConn(conn, cMetadata) // this goroutine must exit after conn unused
return nil
}
func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata M.Metadata) error {
if deadline.NeedAdditionalReadDeadline(conn) {
conn = deadline.NewFallbackPacketConn(bufio.NewNetPacketConn(conn)) // conn from sing should check NeedAdditionalReadDeadline
}
defer func() { _ = conn.Close() }()
mutex := sync.Mutex{}
conn2 := conn // a new interface to set nil in defer
defer func() {
mutex.Lock() // this goroutine must exit after all conn.WritePacket() is not running
defer mutex.Unlock()
conn2 = nil
}()
var buff *buf.Buffer
newBuffer := func() *buf.Buffer {
buff = buf.NewPacket() // do not use stack buffer
return buff
}
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn)
if isReadWaiter {
readWaiter.InitializeReadWaiter(newBuffer)
}
for {
var (
dest M.Socksaddr
err error
)
buff = nil // clear last loop status, avoid repeat release
if isReadWaiter {
dest, err = readWaiter.WaitReadPacket()
} else {
dest, err = conn.ReadPacket(newBuffer())
}
if err != nil {
if buff != nil {
buff.Release()
}
if ShouldIgnorePacketError(err) {
break
}
return err
}
cPacket := &packet{
conn: &conn2,
mutex: &mutex,
rAddr: metadata.Source.UDPAddr(),
lAddr: conn.LocalAddr(),
buff: buff,
}
cMetadata := &C.Metadata{
NetWork: C.UDP,
Type: h.Type,
}
inbound.ApplyAdditions(cMetadata, inbound.WithDstAddr(dest), inbound.WithSrcAddr(metadata.Source), inbound.WithInAddr(conn.LocalAddr()))
inbound.ApplyAdditions(cMetadata, getAdditions(ctx)...)
inbound.ApplyAdditions(cMetadata, h.Additions...)
h.Tunnel.HandleUDPPacket(cPacket, cMetadata)
}
return nil
}
func (h *ListenerHandler) NewError(ctx context.Context, err error) {
log.Warnln("%s listener get error: %+v", h.Type.String(), err)
}
func ShouldIgnorePacketError(err error) bool {
// ignore simple error
if E.IsTimeout(err) || E.IsClosed(err) || E.IsCanceled(err) {
return true
}
return false
}
type packet struct {
conn *network.PacketConn
mutex *sync.Mutex
rAddr net.Addr
lAddr net.Addr
buff *buf.Buffer
}
func (c *packet) Data() []byte {
return c.buff.Bytes()
}
// WriteBack wirtes UDP packet with source(ip, port) = `addr`
func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) {
if addr == nil {
err = errors.New("address is invalid")
return
}
c.mutex.Lock()
defer c.mutex.Unlock()
conn := *c.conn
if conn == nil {
err = errors.New("writeBack to closed connection")
return
}
buff := buf.NewPacket()
defer buff.Release()
n, err = buff.Write(b)
if err != nil {
return
}
err = conn.WritePacket(buff, M.SocksaddrFromNet(addr))
if err != nil {
return
}
return
}
// LocalAddr returns the source IP/Port of UDP Packet
func (c *packet) LocalAddr() net.Addr {
return c.rAddr
}
func (c *packet) Drop() {
c.buff.Release()
}
func (c *packet) InAddr() net.Addr {
return c.lAddr
}