diff --git a/rules/provider/classical_strategy.go b/rules/provider/classical_strategy.go index 25360ec7..e187e213 100644 --- a/rules/provider/classical_strategy.go +++ b/rules/provider/classical_strategy.go @@ -37,36 +37,38 @@ func (c *classicalStrategy) ShouldFindProcess() bool { return c.shouldFindProcess } -func (c *classicalStrategy) OnUpdate(rules []string) { - var classicalRules []C.Rule - shouldResolveIP := false - for _, rawRule := range rules { - ruleType, rule, params := ruleParse(rawRule) +func (c *classicalStrategy) Reset() { + c.rules = nil + c.count = 0 + c.shouldFindProcess = false + c.shouldResolveIP = false +} - if ruleType == "PROCESS-NAME" { +func (c *classicalStrategy) Insert(rule string) { + ruleType, rule, params := ruleParse(rule) + + if ruleType == "PROCESS-NAME" { + c.shouldFindProcess = true + } + + r, err := c.parse(ruleType, rule, "", params) + if err != nil { + log.Warnln("parse rule error:[%s]", err.Error()) + } else { + if r.ShouldResolveIP() { + c.shouldResolveIP = true + } + if r.ShouldFindProcess() { c.shouldFindProcess = true } - r, err := c.parse(ruleType, rule, "", params) - if err != nil { - log.Warnln("parse rule error:[%s]", err.Error()) - } else { - if !shouldResolveIP { - shouldResolveIP = r.ShouldResolveIP() - } - - if !c.shouldFindProcess { - c.shouldFindProcess = r.ShouldFindProcess() - } - - classicalRules = append(classicalRules, r) - } + c.rules = append(c.rules, r) + c.count++ } - - c.rules = classicalRules - c.count = len(classicalRules) } +func (c *classicalStrategy) FinishInsert() {} + func ruleParse(ruleRaw string) (string, string, []string) { item := strings.Split(ruleRaw, ",") if len(item) == 1 { diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go index a2cb795d..d686d598 100644 --- a/rules/provider/domain_strategy.go +++ b/rules/provider/domain_strategy.go @@ -7,8 +7,9 @@ import ( ) type domainStrategy struct { - count int - domainRules *trie.DomainSet + count int + domainTrie *trie.DomainTrie[struct{}] + domainSet *trie.DomainSet } func (d *domainStrategy) ShouldFindProcess() bool { @@ -16,7 +17,7 @@ func (d *domainStrategy) ShouldFindProcess() bool { } func (d *domainStrategy) Match(metadata *C.Metadata) bool { - return d.domainRules != nil && d.domainRules.Has(metadata.RuleHost()) + return d.domainSet != nil && d.domainSet.Has(metadata.RuleHost()) } func (d *domainStrategy) Count() int { @@ -27,16 +28,24 @@ func (d *domainStrategy) ShouldResolveIP() bool { return false } -func (d *domainStrategy) OnUpdate(rules []string) { - domainTrie := trie.New[struct{}]() - for _, rule := range rules { - err := domainTrie.Insert(rule, struct{}{}) - if err != nil { - log.Warnln("invalid domain:[%s]", rule) - } +func (d *domainStrategy) Reset() { + d.domainTrie = trie.New[struct{}]() + d.domainSet = nil + d.count = 0 +} + +func (d *domainStrategy) Insert(rule string) { + err := d.domainTrie.Insert(rule, struct{}{}) + if err != nil { + log.Warnln("invalid domain:[%s]", rule) + } else { + d.count++ } - d.domainRules = domainTrie.NewDomainSet() - d.count = len(rules) +} + +func (d *domainStrategy) FinishInsert() { + d.domainSet = d.domainTrie.NewDomainSet() + d.domainTrie = nil } func NewDomainStrategy() *domainStrategy { diff --git a/rules/provider/ipcidr_strategy.go b/rules/provider/ipcidr_strategy.go index 88228301..f54302f1 100644 --- a/rules/provider/ipcidr_strategy.go +++ b/rules/provider/ipcidr_strategy.go @@ -28,23 +28,24 @@ func (i *ipcidrStrategy) ShouldResolveIP() bool { return i.shouldResolveIP } -func (i *ipcidrStrategy) OnUpdate(rules []string) { - ipCidrTrie := trie.NewIpCidrTrie() - count := 0 - for _, rule := range rules { - err := ipCidrTrie.AddIpCidrForString(rule) - if err != nil { - log.Warnln("invalid Ipcidr:[%s]", rule) - } else { - count++ - } - } - - i.trie = ipCidrTrie - i.count = count - i.shouldResolveIP = i.count > 0 +func (i *ipcidrStrategy) Reset() { + i.trie = trie.NewIpCidrTrie() + i.count = 0 + i.shouldResolveIP = false } +func (i *ipcidrStrategy) Insert(rule string) { + err := i.trie.AddIpCidrForString(rule) + if err != nil { + log.Warnln("invalid Ipcidr:[%s]", rule) + } else { + i.shouldResolveIP = true + i.count++ + } +} + +func (i *ipcidrStrategy) FinishInsert() {} + func NewIPCidrStrategy() *ipcidrStrategy { return &ipcidrStrategy{} } diff --git a/rules/provider/provider.go b/rules/provider/provider.go index 175917c2..c46861e1 100644 --- a/rules/provider/provider.go +++ b/rules/provider/provider.go @@ -1,13 +1,19 @@ package provider import ( + "bufio" + "bytes" "encoding/json" + "errors" + "gopkg.in/yaml.v3" + "io" + "runtime" + "time" + + "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/component/resource" C "github.com/Dreamacro/clash/constant" P "github.com/Dreamacro/clash/constant/provider" - "gopkg.in/yaml.v3" - "runtime" - "time" ) var ( @@ -29,8 +35,8 @@ type RulePayload struct { key: Domain or IP Cidr value: Rule type or is empty */ - Rules []string `yaml:"payload"` - Rules2 []string `yaml:"rules"` + Payload []string `yaml:"payload"` + Rules []string `yaml:"rules"` } type ruleStrategy interface { @@ -38,7 +44,9 @@ type ruleStrategy interface { Count() int ShouldResolveIP() bool ShouldFindProcess() bool - OnUpdate(rules []string) + Reset() + Insert(rule string) + FinishInsert() } func RuleProviders() map[string]P.RuleProvider { @@ -114,13 +122,12 @@ func NewRuleSetProvider(name string, behavior P.RuleType, interval time.Duration } onUpdate := func(elm interface{}) { - rulesRaw := elm.([]string) - rp.strategy.OnUpdate(rulesRaw) + strategy := elm.(ruleStrategy) + rp.strategy = strategy } - fetcher := resource.NewFetcher(name, interval, vehicle, rulesParse, onUpdate) - rp.Fetcher = fetcher rp.strategy = newStrategy(behavior, parse) + rp.Fetcher = resource.NewFetcher(name, interval, vehicle, func(bytes []byte) (any, error) { return rulesParse(bytes, newStrategy(behavior, parse)) }, onUpdate) wrapper := &RuleSetProvider{ rp, @@ -147,12 +154,75 @@ func newStrategy(behavior P.RuleType, parse func(tp, payload, target string, par } } -func rulesParse(buf []byte) (any, error) { - rulePayload := RulePayload{} - err := yaml.Unmarshal(buf, &rulePayload) - if err != nil { - return nil, err +var ErrNoPayload = errors.New("file must have a `payload` field") + +func rulesParse(buf []byte, strategy ruleStrategy) (any, error) { + strategy.Reset() + + schema := &RulePayload{} + + reader := bufio.NewReader(bytes.NewReader(buf)) + + firstLineBuffer := pool.GetBuffer() + defer pool.PutBuffer(firstLineBuffer) + firstLineLength := 0 + + for { + line, isPrefix, err := reader.ReadLine() + if err != nil { + if err == io.EOF { + if firstLineLength == 0 { // find payload head + return nil, ErrNoPayload + } + break + } + return nil, err + } + firstLineBuffer.Write(line) // need a copy because the returned buffer is only valid until the next call to ReadLine + if isPrefix { + // If the line was too long for the buffer then isPrefix is set and the + // beginning of the line is returned. The rest of the line will be returned + // from future calls. + continue + } + if firstLineLength == 0 { // find payload head + firstLineBuffer.WriteByte('\n') + firstLineLength = firstLineBuffer.Len() + firstLineBuffer.WriteString(" - ''") // a test line + + err = yaml.Unmarshal(firstLineBuffer.Bytes(), schema) + firstLineBuffer.Truncate(firstLineLength) + if err == nil && (len(schema.Rules) > 0 || len(schema.Payload) > 0) { // found + continue + } + + // not found or err!=nil + firstLineBuffer.Truncate(0) + firstLineLength = 0 + continue + } + + // parse payload body + err = yaml.Unmarshal(firstLineBuffer.Bytes(), schema) + firstLineBuffer.Truncate(firstLineLength) + if err != nil { + continue + } + var str string + if len(schema.Rules) > 0 { + str = schema.Rules[0] + } + if len(schema.Payload) > 0 { + str = schema.Payload[0] + } + if str == "" { + continue + } + + strategy.Insert(str) } - return append(rulePayload.Rules, rulePayload.Rules2...), nil + strategy.FinishInsert() + + return strategy, nil }