feat: better config for sniffer
This commit is contained in:
parent
d1f5bef25d
commit
df1f6e2b99
7 changed files with 179 additions and 64 deletions
52
component/sniffer/base_sniffer.go
Normal file
52
component/sniffer/base_sniffer.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package sniffer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/Dreamacro/clash/common/utils"
|
||||
"github.com/Dreamacro/clash/constant"
|
||||
"github.com/Dreamacro/clash/constant/sniffer"
|
||||
)
|
||||
|
||||
type SnifferConfig struct {
|
||||
Ports []utils.Range[uint16]
|
||||
}
|
||||
|
||||
type BaseSniffer struct {
|
||||
ports []utils.Range[uint16]
|
||||
supportNetworkType constant.NetWork
|
||||
}
|
||||
|
||||
// Protocol implements sniffer.Sniffer
|
||||
func (*BaseSniffer) Protocol() string {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// SniffTCP implements sniffer.Sniffer
|
||||
func (*BaseSniffer) SniffTCP(bytes []byte) (string, error) {
|
||||
return "", errors.New("TODO")
|
||||
}
|
||||
|
||||
// SupportNetwork implements sniffer.Sniffer
|
||||
func (bs *BaseSniffer) SupportNetwork() constant.NetWork {
|
||||
return bs.supportNetworkType
|
||||
}
|
||||
|
||||
// SupportPort implements sniffer.Sniffer
|
||||
func (bs *BaseSniffer) SupportPort(port uint16) bool {
|
||||
for _, portRange := range bs.ports {
|
||||
if portRange.Contains(port) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func NewBaseSniffer(ports []utils.Range[uint16], networkType constant.NetWork) *BaseSniffer {
|
||||
return &BaseSniffer{
|
||||
ports: ports,
|
||||
supportNetworkType: networkType,
|
||||
}
|
||||
}
|
||||
|
||||
var _ sniffer.Sniffer = (*BaseSniffer)(nil)
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
"github.com/Dreamacro/clash/common/cache"
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/common/utils"
|
||||
"github.com/Dreamacro/clash/component/trie"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
"github.com/Dreamacro/clash/constant/sniffer"
|
||||
|
@ -33,7 +32,6 @@ type SnifferDispatcher struct {
|
|||
|
||||
forceDomain *trie.DomainTrie[struct{}]
|
||||
skipSNI *trie.DomainTrie[struct{}]
|
||||
portRanges *[]utils.Range[uint16]
|
||||
skipList *cache.LruCache[string, uint8]
|
||||
rwMux sync.RWMutex
|
||||
|
||||
|
@ -55,12 +53,14 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
|
|||
}
|
||||
|
||||
inWhitelist := false
|
||||
for _, portRange := range *sd.portRanges {
|
||||
if portRange.Contains(uint16(port)) {
|
||||
inWhitelist = true
|
||||
for _, sniffer := range sd.sniffers {
|
||||
if sniffer.SupportNetwork() == C.TCP || sniffer.SupportNetwork() == C.ALLNet {
|
||||
inWhitelist = sniffer.SupportPort(uint16(port))
|
||||
if inWhitelist {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !inWhitelist {
|
||||
return
|
||||
|
@ -182,21 +182,20 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
|
|||
return &dispatcher, nil
|
||||
}
|
||||
|
||||
func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[struct{}],
|
||||
skipSNI *trie.DomainTrie[struct{}], ports *[]utils.Range[uint16],
|
||||
func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig, forceDomain *trie.DomainTrie[struct{}],
|
||||
skipSNI *trie.DomainTrie[struct{}],
|
||||
forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
|
||||
dispatcher := SnifferDispatcher{
|
||||
enable: true,
|
||||
forceDomain: forceDomain,
|
||||
skipSNI: skipSNI,
|
||||
portRanges: ports,
|
||||
skipList: cache.New[string, uint8](cache.WithSize[string, uint8](128), cache.WithAge[string, uint8](600)),
|
||||
forceDnsMapping: forceDnsMapping,
|
||||
parsePureIp: parsePureIp,
|
||||
}
|
||||
|
||||
for _, snifferName := range needSniffer {
|
||||
s, err := NewSniffer(snifferName)
|
||||
for snifferName, config := range snifferConfig {
|
||||
s, err := NewSniffer(snifferName, config)
|
||||
if err != nil {
|
||||
log.Errorln("Sniffer name[%s] is error", snifferName)
|
||||
return &SnifferDispatcher{enable: false}, err
|
||||
|
@ -208,12 +207,12 @@ func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTr
|
|||
return &dispatcher, nil
|
||||
}
|
||||
|
||||
func NewSniffer(name sniffer.Type) (sniffer.Sniffer, error) {
|
||||
func NewSniffer(name sniffer.Type, snifferConfig SnifferConfig) (sniffer.Sniffer, error) {
|
||||
switch name {
|
||||
case sniffer.TLS:
|
||||
return &TLSSniffer{}, nil
|
||||
return NewTLSSniffer(snifferConfig)
|
||||
case sniffer.HTTP:
|
||||
return &HTTPSniffer{}, nil
|
||||
return NewHTTPSniffer(snifferConfig)
|
||||
default:
|
||||
return nil, ErrorUnsupportedSniffer
|
||||
}
|
||||
|
|
|
@ -7,7 +7,9 @@ import (
|
|||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/Dreamacro/clash/common/utils"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
"github.com/Dreamacro/clash/constant/sniffer"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -24,10 +26,25 @@ const (
|
|||
)
|
||||
|
||||
type HTTPSniffer struct {
|
||||
*BaseSniffer
|
||||
version version
|
||||
host string
|
||||
}
|
||||
|
||||
var _ sniffer.Sniffer = (*HTTPSniffer)(nil)
|
||||
|
||||
func NewHTTPSniffer(snifferConfig SnifferConfig) (*HTTPSniffer, error) {
|
||||
ports := make([]utils.Range[uint16], 0)
|
||||
if len(snifferConfig.Ports) == 0 {
|
||||
ports = append(ports, *utils.NewRange[uint16](80, 80))
|
||||
} else {
|
||||
ports = append(ports, snifferConfig.Ports...)
|
||||
}
|
||||
return &HTTPSniffer{
|
||||
BaseSniffer: NewBaseSniffer(ports, C.TCP),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (http *HTTPSniffer) Protocol() string {
|
||||
switch http.version {
|
||||
case HTTP1:
|
||||
|
|
|
@ -5,7 +5,9 @@ import (
|
|||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Dreamacro/clash/common/utils"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
"github.com/Dreamacro/clash/constant/sniffer"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -13,7 +15,22 @@ var (
|
|||
errNotClientHello = errors.New("not client hello")
|
||||
)
|
||||
|
||||
var _ sniffer.Sniffer = (*TLSSniffer)(nil)
|
||||
|
||||
type TLSSniffer struct {
|
||||
*BaseSniffer
|
||||
}
|
||||
|
||||
func NewTLSSniffer(snifferConfig SnifferConfig) (*TLSSniffer, error) {
|
||||
ports := make([]utils.Range[uint16], 0)
|
||||
if len(snifferConfig.Ports) == 0 {
|
||||
ports = append(ports, *utils.NewRange[uint16](443, 443))
|
||||
} else {
|
||||
ports = append(ports, snifferConfig.Ports...)
|
||||
}
|
||||
return &TLSSniffer{
|
||||
BaseSniffer: NewBaseSniffer(ports, C.TCP),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tls *TLSSniffer) Protocol() string {
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
P "github.com/Dreamacro/clash/component/process"
|
||||
SNIFF "github.com/Dreamacro/clash/component/sniffer"
|
||||
|
||||
"github.com/Dreamacro/clash/adapter"
|
||||
"github.com/Dreamacro/clash/adapter/outbound"
|
||||
|
@ -134,11 +135,10 @@ type IPTables struct {
|
|||
|
||||
type Sniffer struct {
|
||||
Enable bool
|
||||
Sniffers []snifferTypes.Type
|
||||
Sniffers map[snifferTypes.Type]SNIFF.SnifferConfig
|
||||
Reverses *trie.DomainTrie[struct{}]
|
||||
ForceDomain *trie.DomainTrie[struct{}]
|
||||
SkipDomain *trie.DomainTrie[struct{}]
|
||||
Ports *[]utils.Range[uint16]
|
||||
ForceDnsMapping bool
|
||||
ParsePureIp bool
|
||||
}
|
||||
|
@ -294,6 +294,11 @@ type RawSniffer struct {
|
|||
Ports []string `yaml:"port-whitelist" json:"port-whitelist"`
|
||||
ForceDnsMapping bool `yaml:"force-dns-mapping" json:"force-dns-mapping"`
|
||||
ParsePureIp bool `yaml:"parse-pure-ip" json:"parse-pure-ip"`
|
||||
Sniff map[string]RawSniffingConfig `yaml:"sniff" json:"sniff"`
|
||||
}
|
||||
|
||||
type RawSniffingConfig struct {
|
||||
Ports []string `yaml:"ports" json:"ports"`
|
||||
}
|
||||
|
||||
// EBpf config
|
||||
|
@ -1187,44 +1192,44 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
|
|||
ForceDnsMapping: snifferRaw.ForceDnsMapping,
|
||||
ParsePureIp: snifferRaw.ParsePureIp,
|
||||
}
|
||||
loadSniffer := make(map[snifferTypes.Type]SNIFF.SnifferConfig)
|
||||
|
||||
var ports []utils.Range[uint16]
|
||||
if len(snifferRaw.Ports) == 0 {
|
||||
ports = append(ports, *utils.NewRange[uint16](80, 80))
|
||||
ports = append(ports, *utils.NewRange[uint16](443, 443))
|
||||
} else {
|
||||
for _, portRange := range snifferRaw.Ports {
|
||||
portRaws := strings.Split(portRange, "-")
|
||||
p, err := strconv.ParseUint(portRaws[0], 10, 16)
|
||||
if len(snifferRaw.Sniff) != 0 {
|
||||
for sniffType, sniffConfig := range snifferRaw.Sniff {
|
||||
find := false
|
||||
ports, err := parsePortRange(sniffConfig.Ports)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s format error", portRange)
|
||||
return nil, err
|
||||
}
|
||||
for _, snifferType := range snifferTypes.List {
|
||||
if snifferType.String() == strings.ToUpper(sniffType) {
|
||||
find = true
|
||||
loadSniffer[snifferType] = SNIFF.SnifferConfig{
|
||||
Ports: ports,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
start := uint16(p)
|
||||
if len(portRaws) > 1 {
|
||||
p, err = strconv.ParseUint(portRaws[1], 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s format error", portRange)
|
||||
if !find {
|
||||
return nil, fmt.Errorf("not find the sniffer[%s]", sniffType)
|
||||
}
|
||||
}
|
||||
|
||||
end := uint16(p)
|
||||
ports = append(ports, *utils.NewRange(start, end))
|
||||
} else {
|
||||
ports = append(ports, *utils.NewRange(start, start))
|
||||
// Deprecated: Use Sniff instead
|
||||
log.Warnln("Deprecated: Use Sniff instead")
|
||||
globalPorts, err := parsePortRange(snifferRaw.Ports)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sniffer.Ports = &ports
|
||||
|
||||
loadSniffer := make(map[snifferTypes.Type]struct{})
|
||||
|
||||
for _, snifferName := range snifferRaw.Sniffing {
|
||||
find := false
|
||||
for _, snifferType := range snifferTypes.List {
|
||||
if snifferType.String() == strings.ToUpper(snifferName) {
|
||||
find = true
|
||||
loadSniffer[snifferType] = struct{}{}
|
||||
loadSniffer[snifferType] = SNIFF.SnifferConfig{
|
||||
Ports: globalPorts,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1232,10 +1237,9 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
|
|||
return nil, fmt.Errorf("not find the sniffer[%s]", snifferName)
|
||||
}
|
||||
}
|
||||
|
||||
for st := range loadSniffer {
|
||||
sniffer.Sniffers = append(sniffer.Sniffers, st)
|
||||
}
|
||||
|
||||
sniffer.Sniffers = loadSniffer
|
||||
sniffer.ForceDomain = trie.New[struct{}]()
|
||||
for _, domain := range snifferRaw.ForceDomain {
|
||||
err := sniffer.ForceDomain.Insert(domain, struct{}{})
|
||||
|
@ -1256,3 +1260,28 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
|
|||
|
||||
return sniffer, nil
|
||||
}
|
||||
|
||||
func parsePortRange(portRanges []string) ([]utils.Range[uint16], error) {
|
||||
ports := make([]utils.Range[uint16], 0)
|
||||
for _, portRange := range portRanges {
|
||||
portRaws := strings.Split(portRange, "-")
|
||||
p, err := strconv.ParseUint(portRaws[0], 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s format error", portRange)
|
||||
}
|
||||
|
||||
start := uint16(p)
|
||||
if len(portRaws) > 1 {
|
||||
p, err = strconv.ParseUint(portRaws[1], 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s format error", portRange)
|
||||
}
|
||||
|
||||
end := uint16(p)
|
||||
ports = append(ports, *utils.NewRange(start, end))
|
||||
} else {
|
||||
ports = append(ports, *utils.NewRange(start, start))
|
||||
}
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ type Sniffer interface {
|
|||
SupportNetwork() constant.NetWork
|
||||
SniffTCP(bytes []byte) (string, error)
|
||||
Protocol() string
|
||||
SupportPort(port uint16) bool
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
|
@ -279,7 +279,7 @@ func updateTun(general *config.General) {
|
|||
func updateSniffer(sniffer *config.Sniffer) {
|
||||
if sniffer.Enable {
|
||||
dispatcher, err := SNI.NewSnifferDispatcher(
|
||||
sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipDomain, sniffer.Ports,
|
||||
sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipDomain,
|
||||
sniffer.ForceDnsMapping, sniffer.ParsePureIp,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue