feat: sniffer support

sniffer:
  enable: true
  force: false # Overwrite domain
  sniffing:
    - tls
This commit is contained in:
Skyxim 2022-04-09 22:30:36 +08:00
parent 07906c0aa5
commit 544e0f137d
9 changed files with 256 additions and 62 deletions

View file

@ -0,0 +1,104 @@
package sniffer
import (
"errors"
"net"
CN "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
)
var (
ErrorUnsupportedSniffer = errors.New("unsupported sniffer")
)
var Dispatcher SnifferDispatcher
type SnifferDispatcher struct {
enable bool
force bool
sniffers []C.Sniffer
}
func (sd *SnifferDispatcher) Tcp(conn net.Conn, metadata *C.Metadata) {
bufConn, ok := conn.(*CN.BufferedConn)
if !ok {
return
}
if sd.force {
sd.cover(bufConn, metadata)
} else {
if metadata.Host != "" {
return
}
sd.cover(bufConn, metadata)
}
}
func (sd *SnifferDispatcher) Enable() bool {
return sd.enable
}
func (sd *SnifferDispatcher) cover(conn *CN.BufferedConn, metadata *C.Metadata) {
for _, sniffer := range sd.sniffers {
if sniffer.SupportNetwork() == C.TCP {
conn.Peek(1)
len := conn.Buffered()
bytes, err := conn.Peek(len)
if err != nil {
log.Warnln("the data lenght not enough")
continue
}
host, err := sniffer.SniffTCP(bytes)
if err != nil {
log.Warnln("Sniff data failed on Sniffer[%s]", sniffer.Protocol())
continue
}
metadata.Host = host
metadata.DstIP = nil
metadata.AddrType = C.AtypDomainName
if resolver.FakeIPEnabled() {
metadata.DNSMode = C.DNSFakeIP
} else {
metadata.DNSMode = C.DNSMapping
}
resolver.InsertHostByIP(metadata.DstIP, host)
break
}
}
}
func NewSnifferDispatcher(needSniffer []C.SnifferType, force bool) (SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{
enable: true,
force: force,
}
for _, snifferName := range needSniffer {
sniffer, err := NewSniffer(snifferName)
if err != nil {
log.Errorln("Sniffer name[%s] is error", snifferName)
return SnifferDispatcher{enable: false}, err
}
dispatcher.sniffers = append(dispatcher.sniffers, sniffer)
}
return dispatcher, nil
}
func NewSniffer(name C.SnifferType) (C.Sniffer, error) {
switch name {
case C.TLS:
return &TLSSniffer{}, nil
default:
return nil, ErrorUnsupportedSniffer
}
}

View file

@ -1,4 +1,4 @@
package tls
package sniffer
import (
"testing"
@ -142,7 +142,7 @@ func TestTLSHeaders(t *testing.T) {
}
for _, test := range cases {
header, err := SniffTLS(test.input)
domain, err := SniffTLS(test.input)
if test.err {
if err == nil {
t.Errorf("Exepct error but nil in test %v", test)
@ -151,8 +151,8 @@ func TestTLSHeaders(t *testing.T) {
if err != nil {
t.Errorf("Expect no error but actually %s in test %v", err.Error(), test)
}
if header.Domain() != test.domain {
t.Error("expect domain ", test.domain, " but got ", header.Domain())
if *domain != test.domain {
t.Error("expect domain ", test.domain, " but got ", domain)
}
}
}

View file

@ -1,107 +1,116 @@
package tls
package sniffer
import (
"encoding/binary"
"errors"
"strings"
C "github.com/Dreamacro/clash/constant"
)
var ErrNoClue = errors.New("not enough information for making a decision")
type SniffHeader struct {
domain string
}
func (h *SniffHeader) Protocol() string {
return "tls"
}
func (h *SniffHeader) Domain() string {
return h.domain
}
var (
errNotTLS = errors.New("not TLS header")
errNotClientHello = errors.New("not client hello")
ErrNoClue = errors.New("not enough information for making a decision")
)
type TLSSniffer struct {
}
func (tls *TLSSniffer) Protocol() string {
return "tls"
}
func (tls *TLSSniffer) SupportNetwork() C.NetWork {
return C.TCP
}
func (tls *TLSSniffer) SniffTCP(bytes []byte) (string, error) {
domain, err := SniffTLS(bytes)
if err == nil {
return *domain, nil
} else {
return "", err
}
}
func IsValidTLSVersion(major, minor byte) bool {
return major == 3
}
// ReadClientHello returns server name (if any) from TLS client hello message.
// https://github.com/golang/go/blob/master/src/crypto/tls/handshake_messages.go#L300
func ReadClientHello(data []byte, h *SniffHeader) error {
func ReadClientHello(data []byte) (*string, error) {
if len(data) < 42 {
return ErrNoClue
return nil, ErrNoClue
}
sessionIDLen := int(data[38])
if sessionIDLen > 32 || len(data) < 39+sessionIDLen {
return ErrNoClue
return nil, ErrNoClue
}
data = data[39+sessionIDLen:]
if len(data) < 2 {
return ErrNoClue
return nil, ErrNoClue
}
// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
// they are uint16s, the number must be even.
cipherSuiteLen := int(data[0])<<8 | int(data[1])
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
return errNotClientHello
return nil, errNotClientHello
}
data = data[2+cipherSuiteLen:]
if len(data) < 1 {
return ErrNoClue
return nil, ErrNoClue
}
compressionMethodsLen := int(data[0])
if len(data) < 1+compressionMethodsLen {
return ErrNoClue
return nil, ErrNoClue
}
data = data[1+compressionMethodsLen:]
if len(data) == 0 {
return errNotClientHello
return nil, errNotClientHello
}
if len(data) < 2 {
return errNotClientHello
return nil, errNotClientHello
}
extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:]
if extensionsLength != len(data) {
return errNotClientHello
return nil, errNotClientHello
}
for len(data) != 0 {
if len(data) < 4 {
return errNotClientHello
return nil, errNotClientHello
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return errNotClientHello
return nil, errNotClientHello
}
if extension == 0x00 { /* extensionServerName */
d := data[:length]
if len(d) < 2 {
return errNotClientHello
return nil, errNotClientHello
}
namesLen := int(d[0])<<8 | int(d[1])
d = d[2:]
if len(d) != namesLen {
return errNotClientHello
return nil, errNotClientHello
}
for len(d) > 0 {
if len(d) < 3 {
return errNotClientHello
return nil, errNotClientHello
}
nameType := d[0]
nameLen := int(d[1])<<8 | int(d[2])
d = d[3:]
if len(d) < nameLen {
return errNotClientHello
return nil, errNotClientHello
}
if nameType == 0 {
serverName := string(d[:nameLen])
@ -109,21 +118,22 @@ func ReadClientHello(data []byte, h *SniffHeader) error {
// trailing dot. See
// https://tools.ietf.org/html/rfc6066#section-3.
if strings.HasSuffix(serverName, ".") {
return errNotClientHello
return nil, errNotClientHello
}
h.domain = serverName
return nil
return &serverName, nil
}
d = d[nameLen:]
}
}
data = data[length:]
}
return errNotTLS
return nil, errNotTLS
}
func SniffTLS(b []byte) (*SniffHeader, error) {
func SniffTLS(b []byte) (*string, error) {
if len(b) < 5 {
return nil, ErrNoClue
}
@ -139,10 +149,9 @@ func SniffTLS(b []byte) (*SniffHeader, error) {
return nil, ErrNoClue
}
h := &SniffHeader{}
err := ReadClientHello(b[5:5+headerLen], h)
domain, err := ReadClientHello(b[5 : 5+headerLen])
if err == nil {
return h, nil
return domain, nil
}
return nil, err
}

View file

@ -126,6 +126,12 @@ type IPTables struct {
Bypass []string `yaml:"bypass" json:"bypass"`
}
type Sniffer struct {
Enable bool
Force bool
Sniffers []C.SnifferType
}
// Experimental config
type Experimental struct{}
@ -143,6 +149,7 @@ type Config struct {
Proxies map[string]C.Proxy
Providers map[string]providerTypes.ProxyProvider
RuleProviders map[string]*providerTypes.RuleProvider
Sniffer *Sniffer
}
type RawDNS struct {
@ -199,6 +206,7 @@ type RawConfig struct {
GeodataMode bool `yaml:"geodata-mode"`
GeodataLoader string `yaml:"geodata-loader"`
Sniffer SnifferRaw `yaml:"sniffer"`
ProxyProvider map[string]map[string]any `yaml:"proxy-providers"`
RuleProvider map[string]map[string]any `yaml:"rule-providers"`
Hosts map[string]string `yaml:"hosts"`
@ -212,6 +220,12 @@ type RawConfig struct {
Rule []string `yaml:"rules"`
}
type SnifferRaw struct {
Enable bool `yaml:"enable" json:"enable"`
Force bool `yaml:"force" json:"force"`
Sniffing []string `yaml:"sniffing" json:"sniffing"`
}
// Parse config
func Parse(buf []byte) (*Config, error) {
rawCfg, err := UnmarshalRawConfig(buf)
@ -277,6 +291,11 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) {
"www.msftconnecttest.com",
},
},
Sniffer: SnifferRaw{
Enable: false,
Force: false,
Sniffing: []string{},
},
Profile: Profile{
StoreSelected: true,
},
@ -339,6 +358,11 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) {
config.Users = parseAuthentication(rawCfg.Authentication)
config.Sniffer, err = parseSniffer(rawCfg.Sniffer)
if err != nil {
return nil, err
}
elapsedTime := time.Since(startTime) / time.Millisecond // duration in ms
log.Infoln("Initial configuration complete, total time: %dms", elapsedTime) //Segment finished in xxm
return config, nil
@ -882,3 +906,32 @@ func parseTun(rawTun RawTun, general *General) (*Tun, error) {
AutoRoute: rawTun.AutoRoute,
}, nil
}
func parseSniffer(snifferRaw SnifferRaw) (*Sniffer, error) {
sniffer := &Sniffer{
Enable: snifferRaw.Enable,
Force: snifferRaw.Force,
}
loadSniffer := make(map[C.SnifferType]struct{})
for _, snifferName := range snifferRaw.Sniffing {
find := false
for _, snifferType := range C.SnifferList {
if snifferType.String() == strings.ToUpper(snifferName) {
find = true
loadSniffer[snifferType] = struct{}{}
}
}
if !find {
return nil, fmt.Errorf("not find the sniffer[%s]", snifferName)
}
}
for st := range loadSniffer {
sniffer.Sniffers = append(sniffer.Sniffers, st)
}
return sniffer, nil
}

26
constant/sniffer.go Normal file
View file

@ -0,0 +1,26 @@
package constant
type Sniffer interface {
SupportNetwork() NetWork
SniffTCP(bytes []byte) (string, error)
Protocol() string
}
const (
TLS SnifferType = iota
)
var (
SnifferList = []SnifferType{TLS}
)
type SnifferType int
func (rt SnifferType) String() string {
switch rt {
case TLS:
return "TLS"
default:
return "Unknown"
}
}

View file

@ -3,8 +3,8 @@ package context
import (
"net"
CN "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
@ -19,7 +19,7 @@ func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
return &ConnContext{
id: id,
metadata: metadata,
conn: conn,
conn: CN.NewBufferedConn(conn),
}
}

View file

@ -17,6 +17,7 @@ import (
"github.com/Dreamacro/clash/component/profile"
"github.com/Dreamacro/clash/component/profile/cachefile"
"github.com/Dreamacro/clash/component/resolver"
SNI "github.com/Dreamacro/clash/component/sniffer"
"github.com/Dreamacro/clash/component/trie"
"github.com/Dreamacro/clash/config"
C "github.com/Dreamacro/clash/constant"
@ -78,6 +79,7 @@ func ApplyConfig(cfg *config.Config, force bool) {
updateDNS(cfg.DNS, cfg.Tun)
updateGeneral(cfg.General, force)
updateIPTables(cfg)
updateSniffer(cfg.Sniffer)
updateTun(cfg.Tun, cfg.DNS)
updateExperimental(cfg)
updateHosts(cfg.Hosts)
@ -87,6 +89,16 @@ func ApplyConfig(cfg *config.Config, force bool) {
log.SetLevel(cfg.General.LogLevel)
}
func updateSniffer(sniffer *config.Sniffer) {
if sniffer.Enable {
var err error
SNI.Dispatcher, err = SNI.NewSnifferDispatcher(sniffer.Sniffers, sniffer.Force)
if err != nil {
log.Errorln("Init Sniffer failed, err:%v", err)
}
}
}
func GetGeneral() *config.General {
ports := P.GetPorts()
var authenticator []string

View file

@ -1,14 +1,10 @@
package statistic
import (
"errors"
"net"
"time"
"github.com/Dreamacro/clash/common/snifer/tls"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
"github.com/gofrs/uuid"
"go.uber.org/atomic"
@ -52,20 +48,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) {
n, err := tt.Conn.Write(b)
upload := int64(n)
tt.manager.PushUploaded(upload)
if tt.UploadTotal.Load() < 128 && tt.Metadata.Host == "" && (tt.Metadata.DstPort == "443" || tt.Metadata.DstPort == "8443" || tt.Metadata.DstPort == "993" || tt.Metadata.DstPort == "465" || tt.Metadata.DstPort == "995") {
header, err := tls.SniffTLS(b)
if err != nil {
// log.Errorln("Expect no error but actually %s %s:%s:%s", err.Error(), tt.Metadata.Host, tt.Metadata.DstIP.String(), tt.Metadata.DstPort)
} else {
resolver.InsertHostByIP(tt.Metadata.DstIP, header.Domain())
log.Warnln("use sni update host: %s ip: %s", header.Domain(), tt.Metadata.DstIP.String())
tt.manager.Leave(tt)
tt.Conn.Close()
return n, errors.New("sni update, break current link to avoid leaks")
}
}
tt.UploadTotal.Add(upload)
return n, err
}

View file

@ -14,6 +14,7 @@ import (
"github.com/Dreamacro/clash/component/nat"
P "github.com/Dreamacro/clash/component/process"
"github.com/Dreamacro/clash/component/resolver"
"github.com/Dreamacro/clash/component/sniffer"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/provider"
icontext "github.com/Dreamacro/clash/context"
@ -36,6 +37,8 @@ var (
// default timeout for UDP session
udpTimeout = 60 * time.Second
snifferDispatcher *sniffer.SnifferDispatcher
)
func init() {
@ -294,6 +297,10 @@ func handleTCPConn(connCtx C.ConnContext) {
return
}
if sniffer.Dispatcher.Enable() {
sniffer.Dispatcher.Tcp(connCtx.Conn(), metadata)
}
proxy, rule, err := resolveMetadata(connCtx, metadata)
if err != nil {
log.Warnln("[Metadata] parse failed: %s", err.Error())