diff --git a/adapter/provider/fetcher.go b/adapter/provider/fetcher.go index 6c1e96b4..acccf087 100644 --- a/adapter/provider/fetcher.go +++ b/adapter/provider/fetcher.go @@ -26,7 +26,7 @@ type fetcher struct { done chan struct{} hash [16]byte parser parser - onUpdate func(interface{}) + onUpdate func(interface{}) error } func (f *fetcher) Name() string { @@ -167,7 +167,7 @@ func safeWrite(path string, buf []byte) error { return os.WriteFile(path, buf, fileMode) } -func newFetcher(name string, interval time.Duration, vehicle types.Vehicle, parser parser, onUpdate func(interface{})) *fetcher { +func newFetcher(name string, interval time.Duration, vehicle types.Vehicle, parser parser, onUpdate func(interface{}) error) *fetcher { var ticker *time.Ticker if interval != 0 { ticker = time.NewTicker(interval) diff --git a/adapter/provider/parser.go b/adapter/provider/parser.go index 8a173966..5a29ee40 100644 --- a/adapter/provider/parser.go +++ b/adapter/provider/parser.go @@ -60,3 +60,44 @@ func ParseProxyProvider(name string, mapping map[string]interface{}) (types.Prox interval := time.Duration(uint(schema.Interval)) * time.Second return NewProxySetProvider(name, interval, vehicle, hc), nil } + +type ruleProviderSchema struct { + Type string `provider:"type"` + Behavior string `provider:"behavior"` + Path string `provider:"path"` + URL string `provider:"url,omitempty"` + Interval int `provider:"interval,omitempty"` +} + +func ParseRuleProvider(name string, mapping map[string]interface{}) (types.RuleProvider, error) { + schema := &ruleProviderSchema{} + decoder := structure.NewDecoder(structure.Option{TagName: "provider", WeaklyTypedInput: true}) + if err := decoder.Decode(mapping, schema); err != nil { + return nil, err + } + var behavior Behavior + + switch schema.Behavior { + case "domain": + behavior = Domain + case "ipcidr": + behavior = IPCIDR + case "classical": + behavior = Classical + default: + return nil, fmt.Errorf("unsupported behavior type: %s", schema.Behavior) + } + + path := C.Path.Resolve(schema.Path) + var vehicle types.Vehicle + switch schema.Type { + case "file": + vehicle = NewFileVehicle(path) + case "http": + vehicle = NewHTTPVehicle(schema.URL, path) + default: + return nil, fmt.Errorf("unsupported vehicle type: %s", schema.Type) + } + interval := time.Duration(uint(schema.Interval)) * time.Second + return NewRuleSetProvider(name, behavior, interval, vehicle), nil +} diff --git a/adapter/provider/provider.go b/adapter/provider/provider.go index 32baaedd..04051362 100644 --- a/adapter/provider/provider.go +++ b/adapter/provider/provider.go @@ -4,7 +4,9 @@ import ( "encoding/json" "errors" "fmt" + "github.com/Dreamacro/clash/component/trie" "runtime" + "strings" "time" "github.com/Dreamacro/clash/adapter" @@ -132,9 +134,10 @@ func NewProxySetProvider(name string, interval time.Duration, vehicle types.Vehi healthCheck: hc, } - onUpdate := func(elm interface{}) { + onUpdate := func(elm interface{}) error { ret := elm.([]C.Proxy) pd.setProxies(ret) + return nil } fetcher := newFetcher(name, interval, vehicle, proxiesParse, onUpdate) @@ -221,3 +224,278 @@ func NewCompatibleProvider(name string, proxies []C.Proxy, hc *HealthCheck) (*Co runtime.SetFinalizer(wrapper, stopCompatibleProvider) return wrapper, nil } + +// Rule + +type Behavior int + +var ( + parse = func(ruleType, rule string, params []string) (C.Rule, error) { + return nil, errors.New("unimplemented function") + } + + ruleProviders = map[string]types.RuleProvider{} +) + +func SetClassicalRuleParser(function func(ruleType, rule string, params []string) (C.Rule, error)) { + parse = function +} + +func RuleProviders() map[string]types.RuleProvider { + return ruleProviders +} + +func SetRuleProvider(ruleProvider types.RuleProvider) { + if ruleProvider != nil { + ruleProviders[(ruleProvider).Name()] = ruleProvider + } +} + +type ruleSetProvider struct { + *fetcher + behavior Behavior + count int + DomainRules *trie.DomainTrie + IPCIDRRules *trie.IpCidrTrie + ClassicalRules []C.Rule +} + +type RuleSetProvider struct { + *ruleSetProvider +} + +func (r RuleSetProvider) Behavior() types.RuleType { + //TODO implement me + panic("implement me") +} + +func (r RuleSetProvider) ShouldResolveIP() bool { + //TODO implement me + panic("implement me") +} + +func (r RuleSetProvider) AsRule(adaptor string) C.Rule { + //TODO implement me + panic("implement me") +} + +func NewRuleSetProvider(name string, behavior Behavior, interval time.Duration, vehicle types.Vehicle) *RuleSetProvider { + rp := &ruleSetProvider{ + behavior: behavior, + } + + onUpdate := func(elm interface{}) error { + rulesRaw := elm.([]string) + rp.count = len(rulesRaw) + rules, err := constructRules(rp.behavior, rulesRaw) + if err != nil { + return err + } + rp.setRules(rules) + return nil + } + + fetcher := newFetcher(name, interval, vehicle, rulesParse, onUpdate) + rp.fetcher = fetcher + wrapper := &RuleSetProvider{ + rp, + } + + runtime.SetFinalizer(wrapper, stopRuleSetProvider) + return wrapper +} + +func (rp *ruleSetProvider) Name() string { + return rp.name +} + +func (rp *ruleSetProvider) RuleCount() int { + return rp.count +} + +const ( + Domain = iota + IPCIDR + Classical +) + +// RuleType defined + +func (b Behavior) String() string { + switch b { + case Domain: + return "Domain" + case IPCIDR: + return "IPCIDR" + case Classical: + return "Classical" + default: + return "" + } +} + +func (rp *ruleSetProvider) Match(metadata *C.Metadata) bool { + switch rp.behavior { + case Domain: + return rp.DomainRules.Search(metadata.Host) != nil + case IPCIDR: + return rp.IPCIDRRules.IsContain(metadata.DstIP) + case Classical: + for _, rule := range rp.ClassicalRules { + if rule.Match(metadata) { + return true + } + } + return false + default: + return false + } +} + +func (rp *ruleSetProvider) Behavior() Behavior { + return rp.behavior +} + +func (rp *ruleSetProvider) VehicleType() types.VehicleType { + return rp.vehicle.Type() +} + +func (rp *ruleSetProvider) Type() types.ProviderType { + return types.Rule +} + +func (rp *ruleSetProvider) Initial() error { + elm, err := rp.fetcher.Initial() + if err != nil { + return err + } + return rp.fetcher.onUpdate(elm) +} + +func (rp *ruleSetProvider) Update() error { + elm, same, err := rp.fetcher.Update() + if err == nil && !same { + return rp.fetcher.onUpdate(elm) + } + return err +} + +func (rp *ruleSetProvider) setRules(rules interface{}) { + switch rp.behavior { + case Domain: + rp.DomainRules = rules.(*trie.DomainTrie) + case Classical: + rp.ClassicalRules = rules.([]C.Rule) + case IPCIDR: + rp.IPCIDRRules = rules.(*trie.IpCidrTrie) + default: + } +} + +func (rp *ruleSetProvider) MarshalJSON() ([]byte, error) { + return json.Marshal( + map[string]interface{}{ + "behavior": rp.behavior.String(), + "name": rp.Name(), + "ruleCount": rp.RuleCount(), + "type": rp.Type().String(), + "updatedAt": rp.updatedAt, + "vehicleType": rp.VehicleType().String(), + }) +} + +type RulePayload struct { + /** + key: Domain or IP Cidr + value: Rule type or is empty + */ + Rules []string `yaml:"payload"` +} + +func rulesParse(buf []byte) (interface{}, error) { + rulePayload := RulePayload{} + err := yaml.Unmarshal(buf, &rulePayload) + if err != nil { + return nil, err + } + + return rulePayload.Rules, nil +} + +func constructRules(behavior Behavior, rules []string) (interface{}, error) { + switch behavior { + case Domain: + return handleDomainRules(rules) + case IPCIDR: + return handleIpCidrRules(rules) + case Classical: + return handleClassicalRules(rules) + default: + return nil, errors.New("unknown behavior type") + } +} + +func handleDomainRules(rules []string) (interface{}, error) { + domainRules := trie.New() + for _, rawRule := range rules { + ruleType, rule, _ := ruleParse(rawRule) + if ruleType != "" { + return nil, errors.New("error format of domain") + } + + if err := domainRules.Insert(rule, ""); err != nil { + return nil, err + } + } + return domainRules, nil +} + +func handleIpCidrRules(rules []string) (interface{}, error) { + ipCidrRules := trie.NewIpCidrTrie() + for _, rawRule := range rules { + ruleType, rule, _ := ruleParse(rawRule) + if ruleType != "" { + return nil, errors.New("error format of ip-cidr") + } + + if err := ipCidrRules.AddIpCidrForString(rule); err != nil { + return nil, err + } + } + return ipCidrRules, nil +} + +func handleClassicalRules(rules []string) (interface{}, error) { + var classicalRules []C.Rule + for _, rawRule := range rules { + ruleType, rule, params := ruleParse(rawRule) + if ruleType == "RULE-SET" { + return nil, errors.New("error rule type") + } + + r, err := parse(ruleType, rule, params) + if err != nil { + return nil, err + } + + classicalRules = append(classicalRules, r) + } + return classicalRules, nil +} + +func ruleParse(ruleRaw string) (string, string, []string) { + item := strings.Split(ruleRaw, ",") + if len(item) == 1 { + return "", item[0], nil + } else if len(item) == 2 { + return item[0], item[1], nil + } else if len(item) > 2 { + return item[0], item[1], item[2:] + } + + return "", "", nil +} + +func stopRuleSetProvider(rp *RuleSetProvider) { + rp.fetcher.Destroy() +} diff --git a/adapter/provider/provider_test.go b/adapter/provider/provider_test.go new file mode 100644 index 00000000..66dd0abc --- /dev/null +++ b/adapter/provider/provider_test.go @@ -0,0 +1,58 @@ +package provider + +import ( + "github.com/Dreamacro/clash/constant" + rules "github.com/Dreamacro/clash/rule" + + "github.com/stretchr/testify/assert" + "net" + "testing" + "time" +) + +func setup() { + SetClassicalRuleParser(func(ruleType, rule string, params []string) (constant.Rule, error) { + if params == nil { + params = make([]string, 0) + } + + return rules.ParseRule(ruleType, rule, "", params) + }) +} + +func TestDomain(t *testing.T) { + setup() + domainProvider := NewRuleSetProvider("test", Domain, + time.Duration(uint(100000)), NewFileVehicle("./domain.txt")) + assert.Nil(t, domainProvider.Initial()) + assert.True(t, domainProvider.Match(&constant.Metadata{Host: "youtube.com"})) + assert.True(t, domainProvider.Match(&constant.Metadata{Host: "www.youtube.com"})) + assert.True(t, domainProvider.Match(&constant.Metadata{Host: "test.youtube.com"})) + assert.True(t, domainProvider.Match(&constant.Metadata{Host: "bcovlive-a.akamaihd.net"})) + assert.False(t, domainProvider.Match(&constant.Metadata{Host: "baidu.com"})) +} + +func TestClassical(t *testing.T) { + setup() + classicalProvider := NewRuleSetProvider("test", Classical, + time.Duration(uint(100000)), NewFileVehicle("./classical.txt")) + assert.Nil(t, classicalProvider.Initial()) + assert.True(t, classicalProvider.Match(&constant.Metadata{Host: "www.10010.com", AddrType: constant.AtypDomainName})) + assert.False(t, classicalProvider.Match(&constant.Metadata{Host: "google.com", AddrType: constant.AtypDomainName})) + assert.True(t, classicalProvider.Match(&constant.Metadata{Host: "analytics.strava.com", AddrType: constant.AtypDomainName})) + assert.True(t, classicalProvider.Match(&constant.Metadata{DstIP: net.ParseIP("2a0b:b580::1")})) + assert.False(t, classicalProvider.Match(&constant.Metadata{DstIP: net.ParseIP("2a0b:c582::1")})) + assert.True(t, classicalProvider.Match(&constant.Metadata{DstIP: net.ParseIP("1.255.62.34")})) + assert.False(t, classicalProvider.Match(&constant.Metadata{DstIP: net.ParseIP("103.65.41.199")})) +} + +func TestIpCidr(t *testing.T) { + setup() + domainProvider := NewRuleSetProvider("test", IPCIDR, + time.Duration(uint(100000)), NewFileVehicle("./ipcidr.txt")) + assert.Nil(t, domainProvider.Initial()) + assert.True(t, domainProvider.Match(&constant.Metadata{DstIP: net.ParseIP("91.108.22.10")})) + assert.False(t, domainProvider.Match(&constant.Metadata{DstIP: net.ParseIP("149.190.220.251")})) + assert.True(t, domainProvider.Match(&constant.Metadata{DstIP: net.ParseIP("2001:b28:f23f:f005::a")})) + assert.False(t, domainProvider.Match(&constant.Metadata{DstIP: net.ParseIP("2006:b28:f23f:f005::a")})) +} diff --git a/component/trie/ipcidr_node.go b/component/trie/ipcidr_node.go new file mode 100644 index 00000000..def640fe --- /dev/null +++ b/component/trie/ipcidr_node.go @@ -0,0 +1,47 @@ +package trie + +import "errors" + +var ( + ErrorOverMaxValue = errors.New("the value don't over max value") +) + +type IpCidrNode struct { + Mark bool + child map[uint32]*IpCidrNode + maxValue uint32 +} + +func NewIpCidrNode(mark bool, maxValue uint32) *IpCidrNode { + ipCidrNode := &IpCidrNode{ + Mark: mark, + child: map[uint32]*IpCidrNode{}, + maxValue: maxValue, + } + + return ipCidrNode +} + +func (n *IpCidrNode) addChild(value uint32) error { + if value > n.maxValue { + return ErrorOverMaxValue + } + + n.child[value] = NewIpCidrNode(false, n.maxValue) + return nil +} + +func (n *IpCidrNode) hasChild(value uint32) bool { + return n.getChild(value) != nil +} + +func (n *IpCidrNode) getChild(value uint32) *IpCidrNode { + if value <= n.maxValue && !n.Mark { + if n.child[value] == nil { + n.child[value] = NewIpCidrNode(false, n.maxValue) + } + + return n.child[value] + } + return nil +} diff --git a/component/trie/ipcidr_trie.go b/component/trie/ipcidr_trie.go new file mode 100644 index 00000000..d5056a2c --- /dev/null +++ b/component/trie/ipcidr_trie.go @@ -0,0 +1,214 @@ +package trie + +import ( + "github.com/Dreamacro/clash/log" + "net" +) + +type IPV6 bool + +const ( + ipv4GroupMaxValue = 0xFF + ipv6GroupMaxValue = 0xFFFF +) + +type IpCidrTrie struct { + ipv4Trie *IpCidrNode + ipv6Trie *IpCidrNode +} + +func NewIpCidrTrie() *IpCidrTrie { + return &IpCidrTrie{ + ipv4Trie: NewIpCidrNode(false, ipv4GroupMaxValue), + ipv6Trie: NewIpCidrNode(false, ipv6GroupMaxValue), + } +} + +func (trie *IpCidrTrie) AddIpCidr(ipCidr *net.IPNet) error { + subIpCidr, subCidr, isIpv4, err := ipCidrToSubIpCidr(ipCidr) + if err != nil { + return err + } + + for _, sub := range subIpCidr { + addIpCidr(trie, isIpv4, sub, subCidr/8) + } + + return nil +} + +func (trie *IpCidrTrie) AddIpCidrForString(ipCidr string) error { + _, ipNet, err := net.ParseCIDR(ipCidr) + if err != nil { + return err + } + + return trie.AddIpCidr(ipNet) +} + +func (trie *IpCidrTrie) IsContain(ip net.IP) bool { + ip, isIpv4 := checkAndConverterIp(ip) + if ip == nil { + return false + } + + var groupValues []uint32 + var ipCidrNode *IpCidrNode + + if isIpv4 { + ipCidrNode = trie.ipv4Trie + for _, group := range ip { + groupValues = append(groupValues, uint32(group)) + } + } else { + ipCidrNode = trie.ipv6Trie + for i := 0; i < len(ip); i += 2 { + groupValues = append(groupValues, getIpv6GroupValue(ip[i], ip[i+1])) + } + } + + return search(ipCidrNode, groupValues) != nil +} + +func (trie *IpCidrTrie) IsContainForString(ipString string) bool { + return trie.IsContain(net.ParseIP(ipString)) +} + +func ipCidrToSubIpCidr(ipNet *net.IPNet) ([]net.IP, int, bool, error) { + maskSize, _ := ipNet.Mask.Size() + var ( + ipList []net.IP + newMaskSize int + isIpv4 bool + err error + ) + + ip, isIpv4 := checkAndConverterIp(ipNet.IP) + ipList, newMaskSize, err = subIpCidr(ip, maskSize, true) + + return ipList, newMaskSize, isIpv4, err +} + +func subIpCidr(ip net.IP, maskSize int, isIpv4 bool) ([]net.IP, int, error) { + var subIpCidrList []net.IP + groupSize := 8 + if !isIpv4 { + groupSize = 16 + } + + if maskSize%groupSize == 0 { + return append(subIpCidrList, ip), maskSize, nil + } + + lastByteMaskSize := maskSize % 8 + lastByteMaskIndex := maskSize / 8 + subIpCidrNum := 0xFF >> lastByteMaskSize + for i := 0; i < subIpCidrNum; i++ { + subIpCidr := make([]byte, len(ip)) + copy(subIpCidr, ip) + subIpCidr[lastByteMaskIndex] += byte(i) + subIpCidrList = append(subIpCidrList, subIpCidr) + } + + return subIpCidrList, lastByteMaskIndex * 8, nil +} + +func addIpCidr(trie *IpCidrTrie, isIpv4 bool, ip net.IP, groupSize int) { + if isIpv4 { + addIpv4Cidr(trie, ip, groupSize) + } else { + addIpv6Cidr(trie, ip, groupSize) + } +} + +func addIpv4Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) { + node := trie.ipv4Trie.getChild(uint32(ip[0])) + + for i := 1; i < groupSize; i++ { + if node.Mark { + return + } + + groupValue := uint32(ip[i]) + if !node.hasChild(groupValue) { + err := node.addChild(groupValue) + if err != nil { + log.Errorln(err.Error()) + } + } + + node = node.getChild(groupValue) + } + + node.Mark = true + cleanChild(node) +} + +func addIpv6Cidr(trie *IpCidrTrie, ip net.IP, groupSize int) { + node := trie.ipv6Trie.getChild(getIpv6GroupValue(ip[0], ip[1])) + + for i := 2; i < groupSize; i += 2 { + if node.Mark { + return + } + + groupValue := getIpv6GroupValue(ip[i], ip[i+1]) + if !node.hasChild(groupValue) { + err := node.addChild(groupValue) + if err != nil { + log.Errorln(err.Error()) + } + } + + node = node.getChild(groupValue) + } + + node.Mark = true + cleanChild(node) +} + +func getIpv6GroupValue(high, low byte) uint32 { + return (uint32(high) << 8) | uint32(low) +} + +func cleanChild(node *IpCidrNode) { + for i := uint32(0); i < uint32(len(node.child)); i++ { + delete(node.child, i) + } +} + +func search(root *IpCidrNode, groupValues []uint32) *IpCidrNode { + node := root.getChild(groupValues[0]) + if node.Mark { + return node + } + + for _, value := range groupValues[1:] { + if !node.hasChild(value) { + return nil + } + + node = node.getChild(value) + + if node.Mark { + return node + } + } + + return nil +} + +// return net.IP To4 or To16 and is ipv4 +func checkAndConverterIp(ip net.IP) (net.IP, bool) { + ipResult := ip.To4() + if ipResult == nil { + ipResult = ip.To16() + if ipResult == nil { + return nil, false + } + + return ipResult, false + } + + return ipResult, true +} diff --git a/config/config.go b/config/config.go index 37482879..935a99e6 100644 --- a/config/config.go +++ b/config/config.go @@ -122,7 +122,7 @@ type Config struct { Hosts *trie.DomainTrie Profile *Profile Rules []C.Rule - RuleProviders map[string]C.Rule + RuleProviders map[string]providerTypes.RuleProvider Users []auth.AuthUser Proxies map[string]C.Proxy Providers map[string]providerTypes.ProxyProvider @@ -482,9 +482,40 @@ time = ClashTime() return nil } -func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[string]C.Rule, error) { +func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[string]providerTypes.RuleProvider, error) { rules := []C.Rule{} - ruleProviders := map[string]C.Rule{} + ruleProviders := map[string]providerTypes.RuleProvider{} + ruleProviderNameSet := make(map[string]interface{}, len(ruleProviders)) + + // set parse callback for parse rule type + provider.SetClassicalRuleParser(func(ruleType, rule string, params []string) (C.Rule, error) { + if params == nil { + params = make([]string, 0) + } + + return R.ParseRule(ruleType, rule, "", params) + }) + + for name, mapping := range cfg.RuleProvider { + rp, err := provider.ParseRuleProvider(name, mapping) + if err != nil { + return nil, nil, err + } + + ruleProviders[name] = rp + provider.SetRuleProvider(rp) + } + for _, provider := range ruleProviders { + log.Infoln("Start initial provider %s", (provider).Name()) + if err := (provider).Initial(); err != nil { + return nil, nil, fmt.Errorf("initial rule provider %s error: %w", (provider).Name(), err) + } + } + // get all name of rule provider + for k := range ruleProviders { + ruleProviderNameSet[k] = nil + } + rulesConfig := cfg.Rule mode := cfg.Mode @@ -526,6 +557,7 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[strin return nil, nil, fmt.Errorf("rules[%d] [%s] error: proxy [%s] not found", idx, line, target) } + rule = trimArr(rule) params = trimArr(params) parsed, parseErr := R.ParseRule(ruleName, payload, target, params) diff --git a/constant/provider/interface.go b/constant/provider/interface.go index 53bda7ea..3d8332e8 100644 --- a/constant/provider/interface.go +++ b/constant/provider/interface.go @@ -1,5 +1,6 @@ package provider +import "C" import ( "github.com/Dreamacro/clash/constant" ) @@ -79,6 +80,8 @@ const ( Classical ) +type Behavior int + // RuleType defined type RuleType int diff --git a/constant/rule.go b/constant/rule.go index e2087604..a373d49e 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -12,6 +12,7 @@ const ( SrcPort DstPort Process + RuleSet Script MATCH ) @@ -40,6 +41,8 @@ func (rt RuleType) String() string { return "DstPort" case Process: return "Process" + case RuleSet: + return "RuleSet" case Script: return "Script" case MATCH: diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 05de2a10..4baf6670 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -75,7 +75,7 @@ func ApplyConfig(cfg *config.Config, force bool) { updateUsers(cfg.Users) updateProxies(cfg.Proxies, cfg.Providers) - updateRules(cfg.Rules) + updateRules(cfg.Rules, cfg.RuleProviders) updateHosts(cfg.Hosts) updateProfile(cfg) updateIPTables(cfg.DNS, cfg.General) @@ -176,8 +176,8 @@ func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.Pro tunnel.UpdateProxies(proxies, providers) } -func updateRules(rules []C.Rule) { - tunnel.UpdateRules(rules) +func updateRules(rules []C.Rule, ruleProviders map[string]provider.RuleProvider) { + tunnel.UpdateRules(rules, ruleProviders) } func updateGeneral(general *config.General, force bool) { diff --git a/hub/route/provider.go b/hub/route/provider.go index 0b599445..0d3f768e 100644 --- a/hub/route/provider.go +++ b/hub/route/provider.go @@ -75,3 +75,40 @@ func findProviderByName(next http.Handler) http.Handler { next.ServeHTTP(w, r.WithContext(ctx)) }) } + +func getRuleProviders(w http.ResponseWriter, r *http.Request) { + ruleProviders := tunnel.RuleProviders() + render.JSON(w, r, render.M{ + "providers": ruleProviders, + }) +} + +func updateRuleProvider(w http.ResponseWriter, r *http.Request) { + provider := r.Context().Value(CtxKeyProvider).(provider.RuleProvider) + if err := (provider).Update(); err != nil { + render.Status(r, http.StatusServiceUnavailable) + render.JSON(w, r, newError(err.Error())) + } + render.NoContent(w, r) +} +func parseRuleProviderName(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + name := getEscapeParam(r, "name") + ctx := context.WithValue(r.Context(), CtxKeyProviderName, name) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} +func findRuleProviderByName(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + name := r.Context().Value(CtxKeyProviderName).(string) + providers := tunnel.RuleProviders() + provider, exist := providers[name] + if !exist { + render.Status(r, http.StatusNotFound) + render.JSON(w, r, ErrNotFound) + return + } + ctx := context.WithValue(r.Context(), CtxKeyProvider, provider) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/rule/parser.go b/rule/parser.go index fa73c2ae..72427135 100644 --- a/rule/parser.go +++ b/rule/parser.go @@ -40,6 +40,8 @@ func ParseRule(tp, payload, target string, params []string) (C.Rule, error) { parsed, parseErr = NewPort(payload, target, false, ruleExtra) case "PROCESS-NAME": parsed, parseErr = NewProcess(payload, target, ruleExtra) + case "RULE-SET": + parsed, parseErr = NewRuleSet(payload, target, ruleExtra) case "MATCH": parsed = NewMatch(target, ruleExtra) default: diff --git a/rule/rule_set.go b/rule/rule_set.go new file mode 100644 index 00000000..f612e675 --- /dev/null +++ b/rule/rule_set.go @@ -0,0 +1,60 @@ +package rules + +import ( + "fmt" + "github.com/Dreamacro/clash/adapter/provider" + C "github.com/Dreamacro/clash/constant" + types "github.com/Dreamacro/clash/constant/provider" +) + +type RuleSet struct { + ruleProviderName string + adapter string + ruleProvider *types.RuleProvider + ruleExtra *C.RuleExtra +} + +func (rs *RuleSet) RuleType() C.RuleType { + return C.RuleSet +} + +func (rs *RuleSet) Match(metadata *C.Metadata) bool { + return rs.getProviders().Match(metadata) +} + +func (rs *RuleSet) Adapter() string { + return rs.adapter +} + +func (rs *RuleSet) Payload() string { + return rs.getProviders().Name() +} + +func (rs *RuleSet) ShouldResolveIP() bool { + return rs.getProviders().Behavior() != provider.Domain +} +func (rs *RuleSet) getProviders() types.RuleProvider { + if rs.ruleProvider == nil { + rp := provider.RuleProviders()[rs.ruleProviderName] + rs.ruleProvider = &rp + } + + return *rs.ruleProvider +} + +func (rs *RuleSet) RuleExtra() *C.RuleExtra { + return rs.ruleExtra +} + +func NewRuleSet(ruleProviderName string, adapter string, ruleExtra *C.RuleExtra) (*RuleSet, error) { + rp, ok := provider.RuleProviders()[ruleProviderName] + if !ok { + return nil, fmt.Errorf("rule set %s not found", ruleProviderName) + } + return &RuleSet{ + ruleProviderName: ruleProviderName, + adapter: adapter, + ruleProvider: &rp, + ruleExtra: ruleExtra, + }, nil +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index ba5ef9c0..22980b94 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -20,13 +20,14 @@ import ( ) var ( - tcpQueue = make(chan C.ConnContext, 200) - udpQueue = make(chan *inbound.PacketAdapter, 200) - natTable = nat.New() - rules []C.Rule - proxies = make(map[string]C.Proxy) - providers map[string]provider.ProxyProvider - configMux sync.RWMutex + tcpQueue = make(chan C.ConnContext, 200) + udpQueue = make(chan *inbound.PacketAdapter, 200) + natTable = nat.New() + rules []C.Rule + ruleProviders map[string]provider.RuleProvider + proxies = make(map[string]C.Proxy) + providers map[string]provider.ProxyProvider + configMux sync.RWMutex // Outbound Rule mode = Rule @@ -57,9 +58,10 @@ func Rules() []C.Rule { } // UpdateRules handle update rules -func UpdateRules(newRules []C.Rule) { +func UpdateRules(newRules []C.Rule, rp map[string]provider.RuleProvider) { configMux.Lock() rules = newRules + ruleProviders = rp configMux.Unlock() } @@ -81,6 +83,10 @@ func UpdateProxies(newProxies map[string]C.Proxy, newProviders map[string]provid configMux.Unlock() } +func RuleProviders() map[string]provider.RuleProvider { + return ruleProviders +} + // Mode return current mode func Mode() TunnelMode { return mode