From 86ad74a0aeab89a1f3b86e11e4840b67cd544447 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Mon, 23 Jan 2023 13:16:25 +0800 Subject: [PATCH] feat: better config for sniffer --- component/sniffer/base_sniffer.go | 52 ++++++++++++ component/sniffer/dispatcher.go | 27 +++---- component/sniffer/http_sniffer.go | 17 ++++ component/sniffer/tls_sniffer.go | 17 ++++ config/config.go | 127 ++++++++++++++++++------------ constant/sniffer/sniffer.go | 1 + hub/executor/executor.go | 2 +- 7 files changed, 179 insertions(+), 64 deletions(-) create mode 100644 component/sniffer/base_sniffer.go diff --git a/component/sniffer/base_sniffer.go b/component/sniffer/base_sniffer.go new file mode 100644 index 00000000..6d076b59 --- /dev/null +++ b/component/sniffer/base_sniffer.go @@ -0,0 +1,52 @@ +package sniffer + +import ( + "errors" + + "github.com/Dreamacro/clash/common/utils" + "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/constant/sniffer" +) + +type SnifferConfig struct { + Ports []utils.Range[uint16] +} + +type BaseSniffer struct { + ports []utils.Range[uint16] + supportNetworkType constant.NetWork +} + +// Protocol implements sniffer.Sniffer +func (*BaseSniffer) Protocol() string { + return "unknown" +} + +// SniffTCP implements sniffer.Sniffer +func (*BaseSniffer) SniffTCP(bytes []byte) (string, error) { + return "", errors.New("TODO") +} + +// SupportNetwork implements sniffer.Sniffer +func (bs *BaseSniffer) SupportNetwork() constant.NetWork { + return bs.supportNetworkType +} + +// SupportPort implements sniffer.Sniffer +func (bs *BaseSniffer) SupportPort(port uint16) bool { + for _, portRange := range bs.ports { + if portRange.Contains(port) { + return true + } + } + return false +} + +func NewBaseSniffer(ports []utils.Range[uint16], networkType constant.NetWork) *BaseSniffer { + return &BaseSniffer{ + ports: ports, + supportNetworkType: networkType, + } +} + +var _ sniffer.Sniffer = (*BaseSniffer)(nil) diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index 759e1043..a450693b 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -11,7 +11,6 @@ import ( "github.com/Dreamacro/clash/common/cache" N "github.com/Dreamacro/clash/common/net" - "github.com/Dreamacro/clash/common/utils" "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/constant/sniffer" @@ -33,7 +32,6 @@ type SnifferDispatcher struct { forceDomain *trie.DomainTrie[struct{}] skipSNI *trie.DomainTrie[struct{}] - portRanges *[]utils.Range[uint16] skipList *cache.LruCache[string, uint8] rwMux sync.RWMutex @@ -55,10 +53,12 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) { } inWhitelist := false - for _, portRange := range *sd.portRanges { - if portRange.Contains(uint16(port)) { - inWhitelist = true - break + for _, sniffer := range sd.sniffers { + if sniffer.SupportNetwork() == C.TCP || sniffer.SupportNetwork() == C.ALLNet { + inWhitelist = sniffer.SupportPort(uint16(port)) + if inWhitelist { + break + } } } @@ -182,21 +182,20 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) { return &dispatcher, nil } -func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[struct{}], - skipSNI *trie.DomainTrie[struct{}], ports *[]utils.Range[uint16], +func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig, forceDomain *trie.DomainTrie[struct{}], + skipSNI *trie.DomainTrie[struct{}], forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) { dispatcher := SnifferDispatcher{ enable: true, forceDomain: forceDomain, skipSNI: skipSNI, - portRanges: ports, skipList: cache.New[string, uint8](cache.WithSize[string, uint8](128), cache.WithAge[string, uint8](600)), forceDnsMapping: forceDnsMapping, parsePureIp: parsePureIp, } - for _, snifferName := range needSniffer { - s, err := NewSniffer(snifferName) + for snifferName, config := range snifferConfig { + s, err := NewSniffer(snifferName, config) if err != nil { log.Errorln("Sniffer name[%s] is error", snifferName) return &SnifferDispatcher{enable: false}, err @@ -208,12 +207,12 @@ func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTr return &dispatcher, nil } -func NewSniffer(name sniffer.Type) (sniffer.Sniffer, error) { +func NewSniffer(name sniffer.Type, snifferConfig SnifferConfig) (sniffer.Sniffer, error) { switch name { case sniffer.TLS: - return &TLSSniffer{}, nil + return NewTLSSniffer(snifferConfig) case sniffer.HTTP: - return &HTTPSniffer{}, nil + return NewHTTPSniffer(snifferConfig) default: return nil, ErrorUnsupportedSniffer } diff --git a/component/sniffer/http_sniffer.go b/component/sniffer/http_sniffer.go index 551b20c8..bfa7ca6e 100644 --- a/component/sniffer/http_sniffer.go +++ b/component/sniffer/http_sniffer.go @@ -7,7 +7,9 @@ import ( "net" "strings" + "github.com/Dreamacro/clash/common/utils" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/constant/sniffer" ) var ( @@ -24,10 +26,25 @@ const ( ) type HTTPSniffer struct { + *BaseSniffer version version host string } +var _ sniffer.Sniffer = (*HTTPSniffer)(nil) + +func NewHTTPSniffer(snifferConfig SnifferConfig) (*HTTPSniffer, error) { + ports := make([]utils.Range[uint16], 0) + if len(snifferConfig.Ports) == 0 { + ports = append(ports, *utils.NewRange[uint16](80, 80)) + } else { + ports = append(ports, snifferConfig.Ports...) + } + return &HTTPSniffer{ + BaseSniffer: NewBaseSniffer(ports, C.TCP), + }, nil +} + func (http *HTTPSniffer) Protocol() string { switch http.version { case HTTP1: diff --git a/component/sniffer/tls_sniffer.go b/component/sniffer/tls_sniffer.go index f5a8fd99..0867d0f0 100644 --- a/component/sniffer/tls_sniffer.go +++ b/component/sniffer/tls_sniffer.go @@ -5,7 +5,9 @@ import ( "errors" "strings" + "github.com/Dreamacro/clash/common/utils" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/constant/sniffer" ) var ( @@ -13,7 +15,22 @@ var ( errNotClientHello = errors.New("not client hello") ) +var _ sniffer.Sniffer = (*TLSSniffer)(nil) + type TLSSniffer struct { + *BaseSniffer +} + +func NewTLSSniffer(snifferConfig SnifferConfig) (*TLSSniffer, error) { + ports := make([]utils.Range[uint16], 0) + if len(snifferConfig.Ports) == 0 { + ports = append(ports, *utils.NewRange[uint16](443, 443)) + } else { + ports = append(ports, snifferConfig.Ports...) + } + return &TLSSniffer{ + BaseSniffer: NewBaseSniffer(ports, C.TCP), + }, nil } func (tls *TLSSniffer) Protocol() string { diff --git a/config/config.go b/config/config.go index c5ba80fc..f4a991a9 100644 --- a/config/config.go +++ b/config/config.go @@ -14,6 +14,7 @@ import ( "time" P "github.com/Dreamacro/clash/component/process" + SNIFF "github.com/Dreamacro/clash/component/sniffer" "github.com/Dreamacro/clash/adapter" "github.com/Dreamacro/clash/adapter/outbound" @@ -134,11 +135,10 @@ type IPTables struct { type Sniffer struct { Enable bool - Sniffers []snifferTypes.Type + Sniffers map[snifferTypes.Type]SNIFF.SnifferConfig Reverses *trie.DomainTrie[struct{}] ForceDomain *trie.DomainTrie[struct{}] SkipDomain *trie.DomainTrie[struct{}] - Ports *[]utils.Range[uint16] ForceDnsMapping bool ParsePureIp bool } @@ -287,13 +287,18 @@ type RawGeoXUrl struct { } type RawSniffer struct { - Enable bool `yaml:"enable" json:"enable"` - Sniffing []string `yaml:"sniffing" json:"sniffing"` - ForceDomain []string `yaml:"force-domain" json:"force-domain"` - SkipDomain []string `yaml:"skip-domain" json:"skip-domain"` - Ports []string `yaml:"port-whitelist" json:"port-whitelist"` - ForceDnsMapping bool `yaml:"force-dns-mapping" json:"force-dns-mapping"` - ParsePureIp bool `yaml:"parse-pure-ip" json:"parse-pure-ip"` + Enable bool `yaml:"enable" json:"enable"` + Sniffing []string `yaml:"sniffing" json:"sniffing"` + ForceDomain []string `yaml:"force-domain" json:"force-domain"` + SkipDomain []string `yaml:"skip-domain" json:"skip-domain"` + Ports []string `yaml:"port-whitelist" json:"port-whitelist"` + ForceDnsMapping bool `yaml:"force-dns-mapping" json:"force-dns-mapping"` + ParsePureIp bool `yaml:"parse-pure-ip" json:"parse-pure-ip"` + Sniff map[string]RawSniffingConfig `yaml:"sniff" json:"sniff"` +} + +type RawSniffingConfig struct { + Ports []string `yaml:"ports" json:"ports"` } // EBpf config @@ -1187,55 +1192,54 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { ForceDnsMapping: snifferRaw.ForceDnsMapping, ParsePureIp: snifferRaw.ParsePureIp, } + loadSniffer := make(map[snifferTypes.Type]SNIFF.SnifferConfig) - var ports []utils.Range[uint16] - if len(snifferRaw.Ports) == 0 { - ports = append(ports, *utils.NewRange[uint16](80, 80)) - ports = append(ports, *utils.NewRange[uint16](443, 443)) - } else { - for _, portRange := range snifferRaw.Ports { - portRaws := strings.Split(portRange, "-") - p, err := strconv.ParseUint(portRaws[0], 10, 16) + if len(snifferRaw.Sniff) != 0 { + for sniffType, sniffConfig := range snifferRaw.Sniff { + find := false + ports, err := parsePortRange(sniffConfig.Ports) if err != nil { - return nil, fmt.Errorf("%s format error", portRange) + return nil, err } - - start := uint16(p) - if len(portRaws) > 1 { - p, err = strconv.ParseUint(portRaws[1], 10, 16) - if err != nil { - return nil, fmt.Errorf("%s format error", portRange) + for _, snifferType := range snifferTypes.List { + if snifferType.String() == strings.ToUpper(sniffType) { + find = true + loadSniffer[snifferType] = SNIFF.SnifferConfig{ + Ports: ports, + } } + } - end := uint16(p) - ports = append(ports, *utils.NewRange(start, end)) - } else { - ports = append(ports, *utils.NewRange(start, start)) + if !find { + return nil, fmt.Errorf("not find the sniffer[%s]", sniffType) + } + } + } else { + // Deprecated: Use Sniff instead + log.Warnln("Deprecated: Use Sniff instead") + globalPorts, err := parsePortRange(snifferRaw.Ports) + if err != nil { + return nil, err + } + + for _, snifferName := range snifferRaw.Sniffing { + find := false + for _, snifferType := range snifferTypes.List { + if snifferType.String() == strings.ToUpper(snifferName) { + find = true + loadSniffer[snifferType] = SNIFF.SnifferConfig{ + Ports: globalPorts, + } + } + } + + if !find { + return nil, fmt.Errorf("not find the sniffer[%s]", snifferName) } } } - sniffer.Ports = &ports - - loadSniffer := make(map[snifferTypes.Type]struct{}) - - for _, snifferName := range snifferRaw.Sniffing { - find := false - for _, snifferType := range snifferTypes.List { - if snifferType.String() == strings.ToUpper(snifferName) { - find = true - loadSniffer[snifferType] = struct{}{} - } - } - - if !find { - return nil, fmt.Errorf("not find the sniffer[%s]", snifferName) - } - } - - for st := range loadSniffer { - sniffer.Sniffers = append(sniffer.Sniffers, st) - } + sniffer.Sniffers = loadSniffer sniffer.ForceDomain = trie.New[struct{}]() for _, domain := range snifferRaw.ForceDomain { err := sniffer.ForceDomain.Insert(domain, struct{}{}) @@ -1256,3 +1260,28 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) { return sniffer, nil } + +func parsePortRange(portRanges []string) ([]utils.Range[uint16], error) { + ports := make([]utils.Range[uint16], 0) + for _, portRange := range portRanges { + portRaws := strings.Split(portRange, "-") + p, err := strconv.ParseUint(portRaws[0], 10, 16) + if err != nil { + return nil, fmt.Errorf("%s format error", portRange) + } + + start := uint16(p) + if len(portRaws) > 1 { + p, err = strconv.ParseUint(portRaws[1], 10, 16) + if err != nil { + return nil, fmt.Errorf("%s format error", portRange) + } + + end := uint16(p) + ports = append(ports, *utils.NewRange(start, end)) + } else { + ports = append(ports, *utils.NewRange(start, start)) + } + } + return ports, nil +} diff --git a/constant/sniffer/sniffer.go b/constant/sniffer/sniffer.go index 8ab2d496..6b20b3f6 100644 --- a/constant/sniffer/sniffer.go +++ b/constant/sniffer/sniffer.go @@ -6,6 +6,7 @@ type Sniffer interface { SupportNetwork() constant.NetWork SniffTCP(bytes []byte) (string, error) Protocol() string + SupportPort(port uint16) bool } const ( diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 9082f856..d88e91dc 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -279,7 +279,7 @@ func updateTun(general *config.General) { func updateSniffer(sniffer *config.Sniffer) { if sniffer.Enable { dispatcher, err := SNI.NewSnifferDispatcher( - sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipDomain, sniffer.Ports, + sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipDomain, sniffer.ForceDnsMapping, sniffer.ParsePureIp, ) if err != nil {