feat: better config for sniffer

This commit is contained in:
Skyxim 2023-01-23 13:16:25 +08:00
parent d1f5bef25d
commit df1f6e2b99
7 changed files with 179 additions and 64 deletions

View file

@ -0,0 +1,52 @@
package sniffer
import (
"errors"
"github.com/Dreamacro/clash/common/utils"
"github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/sniffer"
)
type SnifferConfig struct {
Ports []utils.Range[uint16]
}
type BaseSniffer struct {
ports []utils.Range[uint16]
supportNetworkType constant.NetWork
}
// Protocol implements sniffer.Sniffer
func (*BaseSniffer) Protocol() string {
return "unknown"
}
// SniffTCP implements sniffer.Sniffer
func (*BaseSniffer) SniffTCP(bytes []byte) (string, error) {
return "", errors.New("TODO")
}
// SupportNetwork implements sniffer.Sniffer
func (bs *BaseSniffer) SupportNetwork() constant.NetWork {
return bs.supportNetworkType
}
// SupportPort implements sniffer.Sniffer
func (bs *BaseSniffer) SupportPort(port uint16) bool {
for _, portRange := range bs.ports {
if portRange.Contains(port) {
return true
}
}
return false
}
func NewBaseSniffer(ports []utils.Range[uint16], networkType constant.NetWork) *BaseSniffer {
return &BaseSniffer{
ports: ports,
supportNetworkType: networkType,
}
}
var _ sniffer.Sniffer = (*BaseSniffer)(nil)

View file

@ -11,7 +11,6 @@ import (
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/utils"
"github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/sniffer" "github.com/Dreamacro/clash/constant/sniffer"
@ -33,7 +32,6 @@ type SnifferDispatcher struct {
forceDomain *trie.DomainTrie[struct{}] forceDomain *trie.DomainTrie[struct{}]
skipSNI *trie.DomainTrie[struct{}] skipSNI *trie.DomainTrie[struct{}]
portRanges *[]utils.Range[uint16]
skipList *cache.LruCache[string, uint8] skipList *cache.LruCache[string, uint8]
rwMux sync.RWMutex rwMux sync.RWMutex
@ -55,12 +53,14 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
} }
inWhitelist := false inWhitelist := false
for _, portRange := range *sd.portRanges { for _, sniffer := range sd.sniffers {
if portRange.Contains(uint16(port)) { if sniffer.SupportNetwork() == C.TCP || sniffer.SupportNetwork() == C.ALLNet {
inWhitelist = true inWhitelist = sniffer.SupportPort(uint16(port))
if inWhitelist {
break break
} }
} }
}
if !inWhitelist { if !inWhitelist {
return return
@ -182,21 +182,20 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
return &dispatcher, nil return &dispatcher, nil
} }
func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[struct{}], func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig, forceDomain *trie.DomainTrie[struct{}],
skipSNI *trie.DomainTrie[struct{}], ports *[]utils.Range[uint16], skipSNI *trie.DomainTrie[struct{}],
forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) { forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{ dispatcher := SnifferDispatcher{
enable: true, enable: true,
forceDomain: forceDomain, forceDomain: forceDomain,
skipSNI: skipSNI, skipSNI: skipSNI,
portRanges: ports,
skipList: cache.New[string, uint8](cache.WithSize[string, uint8](128), cache.WithAge[string, uint8](600)), skipList: cache.New[string, uint8](cache.WithSize[string, uint8](128), cache.WithAge[string, uint8](600)),
forceDnsMapping: forceDnsMapping, forceDnsMapping: forceDnsMapping,
parsePureIp: parsePureIp, parsePureIp: parsePureIp,
} }
for _, snifferName := range needSniffer { for snifferName, config := range snifferConfig {
s, err := NewSniffer(snifferName) s, err := NewSniffer(snifferName, config)
if err != nil { if err != nil {
log.Errorln("Sniffer name[%s] is error", snifferName) log.Errorln("Sniffer name[%s] is error", snifferName)
return &SnifferDispatcher{enable: false}, err return &SnifferDispatcher{enable: false}, err
@ -208,12 +207,12 @@ func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTr
return &dispatcher, nil return &dispatcher, nil
} }
func NewSniffer(name sniffer.Type) (sniffer.Sniffer, error) { func NewSniffer(name sniffer.Type, snifferConfig SnifferConfig) (sniffer.Sniffer, error) {
switch name { switch name {
case sniffer.TLS: case sniffer.TLS:
return &TLSSniffer{}, nil return NewTLSSniffer(snifferConfig)
case sniffer.HTTP: case sniffer.HTTP:
return &HTTPSniffer{}, nil return NewHTTPSniffer(snifferConfig)
default: default:
return nil, ErrorUnsupportedSniffer return nil, ErrorUnsupportedSniffer
} }

View file

@ -7,7 +7,9 @@ import (
"net" "net"
"strings" "strings"
"github.com/Dreamacro/clash/common/utils"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/sniffer"
) )
var ( var (
@ -24,10 +26,25 @@ const (
) )
type HTTPSniffer struct { type HTTPSniffer struct {
*BaseSniffer
version version version version
host string host string
} }
var _ sniffer.Sniffer = (*HTTPSniffer)(nil)
func NewHTTPSniffer(snifferConfig SnifferConfig) (*HTTPSniffer, error) {
ports := make([]utils.Range[uint16], 0)
if len(snifferConfig.Ports) == 0 {
ports = append(ports, *utils.NewRange[uint16](80, 80))
} else {
ports = append(ports, snifferConfig.Ports...)
}
return &HTTPSniffer{
BaseSniffer: NewBaseSniffer(ports, C.TCP),
}, nil
}
func (http *HTTPSniffer) Protocol() string { func (http *HTTPSniffer) Protocol() string {
switch http.version { switch http.version {
case HTTP1: case HTTP1:

View file

@ -5,7 +5,9 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/Dreamacro/clash/common/utils"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/sniffer"
) )
var ( var (
@ -13,7 +15,22 @@ var (
errNotClientHello = errors.New("not client hello") errNotClientHello = errors.New("not client hello")
) )
var _ sniffer.Sniffer = (*TLSSniffer)(nil)
type TLSSniffer struct { type TLSSniffer struct {
*BaseSniffer
}
func NewTLSSniffer(snifferConfig SnifferConfig) (*TLSSniffer, error) {
ports := make([]utils.Range[uint16], 0)
if len(snifferConfig.Ports) == 0 {
ports = append(ports, *utils.NewRange[uint16](443, 443))
} else {
ports = append(ports, snifferConfig.Ports...)
}
return &TLSSniffer{
BaseSniffer: NewBaseSniffer(ports, C.TCP),
}, nil
} }
func (tls *TLSSniffer) Protocol() string { func (tls *TLSSniffer) Protocol() string {

View file

@ -14,6 +14,7 @@ import (
"time" "time"
P "github.com/Dreamacro/clash/component/process" P "github.com/Dreamacro/clash/component/process"
SNIFF "github.com/Dreamacro/clash/component/sniffer"
"github.com/Dreamacro/clash/adapter" "github.com/Dreamacro/clash/adapter"
"github.com/Dreamacro/clash/adapter/outbound" "github.com/Dreamacro/clash/adapter/outbound"
@ -134,11 +135,10 @@ type IPTables struct {
type Sniffer struct { type Sniffer struct {
Enable bool Enable bool
Sniffers []snifferTypes.Type Sniffers map[snifferTypes.Type]SNIFF.SnifferConfig
Reverses *trie.DomainTrie[struct{}] Reverses *trie.DomainTrie[struct{}]
ForceDomain *trie.DomainTrie[struct{}] ForceDomain *trie.DomainTrie[struct{}]
SkipDomain *trie.DomainTrie[struct{}] SkipDomain *trie.DomainTrie[struct{}]
Ports *[]utils.Range[uint16]
ForceDnsMapping bool ForceDnsMapping bool
ParsePureIp bool ParsePureIp bool
} }
@ -294,6 +294,11 @@ type RawSniffer struct {
Ports []string `yaml:"port-whitelist" json:"port-whitelist"` Ports []string `yaml:"port-whitelist" json:"port-whitelist"`
ForceDnsMapping bool `yaml:"force-dns-mapping" json:"force-dns-mapping"` ForceDnsMapping bool `yaml:"force-dns-mapping" json:"force-dns-mapping"`
ParsePureIp bool `yaml:"parse-pure-ip" json:"parse-pure-ip"` ParsePureIp bool `yaml:"parse-pure-ip" json:"parse-pure-ip"`
Sniff map[string]RawSniffingConfig `yaml:"sniff" json:"sniff"`
}
type RawSniffingConfig struct {
Ports []string `yaml:"ports" json:"ports"`
} }
// EBpf config // EBpf config
@ -1187,44 +1192,44 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
ForceDnsMapping: snifferRaw.ForceDnsMapping, ForceDnsMapping: snifferRaw.ForceDnsMapping,
ParsePureIp: snifferRaw.ParsePureIp, ParsePureIp: snifferRaw.ParsePureIp,
} }
loadSniffer := make(map[snifferTypes.Type]SNIFF.SnifferConfig)
var ports []utils.Range[uint16] if len(snifferRaw.Sniff) != 0 {
if len(snifferRaw.Ports) == 0 { for sniffType, sniffConfig := range snifferRaw.Sniff {
ports = append(ports, *utils.NewRange[uint16](80, 80)) find := false
ports = append(ports, *utils.NewRange[uint16](443, 443)) ports, err := parsePortRange(sniffConfig.Ports)
} else {
for _, portRange := range snifferRaw.Ports {
portRaws := strings.Split(portRange, "-")
p, err := strconv.ParseUint(portRaws[0], 10, 16)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s format error", portRange) return nil, err
}
for _, snifferType := range snifferTypes.List {
if snifferType.String() == strings.ToUpper(sniffType) {
find = true
loadSniffer[snifferType] = SNIFF.SnifferConfig{
Ports: ports,
}
}
} }
start := uint16(p) if !find {
if len(portRaws) > 1 { return nil, fmt.Errorf("not find the sniffer[%s]", sniffType)
p, err = strconv.ParseUint(portRaws[1], 10, 16) }
if err != nil {
return nil, fmt.Errorf("%s format error", portRange)
} }
end := uint16(p)
ports = append(ports, *utils.NewRange(start, end))
} else { } else {
ports = append(ports, *utils.NewRange(start, start)) // Deprecated: Use Sniff instead
log.Warnln("Deprecated: Use Sniff instead")
globalPorts, err := parsePortRange(snifferRaw.Ports)
if err != nil {
return nil, err
} }
}
}
sniffer.Ports = &ports
loadSniffer := make(map[snifferTypes.Type]struct{})
for _, snifferName := range snifferRaw.Sniffing { for _, snifferName := range snifferRaw.Sniffing {
find := false find := false
for _, snifferType := range snifferTypes.List { for _, snifferType := range snifferTypes.List {
if snifferType.String() == strings.ToUpper(snifferName) { if snifferType.String() == strings.ToUpper(snifferName) {
find = true find = true
loadSniffer[snifferType] = struct{}{} loadSniffer[snifferType] = SNIFF.SnifferConfig{
Ports: globalPorts,
}
} }
} }
@ -1232,10 +1237,9 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
return nil, fmt.Errorf("not find the sniffer[%s]", snifferName) return nil, fmt.Errorf("not find the sniffer[%s]", snifferName)
} }
} }
for st := range loadSniffer {
sniffer.Sniffers = append(sniffer.Sniffers, st)
} }
sniffer.Sniffers = loadSniffer
sniffer.ForceDomain = trie.New[struct{}]() sniffer.ForceDomain = trie.New[struct{}]()
for _, domain := range snifferRaw.ForceDomain { for _, domain := range snifferRaw.ForceDomain {
err := sniffer.ForceDomain.Insert(domain, struct{}{}) err := sniffer.ForceDomain.Insert(domain, struct{}{})
@ -1256,3 +1260,28 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
return sniffer, nil return sniffer, nil
} }
func parsePortRange(portRanges []string) ([]utils.Range[uint16], error) {
ports := make([]utils.Range[uint16], 0)
for _, portRange := range portRanges {
portRaws := strings.Split(portRange, "-")
p, err := strconv.ParseUint(portRaws[0], 10, 16)
if err != nil {
return nil, fmt.Errorf("%s format error", portRange)
}
start := uint16(p)
if len(portRaws) > 1 {
p, err = strconv.ParseUint(portRaws[1], 10, 16)
if err != nil {
return nil, fmt.Errorf("%s format error", portRange)
}
end := uint16(p)
ports = append(ports, *utils.NewRange(start, end))
} else {
ports = append(ports, *utils.NewRange(start, start))
}
}
return ports, nil
}

View file

@ -6,6 +6,7 @@ type Sniffer interface {
SupportNetwork() constant.NetWork SupportNetwork() constant.NetWork
SniffTCP(bytes []byte) (string, error) SniffTCP(bytes []byte) (string, error)
Protocol() string Protocol() string
SupportPort(port uint16) bool
} }
const ( const (

View file

@ -279,7 +279,7 @@ func updateTun(general *config.General) {
func updateSniffer(sniffer *config.Sniffer) { func updateSniffer(sniffer *config.Sniffer) {
if sniffer.Enable { if sniffer.Enable {
dispatcher, err := SNI.NewSnifferDispatcher( dispatcher, err := SNI.NewSnifferDispatcher(
sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipDomain, sniffer.Ports, sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipDomain,
sniffer.ForceDnsMapping, sniffer.ParsePureIp, sniffer.ForceDnsMapping, sniffer.ParsePureIp,
) )
if err != nil { if err != nil {