Refactor: make inbound request contextual

This commit is contained in:
Dreamacro 2021-01-23 14:49:46 +08:00
parent 35925cb3da
commit f4de055aa1
19 changed files with 302 additions and 125 deletions

View file

@ -6,33 +6,18 @@ import (
"strings" "strings"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
) )
// HTTPAdapter is a adapter for HTTP connection // NewHTTP recieve normal http request and return HTTPContext
type HTTPAdapter struct { func NewHTTP(request *http.Request, conn net.Conn) *context.HTTPContext {
net.Conn
metadata *C.Metadata
R *http.Request
}
// Metadata return destination metadata
func (h *HTTPAdapter) Metadata() *C.Metadata {
return h.metadata
}
// NewHTTP is HTTPAdapter generator
func NewHTTP(request *http.Request, conn net.Conn) *HTTPAdapter {
metadata := parseHTTPAddr(request) metadata := parseHTTPAddr(request)
metadata.Type = C.HTTP metadata.Type = C.HTTP
if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil { if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil {
metadata.SrcIP = ip metadata.SrcIP = ip
metadata.SrcPort = port metadata.SrcPort = port
} }
return &HTTPAdapter{ return context.NewHTTPContext(conn, request, metadata)
metadata: metadata,
R: request,
Conn: conn,
}
} }
// RemoveHopByHopHeaders remove hop-by-hop header // RemoveHopByHopHeaders remove hop-by-hop header

View file

@ -5,18 +5,16 @@ import (
"net/http" "net/http"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
) )
// NewHTTPS is HTTPAdapter generator // NewHTTPS recieve CONNECT request and return ConnContext
func NewHTTPS(request *http.Request, conn net.Conn) *SocketAdapter { func NewHTTPS(request *http.Request, conn net.Conn) *context.ConnContext {
metadata := parseHTTPAddr(request) metadata := parseHTTPAddr(request)
metadata.Type = C.HTTPCONNECT metadata.Type = C.HTTPCONNECT
if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil { if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil {
metadata.SrcIP = ip metadata.SrcIP = ip
metadata.SrcPort = port metadata.SrcPort = port
} }
return &SocketAdapter{ return context.NewConnContext(conn, metadata)
metadata: metadata,
Conn: conn,
}
} }

View file

@ -5,21 +5,11 @@ import (
"github.com/Dreamacro/clash/component/socks5" "github.com/Dreamacro/clash/component/socks5"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
) )
// SocketAdapter is a adapter for socks and redir connection // NewSocket recieve TCP inbound and return ConnContext
type SocketAdapter struct { func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *context.ConnContext {
net.Conn
metadata *C.Metadata
}
// Metadata return destination metadata
func (s *SocketAdapter) Metadata() *C.Metadata {
return s.metadata
}
// NewSocket is SocketAdapter generator
func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *SocketAdapter {
metadata := parseSocksAddr(target) metadata := parseSocksAddr(target)
metadata.NetWork = C.TCP metadata.NetWork = C.TCP
metadata.Type = source metadata.Type = source
@ -28,8 +18,5 @@ func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *SocketAdapter
metadata.SrcPort = port metadata.SrcPort = port
} }
return &SocketAdapter{ return context.NewConnContext(conn, metadata)
Conn: conn,
metadata: metadata,
}
} }

View file

