feat: add override-destination for sniffer

This commit is contained in:
Skyxim 2023-01-23 14:08:11 +08:00
parent df1f6e2b99
commit 096bb8d439
9 changed files with 48 additions and 26 deletions

View file

@ -9,7 +9,8 @@ import (
) )
type SnifferConfig struct { type SnifferConfig struct {
Ports []utils.Range[uint16] OverrideDest bool
Ports []utils.Range[uint16]
} }
type BaseSniffer struct { type BaseSniffer struct {

View file

@ -26,15 +26,12 @@ var (
var Dispatcher *SnifferDispatcher var Dispatcher *SnifferDispatcher
type SnifferDispatcher struct { type SnifferDispatcher struct {
enable bool enable bool
sniffers map[sniffer.Sniffer]SnifferConfig
sniffers []sniffer.Sniffer forceDomain *trie.DomainTrie[struct{}]
skipSNI *trie.DomainTrie[struct{}]
forceDomain *trie.DomainTrie[struct{}] skipList *cache.LruCache[string, uint8]
skipSNI *trie.DomainTrie[struct{}] rwMux sync.RWMutex
skipList *cache.LruCache[string, uint8]
rwMux sync.RWMutex
forceDnsMapping bool forceDnsMapping bool
parsePureIp bool parsePureIp bool
} }
@ -53,10 +50,12 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
} }
inWhitelist := false inWhitelist := false
for _, sniffer := range sd.sniffers { overrideDest := false
for sniffer, config := range sd.sniffers {
if sniffer.SupportNetwork() == C.TCP || sniffer.SupportNetwork() == C.ALLNet { if sniffer.SupportNetwork() == C.TCP || sniffer.SupportNetwork() == C.ALLNet {
inWhitelist = sniffer.SupportPort(uint16(port)) inWhitelist = sniffer.SupportPort(uint16(port))
if inWhitelist { if inWhitelist {
overrideDest = config.OverrideDest
break break
} }
} }
@ -89,12 +88,12 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
sd.skipList.Delete(dst) sd.skipList.Delete(dst)
sd.rwMux.RUnlock() sd.rwMux.RUnlock()
sd.replaceDomain(metadata, host) sd.replaceDomain(metadata, host, overrideDest)
} }
} }
} }
func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string) { func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string, overrideDest bool) {
dstIP := "" dstIP := ""
if metadata.DstIP.IsValid() { if metadata.DstIP.IsValid() {
dstIP = metadata.DstIP.String() dstIP = metadata.DstIP.String()
@ -112,7 +111,11 @@ func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string) {
metadata.Host, host) metadata.Host, host)
} }
metadata.Host = host if overrideDest {
metadata.Host = host
} else {
metadata.SniffHost = host
}
metadata.DNSMode = C.DNSNormal metadata.DNSMode = C.DNSNormal
} }
@ -121,7 +124,7 @@ func (sd *SnifferDispatcher) Enable() bool {
} }
func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) { func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) {
for _, s := range sd.sniffers { for s := range sd.sniffers {
if s.SupportNetwork() == C.TCP { if s.SupportNetwork() == C.TCP {
_ = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(1 * time.Second))
_, err := conn.Peek(1) _, err := conn.Peek(1)
@ -189,9 +192,10 @@ func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig, forceDom
enable: true, enable: true,
forceDomain: forceDomain, forceDomain: forceDomain,
skipSNI: skipSNI, skipSNI: skipSNI,
skipList: cache.New[string, uint8](cache.WithSize[string, uint8](128), cache.WithAge[string, uint8](600)), skipList: cache.New(cache.WithSize[string, uint8](128), cache.WithAge[string, uint8](600)),
forceDnsMapping: forceDnsMapping, forceDnsMapping: forceDnsMapping,
parsePureIp: parsePureIp, parsePureIp: parsePureIp,
sniffers: make(map[sniffer.Sniffer]SnifferConfig, 0),
} }
for snifferName, config := range snifferConfig { for snifferName, config := range snifferConfig {
@ -200,8 +204,7 @@ func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig, forceDom
log.Errorln("Sniffer name[%s] is error", snifferName) log.Errorln("Sniffer name[%s] is error", snifferName)
return &SnifferDispatcher{enable: false}, err return &SnifferDispatcher{enable: false}, err
} }
dispatcher.sniffers[s] = config
dispatcher.sniffers = append(dispatcher.sniffers, s)
} }
return &dispatcher, nil return &dispatcher, nil

View file

@ -288,6 +288,7 @@ type RawGeoXUrl struct {
type RawSniffer struct { type RawSniffer struct {
Enable bool `yaml:"enable" json:"enable"` Enable bool `yaml:"enable" json:"enable"`
OverrideDest bool `yaml:"override-destination" json:"override-destination"`
Sniffing []string `yaml:"sniffing" json:"sniffing"` Sniffing []string `yaml:"sniffing" json:"sniffing"`
ForceDomain []string `yaml:"force-domain" json:"force-domain"` ForceDomain []string `yaml:"force-domain" json:"force-domain"`
SkipDomain []string `yaml:"skip-domain" json:"skip-domain"` SkipDomain []string `yaml:"skip-domain" json:"skip-domain"`
@ -298,7 +299,8 @@ type RawSniffer struct {
} }
type RawSniffingConfig struct { type RawSniffingConfig struct {
Ports []string `yaml:"ports" json:"ports"` Ports []string `yaml:"ports" json:"ports"`
OverrideDest *bool `yaml:"override-destination" json:"override-destination"`
} }
// EBpf config // EBpf config
@ -1201,11 +1203,16 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
overrideDest := snifferRaw.OverrideDest
if sniffConfig.OverrideDest != nil {
overrideDest = *sniffConfig.OverrideDest
}
for _, snifferType := range snifferTypes.List { for _, snifferType := range snifferTypes.List {
if snifferType.String() == strings.ToUpper(sniffType) { if snifferType.String() == strings.ToUpper(sniffType) {
find = true find = true
loadSniffer[snifferType] = SNIFF.SnifferConfig{ loadSniffer[snifferType] = SNIFF.SnifferConfig{
Ports: ports, Ports: ports,
OverrideDest: overrideDest,
} }
} }
} }
@ -1228,7 +1235,8 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
if snifferType.String() == strings.ToUpper(snifferName) { if snifferType.String() == strings.ToUpper(snifferName) {
find = true find = true
loadSniffer[snifferType] = SNIFF.SnifferConfig{ loadSniffer[snifferType] = SNIFF.SnifferConfig{
Ports: globalPorts, Ports: globalPorts,
OverrideDest: snifferRaw.OverrideDest,
} }
} }
} }

View file

@ -134,6 +134,8 @@ type Metadata struct {
SpecialProxy string `json:"specialProxy"` SpecialProxy string `json:"specialProxy"`
SpecialRules string `json:"specialRules"` SpecialRules string `json:"specialRules"`
RemoteDst string `json:"remoteDestination"` RemoteDst string `json:"remoteDestination"`
// Only domain rule
SniffHost string
} }
func (m *Metadata) RemoteAddress() string { func (m *Metadata) RemoteAddress() string {
@ -176,6 +178,14 @@ func (m *Metadata) Resolved() bool {
return m.DstIP.IsValid() return m.DstIP.IsValid()
} }
func (m *Metadata) RuleHost() string {
if len(m.SniffHost) == 0 {
return m.Host
} else {
return m.SniffHost
}
}
// Pure is used to solve unexpected behavior // Pure is used to solve unexpected behavior
// when dialing proxy connection in DNSMapping mode. // when dialing proxy connection in DNSMapping mode.
func (m *Metadata) Pure() *Metadata { func (m *Metadata) Pure() *Metadata {

View file

@ -19,7 +19,7 @@ func (d *Domain) RuleType() C.RuleType {
} }
func (d *Domain) Match(metadata *C.Metadata) (bool, string) { func (d *Domain) Match(metadata *C.Metadata) (bool, string) {
return metadata.Host == d.domain, d.adapter return metadata.RuleHost() == d.domain, d.adapter
} }
func (d *Domain) Adapter() string { func (d *Domain) Adapter() string {

View file

@ -19,7 +19,7 @@ func (dk *DomainKeyword) RuleType() C.RuleType {
} }
func (dk *DomainKeyword) Match(metadata *C.Metadata) (bool, string) { func (dk *DomainKeyword) Match(metadata *C.Metadata) (bool, string) {
domain := metadata.Host domain := metadata.RuleHost()
return strings.Contains(domain, dk.keyword), dk.adapter return strings.Contains(domain, dk.keyword), dk.adapter
} }

View file

@ -19,7 +19,7 @@ func (ds *DomainSuffix) RuleType() C.RuleType {
} }
func (ds *DomainSuffix) Match(metadata *C.Metadata) (bool, string) { func (ds *DomainSuffix) Match(metadata *C.Metadata) (bool, string) {
domain := metadata.Host domain := metadata.RuleHost()
return strings.HasSuffix(domain, "."+ds.suffix) || domain == ds.suffix, ds.adapter return strings.HasSuffix(domain, "."+ds.suffix) || domain == ds.suffix, ds.adapter
} }

View file

@ -29,7 +29,7 @@ func (gs *GEOSITE) Match(metadata *C.Metadata) (bool, string) {
return false, "" return false, ""
} }
domain := metadata.Host domain := metadata.RuleHost()
return gs.matcher.ApplyDomain(domain), gs.adapter return gs.matcher.ApplyDomain(domain), gs.adapter
} }

View file

@ -17,7 +17,7 @@ func (d *domainStrategy) ShouldFindProcess() bool {
} }
func (d *domainStrategy) Match(metadata *C.Metadata) bool { func (d *domainStrategy) Match(metadata *C.Metadata) bool {
return d.domainRules != nil && d.domainRules.Search(metadata.Host) != nil return d.domainRules != nil && d.domainRules.Search(metadata.RuleHost()) != nil
} }
func (d *domainStrategy) Count() int { func (d *domainStrategy) Count() int {