feat: sniffer support

sniffer:
  enable: true
  force: false # Overwrite domain
  sniffing:
    - tls
This commit is contained in:
gVisor bot 2022-04-09 22:30:36 +08:00
parent 277dd2dc68
commit b2becaffe3
9 changed files with 256 additions and 62 deletions

View file

@ -0,0 +1,104 @@
package sniffer
import (
"errors"
"net"
CN "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
)
var (
ErrorUnsupportedSniffer = errors.New("unsupported sniffer")
)
var Dispatcher SnifferDispatcher
type SnifferDispatcher struct {
enable bool
force bool
sniffers []C.Sniffer
}
func (sd *SnifferDispatcher) Tcp(conn net.Conn, metadata *C.Metadata) {
bufConn, ok := conn.(*CN.BufferedConn)
if !ok {
return
}
if sd.force {
sd.cover(bufConn, metadata)
} else {
if metadata.Host != "" {
return
}
sd.cover(bufConn, metadata)
}
}
func (sd *SnifferDispatcher) Enable() bool {
return sd.enable
}
func (sd *SnifferDispatcher) cover(conn *CN.BufferedConn, metadata *C.Metadata) {
for _, sniffer := range sd.sniffers {
if sniffer.SupportNetwork() == C.TCP {
conn.Peek(1)
len := conn.Buffered()
bytes, err := conn.Peek(len)
if err != nil {
log.Warnln("the data lenght not enough")
continue
}
host, err := sniffer.SniffTCP(bytes)
if err != nil {
log.Warnln("Sniff data failed on Sniffer[%s]", sniffer.Protocol())
continue
}
metadata.Host = host
metadata.DstIP = nil
metadata.AddrType = C.AtypDomainName
if resolver.FakeIPEnabled() {
metadata.DNSMode = C.DNSFakeIP
} else {
metadata.DNSMode = C.DNSMapping
}
resolver.InsertHostByIP(metadata.DstIP, host)
break
}
}
}
func NewSnifferDispatcher(needSniffer []C.SnifferType, force bool) (SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{
enable: true,
force: force,
}
for _, snifferName := range needSniffer {
sniffer, err := NewSniffer(snifferName)
if err != nil {
log.Errorln("Sniffer name[%s] is error", snifferName)
return SnifferDispatcher{enable: false}, err
}
dispatcher.sniffers = append(dispatcher.sniffers, sniffer)
}
return dispatcher, nil
}
func NewSniffer(name C.SnifferType) (C.Sniffer, error) {
switch name {
case C.TLS:
return &TLSSniffer{}, nil
default:
return nil, ErrorUnsupportedSniffer
}
}

View file

@ -1,4 +1,4 @@
package tls package sniffer
import ( import (
"testing" "testing"
@ -142,7 +142,7 @@ func TestTLSHeaders(t *testing.T) {
} }
for _, test := range cases { for _, test := range cases {
header, err := SniffTLS(test.input) domain, err := SniffTLS(test.input)
if test.err { if test.err {
if err == nil { if err == nil {
t.Errorf("Exepct error but nil in test %v", test) t.Errorf("Exepct error but nil in test %v", test)
@ -151,8 +151,8 @@ func TestTLSHeaders(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Expect no error but actually %s in test %v", err.Error(), test) t.Errorf("Expect no error but actually %s in test %v", err.Error(), test)
} }
if header.Domain() != test.domain { if *domain != test.domain {
t.Error("expect domain ", test.domain, " but got ", header.Domain()) t.Error("expect domain ", test.domain, " but got ", domain)
} }
} }
} }

View file

@ -1,107 +1,116 @@
package tls package sniffer
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"strings" "strings"
C "github.com/Dreamacro/clash/constant"
) )
var ErrNoClue = errors.New("not enough information for making a decision")
type SniffHeader struct {
domain string
}
func (h *SniffHeader) Protocol() string {
return "tls"
}
func (h *SniffHeader) Domain() string {
return h.domain
}
var ( var (
errNotTLS = errors.New("not TLS header") errNotTLS = errors.New("not TLS header")
errNotClientHello = errors.New("not client hello") errNotClientHello = errors.New("not client hello")
ErrNoClue = errors.New("not enough information for making a decision")
) )
type TLSSniffer struct {
}
func (tls *TLSSniffer) Protocol() string {
return "tls"
}
func (tls *TLSSniffer) SupportNetwork() C.NetWork {
return C.TCP
}
func (tls *TLSSniffer) SniffTCP(bytes []byte) (string, error) {
domain, err := SniffTLS(bytes)
if err == nil {
return *domain, nil
} else {
return "", err
}
}
func IsValidTLSVersion(major, minor byte) bool { func IsValidTLSVersion(major, minor byte) bool {
return major == 3 return major == 3
} }
// ReadClientHello returns server name (if any) from TLS client hello message. // ReadClientHello returns server name (if any) from TLS client hello message.
// https://github.com/golang/go/blob/master/src/crypto/tls/handshake_messages.go#L300 // https://github.com/golang/go/blob/master/src/crypto/tls/handshake_messages.go#L300
func ReadClientHello(data []byte, h *SniffHeader) error { func ReadClientHello(data []byte) (*string, error) {
if len(data) < 42 { if len(data) < 42 {
return ErrNoClue return nil, ErrNoClue
} }
sessionIDLen := int(data[38]) sessionIDLen := int(data[38])
if sessionIDLen > 32 || len(data) < 39+sessionIDLen { if sessionIDLen > 32 || len(data) < 39+sessionIDLen {
return ErrNoClue return nil, ErrNoClue
} }
data = data[39+sessionIDLen:] data = data[39+sessionIDLen:]
if len(data) < 2 { if len(data) < 2 {
return ErrNoClue return nil, ErrNoClue
} }
// cipherSuiteLen is the number of bytes of cipher suite numbers. Since // cipherSuiteLen is the number of bytes of cipher suite numbers. Since
// they are uint16s, the number must be even. // they are uint16s, the number must be even.
cipherSuiteLen := int(data[0])<<8 | int(data[1]) cipherSuiteLen := int(data[0])<<8 | int(data[1])
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
return errNotClientHello return nil, errNotClientHello
} }
data = data[2+cipherSuiteLen:] data = data[2+cipherSuiteLen:]
if len(data) < 1 { if len(data) < 1 {
return ErrNoClue return nil, ErrNoClue
} }
compressionMethodsLen := int(data[0]) compressionMethodsLen := int(data[0])
if len(data) < 1+compressionMethodsLen { if len(data) < 1+compressionMethodsLen {
return ErrNoClue return nil, ErrNoClue
} }
data = data[1+compressionMethodsLen:] data = data[1+compressionMethodsLen:]
if len(data) == 0 { if len(data) == 0 {
return errNotClientHello return nil, errNotClientHello
} }
if len(data) < 2 { if len(data) < 2 {
return errNotClientHello return nil, errNotClientHello
} }
extensionsLength := int(data[0])<<8 | int(data[1]) extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:] data = data[2:]
if extensionsLength != len(data) { if extensionsLength != len(data) {
return errNotClientHello return nil, errNotClientHello
} }
for len(data) != 0 { for len(data) != 0 {
if len(data) < 4 { if len(data) < 4 {
return errNotClientHello return nil, errNotClientHello
} }
extension := uint16(data[0])<<8 | uint16(data[1]) extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3]) length := int(data[2])<<8 | int(data[3])
data = data[4:] data = data[4:]
if len(data) < length { if len(data) < length {
return errNotClientHello return nil, errNotClientHello
} }
if extension == 0x00 { /* extensionServerName */ if extension == 0x00 { /* extensionServerName */
d := data[:length] d := data[:length]
if len(d) < 2 { if len(d) < 2 {
return errNotClientHello return nil, errNotClientHello
} }
namesLen := int(d[0])<<8 | int(d[1]) namesLen := int(d[0])<<8 | int(d[1])
d = d[2:] d = d[2:]
if len(d) != namesLen { if len(d) != namesLen {
return errNotClientHello return nil, errNotClientHello
} }
for len(d) > 0 { for len(d) > 0 {
if len(d) < 3 { if len(d) < 3 {
return errNotClientHello return nil, errNotClientHello
} }
nameType := d[0] nameType := d[0]
nameLen := int(d[1])<<8 | int(d[2]) nameLen := int(d[1])<<8 | int(d[2])
d = d[3:] d = d[3:]
if len(d) < nameLen { if len(d) < nameLen {
return errNotClientHello return nil, errNotClientHello
} }
if nameType == 0 { if nameType == 0 {
serverName := string(d[:nameLen]) serverName := string(d[:nameLen])
@ -109,21 +118,22 @@ func ReadClientHello(data []byte, h *SniffHeader) error {
// trailing dot. See // trailing dot. See
// https://tools.ietf.org/html/rfc6066#section-3. // https://tools.ietf.org/html/rfc6066#section-3.
if strings.HasSuffix(serverName, ".") { if strings.HasSuffix(serverName, ".") {
return errNotClientHello return nil, errNotClientHello
} }
h.domain = serverName
return nil return &serverName, nil
} }
d = d[nameLen:] d = d[nameLen:]
} }
} }
data = data[length:] data = data[length:]
} }
return errNotTLS return nil, errNotTLS
} }
func SniffTLS(b []byte) (*SniffHeader, error) { func SniffTLS(b []byte) (*string, error) {
if len(b) < 5 { if len(b) < 5 {
return nil, ErrNoClue return nil, ErrNoClue
} }
@ -139,10 +149,9 @@ func SniffTLS(b []byte) (*SniffHeader, error) {
return nil, ErrNoClue return nil, ErrNoClue
} }
h := &SniffHeader{} domain, err := ReadClientHello(b[5 : 5+headerLen])
err := ReadClientHello(b[5:5+headerLen], h)
if err == nil { if err == nil {
return h, nil return domain, nil
} }
return nil, err return nil, err
} }

View file

@ -126,6 +126,12 @@ type IPTables struct {
Bypass []string `yaml:"bypass" json:"bypass"` Bypass []string `yaml:"bypass" json:"bypass"`
} }
type Sniffer struct {
Enable bool
Force bool
Sniffers []C.SnifferType
}
// Experimental config // Experimental config
type Experimental struct{} type Experimental struct{}
@ -143,6 +149,7 @@ type Config struct {
Proxies map[string]C.Proxy Proxies map[string]C.Proxy
Providers map[string]providerTypes.ProxyProvider Providers map[string]providerTypes.ProxyProvider
RuleProviders map[string]*providerTypes.RuleProvider RuleProviders map[string]*providerTypes.RuleProvider
Sniffer *Sniffer
} }
type RawDNS struct { type RawDNS struct {
@ -199,6 +206,7 @@ type RawConfig struct {
GeodataMode bool `yaml:"geodata-mode"` GeodataMode bool `yaml:"geodata-mode"`
GeodataLoader string `yaml:"geodata-loader"` GeodataLoader string `yaml:"geodata-loader"`
Sniffer SnifferRaw `yaml:"sniffer"`
ProxyProvider map[string]map[string]any `yaml:"proxy-providers"` ProxyProvider map[string]map[string]any `yaml:"proxy-providers"`
RuleProvider map[string]map[string]any `yaml:"rule-providers"` RuleProvider map[string]map[string]any `yaml:"rule-providers"`
Hosts map[string]string `yaml:"hosts"` Hosts map[string]string `yaml:"hosts"`
@ -212,6 +220,12 @@ type RawConfig struct {
Rule []string `yaml:"rules"` Rule []string `yaml:"rules"`
} }
type SnifferRaw struct {
Enable bool `yaml:"enable" json:"enable"`
Force bool `yaml:"force" json:"force"`
Sniffing []string `yaml:"sniffing" json:"sniffing"`
}
// Parse config // Parse config
func Parse(buf []byte) (*Config, error) { func Parse(buf []byte) (*Config, error) {
rawCfg, err := UnmarshalRawConfig(buf) rawCfg, err := UnmarshalRawConfig(buf)
@ -277,6 +291,11 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) {
"www.msftconnecttest.com", "www.msftconnecttest.com",
}, },
}, },
Sniffer: SnifferRaw{
Enable: false,
Force: false,
Sniffing: []string{},
},
Profile: Profile{ Profile: Profile{
StoreSelected: true, StoreSelected: true,
}, },
@ -339,6 +358,11 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) {
config.Users = parseAuthentication(rawCfg.Authentication) config.Users = parseAuthentication(rawCfg.Authentication)
config.Sniffer, err = parseSniffer(rawCfg.Sniffer)
if err != nil {
return nil, err
}
elapsedTime := time.Since(startTime) / time.Millisecond // duration in ms elapsedTime := time.Since(startTime) / time.Millisecond // duration in ms
log.Infoln("Initial configuration complete, total time: %dms", elapsedTime) //Segment finished in xxm log.Infoln("Initial configuration complete, total time: %dms", elapsedTime) //Segment finished in xxm
return config, nil return config, nil
@ -882,3 +906,32 @@ func parseTun(rawTun RawTun, general *General) (*Tun, error) {
AutoRoute: rawTun.AutoRoute, AutoRoute: rawTun.AutoRoute,
}, nil }, nil
} }
func parseSniffer(snifferRaw SnifferRaw) (*Sniffer, error) {
sniffer := &Sniffer{
Enable: snifferRaw.Enable,
Force: snifferRaw.Force,
}
loadSniffer := make(map[C.SnifferType]struct{})
for _, snifferName := range snifferRaw.Sniffing {
find := false
for _, snifferType := range C.SnifferList {
if snifferType.String() == strings.ToUpper(snifferName) {
find = true
loadSniffer[snifferType] = struct{}{}
}
}
if !find {
return nil, fmt.Errorf("not find the sniffer[%s]", snifferName)
}
}
for st := range loadSniffer {
sniffer.Sniffers = append(sniffer.Sniffers, st)
}
return sniffer, nil
}

