feat: better config for sniffer

This commit is contained in:
gVisor bot 2023-01-23 13:16:25 +08:00
parent 643979800c
commit 86ad74a0ae
7 changed files with 179 additions and 64 deletions

View file

@ -0,0 +1,52 @@
package sniffer
import (
"errors"
"github.com/Dreamacro/clash/common/utils"
"github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/sniffer"
)
type SnifferConfig struct {
Ports []utils.Range[uint16]
}
type BaseSniffer struct {
ports []utils.Range[uint16]
supportNetworkType constant.NetWork
}
// Protocol implements sniffer.Sniffer
func (*BaseSniffer) Protocol() string {
return "unknown"
}
// SniffTCP implements sniffer.Sniffer
func (*BaseSniffer) SniffTCP(bytes []byte) (string, error) {
return "", errors.New("TODO")
}
// SupportNetwork implements sniffer.Sniffer
func (bs *BaseSniffer) SupportNetwork() constant.NetWork {
return bs.supportNetworkType
}
// SupportPort implements sniffer.Sniffer
func (bs *BaseSniffer) SupportPort(port uint16) bool {
for _, portRange := range bs.ports {
if portRange.Contains(port) {
return true
}
}
return false
}
func NewBaseSniffer(ports []utils.Range[uint16], networkType constant.NetWork) *BaseSniffer {
return &BaseSniffer{
ports: ports,
supportNetworkType: networkType,
}
}
var _ sniffer.Sniffer = (*BaseSniffer)(nil)

View file

