diff --git a/component/trie/ipcidr_node.go b/component/trie/ipcidr_node.go new file mode 100644 index 00000000..acaf9a8f --- /dev/null +++ b/component/trie/ipcidr_node.go @@ -0,0 +1,44 @@ +package trie + +import "errors" + +var ( + ErrorOverMaxValue = errors.New("the value don't over max value") +) + +type IpCidrNode struct { + Mark bool + child map[uint32]*IpCidrNode + maxValue uint32 +} + +func NewIpCidrNode(mark bool, maxValue uint32) *IpCidrNode { + ipCidrNode := &IpCidrNode{ + Mark: mark, + child: map[uint32]*IpCidrNode{}, + maxValue: maxValue, + } + + return ipCidrNode +} + +func (n *IpCidrNode) addChild(value uint32) error { + if value > n.maxValue { + return ErrorOverMaxValue + } + + n.child[value] = NewIpCidrNode(false, n.maxValue) + return nil +} + +func (n *IpCidrNode) hasChild(value uint32) bool { + return n.getChild(value) != nil +} + +func (n *IpCidrNode) getChild(value uint32) *IpCidrNode { + if value <= n.maxValue { + return n.child[value] + } + + return nil +} diff --git a/component/trie/ipcidr_trie.go b/component/trie/ipcidr_trie.go new file mode 100644 index 00000000..8b931c2d --- /dev/null +++ b/component/trie/ipcidr_trie.go @@ -0,0 +1,255 @@ +package trie + +import ( + "github.com/Dreamacro/clash/log" + "net" +) + +type IPV6 bool + +const ( + ipv4GroupMaxValue = 0xFF + ipv6GroupMaxValue = 0xFFFF +) + +type IpCidrTrie struct { + ipv4Trie *IpCidrNode + ipv6Trie *IpCidrNode +} + +func NewIpCidrTrie() *IpCidrTrie { + return &IpCidrTrie{ + ipv4Trie: NewIpCidrNode(false, ipv4GroupMaxValue), + ipv6Trie: NewIpCidrNode(false, ipv6GroupMaxValue), + } +} + +func (trie *IpCidrTrie) AddIpCidr(ipCidr *net.IPNet) error { + subIpCidr, subCidr, isIpv4, err := ipCidrToSubIpCidr(ipCidr) + if err != nil { + return err + } + + for _, sub := range subIpCidr { + addIpCidr(trie, isIpv4, sub, subCidr/8) + } + + return nil +} + +func (trie *IpCidrTrie) AddIpCidrForString(ipCidr string) error { + _, ipNet, err := net.ParseCIDR(ipCidr) + if err != nil { + return err + } + + return trie.AddIpCidr(ipNet) +} + +func (trie *IpCidrTrie) IsContain(ip net.IP) bool { + ip, isIpv4 := checkAndConverterIp(ip) + if ip == nil { + return false + } + + var groupValues []uint32 + var ipCidrNode *IpCidrNode + + if isIpv4 { + ipCidrNode = trie.ipv4Trie + for _, group := range ip { + groupValues = append(groupValues, uint32(group)) + } + } else { + ipCidrNode = trie.ipv6Trie + for i := 0; i < len(ip); i += 2 { + groupValues = append(groupValues, getIpv6GroupValue(ip[i], ip[i+1])) + } + } + + return search(ipCidrNode, groupValues) != nil +} + +func (trie *IpCidrTrie) IsContainForString(ipString string) bool { + return trie.IsContain(net.ParseIP(ipString)) +} + +func ipCidrToSubIpCidr(ipNet *net.IPNet) ([]net.IP, int, bool, error) { + maskSize, _ := ipNet.Mask.Size() + var ( + ipList []net.IP + newMaskSize int + isIpv4 bool + err error + ) + + ip, isIpv4 := checkAndConverterIp(ipNet.IP) + ipList, newMaskSize, err = subIpCidr(ip, maskSize, isIpv4) + + return ipList, newMaskSize, isIpv4, err +} + +func subIpCidr(ip net.IP, maskSize int, isIpv4 bool) ([]net.IP, int, error) { + var subIpCidrList []net.IP + groupSize := 8 + if !isIpv4 { + groupSize = 16 + } + + if maskSize%groupSize == 0 { + return append(subIpCidrList, ip), maskSize, nil + } + + lastByteMaskSize := maskSize % 8 + lastByteMaskIndex := maskSize / 8 + subIpCidrNum := 0xFF >> lastByteMaskSize + for i := 0; i < subIpCidrNum; i++ { + subIpCidr := make([]byte, len(ip)) + copy(subIpCidr, ip) + subIpCidr[lastByteMaskIndex] += byte(i) + subIpCidrList = append(subIpCidrList, subIpCidr) + } + + newMaskSize := (lastByteMaskIndex + 1) * 8 + if !isIpv4 { + newMaskSize = (lastByteMaskIndex/2 + 1) * 16 + } + + return subIpCidrList, newMaskSize, nil +} + +func addIpCidr(trie *IpCidrTrie, isIpv4 bool, ip net.IP, groupSize int) { + if isIpv4 { + addIpv4Cidr(trie, ip, groupSize) + } else { + addIpv6Cidr(trie, ip, groupSize) + } +} + +func addIpv4Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) { + preNode := trie.ipv4Trie + node := preNode.getChild(uint32(ip[0])) + if node == nil { + err := preNode.addChild(uint32(ip[0])) + if err != nil { + return + } + + node = preNode.getChild(uint32(ip[0])) + } + + for i := 1; i < groupSize; i++ { + if node.Mark { + return + } + + groupValue := uint32(ip[i]) + if !node.hasChild(groupValue) { + err := node.addChild(groupValue) + if err != nil { + log.Errorln(err.Error()) + } + } + + preNode = node + node = node.getChild(groupValue) + if node == nil { + err := preNode.addChild(uint32(ip[i-1])) + if err != nil { + return + } + + node = preNode.getChild(uint32(ip[i-1])) + } + } + + node.Mark = true + cleanChild(node) +} + +func addIpv6Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) { + preNode := trie.ipv6Trie + node := preNode.getChild(getIpv6GroupValue(ip[0], ip[1])) + if node == nil { + err := preNode.addChild(getIpv6GroupValue(ip[0], ip[1])) + if err != nil { + return + } + + node = preNode.getChild(getIpv6GroupValue(ip[0], ip[1])) + } + + for i := 2; i < groupSize; i += 2 { + if node.Mark { + return + } + + groupValue := getIpv6GroupValue(ip[i], ip[i+1]) + if !node.hasChild(groupValue) { + err := node.addChild(groupValue) + if err != nil { + log.Errorln(err.Error()) + } + } + + preNode = node + node = node.getChild(groupValue) + if node == nil { + err := preNode.addChild(getIpv6GroupValue(ip[i-2], ip[i-1])) + if err != nil { + return + } + + node = preNode.getChild(getIpv6GroupValue(ip[i-2], ip[i-1])) + } + } + + node.Mark = true + cleanChild(node) +} + +func getIpv6GroupValue(high, low byte) uint32 { + return (uint32(high) << 8) | uint32(low) +} + +func cleanChild(node *IpCidrNode) { + for i := uint32(0); i < uint32(len(node.child)); i++ { + delete(node.child, i) + } +} + +func search(root *IpCidrNode, groupValues []uint32) *IpCidrNode { + node := root.getChild(groupValues[0]) + if node == nil || node.Mark { + return node + } + + for _, value := range groupValues[1:] { + if !node.hasChild(value) { + return nil + } + + node = node.getChild(value) + + if node == nil || node.Mark { + return node + } + } + + return nil +} + +// return net.IP To4 or To16 and is ipv4 +func checkAndConverterIp(ip net.IP) (net.IP, bool) { + ipResult := ip.To4() + if ipResult == nil { + ipResult = ip.To16() + if ipResult == nil { + return nil, false + } + + return ipResult, false + } + + return ipResult, true +} diff --git a/component/trie/trie_test.go b/component/trie/trie_test.go new file mode 100644 index 00000000..ec4f1fd2 --- /dev/null +++ b/component/trie/trie_test.go @@ -0,0 +1,82 @@ +package trie + +import ( + "net" + "testing" +) +import "github.com/stretchr/testify/assert" + +func TestIpv4AddSuccess(t *testing.T) { + trie := NewIpCidrTrie() + err := trie.AddIpCidrForString("10.0.0.2/16") + assert.Equal(t, nil, err) +} + +func TestIpv4AddFail(t *testing.T) { + trie := NewIpCidrTrie() + err := trie.AddIpCidrForString("333.00.23.2/23") + assert.IsType(t, new(net.ParseError), err) + + err = trie.AddIpCidrForString("22.3.34.2/222") + assert.IsType(t, new(net.ParseError), err) + + err = trie.AddIpCidrForString("2.2.2.2") + assert.IsType(t, new(net.ParseError), err) +} + +func TestIpv4Search(t *testing.T) { + trie := NewIpCidrTrie() + assert.NoError(t, trie.AddIpCidrForString("129.2.36.0/16")) + assert.NoError(t, trie.AddIpCidrForString("10.2.36.0/18")) + assert.NoError(t, trie.AddIpCidrForString("16.2.23.0/24")) + assert.NoError(t, trie.AddIpCidrForString("11.2.13.2/26")) + assert.NoError(t, trie.AddIpCidrForString("55.5.6.3/8")) + assert.NoError(t, trie.AddIpCidrForString("66.23.25.4/6")) + assert.Equal(t, true, trie.IsContainForString("129.2.3.65")) + assert.Equal(t, false, trie.IsContainForString("15.2.3.1")) + assert.Equal(t, true, trie.IsContainForString("11.2.13.1")) + assert.Equal(t, true, trie.IsContainForString("55.0.0.0")) + assert.Equal(t, true, trie.IsContainForString("64.0.0.0")) + assert.Equal(t, false, trie.IsContainForString("128.0.0.0")) + + assert.Equal(t, false, trie.IsContain(net.ParseIP("22"))) + assert.Equal(t, false, trie.IsContain(net.ParseIP(""))) +} + +func TestIpv6AddSuccess(t *testing.T) { + trie := NewIpCidrTrie() + err := trie.AddIpCidrForString("2001:0db8:02de:0000:0000:0000:0000:0e13/32") + assert.Equal(t, nil, err) + + err = trie.AddIpCidrForString("2001:1db8:f2de::0e13/18") + assert.Equal(t, nil, err) +} + +func TestIpv6AddFail(t *testing.T) { + trie := NewIpCidrTrie() + err := trie.AddIpCidrForString("2001::25de::cade/23") + assert.IsType(t, new(net.ParseError), err) + + err = trie.AddIpCidrForString("2001:0fa3:25de::cade/222") + assert.IsType(t, new(net.ParseError), err) + + err = trie.AddIpCidrForString("2001:0fa3:25de::cade") + assert.IsType(t, new(net.ParseError), err) +} + +func TestIpv6Search(t *testing.T) { + trie := NewIpCidrTrie() + assert.NoError(t, trie.AddIpCidrForString("2001:b28:f23d:f001::e/128")) + assert.NoError(t, trie.AddIpCidrForString("2001:67c:4e8:f002::e/12")) + assert.NoError(t, trie.AddIpCidrForString("2001:b28:f23d:f003::e/96")) + assert.NoError(t, trie.AddIpCidrForString("2001:67c:4e8:f002::a/32")) + assert.NoError(t, trie.AddIpCidrForString("2001:67c:4e8:f004::a/60")) + assert.NoError(t, trie.AddIpCidrForString("2001:b28:f23f:f005::a/64")) + assert.Equal(t, true, trie.IsContainForString("2001:b28:f23d:f001::e")) + assert.Equal(t, false, trie.IsContainForString("2222::fff2")) + assert.Equal(t, true, trie.IsContainForString("2000::ffa0")) + assert.Equal(t, true, trie.IsContainForString("2001:b28:f23f:f005:5662::")) + assert.Equal(t, true, trie.IsContainForString("2001:67c:4e8:9666::1213")) + + assert.Equal(t, false, trie.IsContain(net.ParseIP("22233:22"))) +} diff --git a/config/config.go b/config/config.go index 37482879..461db58e 100644 --- a/config/config.go +++ b/config/config.go @@ -122,10 +122,10 @@ type Config struct { Hosts *trie.DomainTrie Profile *Profile Rules []C.Rule - RuleProviders map[string]C.Rule Users []auth.AuthUser Proxies map[string]C.Proxy Providers map[string]providerTypes.ProxyProvider + RuleProviders map[string]*providerTypes.RuleProvider } type RawDNS struct { @@ -482,9 +482,28 @@ time = ClashTime() return nil } -func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[string]C.Rule, error) { +func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[string]*providerTypes.RuleProvider, error) { + ruleProviders := map[string]*providerTypes.RuleProvider{} + + // parse rule provider + for name, mapping := range cfg.RuleProvider { + rp, err := R.ParseRuleProvider(name, mapping) + if err != nil { + return nil, nil, err + } + + ruleProviders[name] = &rp + R.SetRuleProvider(&rp) + } + + for _, provider := range ruleProviders { + log.Infoln("Start initial provider %s", (*provider).Name()) + if err := (*provider).Initial(); err != nil { + return nil, nil, fmt.Errorf("initial rule provider %s error: %w", (*provider).Name(), err) + } + } + rules := []C.Rule{} - ruleProviders := map[string]C.Rule{} rulesConfig := cfg.Rule mode := cfg.Mode diff --git a/constant/rule.go b/constant/rule.go index e2087604..df7a5d19 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -13,6 +13,7 @@ const ( DstPort Process Script + RuleSet MATCH ) @@ -44,6 +45,8 @@ func (rt RuleType) String() string { return "Script" case MATCH: return "Match" + case RuleSet: + return "RuleSet" default: return "Unknown" } diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 05de2a10..27d36a82 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -75,7 +75,7 @@ func ApplyConfig(cfg *config.Config, force bool) { updateUsers(cfg.Users) updateProxies(cfg.Proxies, cfg.Providers) - updateRules(cfg.Rules) + updateRules(cfg.Rules, cfg.RuleProviders) updateHosts(cfg.Hosts) updateProfile(cfg) updateIPTables(cfg.DNS, cfg.General) @@ -176,8 +176,8 @@ func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.Pro tunnel.UpdateProxies(proxies, providers) } -func updateRules(rules []C.Rule) { - tunnel.UpdateRules(rules) +func updateRules(rules []C.Rule, ruleProviders map[string]*provider.RuleProvider) { + tunnel.UpdateRules(rules, ruleProviders) } func updateGeneral(general *config.General, force bool) { diff --git a/hub/route/provider.go b/hub/route/provider.go index 0b599445..163ee26a 100644 --- a/hub/route/provider.go +++ b/hub/route/provider.go @@ -75,3 +75,54 @@ func findProviderByName(next http.Handler) http.Handler { next.ServeHTTP(w, r.WithContext(ctx)) }) } + +func ruleProviderRouter() http.Handler { + r := chi.NewRouter() + r.Get("/", getRuleProviders) + r.Route("/{name}", func(r chi.Router) { + r.Use(parseRuleProviderName, findRuleProviderByName) + r.Put("/", updateRuleProvider) + }) + return r +} + +func getRuleProviders(w http.ResponseWriter, r *http.Request) { + ruleProviders := tunnel.RuleProviders() + render.JSON(w, r, render.M{ + "providers": ruleProviders, + }) +} + +func updateRuleProvider(w http.ResponseWriter, r *http.Request) { + provider := r.Context().Value(CtxKeyProvider).(*provider.RuleProvider) + if err := (*provider).Update(); err != nil { + render.Status(r, http.StatusServiceUnavailable) + render.JSON(w, r, newError(err.Error())) + } + + render.NoContent(w, r) +} + +func parseRuleProviderName(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + name := getEscapeParam(r, "name") + ctx := context.WithValue(r.Context(), CtxKeyProviderName, name) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func findRuleProviderByName(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + name := r.Context().Value(CtxKeyProviderName).(string) + providers := tunnel.RuleProviders() + provider, exist := providers[name] + if !exist { + render.Status(r, http.StatusNotFound) + render.JSON(w, r, ErrNotFound) + return + } + + ctx := context.WithValue(r.Context(), CtxKeyProvider, provider) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/hub/route/server.go b/hub/route/server.go index e01696be..783996ba 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -70,6 +70,7 @@ func Start(addr string, secret string) { r.Mount("/rules", ruleRouter()) r.Mount("/connections", connectionRouter()) r.Mount("/providers/proxies", proxyProviderRouter()) + r.Mount("/providers/rules", ruleProviderRouter()) }) if uiPath != "" { diff --git a/rule/fetcher.go b/rule/fetcher.go new file mode 100644 index 00000000..89f45ab1 --- /dev/null +++ b/rule/fetcher.go @@ -0,0 +1,186 @@ +package rules + +import ( + "bytes" + "crypto/md5" + P "github.com/Dreamacro/clash/constant/provider" + "github.com/Dreamacro/clash/log" + "io/ioutil" + "os" + "path/filepath" + "time" +) + +var ( + fileMode os.FileMode = 0666 + dirMode os.FileMode = 0755 +) + +type parser = func([]byte) (interface{}, error) + +type fetcher struct { + name string + vehicle P.Vehicle + updatedAt *time.Time + ticker *time.Ticker + done chan struct{} + hash [16]byte + parser parser + onUpdate func(interface{}) error +} + +func (f *fetcher) Name() string { + return f.name +} + +func (f *fetcher) VehicleType() P.VehicleType { + return f.vehicle.Type() +} + +func (f *fetcher) Initial() (interface{}, error) { + var ( + buf []byte + hasLocal bool + err error + ) + + if stat, fErr := os.Stat(f.vehicle.Path()); fErr == nil { + buf, err = ioutil.ReadFile(f.vehicle.Path()) + modTime := stat.ModTime() + f.updatedAt = &modTime + hasLocal = true + } else { + buf, err = f.vehicle.Read() + } + + if err != nil { + return nil, err + } + + rules, err := f.parser(buf) + if err != nil { + if !hasLocal { + return nil, err + } + + buf, err = f.vehicle.Read() + if err != nil { + return nil, err + } + + rules, err = f.parser(buf) + if err != nil { + return nil, err + } + + hasLocal = false + } + + if f.vehicle.Type() != P.File && !hasLocal { + if err := safeWrite(f.vehicle.Path(), buf); err != nil { + return nil, err + } + } + + f.hash = md5.Sum(buf) + if f.ticker != nil { + go f.pullLoop() + } + + return rules, nil +} + +func (f *fetcher) Update() (interface{}, bool, error) { + buf, err := f.vehicle.Read() + if err != nil { + return nil, false, err + } + + now := time.Now() + hash := md5.Sum(buf) + if bytes.Equal(f.hash[:], hash[:]) { + f.updatedAt = &now + return nil, true, nil + } + + rules, err := f.parser(buf) + if err != nil { + return nil, false, err + } + + if f.vehicle.Type() != P.File { + if err := safeWrite(f.vehicle.Path(), buf); err != nil { + return nil, false, err + } + } + + f.updatedAt = &now + f.hash = hash + + return rules, false, nil +} + +func (f *fetcher) Destroy() error { + if f.ticker != nil { + f.done <- struct{}{} + } + return nil +} + +func newFetcher(name string, interval time.Duration, vehicle P.Vehicle, parser parser, onUpdate func(interface{}) error) *fetcher { + var ticker *time.Ticker + if interval != 0 { + ticker = time.NewTicker(interval) + } + + return &fetcher{ + name: name, + ticker: ticker, + vehicle: vehicle, + parser: parser, + done: make(chan struct{}, 1), + onUpdate: onUpdate, + } +} + +func safeWrite(path string, buf []byte) error { + dir := filepath.Dir(path) + + if _, err := os.Stat(dir); os.IsNotExist(err) { + if err := os.MkdirAll(dir, dirMode); err != nil { + return err + } + } + + return ioutil.WriteFile(path, buf, fileMode) +} + +func (f *fetcher) pullLoop() { + for { + select { + case <-f.ticker.C: + elm, same, err := f.Update() + if err != nil { + log.Warnln("[Provider] %s pull error: %s", f.Name(), err.Error()) + continue + } + + if same { + log.Debugln("[Provider] %s's rules doesn't change", f.Name()) + continue + } + + log.Infoln("[Provider] %s's rules update", f.Name()) + if f.onUpdate != nil { + err := f.onUpdate(elm) + if err != nil { + log.Infoln("[Provider] %s update failed", f.Name()) + } + } + + case <-f.done: + f.ticker.Stop() + return + } + } +} diff --git a/rule/parser.go b/rule/parser.go index fa73c2ae..6a72f52f 100644 --- a/rule/parser.go +++ b/rule/parser.go @@ -2,8 +2,11 @@ package rules import ( "fmt" - + "github.com/Dreamacro/clash/adapter/provider" + "github.com/Dreamacro/clash/common/structure" C "github.com/Dreamacro/clash/constant" + P "github.com/Dreamacro/clash/constant/provider" + "time" ) func ParseRule(tp, payload, target string, params []string) (C.Rule, error) { @@ -42,9 +45,52 @@ func ParseRule(tp, payload, target string, params []string) (C.Rule, error) { parsed, parseErr = NewProcess(payload, target, ruleExtra) case "MATCH": parsed = NewMatch(target, ruleExtra) + case "RULE-SET": + parsed, parseErr = NewRuleSet(payload, target) default: parseErr = fmt.Errorf("unsupported rule type %s", tp) } return parsed, parseErr } + +type ruleProviderSchema struct { + Type string `provider:"type"` + Behavior string `provider:"behavior"` + Path string `provider:"path"` + URL string `provider:"url,omitempty"` + Interval int `provider:"interval,omitempty"` +} + +func ParseRuleProvider(name string, mapping map[string]interface{}) (P.RuleProvider, error) { + schema := &ruleProviderSchema{} + decoder := structure.NewDecoder(structure.Option{TagName: "provider", WeaklyTypedInput: true}) + if err := decoder.Decode(mapping, schema); err != nil { + return nil, err + } + var behavior P.RuleType + + switch schema.Behavior { + case "domain": + behavior = P.Domain + case "ipcidr": + behavior = P.IPCIDR + case "classical": + behavior = P.Classical + default: + return nil, fmt.Errorf("unsupported behavior type: %s", schema.Behavior) + } + + path := C.Path.Resolve(schema.Path) + var vehicle P.Vehicle + switch schema.Type { + case "file": + vehicle = provider.NewFileVehicle(path) + case "http": + vehicle = provider.NewHTTPVehicle(schema.URL, path) + default: + return nil, fmt.Errorf("unsupported vehicle type: %s", schema.Type) + } + + return NewRuleSetProvider(name, behavior, time.Duration(uint(schema.Interval))*time.Second, vehicle), nil +} diff --git a/rule/provider.go b/rule/provider.go new file mode 100644 index 00000000..48022fb2 --- /dev/null +++ b/rule/provider.go @@ -0,0 +1,246 @@ +package rules + +import ( + "encoding/json" + "errors" + "github.com/Dreamacro/clash/component/trie" + C "github.com/Dreamacro/clash/constant" + P "github.com/Dreamacro/clash/constant/provider" + "gopkg.in/yaml.v2" + "runtime" + "strings" + "time" +) + +var ( + ruleProviders = map[string]*P.RuleProvider{} +) + +type ruleSetProvider struct { + *fetcher + behavior P.RuleType + shouldResolveIP bool + count int + DomainRules *trie.DomainTrie + IPCIDRRules *trie.IpCidrTrie + ClassicalRules []C.Rule +} + +type RuleSetProvider struct { + *ruleSetProvider +} + +type RulePayload struct { + /** + key: Domain or IP Cidr + value: Rule type or is empty + */ + Rules []string `yaml:"payload"` +} + +func RuleProviders() map[string]*P.RuleProvider { + return ruleProviders +} + +func SetRuleProvider(ruleProvider *P.RuleProvider) { + if ruleProvider != nil { + ruleProviders[(*ruleProvider).Name()] = ruleProvider + } +} + +func (rp *ruleSetProvider) Type() P.ProviderType { + return P.Rule +} + +func (rp *ruleSetProvider) Initial() error { + elm, err := rp.fetcher.Initial() + if err != nil { + return err + } + + return rp.fetcher.onUpdate(elm) +} + +func (rp *ruleSetProvider) Update() error { + elm, same, err := rp.fetcher.Update() + if err == nil && !same { + return rp.fetcher.onUpdate(elm) + } + + return err +} + +func (rp *ruleSetProvider) Behavior() P.RuleType { + return rp.behavior +} + +func (rp *ruleSetProvider) Match(metadata *C.Metadata) bool { + switch rp.behavior { + case P.Domain: + return rp.DomainRules.Search(metadata.Host) != nil + case P.IPCIDR: + return rp.IPCIDRRules.IsContain(metadata.DstIP) + case P.Classical: + for _, rule := range rp.ClassicalRules { + if rule.Match(metadata) { + return true + } + } + + return false + default: + return false + } +} + +func (rp *ruleSetProvider) ShouldResolveIP() bool { + return rp.shouldResolveIP +} + +func (rp *ruleSetProvider) AsRule(adaptor string) C.Rule { + panic("implement me") +} + +func (rp *ruleSetProvider) MarshalJSON() ([]byte, error) { + return json.Marshal( + map[string]interface{}{ + "behavior": rp.behavior.String(), + "name": rp.Name(), + "ruleCount": rp.count, + "type": rp.Type().String(), + "updatedAt": rp.updatedAt, + "vehicleType": rp.VehicleType().String(), + }) +} + +func NewRuleSetProvider(name string, behavior P.RuleType, interval time.Duration, vehicle P.Vehicle) P.RuleProvider { + rp := &ruleSetProvider{ + behavior: behavior, + } + + onUpdate := func(elm interface{}) error { + rulesRaw := elm.([]string) + rp.count = len(rulesRaw) + rules, err := constructRules(rp.behavior, rulesRaw) + if err != nil { + return err + } + + rp.shouldResolveIP = false + rp.setRules(rules) + return nil + } + + fetcher := newFetcher(name, interval, vehicle, rulesParse, onUpdate) + rp.fetcher = fetcher + wrapper := &RuleSetProvider{ + rp, + } + + runtime.SetFinalizer(wrapper, rp.fetcher.Destroy()) + return wrapper +} + +func rulesParse(buf []byte) (interface{}, error) { + rulePayload := RulePayload{} + err := yaml.Unmarshal(buf, &rulePayload) + if err != nil { + return nil, err + } + + return rulePayload.Rules, nil +} + +func constructRules(behavior P.RuleType, rules []string) (interface{}, error) { + switch behavior { + case P.Domain: + return handleDomainRules(rules) + case P.IPCIDR: + return handleIpCidrRules(rules) + case P.Classical: + return handleClassicalRules(rules) + default: + return nil, errors.New("unknown behavior type") + } +} + +func handleDomainRules(rules []string) (interface{}, error) { + domainRules := trie.New() + for _, rawRule := range rules { + ruleType, rule, _ := ruleParse(rawRule) + if ruleType != "" { + return nil, errors.New("error format of domain") + } + + if err := domainRules.Insert(rule, ""); err != nil { + return nil, err + } + } + return domainRules, nil +} + +func handleIpCidrRules(rules []string) (interface{}, error) { + ipCidrRules := trie.NewIpCidrTrie() + for _, rawRule := range rules { + ruleType, rule, _ := ruleParse(rawRule) + if ruleType != "" { + return nil, errors.New("error format of ip-cidr") + } + + if err := ipCidrRules.AddIpCidrForString(rule); err != nil { + return nil, err + } + } + return ipCidrRules, nil +} + +func handleClassicalRules(rules []string) (interface{}, error) { + var classicalRules []C.Rule + for _, rawRule := range rules { + ruleType, rule, params := ruleParse(rawRule) + if ruleType == "RULE-SET" { + return nil, errors.New("error rule type") + } + + r, err := ParseRule(ruleType, rule, "", params) + if err != nil { + return nil, err + } + + classicalRules = append(classicalRules, r) + } + return classicalRules, nil +} + +func ruleParse(ruleRaw string) (string, string, []string) { + item := strings.Split(ruleRaw, ",") + if len(item) == 1 { + return "", item[0], nil + } else if len(item) == 2 { + return item[0], item[1], nil + } else if len(item) > 2 { + return item[0], item[1], item[2:] + } + + return "", "", nil +} + +func (rp *ruleSetProvider) setRules(rules interface{}) { + switch rp.behavior { + case P.Domain: + rp.DomainRules = rules.(*trie.DomainTrie) + rp.shouldResolveIP = false + case P.Classical: + rp.ClassicalRules = rules.([]C.Rule) + for i := range rp.ClassicalRules { + if rp.ClassicalRules[i].ShouldResolveIP() { + rp.shouldResolveIP = true + break + } + } + case P.IPCIDR: + rp.IPCIDRRules = rules.(*trie.IpCidrTrie) + rp.shouldResolveIP = true + default: + } +} diff --git a/rule/rule_set.go b/rule/rule_set.go new file mode 100644 index 00000000..82b03175 --- /dev/null +++ b/rule/rule_set.go @@ -0,0 +1,57 @@ +package rules + +import ( + "fmt" + C "github.com/Dreamacro/clash/constant" + P "github.com/Dreamacro/clash/constant/provider" +) + +type RuleSet struct { + ruleProviderName string + adapter string + ruleProvider *P.RuleProvider +} + +func (rs *RuleSet) RuleType() C.RuleType { + return C.RuleSet +} + +func (rs *RuleSet) Match(metadata *C.Metadata) bool { + return rs.getProviders().Match(metadata) +} + +func (rs *RuleSet) Adapter() string { + return rs.adapter +} + +func (rs *RuleSet) Payload() string { + return rs.getProviders().Name() +} + +func (rs *RuleSet) ShouldResolveIP() bool { + return rs.getProviders().Behavior() != P.Domain +} +func (rs *RuleSet) getProviders() P.RuleProvider { + if rs.ruleProvider == nil { + rp := RuleProviders()[rs.ruleProviderName] + rs.ruleProvider = rp + } + + return *rs.ruleProvider +} + +func (rs *RuleSet) RuleExtra() *C.RuleExtra { + return nil +} + +func NewRuleSet(ruleProviderName string, adapter string) (*RuleSet, error) { + rp, ok := RuleProviders()[ruleProviderName] + if !ok { + return nil, fmt.Errorf("rule set %s not found", ruleProviderName) + } + return &RuleSet{ + ruleProviderName: ruleProviderName, + adapter: adapter, + ruleProvider: rp, + }, nil +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index ba5ef9c0..4551bf09 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -20,14 +20,14 @@ import ( ) var ( - tcpQueue = make(chan C.ConnContext, 200) - udpQueue = make(chan *inbound.PacketAdapter, 200) - natTable = nat.New() - rules []C.Rule - proxies = make(map[string]C.Proxy) - providers map[string]provider.ProxyProvider - configMux sync.RWMutex - + tcpQueue = make(chan C.ConnContext, 200) + udpQueue = make(chan *inbound.PacketAdapter, 200) + natTable = nat.New() + rules []C.Rule + proxies = make(map[string]C.Proxy) + providers map[string]provider.ProxyProvider + configMux sync.RWMutex + ruleProviders map[string]*provider.RuleProvider // Outbound Rule mode = Rule @@ -57,9 +57,10 @@ func Rules() []C.Rule { } // UpdateRules handle update rules -func UpdateRules(newRules []C.Rule) { +func UpdateRules(newRules []C.Rule, rp map[string]*provider.RuleProvider) { configMux.Lock() rules = newRules + ruleProviders = rp configMux.Unlock() } @@ -73,6 +74,11 @@ func Providers() map[string]provider.ProxyProvider { return providers } +// RuleProviders return all loaded rule providers +func RuleProviders() map[string]*provider.RuleProvider { + return ruleProviders +} + // UpdateProxies handle update proxies func UpdateProxies(newProxies map[string]C.Proxy, newProviders map[string]provider.ProxyProvider) { configMux.Lock()