26
constant/sniffer.go Normal file
View file

@ -0,0 +1,26 @@
package constant
type Sniffer interface {
SupportNetwork() NetWork
SniffTCP(bytes []byte) (string, error)
Protocol() string
}
const (
TLS SnifferType = iota
)
var (
SnifferList = []SnifferType{TLS}
)
type SnifferType int
func (rt SnifferType) String() string {
switch rt {
case TLS:
return "TLS"
default:
return "Unknown"
}
}

View file

@ -3,8 +3,8 @@ package context
import ( import (
"net" "net"
CN "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
) )
@ -19,7 +19,7 @@ func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
return &ConnContext{ return &ConnContext{
id: id, id: id,
metadata: metadata, metadata: metadata,
conn: conn, conn: CN.NewBufferedConn(conn),
} }
} }

View file

@ -17,6 +17,7 @@ import (
"github.com/Dreamacro/clash/component/profile" "github.com/Dreamacro/clash/component/profile"
"github.com/Dreamacro/clash/component/profile/cachefile" "github.com/Dreamacro/clash/component/profile/cachefile"
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
SNI "github.com/Dreamacro/clash/component/sniffer"
"github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/component/trie"
"github.com/Dreamacro/clash/config" "github.com/Dreamacro/clash/config"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
@ -78,6 +79,7 @@ func ApplyConfig(cfg *config.Config, force bool) {
updateDNS(cfg.DNS, cfg.Tun) updateDNS(cfg.DNS, cfg.Tun)
updateGeneral(cfg.General, force) updateGeneral(cfg.General, force)
updateIPTables(cfg) updateIPTables(cfg)
updateSniffer(cfg.Sniffer)
updateTun(cfg.Tun, cfg.DNS) updateTun(cfg.Tun, cfg.DNS)
updateExperimental(cfg) updateExperimental(cfg)
updateHosts(cfg.Hosts) updateHosts(cfg.Hosts)
@ -87,6 +89,16 @@ func ApplyConfig(cfg *config.Config, force bool) {
log.SetLevel(cfg.General.LogLevel) log.SetLevel(cfg.General.LogLevel)
} }
func updateSniffer(sniffer *config.Sniffer) {
if sniffer.Enable {
var err error
SNI.Dispatcher, err = SNI.NewSnifferDispatcher(sniffer.Sniffers, sniffer.Force)
if err != nil {
log.Errorln("Init Sniffer failed, err:%v", err)
}
}
}
func GetGeneral() *config.General { func GetGeneral() *config.General {
ports := P.GetPorts() ports := P.GetPorts()
var authenticator []string var authenticator []string

View file

@ -1,14 +1,10 @@
package statistic package statistic
import ( import (
"errors"
"net" "net"
"time" "time"
"github.com/Dreamacro/clash/common/snifer/tls"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
"go.uber.org/atomic" "go.uber.org/atomic"
@ -52,20 +48,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) {
n, err := tt.Conn.Write(b) n, err := tt.Conn.Write(b)
upload := int64(n) upload := int64(n)
tt.manager.PushUploaded(upload) tt.manager.PushUploaded(upload)
if tt.UploadTotal.Load() < 128 && tt.Metadata.Host == "" && (tt.Metadata.DstPort == "443" || tt.Metadata.DstPort == "8443" || tt.Metadata.DstPort == "993" || tt.Metadata.DstPort == "465" || tt.Metadata.DstPort == "995") {
header, err := tls.SniffTLS(b)
if err != nil {
// log.Errorln("Expect no error but actually %s %s:%s:%s", err.Error(), tt.Metadata.Host, tt.Metadata.DstIP.String(), tt.Metadata.DstPort)
} else {
resolver.InsertHostByIP(tt.Metadata.DstIP, header.Domain())
log.Warnln("use sni update host: %s ip: %s", header.Domain(), tt.Metadata.DstIP.String())
tt.manager.Leave(tt)
tt.Conn.Close()
return n, errors.New("sni update, break current link to avoid leaks")
}
}
tt.UploadTotal.Add(upload) tt.UploadTotal.Add(upload)
return n, err return n, err
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/Dreamacro/clash/component/nat" "github.com/Dreamacro/clash/component/nat"
P "github.com/Dreamacro/clash/component/process" P "github.com/Dreamacro/clash/component/process"
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
"github.com/Dreamacro/clash/component/sniffer"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/provider" "github.com/Dreamacro/clash/constant/provider"
icontext "github.com/Dreamacro/clash/context" icontext "github.com/Dreamacro/clash/context"
@ -36,6 +37,8 @@ var (
// default timeout for UDP session // default timeout for UDP session
udpTimeout = 60 * time.Second udpTimeout = 60 * time.Second
snifferDispatcher *sniffer.SnifferDispatcher
) )
func init() { func init() {
@ -294,6 +297,10 @@ func handleTCPConn(connCtx C.ConnContext) {
return return
} }
if sniffer.Dispatcher.Enable() {
sniffer.Dispatcher.Tcp(connCtx.Conn(), metadata)
}
proxy, rule, err := resolveMetadata(connCtx, metadata) proxy, rule, err := resolveMetadata(connCtx, metadata)
if err != nil { if err != nil {
log.Warnln("[Metadata] parse failed: %s", err.Error()) log.Warnln("[Metadata] parse failed: %s", err.Error())