@ -11,7 +11,6 @@ import (
"github.com/Dreamacro/clash/common/cache"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/utils"
"github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/sniffer"
@ -33,7 +32,6 @@ type SnifferDispatcher struct {
forceDomain *trie.DomainTrie[struct{}]
skipSNI *trie.DomainTrie[struct{}]
portRanges *[]utils.Range[uint16]
skipList *cache.LruCache[string, uint8]
rwMux sync.RWMutex
@ -55,10 +53,12 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
}
inWhitelist := false
for _, portRange := range *sd.portRanges {
if portRange.Contains(uint16(port)) {
inWhitelist = true
break
for _, sniffer := range sd.sniffers {
if sniffer.SupportNetwork() == C.TCP || sniffer.SupportNetwork() == C.ALLNet {
inWhitelist = sniffer.SupportPort(uint16(port))
if inWhitelist {
break
}
}
}
@ -182,21 +182,20 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
return &dispatcher, nil
}
func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[struct{}],
skipSNI *trie.DomainTrie[struct{}], ports *[]utils.Range[uint16],
func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig, forceDomain *trie.DomainTrie[struct{}],
skipSNI *trie.DomainTrie[struct{}],
forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{
enable: true,
forceDomain: forceDomain,
skipSNI: skipSNI,
portRanges: ports,
skipList: cache.New[string, uint8](cache.WithSize[string, uint8](128), cache.WithAge[string, uint8](600)),
forceDnsMapping: forceDnsMapping,
parsePureIp: parsePureIp,
}
for _, snifferName := range needSniffer {
s, err := NewSniffer(snifferName)
for snifferName, config := range snifferConfig {
s, err := NewSniffer(snifferName, config)
if err != nil {
log.Errorln("Sniffer name[%s] is error", snifferName)
return &SnifferDispatcher{enable: false}, err
@ -208,12 +207,12 @@ func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTr
return &dispatcher, nil
}
func NewSniffer(name sniffer.Type) (sniffer.Sniffer, error) {
func NewSniffer(name sniffer.Type, snifferConfig SnifferConfig) (sniffer.Sniffer, error) {
switch name {
case sniffer.TLS:
return &TLSSniffer{}, nil
return NewTLSSniffer(snifferConfig)
case sniffer.HTTP:
return &HTTPSniffer{}, nil
return NewHTTPSniffer(snifferConfig)
default:
return nil, ErrorUnsupportedSniffer
}

View file

@ -7,7 +7,9 @@ import (
"net"
"strings"
"github.com/Dreamacro/clash/common/utils"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/sniffer"
)
var (
@ -24,10 +26,25 @@ const (
)
type HTTPSniffer struct {
*BaseSniffer
version version
host string
}
var _ sniffer.Sniffer = (*HTTPSniffer)(nil)
func NewHTTPSniffer(snifferConfig SnifferConfig) (*HTTPSniffer, error) {
ports := make([]utils.Range[uint16], 0)
if len(snifferConfig.Ports) == 0 {
ports = append(ports, *utils.NewRange[uint16](80, 80))
} else {
ports = append(ports, snifferConfig.Ports...)
}
return &HTTPSniffer{
BaseSniffer: NewBaseSniffer(ports, C.TCP),
}, nil
}
func (http *HTTPSniffer) Protocol() string {
switch http.version {
case HTTP1:

View file

@ -5,7 +5,9 @@ import (
"errors"
"strings"
"github.com/Dreamacro/clash/common/utils"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/sniffer"
)
var (
@ -13,7 +15,22 @@ var (
errNotClientHello = errors.New("not client hello")
)
var _ sniffer.Sniffer = (*TLSSniffer)(nil)
type TLSSniffer struct {
*BaseSniffer
}
func NewTLSSniffer(snifferConfig SnifferConfig) (*TLSSniffer, error) {
ports := make([]utils.Range[uint16], 0)
if len(snifferConfig.Ports) == 0 {
ports = append(ports, *utils.NewRange[uint16](443, 443))
} else {
ports = append(ports, snifferConfig.Ports...)
}
return &TLSSniffer{
BaseSniffer: NewBaseSniffer(ports, C.TCP),
}, nil
}
func (tls *TLSSniffer) Protocol() string {

View file

@ -14,6 +14,7 @@ import (
"time"
P "github.com/Dreamacro/clash/component/process"
SNIFF "github.com/Dreamacro/clash/component/sniffer"
"github.com/Dreamacro/clash/adapter"
"github.com/Dreamacro/clash/adapter/outbound"
@ -134,11 +135,10 @@ type IPTables struct {
type Sniffer struct {
Enable bool
Sniffers []snifferTypes.Type
Sniffers map[snifferTypes.Type]SNIFF.SnifferConfig
Reverses *trie.DomainTrie[struct{}]
ForceDomain *trie.DomainTrie[struct{}]
SkipDomain *trie.DomainTrie[struct{}]
Ports *[]utils.Range[uint16]
ForceDnsMapping bool
ParsePureIp bool
}
@ -287,13 +287,18 @@ type RawGeoXUrl struct {
}
type RawSniffer struct {
Enable bool `yaml:"enable" json:"enable"`
Sniffing []string `yaml:"sniffing" json:"sniffing"`
ForceDomain []string `yaml:"force-domain" json:"force-domain"`
SkipDomain []string `yaml:"skip-domain" json:"skip-domain"`
Ports []string `yaml:"port-whitelist" json:"port-whitelist"`
ForceDnsMapping bool `yaml:"force-dns-mapping" json:"force-dns-mapping"`
ParsePureIp bool `yaml:"parse-pure-ip" json:"parse-pure-ip"`
Enable bool `yaml:"enable" json:"enable"`
Sniffing []string `yaml:"sniffing" json:"sniffing"`
ForceDomain []string `yaml:"force-domain" json:"force-domain"`
SkipDomain []string `yaml:"skip-domain" json:"skip-domain"`
Ports []string `yaml:"port-whitelist" json:"port-whitelist"`
ForceDnsMapping bool `yaml:"force-dns-mapping" json:"force-dns-mapping"`
ParsePureIp bool `yaml:"parse-pure-ip" json:"parse-pure-ip"`
Sniff map[string]RawSniffingConfig `yaml:"sniff" json:"sniff"`
}
type RawSniffingConfig struct {
Ports []string `yaml:"ports" json:"ports"`
}
// EBpf config
@ -1187,55 +1192,54 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
ForceDnsMapping: snifferRaw.ForceDnsMapping,
ParsePureIp: snifferRaw.ParsePureIp,
}
loadSniffer := make(map[snifferTypes.Type]SNIFF.SnifferConfig)
var ports []utils.Range[uint16]
if len(snifferRaw.Ports) == 0 {
ports = append(ports, *utils.NewRange[uint16](80, 80))
ports = append(ports, *utils.NewRange[uint16](443, 443))
} else {
for _, portRange := range snifferRaw.Ports {
portRaws := strings.Split(portRange, "-")
p, err := strconv.ParseUint(portRaws[0], 10, 16)
if len(snifferRaw.Sniff) != 0 {
for sniffType, sniffConfig := range snifferRaw.Sniff {
find := false
ports, err := parsePortRange(sniffConfig.Ports)
if err != nil {
return nil, fmt.Errorf("%s format error", portRange)
return nil, err
}
start := uint16(p)
if len(portRaws) > 1 {
p, err = strconv.ParseUint(portRaws[1], 10, 16)
if err != nil {
return nil, fmt.Errorf("%s format error", portRange)
for _, snifferType := range snifferTypes.List {
if snifferType.String() == strings.ToUpper(sniffType) {
find = true
loadSniffer[snifferType] = SNIFF.SnifferConfig{
Ports: ports,
}
}
}
end := uint16(p)
ports = append(ports, *utils.NewRange(start, end))
} else {
ports = append(ports, *utils.NewRange(start, start))
if !find {
return nil, fmt.Errorf("not find the sniffer[%s]", sniffType)
}
}
} else {
// Deprecated: Use Sniff instead
log.Warnln("Deprecated: Use Sniff instead")
globalPorts, err := parsePortRange(snifferRaw.Ports)
if err != nil {
return nil, err
}
for _, snifferName := range snifferRaw.Sniffing {
find := false
for _, snifferType := range snifferTypes.List {
if snifferType.String() == strings.ToUpper(snifferName) {
find = true
loadSniffer[snifferType] = SNIFF.SnifferConfig{
Ports: globalPorts,
}
}
}
if !find {
return nil, fmt.Errorf("not find the sniffer[%s]", snifferName)
}
}
}
sniffer.Ports = &ports
loadSniffer := make(map[snifferTypes.Type]struct{})
for _, snifferName := range snifferRaw.Sniffing {
find := false
for _, snifferType := range snifferTypes.List {
if snifferType.String() == strings.ToUpper(snifferName) {
find = true
loadSniffer[snifferType] = struct{}{}
}
}
if !find {
return nil, fmt.Errorf("not find the sniffer[%s]", snifferName)
}
}
for st := range loadSniffer {
sniffer.Sniffers = append(sniffer.Sniffers, st)
}
sniffer.Sniffers = loadSniffer
sniffer.ForceDomain = trie.New[struct{}]()
for _, domain := range snifferRaw.ForceDomain {
err := sniffer.ForceDomain.Insert(domain, struct{}{})
@ -1256,3 +1260,28 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
return sniffer, nil
}
func parsePortRange(portRanges []string) ([]utils.Range[uint16], error) {
ports := make([]utils.Range[uint16], 0)
for _, portRange := range portRanges {
portRaws := strings.Split(portRange, "-")
p, err := strconv.ParseUint(portRaws[0], 10, 16)
if err != nil {
return nil, fmt.Errorf("%s format error", portRange)
}
start := uint16(p)
if len(portRaws) > 1 {
p, err = strconv.ParseUint(portRaws[1], 10, 16)
if err != nil {
return nil, fmt.Errorf("%s format error", portRange)
}
end := uint16(p)
ports = append(ports, *utils.NewRange(start, end))
} else {
ports = append(ports, *utils.NewRange(start, start))
}
}
return ports, nil
}

View file

@ -6,6 +6,7 @@ type Sniffer interface {
SupportNetwork() constant.NetWork
SniffTCP(bytes []byte) (string, error)
Protocol() string
SupportPort(port uint16) bool
}
const (

View file

@ -279,7 +279,7 @@ func updateTun(general *config.General) {
func updateSniffer(sniffer *config.Sniffer) {
if sniffer.Enable {
dispatcher, err := SNI.NewSnifferDispatcher(
sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipDomain, sniffer.Ports,
sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipDomain,
sniffer.ForceDnsMapping, sniffer.ParsePureIp,
)
if err != nil {