From 0a0b8074f4932a5438f0c618c46f0cceaf0c7808 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Sat, 26 Mar 2022 18:34:15 +0800 Subject: [PATCH] refactor: rule-set and its provider --- rule/provider/classical_strategy.go | 54 +++++++++ rule/provider/domain_strategy.go | 57 ++++++++++ rule/provider/ipcidr_strategy.go | 44 ++++++++ rule/provider/provider.go | 169 +++++----------------------- 4 files changed, 186 insertions(+), 138 deletions(-) create mode 100644 rule/provider/classical_strategy.go create mode 100644 rule/provider/domain_strategy.go create mode 100644 rule/provider/ipcidr_strategy.go diff --git a/rule/provider/classical_strategy.go b/rule/provider/classical_strategy.go new file mode 100644 index 00000000..45700b81 --- /dev/null +++ b/rule/provider/classical_strategy.go @@ -0,0 +1,54 @@ +package provider + +import ( + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" +) + +type classicalStrategy struct { + rules []C.Rule + count int + shouldResolveIP bool +} + +func (c *classicalStrategy) Match(metadata *C.Metadata) bool { + for _, rule := range c.rules { + if rule.Match(metadata) { + return true + } + } + + return false +} + +func (c *classicalStrategy) Count() int { + return c.count +} + +func (c *classicalStrategy) ShouldResolveIP() bool { + return c.shouldResolveIP +} + +func (c *classicalStrategy) OnUpdate(rules []string) { + var classicalRules []C.Rule + shouldResolveIP := false + for _, rawRule := range rules { + ruleType, rule, params := ruleParse(rawRule) + r, err := parseRule(ruleType, rule, "", params) + if err != nil { + log.Warnln("parse rule error:[%s]", err.Error()) + } + + if !shouldResolveIP { + shouldResolveIP = shouldResolveIP || r.ShouldResolveIP() + } + + classicalRules = append(classicalRules, r) + } + + c.rules = classicalRules +} + +func NewClassicalStrategy() *classicalStrategy { + return &classicalStrategy{rules: []C.Rule{}} +} diff --git a/rule/provider/domain_strategy.go b/rule/provider/domain_strategy.go new file mode 100644 index 00000000..31ecb184 --- /dev/null +++ b/rule/provider/domain_strategy.go @@ -0,0 +1,57 @@ +package provider + +import ( + "github.com/Dreamacro/clash/component/trie" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" + "strings" +) + +type domainStrategy struct { + shouldResolveIP bool + count int + domainRules *trie.DomainTrie +} + +func (d *domainStrategy) Match(metadata *C.Metadata) bool { + return d.domainRules != nil && d.domainRules.Search(metadata.Host) != nil +} + +func (d *domainStrategy) Count() int { + return d.count +} + +func (d *domainStrategy) ShouldResolveIP() bool { + return d.shouldResolveIP +} + +func (d *domainStrategy) OnUpdate(rules []string) { + domainTrie := trie.New() + for _, rule := range rules { + err := domainTrie.Insert(rule, "") + if err != nil { + log.Warnln("invalid domain:[%s]", rule) + } else { + d.count++ + } + } + + d.domainRules = domainTrie +} + +func ruleParse(ruleRaw string) (string, string, []string) { + item := strings.Split(ruleRaw, ",") + if len(item) == 1 { + return "", item[0], nil + } else if len(item) == 2 { + return item[0], item[1], nil + } else if len(item) > 2 { + return item[0], item[1], item[2:] + } + + return "", "", nil +} + +func NewDomainStrategy() *domainStrategy { + return &domainStrategy{shouldResolveIP: false} +} diff --git a/rule/provider/ipcidr_strategy.go b/rule/provider/ipcidr_strategy.go new file mode 100644 index 00000000..31533ba2 --- /dev/null +++ b/rule/provider/ipcidr_strategy.go @@ -0,0 +1,44 @@ +package provider + +import ( + "github.com/Dreamacro/clash/component/trie" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" +) + +type ipcidrStrategy struct { + count int + shouldResolveIP bool + trie *trie.IpCidrTrie +} + +func (i *ipcidrStrategy) Match(metadata *C.Metadata) bool { + return i.trie != nil && i.trie.IsContain(metadata.DstIP) +} + +func (i *ipcidrStrategy) Count() int { + return i.count +} + +func (i *ipcidrStrategy) ShouldResolveIP() bool { + return i.shouldResolveIP +} + +func (i *ipcidrStrategy) OnUpdate(rules []string) { + ipCidrTrie := trie.NewIpCidrTrie() + for _, rule := range rules { + err := ipCidrTrie.AddIpCidrForString(rule) + if err != nil { + log.Warnln("invalid Ipcidr:[%s]", rule) + } else { + i.count++ + } + } + + i.trie = ipCidrTrie + i.shouldResolveIP = i.count > 0 +} + +func NewIPCidrStrategy() *ipcidrStrategy { + return &ipcidrStrategy{} +} diff --git a/rule/provider/provider.go b/rule/provider/provider.go index 9ffc788f..19905bce 100644 --- a/rule/provider/provider.go +++ b/rule/provider/provider.go @@ -2,14 +2,10 @@ package provider import ( "encoding/json" - "errors" - "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" P "github.com/Dreamacro/clash/constant/provider" - "github.com/Dreamacro/clash/log" "gopkg.in/yaml.v2" "runtime" - "strings" "time" ) @@ -19,12 +15,8 @@ var ( type ruleSetProvider struct { *fetcher - behavior P.RuleType - shouldResolveIP bool - count int - DomainRules *trie.DomainTrie - IPCIDRRules *trie.IpCidrTrie - ClassicalRules []C.Rule + behavior P.RuleType + strategy ruleStrategy } type RuleSetProvider struct { @@ -39,6 +31,13 @@ type RulePayload struct { Rules []string `yaml:"payload"` } +type ruleStrategy interface { + Match(metadata *C.Metadata) bool + Count() int + ShouldResolveIP() bool + OnUpdate(rules []string) +} + func RuleProviders() map[string]P.RuleProvider { return ruleProviders } @@ -76,30 +75,11 @@ func (rp *ruleSetProvider) Behavior() P.RuleType { } func (rp *ruleSetProvider) Match(metadata *C.Metadata) bool { - if rp.count == 0 { - return false - } - - switch rp.behavior { - case P.Domain: - return rp.DomainRules != nil && rp.DomainRules.Search(metadata.Host) != nil - case P.IPCIDR: - return rp.IPCIDRRules != nil && rp.IPCIDRRules.IsContain(metadata.DstIP) - case P.Classical: - for _, rule := range rp.ClassicalRules { - if rule.Match(metadata) { - return true - } - } - - return false - default: - return false - } + return rp.strategy != nil && rp.strategy.Match(metadata) } func (rp *ruleSetProvider) ShouldResolveIP() bool { - return rp.shouldResolveIP + return rp.strategy.ShouldResolveIP() } func (rp *ruleSetProvider) AsRule(adaptor string) C.Rule { @@ -111,7 +91,7 @@ func (rp *ruleSetProvider) MarshalJSON() ([]byte, error) { map[string]interface{}{ "behavior": rp.behavior.String(), "name": rp.Name(), - "ruleCount": rp.count, + "ruleCount": rp.strategy.Count(), "type": rp.Type().String(), "updatedAt": rp.updatedAt, "vehicleType": rp.VehicleType().String(), @@ -125,23 +105,14 @@ func NewRuleSetProvider(name string, behavior P.RuleType, interval time.Duration onUpdate := func(elm interface{}) error { rulesRaw := elm.([]string) - rules, err := constructRules(rp.behavior, rulesRaw) - if err != nil { - return err - } - - if rp.behavior == P.Classical { - rp.count = len(rules.([]C.Rule)) - } else { - rp.count = len(rulesRaw) - } - - rp.setRules(rules) + rp.strategy.OnUpdate(rulesRaw) return nil } fetcher := newFetcher(name, interval, vehicle, rulesParse, onUpdate) rp.fetcher = fetcher + rp.strategy = newStrategy(behavior) + wrapper := &RuleSetProvider{ rp, } @@ -150,6 +121,22 @@ func NewRuleSetProvider(name string, behavior P.RuleType, interval time.Duration return wrapper } +func newStrategy(behavior P.RuleType) ruleStrategy { + switch behavior { + case P.Domain: + strategy := NewDomainStrategy() + return strategy + case P.IPCIDR: + strategy := NewIPCidrStrategy() + return strategy + case P.Classical: + strategy := NewClassicalStrategy() + return strategy + default: + return nil + } +} + func rulesParse(buf []byte) (interface{}, error) { rulePayload := RulePayload{} err := yaml.Unmarshal(buf, &rulePayload) @@ -159,97 +146,3 @@ func rulesParse(buf []byte) (interface{}, error) { return rulePayload.Rules, nil } - -func constructRules(behavior P.RuleType, rules []string) (interface{}, error) { - switch behavior { - case P.Domain: - return handleDomainRules(rules) - case P.IPCIDR: - return handleIpCidrRules(rules) - case P.Classical: - return handleClassicalRules(rules) - default: - return nil, errors.New("unknown behavior type") - } -} - -func handleDomainRules(rules []string) (interface{}, error) { - domainRules := trie.New() - for _, rawRule := range rules { - ruleType, rule, _ := ruleParse(rawRule) - if ruleType != "" { - return nil, errors.New("error format of domain") - } - - if err := domainRules.Insert(rule, ""); err != nil { - return nil, err - } - } - return domainRules, nil -} - -func handleIpCidrRules(rules []string) (interface{}, error) { - ipCidrRules := trie.NewIpCidrTrie() - for _, rawRule := range rules { - ruleType, rule, _ := ruleParse(rawRule) - if ruleType != "" { - return nil, errors.New("error format of ip-cidr") - } - - if err := ipCidrRules.AddIpCidrForString(rule); err != nil { - return nil, err - } - } - return ipCidrRules, nil -} - -func handleClassicalRules(rules []string) (interface{}, error) { - var classicalRules []C.Rule - for _, rawRule := range rules { - ruleType, rule, params := ruleParse(rawRule) - - r, err := parseRule(ruleType, rule, "", params) - if err != nil { - //return nil, err - log.Warnln("%s", err) - continue - } - - classicalRules = append(classicalRules, r) - } - return classicalRules, nil -} - -func ruleParse(ruleRaw string) (string, string, []string) { - item := strings.Split(ruleRaw, ",") - if len(item) == 1 { - return "", item[0], nil - } else if len(item) == 2 { - return item[0], item[1], nil - } else if len(item) > 2 { - return item[0], item[1], item[2:] - } - - return "", "", nil -} - -func (rp *ruleSetProvider) setRules(rules interface{}) { - switch rp.behavior { - case P.Domain: - rp.DomainRules = rules.(*trie.DomainTrie) - rp.shouldResolveIP = false - case P.Classical: - rp.ClassicalRules = rules.([]C.Rule) - for i := range rp.ClassicalRules { - if rp.ClassicalRules[i].ShouldResolveIP() { - rp.shouldResolveIP = true - break - } - } - case P.IPCIDR: - rp.IPCIDRRules = rules.(*trie.IpCidrTrie) - rp.shouldResolveIP = true - default: - rp.shouldResolveIP = false - } -}