From 544e0f137d1951f6d4891cf1bd7a7e0ec36a527b Mon Sep 17 00:00:00 2001 From: Skyxim Date: Sat, 9 Apr 2022 22:30:36 +0800 Subject: [PATCH] feat: sniffer support sniffer: enable: true force: false # Overwrite domain sniffing: - tls --- component/sniffer/dispatcher.go | 104 ++++++++++++++++++ .../tls => component/sniffer}/sniff_test.go | 8 +- .../sniffer/tls_sniffer.go | 87 ++++++++------- config/config.go | 53 +++++++++ constant/sniffer.go | 26 +++++ context/conn.go | 4 +- hub/executor/executor.go | 12 ++ tunnel/statistic/tracker.go | 17 --- tunnel/tunnel.go | 7 ++ 9 files changed, 256 insertions(+), 62 deletions(-) create mode 100644 component/sniffer/dispatcher.go rename {common/snifer/tls => component/sniffer}/sniff_test.go (97%) rename common/snifer/tls/sniff.go => component/sniffer/tls_sniffer.go (66%) create mode 100644 constant/sniffer.go diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go new file mode 100644 index 00000000..60c52015 --- /dev/null +++ b/component/sniffer/dispatcher.go @@ -0,0 +1,104 @@ +package sniffer + +import ( + "errors" + "net" + + CN "github.com/Dreamacro/clash/common/net" + "github.com/Dreamacro/clash/component/resolver" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" +) + +var ( + ErrorUnsupportedSniffer = errors.New("unsupported sniffer") +) + +var Dispatcher SnifferDispatcher + +type SnifferDispatcher struct { + enable bool + force bool + sniffers []C.Sniffer +} + +func (sd *SnifferDispatcher) Tcp(conn net.Conn, metadata *C.Metadata) { + bufConn, ok := conn.(*CN.BufferedConn) + if !ok { + return + } + + if sd.force { + sd.cover(bufConn, metadata) + } else { + if metadata.Host != "" { + return + } + + sd.cover(bufConn, metadata) + } +} + +func (sd *SnifferDispatcher) Enable() bool { + return sd.enable +} + +func (sd *SnifferDispatcher) cover(conn *CN.BufferedConn, metadata *C.Metadata) { + for _, sniffer := range sd.sniffers { + if sniffer.SupportNetwork() == C.TCP { + conn.Peek(1) + len := conn.Buffered() + bytes, err := conn.Peek(len) + if err != nil { + log.Warnln("the data lenght not enough") + continue + } + + host, err := sniffer.SniffTCP(bytes) + if err != nil { + log.Warnln("Sniff data failed on Sniffer[%s]", sniffer.Protocol()) + continue + } + + metadata.Host = host + metadata.DstIP = nil + metadata.AddrType = C.AtypDomainName + if resolver.FakeIPEnabled() { + metadata.DNSMode = C.DNSFakeIP + } else { + metadata.DNSMode = C.DNSMapping + } + + resolver.InsertHostByIP(metadata.DstIP, host) + break + } + } +} + +func NewSnifferDispatcher(needSniffer []C.SnifferType, force bool) (SnifferDispatcher, error) { + dispatcher := SnifferDispatcher{ + enable: true, + force: force, + } + + for _, snifferName := range needSniffer { + sniffer, err := NewSniffer(snifferName) + if err != nil { + log.Errorln("Sniffer name[%s] is error", snifferName) + return SnifferDispatcher{enable: false}, err + } + + dispatcher.sniffers = append(dispatcher.sniffers, sniffer) + } + + return dispatcher, nil +} + +func NewSniffer(name C.SnifferType) (C.Sniffer, error) { + switch name { + case C.TLS: + return &TLSSniffer{}, nil + default: + return nil, ErrorUnsupportedSniffer + } +} diff --git a/common/snifer/tls/sniff_test.go b/component/sniffer/sniff_test.go similarity index 97% rename from common/snifer/tls/sniff_test.go rename to component/sniffer/sniff_test.go index 26f5f1ee..e7ced43c 100644 --- a/common/snifer/tls/sniff_test.go +++ b/component/sniffer/sniff_test.go @@ -1,4 +1,4 @@ -package tls +package sniffer import ( "testing" @@ -142,7 +142,7 @@ func TestTLSHeaders(t *testing.T) { } for _, test := range cases { - header, err := SniffTLS(test.input) + domain, err := SniffTLS(test.input) if test.err { if err == nil { t.Errorf("Exepct error but nil in test %v", test) @@ -151,8 +151,8 @@ func TestTLSHeaders(t *testing.T) { if err != nil { t.Errorf("Expect no error but actually %s in test %v", err.Error(), test) } - if header.Domain() != test.domain { - t.Error("expect domain ", test.domain, " but got ", header.Domain()) + if *domain != test.domain { + t.Error("expect domain ", test.domain, " but got ", domain) } } } diff --git a/common/snifer/tls/sniff.go b/component/sniffer/tls_sniffer.go similarity index 66% rename from common/snifer/tls/sniff.go rename to component/sniffer/tls_sniffer.go index 1471fc68..6af0b97b 100644 --- a/common/snifer/tls/sniff.go +++ b/component/sniffer/tls_sniffer.go @@ -1,107 +1,116 @@ -package tls +package sniffer import ( "encoding/binary" "errors" "strings" + + C "github.com/Dreamacro/clash/constant" ) -var ErrNoClue = errors.New("not enough information for making a decision") - -type SniffHeader struct { - domain string -} - -func (h *SniffHeader) Protocol() string { - return "tls" -} - -func (h *SniffHeader) Domain() string { - return h.domain -} - var ( errNotTLS = errors.New("not TLS header") errNotClientHello = errors.New("not client hello") + ErrNoClue = errors.New("not enough information for making a decision") ) +type TLSSniffer struct { +} + +func (tls *TLSSniffer) Protocol() string { + return "tls" +} + +func (tls *TLSSniffer) SupportNetwork() C.NetWork { + return C.TCP +} + +func (tls *TLSSniffer) SniffTCP(bytes []byte) (string, error) { + domain, err := SniffTLS(bytes) + if err == nil { + return *domain, nil + } else { + return "", err + } +} + func IsValidTLSVersion(major, minor byte) bool { return major == 3 } // ReadClientHello returns server name (if any) from TLS client hello message. // https://github.com/golang/go/blob/master/src/crypto/tls/handshake_messages.go#L300 -func ReadClientHello(data []byte, h *SniffHeader) error { +func ReadClientHello(data []byte) (*string, error) { if len(data) < 42 { - return ErrNoClue + return nil, ErrNoClue } sessionIDLen := int(data[38]) if sessionIDLen > 32 || len(data) < 39+sessionIDLen { - return ErrNoClue + return nil, ErrNoClue } data = data[39+sessionIDLen:] if len(data) < 2 { - return ErrNoClue + return nil, ErrNoClue } // cipherSuiteLen is the number of bytes of cipher suite numbers. Since // they are uint16s, the number must be even. cipherSuiteLen := int(data[0])<<8 | int(data[1]) if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { - return errNotClientHello + return nil, errNotClientHello } data = data[2+cipherSuiteLen:] if len(data) < 1 { - return ErrNoClue + return nil, ErrNoClue } compressionMethodsLen := int(data[0]) if len(data) < 1+compressionMethodsLen { - return ErrNoClue + return nil, ErrNoClue } data = data[1+compressionMethodsLen:] if len(data) == 0 { - return errNotClientHello + return nil, errNotClientHello } if len(data) < 2 { - return errNotClientHello + return nil, errNotClientHello } extensionsLength := int(data[0])<<8 | int(data[1]) data = data[2:] if extensionsLength != len(data) { - return errNotClientHello + return nil, errNotClientHello } for len(data) != 0 { if len(data) < 4 { - return errNotClientHello + return nil, errNotClientHello } extension := uint16(data[0])<<8 | uint16(data[1]) length := int(data[2])<<8 | int(data[3]) data = data[4:] if len(data) < length { - return errNotClientHello + return nil, errNotClientHello } if extension == 0x00 { /* extensionServerName */ d := data[:length] if len(d) < 2 { - return errNotClientHello + return nil, errNotClientHello } namesLen := int(d[0])<<8 | int(d[1]) d = d[2:] if len(d) != namesLen { - return errNotClientHello + return nil, errNotClientHello } for len(d) > 0 { if len(d) < 3 { - return errNotClientHello + return nil, errNotClientHello } nameType := d[0] nameLen := int(d[1])<<8 | int(d[2]) d = d[3:] if len(d) < nameLen { - return errNotClientHello + return nil, errNotClientHello } if nameType == 0 { serverName := string(d[:nameLen]) @@ -109,21 +118,22 @@ func ReadClientHello(data []byte, h *SniffHeader) error { // trailing dot. See // https://tools.ietf.org/html/rfc6066#section-3. if strings.HasSuffix(serverName, ".") { - return errNotClientHello + return nil, errNotClientHello } - h.domain = serverName - return nil + + return &serverName, nil } + d = d[nameLen:] } } data = data[length:] } - return errNotTLS + return nil, errNotTLS } -func SniffTLS(b []byte) (*SniffHeader, error) { +func SniffTLS(b []byte) (*string, error) { if len(b) < 5 { return nil, ErrNoClue } @@ -139,10 +149,9 @@ func SniffTLS(b []byte) (*SniffHeader, error) { return nil, ErrNoClue } - h := &SniffHeader{} - err := ReadClientHello(b[5:5+headerLen], h) + domain, err := ReadClientHello(b[5 : 5+headerLen]) if err == nil { - return h, nil + return domain, nil } return nil, err } diff --git a/config/config.go b/config/config.go index 7db86916..8386c0c6 100644 --- a/config/config.go +++ b/config/config.go @@ -126,6 +126,12 @@ type IPTables struct { Bypass []string `yaml:"bypass" json:"bypass"` } +type Sniffer struct { + Enable bool + Force bool + Sniffers []C.SnifferType +} + // Experimental config type Experimental struct{} @@ -143,6 +149,7 @@ type Config struct { Proxies map[string]C.Proxy Providers map[string]providerTypes.ProxyProvider RuleProviders map[string]*providerTypes.RuleProvider + Sniffer *Sniffer } type RawDNS struct { @@ -199,6 +206,7 @@ type RawConfig struct { GeodataMode bool `yaml:"geodata-mode"` GeodataLoader string `yaml:"geodata-loader"` + Sniffer SnifferRaw `yaml:"sniffer"` ProxyProvider map[string]map[string]any `yaml:"proxy-providers"` RuleProvider map[string]map[string]any `yaml:"rule-providers"` Hosts map[string]string `yaml:"hosts"` @@ -212,6 +220,12 @@ type RawConfig struct { Rule []string `yaml:"rules"` } +type SnifferRaw struct { + Enable bool `yaml:"enable" json:"enable"` + Force bool `yaml:"force" json:"force"` + Sniffing []string `yaml:"sniffing" json:"sniffing"` +} + // Parse config func Parse(buf []byte) (*Config, error) { rawCfg, err := UnmarshalRawConfig(buf) @@ -277,6 +291,11 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) { "www.msftconnecttest.com", }, }, + Sniffer: SnifferRaw{ + Enable: false, + Force: false, + Sniffing: []string{}, + }, Profile: Profile{ StoreSelected: true, }, @@ -339,6 +358,11 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) { config.Users = parseAuthentication(rawCfg.Authentication) + config.Sniffer, err = parseSniffer(rawCfg.Sniffer) + if err != nil { + return nil, err + } + elapsedTime := time.Since(startTime) / time.Millisecond // duration in ms log.Infoln("Initial configuration complete, total time: %dms", elapsedTime) //Segment finished in xxm return config, nil @@ -882,3 +906,32 @@ func parseTun(rawTun RawTun, general *General) (*Tun, error) { AutoRoute: rawTun.AutoRoute, }, nil } + +func parseSniffer(snifferRaw SnifferRaw) (*Sniffer, error) { + sniffer := &Sniffer{ + Enable: snifferRaw.Enable, + Force: snifferRaw.Force, + } + + loadSniffer := make(map[C.SnifferType]struct{}) + + for _, snifferName := range snifferRaw.Sniffing { + find := false + for _, snifferType := range C.SnifferList { + if snifferType.String() == strings.ToUpper(snifferName) { + find = true + loadSniffer[snifferType] = struct{}{} + } + } + + if !find { + return nil, fmt.Errorf("not find the sniffer[%s]", snifferName) + } + } + + for st := range loadSniffer { + sniffer.Sniffers = append(sniffer.Sniffers, st) + } + + return sniffer, nil +} diff --git a/constant/sniffer.go b/constant/sniffer.go new file mode 100644 index 00000000..5454279a --- /dev/null +++ b/constant/sniffer.go @@ -0,0 +1,26 @@ +package constant + +type Sniffer interface { + SupportNetwork() NetWork + SniffTCP(bytes []byte) (string, error) + Protocol() string +} + +const ( + TLS SnifferType = iota +) + +var ( + SnifferList = []SnifferType{TLS} +) + +type SnifferType int + +func (rt SnifferType) String() string { + switch rt { + case TLS: + return "TLS" + default: + return "Unknown" + } +} diff --git a/context/conn.go b/context/conn.go index ee0f3a9d..8ecbf56b 100644 --- a/context/conn.go +++ b/context/conn.go @@ -3,8 +3,8 @@ package context import ( "net" + CN "github.com/Dreamacro/clash/common/net" C "github.com/Dreamacro/clash/constant" - "github.com/gofrs/uuid" ) @@ -19,7 +19,7 @@ func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext { return &ConnContext{ id: id, metadata: metadata, - conn: conn, + conn: CN.NewBufferedConn(conn), } } diff --git a/hub/executor/executor.go b/hub/executor/executor.go index bcdafb89..f4b57347 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -17,6 +17,7 @@ import ( "github.com/Dreamacro/clash/component/profile" "github.com/Dreamacro/clash/component/profile/cachefile" "github.com/Dreamacro/clash/component/resolver" + SNI "github.com/Dreamacro/clash/component/sniffer" "github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/config" C "github.com/Dreamacro/clash/constant" @@ -78,6 +79,7 @@ func ApplyConfig(cfg *config.Config, force bool) { updateDNS(cfg.DNS, cfg.Tun) updateGeneral(cfg.General, force) updateIPTables(cfg) + updateSniffer(cfg.Sniffer) updateTun(cfg.Tun, cfg.DNS) updateExperimental(cfg) updateHosts(cfg.Hosts) @@ -87,6 +89,16 @@ func ApplyConfig(cfg *config.Config, force bool) { log.SetLevel(cfg.General.LogLevel) } +func updateSniffer(sniffer *config.Sniffer) { + if sniffer.Enable { + var err error + SNI.Dispatcher, err = SNI.NewSnifferDispatcher(sniffer.Sniffers, sniffer.Force) + if err != nil { + log.Errorln("Init Sniffer failed, err:%v", err) + } + } +} + func GetGeneral() *config.General { ports := P.GetPorts() var authenticator []string diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index db018c05..1f5f1f9c 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -1,14 +1,10 @@ package statistic import ( - "errors" "net" "time" - "github.com/Dreamacro/clash/common/snifer/tls" - "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" - "github.com/Dreamacro/clash/log" "github.com/gofrs/uuid" "go.uber.org/atomic" @@ -52,20 +48,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { n, err := tt.Conn.Write(b) upload := int64(n) tt.manager.PushUploaded(upload) - if tt.UploadTotal.Load() < 128 && tt.Metadata.Host == "" && (tt.Metadata.DstPort == "443" || tt.Metadata.DstPort == "8443" || tt.Metadata.DstPort == "993" || tt.Metadata.DstPort == "465" || tt.Metadata.DstPort == "995") { - header, err := tls.SniffTLS(b) - if err != nil { - // log.Errorln("Expect no error but actually %s %s:%s:%s", err.Error(), tt.Metadata.Host, tt.Metadata.DstIP.String(), tt.Metadata.DstPort) - } else { - resolver.InsertHostByIP(tt.Metadata.DstIP, header.Domain()) - log.Warnln("use sni update host: %s ip: %s", header.Domain(), tt.Metadata.DstIP.String()) - tt.manager.Leave(tt) - tt.Conn.Close() - return n, errors.New("sni update, break current link to avoid leaks") - } - } tt.UploadTotal.Add(upload) - return n, err } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 0a0a7ce5..c3623957 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -14,6 +14,7 @@ import ( "github.com/Dreamacro/clash/component/nat" P "github.com/Dreamacro/clash/component/process" "github.com/Dreamacro/clash/component/resolver" + "github.com/Dreamacro/clash/component/sniffer" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/constant/provider" icontext "github.com/Dreamacro/clash/context" @@ -36,6 +37,8 @@ var ( // default timeout for UDP session udpTimeout = 60 * time.Second + + snifferDispatcher *sniffer.SnifferDispatcher ) func init() { @@ -294,6 +297,10 @@ func handleTCPConn(connCtx C.ConnContext) { return } + if sniffer.Dispatcher.Enable() { + sniffer.Dispatcher.Tcp(connCtx.Conn(), metadata) + } + proxy, rule, err := resolveMetadata(connCtx, metadata) if err != nil { log.Warnln("[Metadata] parse failed: %s", err.Error())