@ -27,11 +27,6 @@ const (
LoadBalance LoadBalance
) )
type ServerAdapter interface {
net.Conn
Metadata() *Metadata
}
type Connection interface { type Connection interface {
Chains() Chain Chains() Chain
AppendToChains(adapter ProxyAdapter) AppendToChains(adapter ProxyAdapter)
@ -50,6 +45,15 @@ func (c Chain) String() string {
} }
} }
func (c Chain) Last() string {
switch len(c) {
case 0:
return ""
default:
return c[0]
}
}
type Conn interface { type Conn interface {
net.Conn net.Conn
Connection Connection

23
constant/context.go Normal file
View file

@ -0,0 +1,23 @@
package constant
import (
"net"
"github.com/gofrs/uuid"
)
type PlainContext interface {
ID() uuid.UUID
}
type ConnContext interface {
PlainContext
Metadata() *Metadata
Conn() net.Conn
}
type PacketConnContext interface {
PlainContext
Metadata() *Metadata
PacketConn() net.PacketConn
}

39
context/conn.go Normal file
View file

@ -0,0 +1,39 @@
package context
import (
"net"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
type ConnContext struct {
id uuid.UUID
metadata *C.Metadata
conn net.Conn
}
func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
id, _ := uuid.NewV4()
return &ConnContext{
id: id,
metadata: metadata,
conn: conn,
}
}
// ID implement C.ConnContext ID
func (c *ConnContext) ID() uuid.UUID {
return c.id
}
// Metadata implement C.ConnContext Metadata
func (c *ConnContext) Metadata() *C.Metadata {
return c.metadata
}
// Conn implement C.ConnContext Conn
func (c *ConnContext) Conn() net.Conn {
return c.conn
}

41
context/dns.go Normal file
View file

@ -0,0 +1,41 @@
package context
import (
"github.com/gofrs/uuid"
"github.com/miekg/dns"
)
const (
DNSTypeHost = "host"
DNSTypeFakeIP = "fakeip"
DNSTypeRaw = "raw"
)
type DNSContext struct {
id uuid.UUID
msg *dns.Msg
tp string
}
func NewDNSContext(msg *dns.Msg) *DNSContext {
id, _ := uuid.NewV4()
return &DNSContext{
id: id,
msg: msg,
}
}
// ID implement C.PlainContext ID
func (c *DNSContext) ID() uuid.UUID {
return c.id
}
// SetType set type of response
func (c *DNSContext) SetType(tp string) {
c.tp = tp
}
// Type return type of response
func (c *DNSContext) Type() string {
return c.tp
}

47
context/http.go Normal file
View file

@ -0,0 +1,47 @@
package context
import (
"net"
"net/http"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
type HTTPContext struct {
id uuid.UUID
metadata *C.Metadata
conn net.Conn
req *http.Request
}
func NewHTTPContext(conn net.Conn, req *http.Request, metadata *C.Metadata) *HTTPContext {
id, _ := uuid.NewV4()
return &HTTPContext{
id: id,
metadata: metadata,
conn: conn,
req: req,
}
}
// ID implement C.ConnContext ID
func (hc *HTTPContext) ID() uuid.UUID {
return hc.id
}
// Metadata implement C.ConnContext Metadata
func (hc *HTTPContext) Metadata() *C.Metadata {
return hc.metadata
}
// Conn implement C.ConnContext Conn
func (hc *HTTPContext) Conn() net.Conn {
return hc.conn
}
// Request return the http request struct
func (hc *HTTPContext) Request() *http.Request {
return hc.req
}

43
context/packetconn.go Normal file
View file

@ -0,0 +1,43 @@
package context
import (
"net"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
type PacketConnContext struct {
id uuid.UUID
metadata *C.Metadata
packetConn net.PacketConn
}
func NewPacketConnContext(metadata *C.Metadata) *PacketConnContext {
id, _ := uuid.NewV4()
return &PacketConnContext{
id: id,
metadata: metadata,
}
}
// ID implement C.PacketConnContext ID
func (pc *PacketConnContext) ID() uuid.UUID {
return pc.id
}
// Metadata implement C.PacketConnContext Metadata
func (pc *PacketConnContext) Metadata() *C.Metadata {
return pc.metadata
}
// PacketConn implement C.PacketConnContext PacketConn
func (pc *PacketConnContext) PacketConn() net.PacketConn {
return pc.packetConn
}
// InjectPacketConn injectPacketConn manually
func (pc *PacketConnContext) InjectPacketConn(pconn C.PacketConn) {
pc.packetConn = pconn
}

View file

@ -8,26 +8,27 @@ import (
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip" "github.com/Dreamacro/clash/component/fakeip"
"github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/component/trie"
"github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
D "github.com/miekg/dns" D "github.com/miekg/dns"
) )
type handler func(r *D.Msg) (*D.Msg, error) type handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)
type middleware func(next handler) handler type middleware func(next handler) handler
func withHosts(hosts *trie.DomainTrie) middleware { func withHosts(hosts *trie.DomainTrie) middleware {
return func(next handler) handler { return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0] q := r.Question[0]
if !isIPRequest(q) { if !isIPRequest(q) {
return next(r) return next(ctx, r)
} }
record := hosts.Search(strings.TrimRight(q.Name, ".")) record := hosts.Search(strings.TrimRight(q.Name, "."))
if record == nil { if record == nil {
return next(r) return next(ctx, r)
} }
ip := record.Data.(net.IP) ip := record.Data.(net.IP)
@ -46,9 +47,10 @@ func withHosts(hosts *trie.DomainTrie) middleware {
msg.Answer = []D.RR{rr} msg.Answer = []D.RR{rr}
} else { } else {
return next(r) return next(ctx, r)
} }
ctx.SetType(context.DNSTypeHost)
msg.SetRcode(r, D.RcodeSuccess) msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true msg.Authoritative = true
msg.RecursionAvailable = true msg.RecursionAvailable = true
@ -60,14 +62,14 @@ func withHosts(hosts *trie.DomainTrie) middleware {
func withMapping(mapping *cache.LruCache) middleware { func withMapping(mapping *cache.LruCache) middleware {
return func(next handler) handler { return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0] q := r.Question[0]
if !isIPRequest(q) { if !isIPRequest(q) {
return next(r) return next(ctx, r)
} }
msg, err := next(r) msg, err := next(ctx, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -99,12 +101,12 @@ func withMapping(mapping *cache.LruCache) middleware {
func withFakeIP(fakePool *fakeip.Pool) middleware { func withFakeIP(fakePool *fakeip.Pool) middleware {
return func(next handler) handler { return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0] q := r.Question[0]
host := strings.TrimRight(q.Name, ".") host := strings.TrimRight(q.Name, ".")
if fakePool.LookupHost(host) { if fakePool.LookupHost(host) {
return next(r) return next(ctx, r)
} }
switch q.Qtype { switch q.Qtype {
@ -113,7 +115,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
} }
if q.Qtype != D.TypeA { if q.Qtype != D.TypeA {
return next(r) return next(ctx, r)
} }
rr := &D.A{} rr := &D.A{}
@ -123,6 +125,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
msg := r.Copy() msg := r.Copy()
msg.Answer = []D.RR{rr} msg.Answer = []D.RR{rr}
ctx.SetType(context.DNSTypeFakeIP)
setMsgTTL(msg, 1) setMsgTTL(msg, 1)
msg.SetRcode(r, D.RcodeSuccess) msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true msg.Authoritative = true
@ -134,7 +137,8 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
} }
func withResolver(resolver *Resolver) handler { func withResolver(resolver *Resolver) handler {
return func(r *D.Msg) (*D.Msg, error) { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
ctx.SetType(context.DNSTypeRaw)
q := r.Question[0] q := r.Question[0]
// return a empty AAAA msg when ipv6 disabled // return a empty AAAA msg when ipv6 disabled

View file

@ -212,7 +212,7 @@ func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
fallbackMsg := r.asyncExchange(r.fallback, m) fallbackMsg := r.asyncExchange(r.fallback, m)
res := <-msgCh res := <-msgCh
if res.Error == nil { if res.Error == nil {
if ips := r.msgToIP(res.Msg); len(ips) != 0 { if ips := msgToIP(res.Msg); len(ips) != 0 {
if !r.shouldIPFallback(ips[0]) { if !r.shouldIPFallback(ips[0]) {
msg = res.Msg // no need to wait for fallback result msg = res.Msg // no need to wait for fallback result
err = res.Error err = res.Error
@ -247,7 +247,7 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error)
return nil, err return nil, err
} }
ips := r.msgToIP(msg) ips := msgToIP(msg)
ipLength := len(ips) ipLength := len(ips)
if ipLength == 0 { if ipLength == 0 {
return nil, resolver.ErrIPNotFound return nil, resolver.ErrIPNotFound
@ -257,21 +257,6 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error)
return return
} }
func (r *Resolver) msgToIP(msg *D.Msg) []net.IP {
ips := []net.IP{}
for _, answer := range msg.Answer {
switch ans := answer.(type) {
case *D.AAAA:
ips = append(ips, ans.AAAA)
case *D.A:
ips = append(ips, ans.A)
}
}
return ips
}
func (r *Resolver) msgToDomain(msg *D.Msg) string { func (r *Resolver) msgToDomain(msg *D.Msg) string {
if len(msg.Question) > 0 { if len(msg.Question) > 0 {
return strings.TrimRight(msg.Question[0].Name, ".") return strings.TrimRight(msg.Question[0].Name, ".")

View file

@ -1,9 +1,11 @@
package dns package dns
import ( import (
"errors"
"net" "net"
"github.com/Dreamacro/clash/common/sockopt" "github.com/Dreamacro/clash/common/sockopt"
"github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
D "github.com/miekg/dns" D "github.com/miekg/dns"
@ -21,21 +23,25 @@ type Server struct {
handler handler handler handler
} }
// ServeDNS implement D.Handler ServeDNS
func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) { func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) {
if len(r.Question) == 0 { msg, err := handlerWithContext(s.handler, r)
D.HandleFailed(w, r)
return
}
msg, err := s.handler(r)
if err != nil { if err != nil {
D.HandleFailed(w, r) D.HandleFailed(w, r)
return return
} }
w.WriteMsg(msg) w.WriteMsg(msg)
} }
func handlerWithContext(handler handler, msg *D.Msg) (*D.Msg, error) {
if len(msg.Question) == 0 {
return nil, errors.New("at least one question is required")
}
ctx := context.NewDNSContext(msg)
return handler(ctx, msg)
}
func (s *Server) setHandler(handler handler) { func (s *Server) setHandler(handler handler) {
s.handler = handler s.handler = handler
} }

View file

@ -153,3 +153,18 @@ func handleMsgWithEmptyAnswer(r *D.Msg) *D.Msg {
return msg return msg
} }
func msgToIP(msg *D.Msg) []net.IP {
ips := []net.IP{}
for _, answer := range msg.Answer {
switch ans := answer.(type) {
case *D.AAAA:
ips = append(ips, ans.AAAA)
case *D.A:
ips = append(ips, ans.A)
}
}
return ips
}

View file

@ -7,7 +7,7 @@ import (
"strconv" "strconv"
"time" "time"
T "github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel/statistic"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/go-chi/chi" "github.com/go-chi/chi"
@ -24,7 +24,7 @@ func connectionRouter() http.Handler {
func getConnections(w http.ResponseWriter, r *http.Request) { func getConnections(w http.ResponseWriter, r *http.Request) {
if !websocket.IsWebSocketUpgrade(r) { if !websocket.IsWebSocketUpgrade(r) {
snapshot := T.DefaultManager.Snapshot() snapshot := statistic.DefaultManager.Snapshot()
render.JSON(w, r, snapshot) render.JSON(w, r, snapshot)
return return
} }
@ -50,7 +50,7 @@ func getConnections(w http.ResponseWriter, r *http.Request) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
sendSnapshot := func() error { sendSnapshot := func() error {
buf.Reset() buf.Reset()
snapshot := T.DefaultManager.Snapshot() snapshot := statistic.DefaultManager.Snapshot()
if err := json.NewEncoder(buf).Encode(snapshot); err != nil { if err := json.NewEncoder(buf).Encode(snapshot); err != nil {
return err return err
} }
@ -73,7 +73,7 @@ func getConnections(w http.ResponseWriter, r *http.Request) {
func closeConnection(w http.ResponseWriter, r *http.Request) { func closeConnection(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
snapshot := T.DefaultManager.Snapshot() snapshot := statistic.DefaultManager.Snapshot()
for _, c := range snapshot.Connections { for _, c := range snapshot.Connections {
if id == c.ID() { if id == c.ID() {
c.Close() c.Close()
@ -84,7 +84,7 @@ func closeConnection(w http.ResponseWriter, r *http.Request) {
} }
func closeAllConnections(w http.ResponseWriter, r *http.Request) { func closeAllConnections(w http.ResponseWriter, r *http.Request) {
snapshot := T.DefaultManager.Snapshot() snapshot := statistic.DefaultManager.Snapshot()
for _, c := range snapshot.Connections { for _, c := range snapshot.Connections {
c.Close() c.Close()
} }

View file

@ -9,7 +9,7 @@ import (
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
T "github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel/statistic"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/cors" "github.com/go-chi/cors"
@ -143,7 +143,7 @@ func traffic(w http.ResponseWriter, r *http.Request) {
tick := time.NewTicker(time.Second) tick := time.NewTicker(time.Second)
defer tick.Stop() defer tick.Stop()
t := T.DefaultManager t := statistic.DefaultManager
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
var err error var err error
for range tick.C { for range tick.C {

View file

@ -13,13 +13,15 @@ import (
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
) )
func handleHTTP(request *inbound.HTTPAdapter, outbound net.Conn) { func handleHTTP(ctx *context.HTTPContext, outbound net.Conn) {
req := request.R req := ctx.Request()
conn := ctx.Conn()
host := req.Host host := req.Host
inboundReader := bufio.NewReader(request) inboundReader := bufio.NewReader(conn)
outboundReader := bufio.NewReader(outbound) outboundReader := bufio.NewReader(outbound)
for { for {
@ -43,7 +45,7 @@ func handleHTTP(request *inbound.HTTPAdapter, outbound net.Conn) {
inbound.RemoveHopByHopHeaders(resp.Header) inbound.RemoveHopByHopHeaders(resp.Header)
if resp.StatusCode == http.StatusContinue { if resp.StatusCode == http.StatusContinue {
err = resp.Write(request) err = resp.Write(conn)
if err != nil { if err != nil {
break break
} }
@ -58,14 +60,14 @@ func handleHTTP(request *inbound.HTTPAdapter, outbound net.Conn) {
} else { } else {
resp.Close = true resp.Close = true
} }
err = resp.Write(request) err = resp.Write(conn)
if err != nil || resp.Close { if err != nil || resp.Close {
break break
} }
// even if resp.Write write body to the connection, but some http request have to Copy to close it // even if resp.Write write body to the connection, but some http request have to Copy to close it
buf := pool.Get(pool.RelayBufferSize) buf := pool.Get(pool.RelayBufferSize)
_, err = io.CopyBuffer(request, resp.Body, buf) _, err = io.CopyBuffer(conn, resp.Body, buf)
pool.Put(buf) pool.Put(buf)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
break break
@ -129,8 +131,8 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, fAddr n
} }
} }
func handleSocket(request C.ServerAdapter, outbound net.Conn) { func handleSocket(ctx C.ConnContext, outbound net.Conn) {
relay(request, outbound) relay(ctx.Conn(), outbound)
} }
// relay copies between left and right bidirectionally. // relay copies between left and right bidirectionally.

View file

@ -1,4 +1,4 @@
package tunnel package statistic
import ( import (
"sync" "sync"

View file

@ -1,4 +1,4 @@
package tunnel package statistic
import ( import (
"net" "net"
@ -57,7 +57,7 @@ func (tt *tcpTracker) Close() error {
return tt.Conn.Close() return tt.Conn.Close()
} }
func newTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker { func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker {
uuid, _ := uuid.NewV4() uuid, _ := uuid.NewV4()
t := &tcpTracker{ t := &tcpTracker{
@ -114,7 +114,7 @@ func (ut *udpTracker) Close() error {
return ut.PacketConn.Close() return ut.PacketConn.Close()
} }
func newUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule) *udpTracker { func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule) *udpTracker {
uuid, _ := uuid.NewV4() uuid, _ := uuid.NewV4()
ut := &udpTracker{ ut := &udpTracker{

View file

@ -12,11 +12,13 @@ import (
"github.com/Dreamacro/clash/component/nat" "github.com/Dreamacro/clash/component/nat"
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/tunnel/statistic"
) )
var ( var (
tcpQueue = make(chan C.ServerAdapter, 200) tcpQueue = make(chan C.ConnContext, 200)
udpQueue = make(chan *inbound.PacketAdapter, 200) udpQueue = make(chan *inbound.PacketAdapter, 200)
natTable = nat.New() natTable = nat.New()
rules []C.Rule rules []C.Rule
@ -36,8 +38,8 @@ func init() {
} }
// Add request to queue // Add request to queue
func Add(req C.ServerAdapter) { func Add(ctx C.ConnContext) {
tcpQueue <- req tcpQueue <- ctx
} }
// AddPacket add udp Packet to queue // AddPacket add udp Packet to queue
@ -141,9 +143,7 @@ func preHandleMetadata(metadata *C.Metadata) error {
return nil return nil
} }
func resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { func resolveMetadata(ctx C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) {
var proxy C.Proxy
var rule C.Rule
switch mode { switch mode {
case Direct: case Direct:
proxy = proxies["DIRECT"] proxy = proxies["DIRECT"]
@ -151,13 +151,9 @@ func resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
proxy = proxies["GLOBAL"] proxy = proxies["GLOBAL"]
// Rule // Rule
default: default:
var err error
proxy, rule, err = match(metadata) proxy, rule, err = match(metadata)
if err != nil {
return nil, nil, err
}
} }
return proxy, rule, nil return
} }
func handleUDPConn(packet *inbound.PacketAdapter) { func handleUDPConn(packet *inbound.PacketAdapter) {
@ -210,7 +206,8 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
cond.Broadcast() cond.Broadcast()
}() }()
proxy, rule, err := resolveMetadata(metadata) ctx := context.NewPacketConnContext(metadata)
proxy, rule, err := resolveMetadata(ctx, metadata)
if err != nil { if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return return
@ -225,7 +222,8 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
} }
return return
} }
pc := newUDPTracker(rawPc, DefaultManager, metadata, rule) ctx.InjectPacketConn(rawPc)
pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule)
switch true { switch true {
case rule != nil: case rule != nil:
@ -245,10 +243,10 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
}() }()
} }
func handleTCPConn(localConn C.ServerAdapter) { func handleTCPConn(ctx C.ConnContext) {
defer localConn.Close() defer ctx.Conn().Close()
metadata := localConn.Metadata() metadata := ctx.Metadata()
if !metadata.Valid() { if !metadata.Valid() {
log.Warnln("[Metadata] not valid: %#v", metadata) log.Warnln("[Metadata] not valid: %#v", metadata)
return return
@ -259,7 +257,7 @@ func handleTCPConn(localConn C.ServerAdapter) {
return return
} }
proxy, rule, err := resolveMetadata(metadata) proxy, rule, err := resolveMetadata(ctx, metadata)
if err != nil { if err != nil {
log.Warnln("[Metadata] parse failed: %s", err.Error()) log.Warnln("[Metadata] parse failed: %s", err.Error())
return return
@ -274,7 +272,7 @@ func handleTCPConn(localConn C.ServerAdapter) {
} }
return return
} }
remoteConn = newTCPTracker(remoteConn, DefaultManager, metadata, rule) remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule)
defer remoteConn.Close() defer remoteConn.Close()
switch true { switch true {
@ -288,11 +286,11 @@ func handleTCPConn(localConn C.ServerAdapter) {
log.Infoln("[TCP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String()) log.Infoln("[TCP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String())
} }
switch adapter := localConn.(type) { switch c := ctx.(type) {
case *inbound.HTTPAdapter: case *context.HTTPContext:
handleHTTP(adapter, remoteConn) handleHTTP(c, remoteConn)
case *inbound.SocketAdapter: default:
handleSocket(adapter, remoteConn) handleSocket(ctx, remoteConn)
} }
} }