Refactor: make inbound request contextual

This commit is contained in:
Dreamacro 2021-01-23 14:49:46 +08:00
parent 35925cb3da
commit f4de055aa1
19 changed files with 302 additions and 125 deletions

View file

@ -6,33 +6,18 @@ import (
"strings"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
)
// HTTPAdapter is a adapter for HTTP connection
type HTTPAdapter struct {
net.Conn
metadata *C.Metadata
R *http.Request
}
// Metadata return destination metadata
func (h *HTTPAdapter) Metadata() *C.Metadata {
return h.metadata
}
// NewHTTP is HTTPAdapter generator
func NewHTTP(request *http.Request, conn net.Conn) *HTTPAdapter {
// NewHTTP recieve normal http request and return HTTPContext
func NewHTTP(request *http.Request, conn net.Conn) *context.HTTPContext {
metadata := parseHTTPAddr(request)
metadata.Type = C.HTTP
if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil {
metadata.SrcIP = ip
metadata.SrcPort = port
}
return &HTTPAdapter{
metadata: metadata,
R: request,
Conn: conn,
}
return context.NewHTTPContext(conn, request, metadata)
}
// RemoveHopByHopHeaders remove hop-by-hop header

View file

@ -5,18 +5,16 @@ import (
"net/http"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
)
// NewHTTPS is HTTPAdapter generator
func NewHTTPS(request *http.Request, conn net.Conn) *SocketAdapter {
// NewHTTPS recieve CONNECT request and return ConnContext
func NewHTTPS(request *http.Request, conn net.Conn) *context.ConnContext {
metadata := parseHTTPAddr(request)
metadata.Type = C.HTTPCONNECT
if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil {
metadata.SrcIP = ip
metadata.SrcPort = port
}
return &SocketAdapter{
metadata: metadata,
Conn: conn,
}
return context.NewConnContext(conn, metadata)
}

View file

@ -5,21 +5,11 @@ import (
"github.com/Dreamacro/clash/component/socks5"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
)
// SocketAdapter is a adapter for socks and redir connection
type SocketAdapter struct {
net.Conn
metadata *C.Metadata
}
// Metadata return destination metadata
func (s *SocketAdapter) Metadata() *C.Metadata {
return s.metadata
}
// NewSocket is SocketAdapter generator
func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *SocketAdapter {
// NewSocket recieve TCP inbound and return ConnContext
func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *context.ConnContext {
metadata := parseSocksAddr(target)
metadata.NetWork = C.TCP
metadata.Type = source
@ -28,8 +18,5 @@ func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *SocketAdapter
metadata.SrcPort = port
}
return &SocketAdapter{
Conn: conn,
metadata: metadata,
}
return context.NewConnContext(conn, metadata)
}

View file

@ -27,11 +27,6 @@ const (
LoadBalance
)
type ServerAdapter interface {
net.Conn
Metadata() *Metadata
}
type Connection interface {
Chains() Chain
AppendToChains(adapter ProxyAdapter)
@ -50,6 +45,15 @@ func (c Chain) String() string {
}
}
func (c Chain) Last() string {
switch len(c) {
case 0:
return ""
default:
return c[0]
}
}
type Conn interface {
net.Conn
Connection

23
constant/context.go Normal file
View file

@ -0,0 +1,23 @@
package constant
import (
"net"
"github.com/gofrs/uuid"
)
type PlainContext interface {
ID() uuid.UUID
}
type ConnContext interface {
PlainContext
Metadata() *Metadata
Conn() net.Conn
}
type PacketConnContext interface {
PlainContext
Metadata() *Metadata
PacketConn() net.PacketConn
}

39
context/conn.go Normal file
View file

@ -0,0 +1,39 @@
package context
import (
"net"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
type ConnContext struct {
id uuid.UUID
metadata *C.Metadata
conn net.Conn
}
func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
id, _ := uuid.NewV4()
return &ConnContext{
id: id,
metadata: metadata,
conn: conn,
}
}
// ID implement C.ConnContext ID
func (c *ConnContext) ID() uuid.UUID {
return c.id
}
// Metadata implement C.ConnContext Metadata
func (c *ConnContext) Metadata() *C.Metadata {
return c.metadata
}
// Conn implement C.ConnContext Conn
func (c *ConnContext) Conn() net.Conn {
return c.conn
}

41
context/dns.go Normal file
View file

@ -0,0 +1,41 @@
package context
import (
"github.com/gofrs/uuid"
"github.com/miekg/dns"
)
const (
DNSTypeHost = "host"
DNSTypeFakeIP = "fakeip"
DNSTypeRaw = "raw"
)
type DNSContext struct {
id uuid.UUID
msg *dns.Msg
tp string
}
func NewDNSContext(msg *dns.Msg) *DNSContext {
id, _ := uuid.NewV4()
return &DNSContext{
id: id,
msg: msg,
}
}
// ID implement C.PlainContext ID
func (c *DNSContext) ID() uuid.UUID {
return c.id
}
// SetType set type of response
func (c *DNSContext) SetType(tp string) {
c.tp = tp
}
// Type return type of response
func (c *DNSContext) Type() string {
return c.tp
}

47
context/http.go Normal file
View file

@ -0,0 +1,47 @@
package context
import (
"net"
"net/http"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
type HTTPContext struct {
id uuid.UUID
metadata *C.Metadata
conn net.Conn
req *http.Request
}
func NewHTTPContext(conn net.Conn, req *http.Request, metadata *C.Metadata) *HTTPContext {
id, _ := uuid.NewV4()
return &HTTPContext{
id: id,
metadata: metadata,
conn: conn,
req: req,
}
}
// ID implement C.ConnContext ID
func (hc *HTTPContext) ID() uuid.UUID {
return hc.id
}
// Metadata implement C.ConnContext Metadata
func (hc *HTTPContext) Metadata() *C.Metadata {
return hc.metadata
}
// Conn implement C.ConnContext Conn
func (hc *HTTPContext) Conn() net.Conn {
return hc.conn
}
// Request return the http request struct
func (hc *HTTPContext) Request() *http.Request {
return hc.req
}

43
context/packetconn.go Normal file
View file

@ -0,0 +1,43 @@
package context
import (
"net"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
type PacketConnContext struct {
id uuid.UUID
metadata *C.Metadata
packetConn net.PacketConn
}
func NewPacketConnContext(metadata *C.Metadata) *PacketConnContext {
id, _ := uuid.NewV4()
return &PacketConnContext{
id: id,
metadata: metadata,
}
}
// ID implement C.PacketConnContext ID
func (pc *PacketConnContext) ID() uuid.UUID {
return pc.id
}
// Metadata implement C.PacketConnContext Metadata
func (pc *PacketConnContext) Metadata() *C.Metadata {
return pc.metadata
}
// PacketConn implement C.PacketConnContext PacketConn
func (pc *PacketConnContext) PacketConn() net.PacketConn {
return pc.packetConn
}
// InjectPacketConn injectPacketConn manually
func (pc *PacketConnContext) InjectPacketConn(pconn C.PacketConn) {
pc.packetConn = pconn
}

View file

@ -8,26 +8,27 @@ import (
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip"
"github.com/Dreamacro/clash/component/trie"
"github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log"
D "github.com/miekg/dns"
)
type handler func(r *D.Msg) (*D.Msg, error)
type handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)
type middleware func(next handler) handler
func withHosts(hosts *trie.DomainTrie) middleware {
return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
if !isIPRequest(q) {
return next(r)
return next(ctx, r)
}
record := hosts.Search(strings.TrimRight(q.Name, "."))
if record == nil {
return next(r)
return next(ctx, r)
}
ip := record.Data.(net.IP)
@ -46,9 +47,10 @@ func withHosts(hosts *trie.DomainTrie) middleware {
msg.Answer = []D.RR{rr}
} else {
return next(r)
return next(ctx, r)
}
ctx.SetType(context.DNSTypeHost)
msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true
msg.RecursionAvailable = true
@ -60,14 +62,14 @@ func withHosts(hosts *trie.DomainTrie) middleware {
func withMapping(mapping *cache.LruCache) middleware {
return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
if !isIPRequest(q) {
return next(r)
return next(ctx, r)
}
msg, err := next(r)
msg, err := next(ctx, r)
if err != nil {
return nil, err
}
@ -99,12 +101,12 @@ func withMapping(mapping *cache.LruCache) middleware {
func withFakeIP(fakePool *fakeip.Pool) middleware {
return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
host := strings.TrimRight(q.Name, ".")
if fakePool.LookupHost(host) {
return next(r)
return next(ctx, r)
}
switch q.Qtype {
@ -113,7 +115,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
}
if q.Qtype != D.TypeA {
return next(r)
return next(ctx, r)
}
rr := &D.A{}
@ -123,6 +125,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
msg := r.Copy()
msg.Answer = []D.RR{rr}
ctx.SetType(context.DNSTypeFakeIP)
setMsgTTL(msg, 1)
msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true
@ -134,7 +137,8 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
}
func withResolver(resolver *Resolver) handler {
return func(r *D.Msg) (*D.Msg, error) {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
ctx.SetType(context.DNSTypeRaw)
q := r.Question[0]
// return a empty AAAA msg when ipv6 disabled

View file

@ -212,7 +212,7 @@ func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
fallbackMsg := r.asyncExchange(r.fallback, m)
res := <-msgCh
if res.Error == nil {
if ips := r.msgToIP(res.Msg); len(ips) != 0 {
if ips := msgToIP(res.Msg); len(ips) != 0 {
if !r.shouldIPFallback(ips[0]) {
msg = res.Msg // no need to wait for fallback result
err = res.Error
@ -247,7 +247,7 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error)
return nil, err
}
ips := r.msgToIP(msg)
ips := msgToIP(msg)
ipLength := len(ips)
if ipLength == 0 {
return nil, resolver.ErrIPNotFound
@ -257,21 +257,6 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error)
return
}
func (r *Resolver) msgToIP(msg *D.Msg) []net.IP {
ips := []net.IP{}
for _, answer := range msg.Answer {
switch ans := answer.(type) {
case *D.AAAA:
ips = append(ips, ans.AAAA)
case *D.A:
ips = append(ips, ans.A)
}
}
return ips
}
func (r *Resolver) msgToDomain(msg *D.Msg) string {
if len(msg.Question) > 0 {
return strings.TrimRight(msg.Question[0].Name, ".")

View file

@ -1,9 +1,11 @@
package dns
import (
"errors"
"net"
"github.com/Dreamacro/clash/common/sockopt"
"github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log"
D "github.com/miekg/dns"
@ -21,21 +23,25 @@ type Server struct {
handler handler
}
// ServeDNS implement D.Handler ServeDNS
func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) {
if len(r.Question) == 0 {
D.HandleFailed(w, r)
return
}
msg, err := s.handler(r)
msg, err := handlerWithContext(s.handler, r)
if err != nil {
D.HandleFailed(w, r)
return
}
w.WriteMsg(msg)
}
func handlerWithContext(handler handler, msg *D.Msg) (*D.Msg, error) {
if len(msg.Question) == 0 {
return nil, errors.New("at least one question is required")
}
ctx := context.NewDNSContext(msg)
return handler(ctx, msg)
}
func (s *Server) setHandler(handler handler) {
s.handler = handler
}

View file

@ -153,3 +153,18 @@ func handleMsgWithEmptyAnswer(r *D.Msg) *D.Msg {
return msg
}
func msgToIP(msg *D.Msg) []net.IP {
ips := []net.IP{}
for _, answer := range msg.Answer {
switch ans := answer.(type) {
case *D.AAAA:
ips = append(ips, ans.AAAA)
case *D.A:
ips = append(ips, ans.A)
}
}
return ips
}

View file

@ -7,7 +7,7 @@ import (
"strconv"
"time"
T "github.com/Dreamacro/clash/tunnel"
"github.com/Dreamacro/clash/tunnel/statistic"
"github.com/gorilla/websocket"
"github.com/go-chi/chi"
@ -24,7 +24,7 @@ func connectionRouter() http.Handler {
func getConnections(w http.ResponseWriter, r *http.Request) {
if !websocket.IsWebSocketUpgrade(r) {
snapshot := T.DefaultManager.Snapshot()
snapshot := statistic.DefaultManager.Snapshot()
render.JSON(w, r, snapshot)
return
}
@ -50,7 +50,7 @@ func getConnections(w http.ResponseWriter, r *http.Request) {
buf := &bytes.Buffer{}
sendSnapshot := func() error {
buf.Reset()
snapshot := T.DefaultManager.Snapshot()
snapshot := statistic.DefaultManager.Snapshot()
if err := json.NewEncoder(buf).Encode(snapshot); err != nil {
return err
}
@ -73,7 +73,7 @@ func getConnections(w http.ResponseWriter, r *http.Request) {
func closeConnection(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
snapshot := T.DefaultManager.Snapshot()
snapshot := statistic.DefaultManager.Snapshot()
for _, c := range snapshot.Connections {
if id == c.ID() {
c.Close()
@ -84,7 +84,7 @@ func closeConnection(w http.ResponseWriter, r *http.Request) {
}
func closeAllConnections(w http.ResponseWriter, r *http.Request) {
snapshot := T.DefaultManager.Snapshot()
snapshot := statistic.DefaultManager.Snapshot()
for _, c := range snapshot.Connections {
c.Close()
}

View file

@ -9,7 +9,7 @@ import (
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
T "github.com/Dreamacro/clash/tunnel"
"github.com/Dreamacro/clash/tunnel/statistic"
"github.com/go-chi/chi"
"github.com/go-chi/cors"
@ -143,7 +143,7 @@ func traffic(w http.ResponseWriter, r *http.Request) {
tick := time.NewTicker(time.Second)
defer tick.Stop()
t := T.DefaultManager
t := statistic.DefaultManager
buf := &bytes.Buffer{}
var err error
for range tick.C {

View file

@ -13,13 +13,15 @@ import (
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
)
func handleHTTP(request *inbound.HTTPAdapter, outbound net.Conn) {
req := request.R
func handleHTTP(ctx *context.HTTPContext, outbound net.Conn) {
req := ctx.Request()
conn := ctx.Conn()
host := req.Host
inboundReader := bufio.NewReader(request)
inboundReader := bufio.NewReader(conn)
outboundReader := bufio.NewReader(outbound)
for {
@ -43,7 +45,7 @@ func handleHTTP(request *inbound.HTTPAdapter, outbound net.Conn) {
inbound.RemoveHopByHopHeaders(resp.Header)
if resp.StatusCode == http.StatusContinue {
err = resp.Write(request)
err = resp.Write(conn)
if err != nil {
break
}
@ -58,14 +60,14 @@ func handleHTTP(request *inbound.HTTPAdapter, outbound net.Conn) {
} else {
resp.Close = true
}
err = resp.Write(request)
err = resp.Write(conn)
if err != nil || resp.Close {
break
}
// even if resp.Write write body to the connection, but some http request have to Copy to close it
buf := pool.Get(pool.RelayBufferSize)
_, err = io.CopyBuffer(request, resp.Body, buf)
_, err = io.CopyBuffer(conn, resp.Body, buf)
pool.Put(buf)
if err != nil && err != io.EOF {
break
@ -129,8 +131,8 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, fAddr n
}
}
func handleSocket(request C.ServerAdapter, outbound net.Conn) {
relay(request, outbound)
func handleSocket(ctx C.ConnContext, outbound net.Conn) {
relay(ctx.Conn(), outbound)
}
// relay copies between left and right bidirectionally.

View file

@ -1,4 +1,4 @@
package tunnel
package statistic
import (
"sync"

View file

@ -1,4 +1,4 @@
package tunnel
package statistic
import (
"net"
@ -57,7 +57,7 @@ func (tt *tcpTracker) Close() error {
return tt.Conn.Close()
}
func newTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker {
func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker {
uuid, _ := uuid.NewV4()
t := &tcpTracker{
@ -114,7 +114,7 @@ func (ut *udpTracker) Close() error {
return ut.PacketConn.Close()
}
func newUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule) *udpTracker {
func NewUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, rule C.Rule) *udpTracker {
uuid, _ := uuid.NewV4()
ut := &udpTracker{

View file

@ -12,11 +12,13 @@ import (
"github.com/Dreamacro/clash/component/nat"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/tunnel/statistic"
)
var (
tcpQueue = make(chan C.ServerAdapter, 200)
tcpQueue = make(chan C.ConnContext, 200)
udpQueue = make(chan *inbound.PacketAdapter, 200)
natTable = nat.New()
rules []C.Rule
@ -36,8 +38,8 @@ func init() {
}
// Add request to queue
func Add(req C.ServerAdapter) {
tcpQueue <- req
func Add(ctx C.ConnContext) {
tcpQueue <- ctx
}
// AddPacket add udp Packet to queue
@ -141,9 +143,7 @@ func preHandleMetadata(metadata *C.Metadata) error {
return nil
}
func resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
var proxy C.Proxy
var rule C.Rule
func resolveMetadata(ctx C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) {
switch mode {
case Direct:
proxy = proxies["DIRECT"]
@ -151,13 +151,9 @@ func resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
proxy = proxies["GLOBAL"]
// Rule
default:
var err error
proxy, rule, err = match(metadata)
if err != nil {
return nil, nil, err
}
}
return proxy, rule, nil
return
}
func handleUDPConn(packet *inbound.PacketAdapter) {
@ -210,7 +206,8 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
cond.Broadcast()
}()
proxy, rule, err := resolveMetadata(metadata)
ctx := context.NewPacketConnContext(metadata)
proxy, rule, err := resolveMetadata(ctx, metadata)
if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return
@ -225,7 +222,8 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
}
return
}
pc := newUDPTracker(rawPc, DefaultManager, metadata, rule)
ctx.InjectPacketConn(rawPc)
pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule)
switch true {
case rule != nil:
@ -245,10 +243,10 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
}()
}
func handleTCPConn(localConn C.ServerAdapter) {
defer localConn.Close()
func handleTCPConn(ctx C.ConnContext) {
defer ctx.Conn().Close()
metadata := localConn.Metadata()
metadata := ctx.Metadata()
if !metadata.Valid() {
log.Warnln("[Metadata] not valid: %#v", metadata)
return
@ -259,7 +257,7 @@ func handleTCPConn(localConn C.ServerAdapter) {
return
}
proxy, rule, err := resolveMetadata(metadata)
proxy, rule, err := resolveMetadata(ctx, metadata)
if err != nil {
log.Warnln("[Metadata] parse failed: %s", err.Error())
return
@ -274,7 +272,7 @@ func handleTCPConn(localConn C.ServerAdapter) {
}
return
}
remoteConn = newTCPTracker(remoteConn, DefaultManager, metadata, rule)
remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule)
defer remoteConn.Close()
switch true {
@ -288,11 +286,11 @@ func handleTCPConn(localConn C.ServerAdapter) {
log.Infoln("[TCP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String())
}
switch adapter := localConn.(type) {
case *inbound.HTTPAdapter:
handleHTTP(adapter, remoteConn)
case *inbound.SocketAdapter:
handleSocket(adapter, remoteConn)
switch c := ctx.(type) {
case *context.HTTPContext:
handleHTTP(c, remoteConn)
default:
handleSocket(ctx, remoteConn)
}
}