From e52c111ae05d0ab279dfef2c250bfae6b8e19e79 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 25 Nov 2021 23:14:31 +0800 Subject: [PATCH] [test] rule providers --- component/trie/domain.go | 292 +++++++++++++++++++++++++++++++++ constant/provider/interface.go | 26 +++ constant/rule.go | 3 + rule/parser.go | 2 + rule/ruleset.go | 57 +++++++ 5 files changed, 380 insertions(+) create mode 100644 rule/ruleset.go diff --git a/component/trie/domain.go b/component/trie/domain.go index ffd0b754..58e76b71 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -2,6 +2,8 @@ package trie import ( "errors" + "github.com/Dreamacro/clash/log" + "net" "strings" ) @@ -127,3 +129,293 @@ func (t *DomainTrie) search(node *Node, parts []string) *Node { func New() *DomainTrie { return &DomainTrie{root: newNode(nil)} } + +type IPV6 bool + +const ( + ipv4GroupMaxValue = 0xFF + ipv6GroupMaxValue = 0xFFFF +) + +type IpCidrTrie struct { + ipv4Trie *IpCidrNode + ipv6Trie *IpCidrNode +} + +func NewIpCidrTrie() *IpCidrTrie { + return &IpCidrTrie{ + ipv4Trie: NewIpCidrNode(false, ipv4GroupMaxValue), + ipv6Trie: NewIpCidrNode(false, ipv6GroupMaxValue), + } +} + +func (trie *IpCidrTrie) AddIpCidr(ipCidr *net.IPNet) error { + subIpCidr, subCidr, isIpv4, err := ipCidrToSubIpCidr(ipCidr) + if err != nil { + return err + } + + for _, sub := range subIpCidr { + addIpCidr(trie, isIpv4, sub, subCidr/8) + } + + return nil +} + +func (trie *IpCidrTrie) AddIpCidrForString(ipCidr string) error { + _, ipNet, err := net.ParseCIDR(ipCidr) + if err != nil { + return err + } + + return trie.AddIpCidr(ipNet) +} + +func (trie *IpCidrTrie) IsContain(ip net.IP) bool { + ip, isIpv4 := checkAndConverterIp(ip) + if ip == nil { + return false + } + + var groupValues []uint32 + var ipCidrNode *IpCidrNode + + if isIpv4 { + ipCidrNode = trie.ipv4Trie + for _, group := range ip { + groupValues = append(groupValues, uint32(group)) + } + } else { + ipCidrNode = trie.ipv6Trie + for i := 0; i < len(ip); i += 2 { + groupValues = append(groupValues, getIpv6GroupValue(ip[i], ip[i+1])) + } + } + + return search(ipCidrNode, groupValues) != nil +} + +func (trie *IpCidrTrie) IsContainForString(ipString string) bool { + return trie.IsContain(net.ParseIP(ipString)) +} + +func ipCidrToSubIpCidr(ipNet *net.IPNet) ([]net.IP, int, bool, error) { + maskSize, _ := ipNet.Mask.Size() + var ( + ipList []net.IP + newMaskSize int + isIpv4 bool + err error + ) + + ip, isIpv4 := checkAndConverterIp(ipNet.IP) + ipList, newMaskSize, err = subIpCidr(ip, maskSize, isIpv4) + + return ipList, newMaskSize, isIpv4, err +} + +func subIpCidr(ip net.IP, maskSize int, isIpv4 bool) ([]net.IP, int, error) { + var subIpCidrList []net.IP + groupSize := 8 + if !isIpv4 { + groupSize = 16 + } + + if maskSize%groupSize == 0 { + return append(subIpCidrList, ip), maskSize, nil + } + + lastByteMaskSize := maskSize % 8 + lastByteMaskIndex := maskSize / 8 + subIpCidrNum := 0xFF >> lastByteMaskSize + for i := 0; i < subIpCidrNum; i++ { + subIpCidr := make([]byte, len(ip)) + copy(subIpCidr, ip) + subIpCidr[lastByteMaskIndex] += byte(i) + subIpCidrList = append(subIpCidrList, subIpCidr) + } + + newMaskSize := (lastByteMaskIndex + 1) * 8 + if !isIpv4 { + newMaskSize = (lastByteMaskIndex/2 + 1) * 16 + } + + return subIpCidrList, newMaskSize, nil +} + +func addIpCidr(trie *IpCidrTrie, isIpv4 bool, ip net.IP, groupSize int) { + if isIpv4 { + addIpv4Cidr(trie, ip, groupSize) + } else { + addIpv6Cidr(trie, ip, groupSize) + } +} + +func addIpv4Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) { + preNode := trie.ipv4Trie + node := preNode.getChild(uint32(ip[0])) + if node == nil { + err := preNode.addChild(uint32(ip[0])) + if err != nil { + return + } + + node = preNode.getChild(uint32(ip[0])) + } + + for i := 1; i < groupSize; i++ { + if node.Mark { + return + } + + groupValue := uint32(ip[i]) + if !node.hasChild(groupValue) { + err := node.addChild(groupValue) + if err != nil { + log.Errorln(err.Error()) + } + } + + preNode = node + node = node.getChild(groupValue) + if node == nil { + err := preNode.addChild(uint32(ip[i-1])) + if err != nil { + return + } + + node = preNode.getChild(uint32(ip[i-1])) + } + } + + node.Mark = true + cleanChild(node) +} + +func addIpv6Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) { + preNode := trie.ipv6Trie + node := preNode.getChild(getIpv6GroupValue(ip[0], ip[1])) + if node == nil { + err := preNode.addChild(getIpv6GroupValue(ip[0], ip[1])) + if err != nil { + return + } + + node = preNode.getChild(getIpv6GroupValue(ip[0], ip[1])) + } + + for i := 2; i < groupSize; i += 2 { + if node.Mark { + return + } + + groupValue := getIpv6GroupValue(ip[i], ip[i+1]) + if !node.hasChild(groupValue) { + err := node.addChild(groupValue) + if err != nil { + log.Errorln(err.Error()) + } + } + + preNode = node + node = node.getChild(groupValue) + if node == nil { + err := preNode.addChild(getIpv6GroupValue(ip[i-2], ip[i-1])) + if err != nil { + return + } + + node = preNode.getChild(getIpv6GroupValue(ip[i-2], ip[i-1])) + } + } + + node.Mark = true + cleanChild(node) +} + +func getIpv6GroupValue(high, low byte) uint32 { + return (uint32(high) << 8) | uint32(low) +} + +func cleanChild(node *IpCidrNode) { + for i := uint32(0); i < uint32(len(node.child)); i++ { + delete(node.child, i) + } +} + +func search(root *IpCidrNode, groupValues []uint32) *IpCidrNode { + node := root.getChild(groupValues[0]) + if node == nil || node.Mark { + return node + } + + for _, value := range groupValues[1:] { + if !node.hasChild(value) { + return nil + } + + node = node.getChild(value) + + if node == nil || node.Mark { + return node + } + } + + return nil +} + +// return net.IP To4 or To16 and is ipv4 +func checkAndConverterIp(ip net.IP) (net.IP, bool) { + ipResult := ip.To4() + if ipResult == nil { + ipResult = ip.To16() + if ipResult == nil { + return nil, false + } + + return ipResult, false + } + + return ipResult, true +} + +var ( + ErrorOverMaxValue = errors.New("the value don't over max value") +) + +type IpCidrNode struct { + Mark bool + child map[uint32]*IpCidrNode + maxValue uint32 +} + +func NewIpCidrNode(mark bool, maxValue uint32) *IpCidrNode { + ipCidrNode := &IpCidrNode{ + Mark: mark, + child: map[uint32]*IpCidrNode{}, + maxValue: maxValue, + } + + return ipCidrNode +} + +func (n *IpCidrNode) addChild(value uint32) error { + if value > n.maxValue { + return ErrorOverMaxValue + } + + n.child[value] = NewIpCidrNode(false, n.maxValue) + return nil +} + +func (n *IpCidrNode) hasChild(value uint32) bool { + return n.getChild(value) != nil +} + +func (n *IpCidrNode) getChild(value uint32) *IpCidrNode { + if value <= n.maxValue { + return n.child[value] + } + + return nil +} diff --git a/constant/provider/interface.go b/constant/provider/interface.go index 53bda7ea..da03033b 100644 --- a/constant/provider/interface.go +++ b/constant/provider/interface.go @@ -1,6 +1,9 @@ package provider +import "C" import ( + "errors" + "github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/constant" ) @@ -103,3 +106,26 @@ type RuleProvider interface { ShouldResolveIP() bool AsRule(adaptor string) constant.Rule } + +var ( + parse = func(ruleType, rule string, params []string) (C.Rule, error) { + return nil, errors.New("unimplemented function") + } + + ruleProviders = map[string]*RuleProvider{} +) + +func RuleProviders() map[string]*RuleProvider { + return ruleProviders +} + +type ruleSetProvider struct { + count int + DomainRules *trie.DomainTrie + IPCIDRRules *trie.IpCidrTrie + ClassicalRules []C.Rule +} + +type RuleSetProvider struct { + *ruleSetProvider +} diff --git a/constant/rule.go b/constant/rule.go index e2087604..a373d49e 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -12,6 +12,7 @@ const ( SrcPort DstPort Process + RuleSet Script MATCH ) @@ -40,6 +41,8 @@ func (rt RuleType) String() string { return "DstPort" case Process: return "Process" + case RuleSet: + return "RuleSet" case Script: return "Script" case MATCH: diff --git a/rule/parser.go b/rule/parser.go index 60bffefb..0942da1c 100644 --- a/rule/parser.go +++ b/rule/parser.go @@ -40,6 +40,8 @@ func ParseRule(tp, payload, target string, params []string) (C.Rule, error) { parsed, parseErr = NewPort(payload, target, false, ruleExtra) case "PROCESS-NAME": parsed, parseErr = NewProcess(payload, target, ruleExtra) + case "RULE-SET": + parsed, parseErr = NewRuleSet(payload, target, ruleExtra) case "SCRIPT": parsed, parseErr = NewScript(payload, target) case "MATCH": diff --git a/rule/ruleset.go b/rule/ruleset.go new file mode 100644 index 00000000..e634dbae --- /dev/null +++ b/rule/ruleset.go @@ -0,0 +1,57 @@ +package rules + +import ( + "fmt" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/constant/provider" +) + +type RuleSet struct { + ruleProviderName string + adapter string + ruleProvider *provider.RuleProvider + ruleExtra *C.RuleExtra +} + +func (rs *RuleSet) RuleType() C.RuleType { + return C.RuleSet +} + +func (rs *RuleSet) Match(metadata *C.Metadata) bool { + return rs.getProviders().Match(metadata) +} + +func (rs *RuleSet) Adapter() string { + return rs.adapter +} + +func (rs *RuleSet) Payload() string { + return rs.getProviders().Name() +} + +func (rs *RuleSet) ShouldResolveIP() bool { + return rs.getProviders().Behavior() != provider.Domain +} +func (rs *RuleSet) getProviders() provider.RuleProvider { + if rs.ruleProvider == nil { + rp := provider.RuleProviders()[rs.ruleProviderName] + rs.ruleProvider = rp + } + return *rs.ruleProvider +} +func (rs *RuleSet) RuleExtra() *C.RuleExtra { + return rs.ruleExtra +} + +func NewRuleSet(ruleProviderName string, adapter string, ruleExtra *C.RuleExtra) (*RuleSet, error) { + rp, ok := provider.RuleProviders()[ruleProviderName] + if !ok { + return nil, fmt.Errorf("rule set %s not found", ruleProviderName) + } + return &RuleSet{ + ruleProviderName: ruleProviderName, + adapter: adapter, + ruleProvider: rp, + ruleExtra: ruleExtra, + }, nil +}