From f4de055aa174000024d439e00a1301b99d476478 Mon Sep 17 00:00:00 2001 From: Dreamacro <8615343+Dreamacro@users.noreply.github.com> Date: Sat, 23 Jan 2021 14:49:46 +0800 Subject: [PATCH] Refactor: make inbound request contextual --- adapters/inbound/http.go | 23 +++------------ adapters/inbound/https.go | 10 +++---- adapters/inbound/socket.go | 21 +++----------- constant/adapters.go | 14 +++++---- constant/context.go | 23 +++++++++++++++ context/conn.go | 39 +++++++++++++++++++++++++ context/dns.go | 41 +++++++++++++++++++++++++++ context/http.go | 47 +++++++++++++++++++++++++++++++ context/packetconn.go | 43 ++++++++++++++++++++++++++++ dns/middleware.go | 28 ++++++++++-------- dns/resolver.go | 19 ++----------- dns/server.go | 20 ++++++++----- dns/util.go | 15 ++++++++++ hub/route/connections.go | 10 +++---- hub/route/server.go | 4 +-- tunnel/connection.go | 18 ++++++------ tunnel/{ => statistic}/manager.go | 2 +- tunnel/{ => statistic}/tracker.go | 6 ++-- tunnel/tunnel.go | 44 ++++++++++++++--------------- 19 files changed, 302 insertions(+), 125 deletions(-) create mode 100644 constant/context.go create mode 100644 context/conn.go create mode 100644 context/dns.go create mode 100644 context/http.go create mode 100644 context/packetconn.go rename tunnel/{ => statistic}/manager.go (99%) rename tunnel/{ => statistic}/tracker.go (95%) diff --git a/adapters/inbound/http.go b/adapters/inbound/http.go index 867dab22..ee26ac7c 100644 --- a/adapters/inbound/http.go +++ b/adapters/inbound/http.go @@ -6,33 +6,18 @@ import ( "strings" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/context" ) -// HTTPAdapter is a adapter for HTTP connection -type HTTPAdapter struct { - net.Conn - metadata *C.Metadata - R *http.Request -} - -// Metadata return destination metadata -func (h *HTTPAdapter) Metadata() *C.Metadata { - return h.metadata -} - -// NewHTTP is HTTPAdapter generator -func NewHTTP(request *http.Request, conn net.Conn) *HTTPAdapter { +// NewHTTP recieve normal http request and return HTTPContext +func NewHTTP(request *http.Request, conn net.Conn) *context.HTTPContext { metadata := parseHTTPAddr(request) metadata.Type = C.HTTP if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil { metadata.SrcIP = ip metadata.SrcPort = port } - return &HTTPAdapter{ - metadata: metadata, - R: request, - Conn: conn, - } + return context.NewHTTPContext(conn, request, metadata) } // RemoveHopByHopHeaders remove hop-by-hop header diff --git a/adapters/inbound/https.go b/adapters/inbound/https.go index 7a219980..bb2bd97d 100644 --- a/adapters/inbound/https.go +++ b/adapters/inbound/https.go @@ -5,18 +5,16 @@ import ( "net/http" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/context" ) -// NewHTTPS is HTTPAdapter generator -func NewHTTPS(request *http.Request, conn net.Conn) *SocketAdapter { +// NewHTTPS recieve CONNECT request and return ConnContext +func NewHTTPS(request *http.Request, conn net.Conn) *context.ConnContext { metadata := parseHTTPAddr(request) metadata.Type = C.HTTPCONNECT if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil { metadata.SrcIP = ip metadata.SrcPort = port } - return &SocketAdapter{ - metadata: metadata, - Conn: conn, - } + return context.NewConnContext(conn, metadata) } diff --git a/adapters/inbound/socket.go b/adapters/inbound/socket.go index 134be489..1370b701 100644 --- a/adapters/inbound/socket.go +++ b/adapters/inbound/socket.go @@ -5,21 +5,11 @@ import ( "github.com/Dreamacro/clash/component/socks5" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/context" ) -// SocketAdapter is a adapter for socks and redir connection -type SocketAdapter struct { - net.Conn - metadata *C.Metadata -} - -// Metadata return destination metadata -func (s *SocketAdapter) Metadata() *C.Metadata { - return s.metadata -} - -// NewSocket is SocketAdapter generator -func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *SocketAdapter { +// NewSocket recieve TCP inbound and return ConnContext +func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *context.ConnContext { metadata := parseSocksAddr(target) metadata.NetWork = C.TCP metadata.Type = source @@ -28,8 +18,5 @@ func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *SocketAdapter metadata.SrcPort = port } - return &SocketAdapter{ - Conn: conn, - metadata: metadata, - } + return context.NewConnContext(conn, metadata) } diff --git a/constant/adapters.go b/constant/adapters.go index 4ba891de..7456c304 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -27,11 +27,6 @@ const ( LoadBalance ) -type ServerAdapter interface { - net.Conn - Metadata() *Metadata -} - type Connection interface { Chains() Chain AppendToChains(adapter ProxyAdapter) @@ -50,6 +45,15 @@ func (c Chain) String() string { } } +func (c Chain) Last() string { + switch len(c) { + case 0: + return "" + default: + return c[0] + } +} + type Conn interface { net.Conn Connection diff --git a/constant/context.go b/constant/context.go new file mode 100644 index 00000000..e641ed14 --- /dev/null +++ b/constant/context.go @@ -0,0 +1,23 @@ +package constant + +import ( + "net" + + "github.com/gofrs/uuid" +) + +type PlainContext interface { + ID() uuid.UUID +} + +type ConnContext interface { + PlainContext + Metadata() *Metadata + Conn() net.Conn +} + +type PacketConnContext interface { + PlainContext + Metadata() *Metadata + PacketConn() net.PacketConn +} diff --git a/context/conn.go b/context/conn.go new file mode 100644 index 00000000..ee0f3a9d --- /dev/null +++ b/context/conn.go @@ -0,0 +1,39 @@ +package context + +import ( + "net" + + C "github.com/Dreamacro/clash/constant" + + "github.com/gofrs/uuid" +) + +type ConnContext struct { + id uuid.UUID + metadata *C.Metadata + conn net.Conn +} + +func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext { + id, _ := uuid.NewV4() + return &ConnContext{ + id: id, + metadata: metadata, + conn: conn, + } +} + +// ID implement C.ConnContext ID +func (c *ConnContext) ID() uuid.UUID { + return c.id +} + +// Metadata implement C.ConnContext Metadata +func (c *ConnContext) Metadata() *C.Metadata { + return c.metadata +} + +// Conn implement C.ConnContext Conn +func (c *ConnContext) Conn() net.Conn { + return c.conn +} diff --git a/context/dns.go b/context/dns.go new file mode 100644 index 00000000..0be4a1fc --- /dev/null +++ b/context/dns.go @@ -0,0 +1,41 @@ +package context + +import ( + "github.com/gofrs/uuid" + "github.com/miekg/dns" +) + +const ( + DNSTypeHost = "host" + DNSTypeFakeIP = "fakeip" + DNSTypeRaw = "raw" +) + +type DNSContext struct { + id uuid.UUID + msg *dns.Msg + tp string +} + +func NewDNSContext(msg *dns.Msg) *DNSContext { + id, _ := uuid.NewV4() + return &DNSContext{ + id: id, + msg: msg, + } +} + +// ID implement C.PlainContext ID +func (c *DNSContext) ID() uuid.UUID { + return c.id +} + +// SetType set type of response +func (c *DNSContext) SetType(tp string) { + c.tp = tp +} + +// Type return type of response +func (c *DNSContext) Type() string { + return c.tp +} diff --git a/context/http.go b/context/http.go new file mode 100644 index 00000000..292f7d97 --- /dev/null +++ b/context/http.go @@ -0,0 +1,47 @@ +package context + +import ( + "net" + "net/http" + + C "github.com/Dreamacro/clash/constant" + + "github.com/gofrs/uuid" +) + +type HTTPContext struct { + id uuid.UUID + metadata *C.Metadata + conn net.Conn + req *http.Request +} + +func NewHTTPContext(conn net.Conn, req *http.Request, metadata *C.Metadata) *HTTPContext { + id, _ := uuid.NewV4() + return &HTTPContext{ + id: id, + metadata: metadata, + conn: conn, + req: req, + } +} + +// ID implement C.ConnContext ID +func (hc *HTTPContext) ID() uuid.UUID { + return hc.id +} + +// Metadata implement C.ConnContext Metadata +func (hc *HTTPContext) Metadata() *C.Metadata { + return hc.metadata +} + +// Conn implement C.ConnContext Conn +func (hc *HTTPContext) Conn() net.Conn { + return hc.conn +} + +// Request return the http request struct +func (hc *HTTPContext) Request() *http.Request { + return hc.req +} diff --git a/context/packetconn.go b/context/packetconn.go new file mode 100644 index 00000000..3b005141 --- /dev/null +++ b/context/packetconn.go @@ -0,0 +1,43 @@ +package context + +import ( + "net" + + C "github.com/Dreamacro/clash/constant" + + "github.com/gofrs/uuid" +) + +type PacketConnContext struct { + id uuid.UUID + metadata *C.Metadata + packetConn net.PacketConn +} + +func NewPacketConnContext(metadata *C.Metadata) *PacketConnContext { + id, _ := uuid.NewV4() + return &PacketConnContext{ + id: id, + metadata: metadata, + } +} + +// ID implement C.PacketConnContext ID +func (pc *PacketConnContext) ID() uuid.UUID { + return pc.id +} + +// Metadata implement C.PacketConnContext Metadata +func (pc *PacketConnContext) Metadata() *C.Metadata { + return pc.metadata +} + +// PacketConn implement C.PacketConnContext PacketConn +func (pc *PacketConnContext) PacketConn() net.PacketConn { + return pc.packetConn +} + +// InjectPacketConn injectPacketConn manually +func (pc *PacketConnContext) InjectPacketConn(pconn C.PacketConn) { + pc.packetConn = pconn +} diff --git a/dns/middleware.go b/dns/middleware.go index 9a119e78..782c0ef0 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -8,26 +8,27 @@ import ( "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/component/fakeip" "github.com/Dreamacro/clash/component/trie" + "github.com/Dreamacro/clash/context" "github.com/Dreamacro/clash/log" D "github.com/miekg/dns" ) -type handler func(r *D.Msg) (*D.Msg, error) +type handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) type middleware func(next handler) handler func withHosts(hosts *trie.DomainTrie) middleware { return func(next handler) handler { - return func(r *D.Msg) (*D.Msg, error) { + return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] if !isIPRequest(q) { - return next(r) + return next(ctx, r) } record := hosts.Search(strings.TrimRight(q.Name, ".")) if record == nil { - return next(r) + return next(ctx, r) } ip := record.Data.(net.IP) @@ -46,9 +47,10 @@ func withHosts(hosts *trie.DomainTrie) middleware { msg.Answer = []D.RR{rr} } else { - return next(r) + return next(ctx, r) } + ctx.SetType(context.DNSTypeHost) msg.SetRcode(r, D.RcodeSuccess) msg.Authoritative = true msg.RecursionAvailable = true @@ -60,14 +62,14 @@ func withHosts(hosts *trie.DomainTrie) middleware { func withMapping(mapping *cache.LruCache) middleware { return func(next handler) handler { - return func(r *D.Msg) (*D.Msg, error) { + return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] if !isIPRequest(q) { - return next(r) + return next(ctx, r) } - msg, err := next(r) + msg, err := next(ctx, r) if err != nil { return nil, err } @@ -99,12 +101,12 @@ func withMapping(mapping *cache.LruCache) middleware { func withFakeIP(fakePool *fakeip.Pool) middleware { return func(next handler) handler { - return func(r *D.Msg) (*D.Msg, error) { + return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] host := strings.TrimRight(q.Name, ".") if fakePool.LookupHost(host) { - return next(r) + return next(ctx, r) } switch q.Qtype { @@ -113,7 +115,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware { } if q.Qtype != D.TypeA { - return next(r) + return next(ctx, r) } rr := &D.A{} @@ -123,6 +125,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware { msg := r.Copy() msg.Answer = []D.RR{rr} + ctx.SetType(context.DNSTypeFakeIP) setMsgTTL(msg, 1) msg.SetRcode(r, D.RcodeSuccess) msg.Authoritative = true @@ -134,7 +137,8 @@ func withFakeIP(fakePool *fakeip.Pool) middleware { } func withResolver(resolver *Resolver) handler { - return func(r *D.Msg) (*D.Msg, error) { + return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { + ctx.SetType(context.DNSTypeRaw) q := r.Question[0] // return a empty AAAA msg when ipv6 disabled diff --git a/dns/resolver.go b/dns/resolver.go index 93e3ca6e..d110aa34 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -212,7 +212,7 @@ func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) { fallbackMsg := r.asyncExchange(r.fallback, m) res := <-msgCh if res.Error == nil { - if ips := r.msgToIP(res.Msg); len(ips) != 0 { + if ips := msgToIP(res.Msg); len(ips) != 0 { if !r.shouldIPFallback(ips[0]) { msg = res.Msg // no need to wait for fallback result err = res.Error @@ -247,7 +247,7 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) return nil, err } - ips := r.msgToIP(msg) + ips := msgToIP(msg) ipLength := len(ips) if ipLength == 0 { return nil, resolver.ErrIPNotFound @@ -257,21 +257,6 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) return } -func (r *Resolver) msgToIP(msg *D.Msg) []net.IP { - ips := []net.IP{} - - for _, answer := range msg.Answer { - switch ans := answer.(type) { - case *D.AAAA: - ips = append(ips, ans.AAAA) - case *D.A: - ips = append(ips, ans.A) - } - } - - return ips -} - func (r *Resolver) msgToDomain(msg *D.Msg) string { if len(msg.Question) > 0 { return strings.TrimRight(msg.Question[0].Name, ".") diff --git a/dns/server.go b/dns/server.go index 23718c4f..84ff0bac 100644 --- a/dns/server.go +++ b/dns/server.go @@ -1,9 +1,11 @@ package dns import ( + "errors" "net" "github.com/Dreamacro/clash/common/sockopt" + "github.com/Dreamacro/clash/context" "github.com/Dreamacro/clash/log" D "github.com/miekg/dns" @@ -21,21 +23,25 @@ type Server struct { handler handler } +// ServeDNS implement D.Handler ServeDNS func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) { - if len(r.Question) == 0 { - D.HandleFailed(w, r) - return - } - - msg, err := s.handler(r) + msg, err := handlerWithContext(s.handler, r) if err != nil { D.HandleFailed(w, r) return } - w.WriteMsg(msg) } +func handlerWithContext(handler handler, msg *D.Msg) (*D.Msg, error) { + if len(msg.Question) == 0 { + return nil, errors.New("at least one question is required") + } + + ctx := context.NewDNSContext(msg) + return handler(ctx, msg) +} + func (s *Server) setHandler(handler handler) { s.handler = handler } diff --git a/dns/util.go b/dns/util.go index 55a280db..c2bb11d8 100644 --- a/dns/util.go +++ b/dns/util.go @@ -153,3 +153,18 @@ func handleMsgWithEmptyAnswer(r *D.Msg) *D.Msg { return msg } + +func msgToIP(msg *D.Msg) []net.IP { + ips := []net.IP{} + + for _, answer := range msg.Answer { + switch ans := answer.(type) { + case *D.AAAA: + ips = append(ips, ans.AAAA) + case *D.A: + ips = append(ips, ans.A) + } + } + + return ips +} diff --git a/hub/route/connections.go b/hub/route/connections.go index 21792625..edec6d6a 100644 --- a/hub/route/connections.go +++ b/hub/route/connections.go @@ -7,7 +7,7 @@ import ( "strconv" "time" - T "github.com/Dreamacro/clash/tunnel" + "github.com/Dreamacro/clash/tunnel/statistic" "github.com/gorilla/websocket" "github.com/go-chi/chi" @@ -24,7 +24,7 @@ func connectionRouter() http.Handler { func getConnections(w http.ResponseWriter, r *http.Request) { if !websocket.IsWebSocketUpgrade(r) { - snapshot := T.DefaultManager.Snapshot() + snapshot := statistic.DefaultManager.Snapshot() render.JSON(w, r, snapshot) return } @@ -50,7 +50,7 @@ func getConnections(w http.ResponseWriter, r *http.Request) { buf := &bytes.Buffer{} sendSnapshot := func() error { buf.Reset() - snapshot := T.DefaultManager.Snapshot() + snapshot := statistic.DefaultManager.Snapshot() if err := json.NewEncoder(buf).Encode(snapshot); err != nil { return err } @@ -73,7 +73,7 @@ func getConnections(w http.ResponseWriter, r *http.Request) { func closeConnection(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") - snapshot := T.DefaultManager.Snapshot() + snapshot := statistic.DefaultManager.Snapshot() for _, c := range snapshot.Connections { if id == c.ID() { c.Close() @@ -84,7 +84,7 @@ func closeConnection(w http.ResponseWriter, r *http.Request) { } func closeAllConnections(w http.ResponseWriter, r *http.Request) { - snapshot := T.DefaultManager.Snapshot() + snapshot := statistic.DefaultManager.Snapshot() for _, c := range snapshot.Connections { c.Close() } diff --git a/hub/route/server.go b/hub/route/server.go index 5948f121..ed0133cb 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -9,7 +9,7 @@ import ( C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" - T "github.com/Dreamacro/clash/tunnel" + "github.com/Dreamacro/clash/tunnel/statistic" "github.com/go-chi/chi" "github.com/go-chi/cors" @@ -143,7 +143,7 @@ func traffic(w http.ResponseWriter, r *http.Request) { tick := time.NewTicker(time.Second) defer tick.Stop() - t := T.DefaultManager + t := statistic.DefaultManager buf := &bytes.Buffer{} var err error for range tick.C { diff --git a/tunnel/connection.go b/tunnel/connection.go index 15a62555..77671a4b 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -13,13 +13,15 @@ import ( "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/context" ) -func handleHTTP(request *inbound.HTTPAdapter, outbound net.Conn) { - req := request.R +func handleHTTP(ctx *context.HTTPContext, outbound net.Conn) { + req := ctx.Request() + conn := ctx.Conn() host := req.Host - inboundReader := bufio.NewReader(request) + inboundReader := bufio.NewReader(conn) outboundReader := bufio.NewReader(outbound) for { @@ -43,7 +45,7 @@ func handleHTTP(request *inbound.HTTPAdapter, outbound net.Conn) { inbound.RemoveHopByHopHeaders(resp.Header) if resp.StatusCode == http.StatusContinue { - err = resp.Write(request) + err = resp.Write(conn) if err != nil { break } @@ -58,14 +60,14 @@ func handleHTTP(request *inbound.HTTPAdapter, outbound net.Conn) { } else { resp.Close = true } - err = resp.Write(request) + err = resp.Write(conn) if err != nil || resp.Close { break } // even if resp.Write write body to the connection, but some http request have to Copy to close it buf := pool.Get(pool.RelayBufferSize) - _, err = io.CopyBuffer(request, resp.Body, buf) + _, err = io.CopyBuffer(conn, resp.Body, buf) pool.Put(buf) if err != nil && err != io.EOF { break @@ -129,8 +131,8 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, fAddr n } } -func handleSocket(request C.ServerAdapter, outbound net.Conn) { - relay(request, outbound) +func handleSocket(ctx C.ConnContext, outbound net.Conn) { + relay(ctx.Conn(), outbound) } // relay copies between left and right bidirectionally. diff --git a/tunnel/manager.go b/tunnel/statistic/manager.go similarity index 99% rename from tunnel/manager.go rename to tunnel/statistic/manager.go index 784d57d9..462da674 100644 --- a/tunnel/manager.go +++ b/tunnel/statistic/manager.go @@ -1,4 +1,4 @@ -package tunnel +package statistic import ( "sync" diff --git a/tunnel/tracker.go b/tunnel/statistic/tracker.go similarity index 95% rename from tunnel/tracker.go rename to tunnel/statistic/tracker.go index dcb81e7f..1f5f1f9c 100644 --- a/tunnel/tracker.go +++ b/tunnel/statistic/tracker.go @@ -1,4 +1,4 @@ -package tunnel +package statistic import ( "net" @@ -57,7 +57,7 @@ func (tt *tcpTracker) Close() error { return tt.Conn.Close() } -func newTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker { +func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker { uuid, _ := uuid.NewV4() t := &tcpTracker{ @@ -114,7 +114,7 @@ func (ut *udpTracker) Close() error { return ut.PacketConn.Close() } -func newUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule) *udpTracker { +func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule) *udpTracker { uuid, _ := uuid.NewV4() ut := &udpTracker{ diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 1339701a..d7ca4c80 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -12,11 +12,13 @@ import ( "github.com/Dreamacro/clash/component/nat" "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/context" "github.com/Dreamacro/clash/log" + "github.com/Dreamacro/clash/tunnel/statistic" ) var ( - tcpQueue = make(chan C.ServerAdapter, 200) + tcpQueue = make(chan C.ConnContext, 200) udpQueue = make(chan *inbound.PacketAdapter, 200) natTable = nat.New() rules []C.Rule @@ -36,8 +38,8 @@ func init() { } // Add request to queue -func Add(req C.ServerAdapter) { - tcpQueue <- req +func Add(ctx C.ConnContext) { + tcpQueue <- ctx } // AddPacket add udp Packet to queue @@ -141,9 +143,7 @@ func preHandleMetadata(metadata *C.Metadata) error { return nil } -func resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { - var proxy C.Proxy - var rule C.Rule +func resolveMetadata(ctx C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) { switch mode { case Direct: proxy = proxies["DIRECT"] @@ -151,13 +151,9 @@ func resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { proxy = proxies["GLOBAL"] // Rule default: - var err error proxy, rule, err = match(metadata) - if err != nil { - return nil, nil, err - } } - return proxy, rule, nil + return } func handleUDPConn(packet *inbound.PacketAdapter) { @@ -210,7 +206,8 @@ func handleUDPConn(packet *inbound.PacketAdapter) { cond.Broadcast() }() - proxy, rule, err := resolveMetadata(metadata) + ctx := context.NewPacketConnContext(metadata) + proxy, rule, err := resolveMetadata(ctx, metadata) if err != nil { log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) return @@ -225,7 +222,8 @@ func handleUDPConn(packet *inbound.PacketAdapter) { } return } - pc := newUDPTracker(rawPc, DefaultManager, metadata, rule) + ctx.InjectPacketConn(rawPc) + pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule) switch true { case rule != nil: @@ -245,10 +243,10 @@ func handleUDPConn(packet *inbound.PacketAdapter) { }() } -func handleTCPConn(localConn C.ServerAdapter) { - defer localConn.Close() +func handleTCPConn(ctx C.ConnContext) { + defer ctx.Conn().Close() - metadata := localConn.Metadata() + metadata := ctx.Metadata() if !metadata.Valid() { log.Warnln("[Metadata] not valid: %#v", metadata) return @@ -259,7 +257,7 @@ func handleTCPConn(localConn C.ServerAdapter) { return } - proxy, rule, err := resolveMetadata(metadata) + proxy, rule, err := resolveMetadata(ctx, metadata) if err != nil { log.Warnln("[Metadata] parse failed: %s", err.Error()) return @@ -274,7 +272,7 @@ func handleTCPConn(localConn C.ServerAdapter) { } return } - remoteConn = newTCPTracker(remoteConn, DefaultManager, metadata, rule) + remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule) defer remoteConn.Close() switch true { @@ -288,11 +286,11 @@ func handleTCPConn(localConn C.ServerAdapter) { log.Infoln("[TCP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String()) } - switch adapter := localConn.(type) { - case *inbound.HTTPAdapter: - handleHTTP(adapter, remoteConn) - case *inbound.SocketAdapter: - handleSocket(adapter, remoteConn) + switch c := ctx.(type) { + case *context.HTTPContext: + handleHTTP(c, remoteConn) + default: + handleSocket(ctx, remoteConn) } }