fix: npe when parse rule

This commit is contained in:
Skyxim 2022-04-16 00:21:08 +08:00
parent 36a719e2f8
commit 45fe6e996b
6 changed files with 86 additions and 42 deletions

View file

@ -13,22 +13,23 @@ type NetworkType struct {
} }
func NewNetworkType(network, adapter string) (*NetworkType, error) { func NewNetworkType(network, adapter string) (*NetworkType, error) {
var netType C.NetWork ntType := NetworkType{
Base: &Base{},
}
ntType.adapter = adapter
switch strings.ToUpper(network) { switch strings.ToUpper(network) {
case "TCP": case "TCP":
netType = C.TCP ntType.network = C.TCP
break break
case "UDP": case "UDP":
netType = C.UDP ntType.network = C.UDP
break break
default: default:
return nil, fmt.Errorf("unsupported network type, only TCP/UDP") return nil, fmt.Errorf("unsupported network type, only TCP/UDP")
} }
return &NetworkType{
Base: &Base{}, return &ntType, nil
network: netType,
adapter: adapter,
}, nil
} }
func (n *NetworkType) RuleType() C.RuleType { func (n *NetworkType) RuleType() C.RuleType {

View file

@ -1,8 +1,6 @@
package logic package logic
import ( import (
"fmt"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/rule/common" "github.com/Dreamacro/clash/rule/common"
) )
@ -21,16 +19,12 @@ func (A *AND) ShouldFindProcess() bool {
func NewAND(payload string, adapter string) (*AND, error) { func NewAND(payload string, adapter string) (*AND, error) {
and := &AND{Base: &common.Base{}, payload: payload, adapter: adapter} and := &AND{Base: &common.Base{}, payload: payload, adapter: adapter}
rules, err := parseRuleByPayload(payload, true) rules, err := parseRuleByPayload(payload)
if err != nil { if err != nil {
return nil, err return nil, err
} }
and.rules = rules and.rules = rules
if len(and.rules) == 0 {
return nil, fmt.Errorf("And rule is error, may be format error or not contain least one rule")
}
for _, rule := range rules { for _, rule := range rules {
if rule.ShouldResolveIP() { if rule.ShouldResolveIP() {
and.needIP = true and.needIP = true

View file

@ -2,20 +2,19 @@ package logic
import ( import (
"fmt" "fmt"
"io"
"net/http"
"os"
"regexp"
"strings"
"github.com/Dreamacro/clash/common/collections" "github.com/Dreamacro/clash/common/collections"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
RC "github.com/Dreamacro/clash/rule/common" RC "github.com/Dreamacro/clash/rule/common"
"github.com/Dreamacro/clash/rule/provider" "github.com/Dreamacro/clash/rule/provider"
"io"
"net/http"
"os"
"regexp"
"strings"
) )
func parseRuleByPayload(payload string, skip bool) ([]C.Rule, error) { func parseRuleByPayload(payload string) ([]C.Rule, error) {
regex, err := regexp.Compile("\\(.*\\)") regex, err := regexp.Compile("\\(.*\\)")
if err != nil { if err != nil {
return nil, err return nil, err
@ -28,7 +27,7 @@ func parseRuleByPayload(payload string, skip bool) ([]C.Rule, error) {
} }
rules := make([]C.Rule, 0, len(subAllRanges)) rules := make([]C.Rule, 0, len(subAllRanges))
subRanges := findSubRuleRange(payload, subAllRanges, skip) subRanges := findSubRuleRange(payload, subAllRanges)
for _, subRange := range subRanges { for _, subRange := range subRanges {
subPayload := payload[subRange.start+1 : subRange.end] subPayload := payload[subRange.start+1 : subRange.end]
@ -53,7 +52,7 @@ func containRange(r Range, preStart, preEnd int) bool {
func payloadToRule(subPayload string) (C.Rule, error) { func payloadToRule(subPayload string) (C.Rule, error) {
splitStr := strings.SplitN(subPayload, ",", 2) splitStr := strings.SplitN(subPayload, ",", 2)
if len(splitStr) < 2 { if len(splitStr) < 2 {
return nil, fmt.Errorf("The logic rule contain a rule of error format") return nil, fmt.Errorf("[%s] format is error", subPayload)
} }
tp := splitStr[0] tp := splitStr[0]
@ -91,9 +90,9 @@ func parseRule(tp, payload string, params []string) (C.Rule, error) {
parsed, parseErr = RC.NewGEOIP(payload, "", noResolve) parsed, parseErr = RC.NewGEOIP(payload, "", noResolve)
case "IP-CIDR", "IP-CIDR6": case "IP-CIDR", "IP-CIDR6":
noResolve := RC.HasNoResolve(params) noResolve := RC.HasNoResolve(params)
parsed, parseErr = RC.NewIPCIDR(payload, "", nil, RC.WithIPCIDRNoResolve(noResolve)) parsed, parseErr = RC.NewIPCIDR(payload, "", RC.WithIPCIDRNoResolve(noResolve))
case "SRC-IP-CIDR": case "SRC-IP-CIDR":
parsed, parseErr = RC.NewIPCIDR(payload, "", nil, RC.WithIPCIDRSourceIP(true), RC.WithIPCIDRNoResolve(true)) parsed, parseErr = RC.NewIPCIDR(payload, "", RC.WithIPCIDRSourceIP(true), RC.WithIPCIDRNoResolve(true))
case "SRC-PORT": case "SRC-PORT":
parsed, parseErr = RC.NewPort(payload, "", true) parsed, parseErr = RC.NewPort(payload, "", true)
case "DST-PORT": case "DST-PORT":
@ -113,7 +112,7 @@ func parseRule(tp, payload string, params []string) (C.Rule, error) {
case "NETWORK": case "NETWORK":
parsed, parseErr = RC.NewNetworkType(payload, "") parsed, parseErr = RC.NewNetworkType(payload, "")
default: default:
parseErr = fmt.Errorf("unsupported rule type %s", tp) parsed, parseErr = nil, fmt.Errorf("unsupported rule type %s", tp)
} }
if parseErr != nil { if parseErr != nil {
@ -151,6 +150,10 @@ func format(payload string) ([]Range, error) {
num++ num++
stack.Push(sr) stack.Push(sr)
} else if c == ')' { } else if c == ')' {
if stack.Len() == 0 {
return nil, fmt.Errorf("missing '('")
}
sr := stack.Pop().(Range) sr := stack.Pop().(Range)
sr.end = i sr.end = i
subRanges = append(subRanges, sr) subRanges = append(subRanges, sr)
@ -169,11 +172,11 @@ func format(payload string) ([]Range, error) {
return sortResult, nil return sortResult, nil
} }
func findSubRuleRange(payload string, ruleRanges []Range, skip bool) []Range { func findSubRuleRange(payload string, ruleRanges []Range) []Range {
payloadLen := len(payload) payloadLen := len(payload)
subRuleRange := make([]Range, 0) subRuleRange := make([]Range, 0)
for _, rr := range ruleRanges { for _, rr := range ruleRanges {
if rr.start == 0 && rr.end == payloadLen-1 && skip { if rr.start == 0 && rr.end == payloadLen-1 {
// 最大范围跳过 // 最大范围跳过
continue continue
} }

49
rule/logic/logic_test.go Normal file
View file

@ -0,0 +1,49 @@
package logic
import (
"github.com/Dreamacro/clash/constant"
"github.com/stretchr/testify/assert"
"testing"
)
func TestAND(t *testing.T) {
and, err := NewAND("((DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT")
assert.Equal(t, nil, err)
assert.Equal(t, "DIRECT", and.adapter)
assert.Equal(t, false, and.ShouldResolveIP())
assert.Equal(t, true, and.Match(&constant.Metadata{
Host: "baidu.com",
AddrType: constant.AtypDomainName,
NetWork: constant.TCP,
DstPort: "20000",
}))
and, err = NewAND("(DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT")
assert.NotEqual(t, nil, err)
and, err = NewAND("((AND,(DOMAIN,baidu.com),(NETWORK,TCP)),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT")
assert.Equal(t, nil, err)
}
func TestNOT(t *testing.T) {
not, err := NewNOT("((DST-PORT,6000-6500))", "REJECT")
assert.Equal(t, nil, err)
assert.Equal(t, false, not.Match(&constant.Metadata{
DstPort: "6100",
}))
_, err = NewNOT("((DST-PORT,5600-6666),(DOMAIN,baidu.com))", "DIRECT")
assert.NotEqual(t, nil, err)
_, err = NewNOT("(())", "DIRECT")
assert.NotEqual(t, nil, err)
}
func TestOR(t *testing.T) {
or, err := NewOR("((DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT")
assert.Equal(t, nil, err)
assert.Equal(t, true, or.Match(&constant.Metadata{
NetWork: constant.TCP,
}))
assert.Equal(t, false, or.ShouldResolveIP())
}

View file

@ -19,16 +19,19 @@ func (not *NOT) ShouldFindProcess() bool {
func NewNOT(payload string, adapter string) (*NOT, error) { func NewNOT(payload string, adapter string) (*NOT, error) {
not := &NOT{Base: &common.Base{}, payload: payload, adapter: adapter} not := &NOT{Base: &common.Base{}, payload: payload, adapter: adapter}
rule, err := parseRuleByPayload(payload, false) rule, err := parseRuleByPayload(payload)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(rule) < 1 { if len(rule) > 1 {
return nil, fmt.Errorf("NOT rule have not a rule") return nil, fmt.Errorf("not rule can contain at most one rule")
} }
if len(rule) > 0 {
not.rule = rule[0] not.rule = rule[0]
}
return not, nil return not, nil
} }
@ -37,7 +40,7 @@ func (not *NOT) RuleType() C.RuleType {
} }
func (not *NOT) Match(metadata *C.Metadata) bool { func (not *NOT) Match(metadata *C.Metadata) bool {
return !not.rule.Match(metadata) return not.rule == nil || !not.rule.Match(metadata)
} }
func (not *NOT) Adapter() string { func (not *NOT) Adapter() string {
@ -49,5 +52,5 @@ func (not *NOT) Payload() string {
} }
func (not *NOT) ShouldResolveIP() bool { func (not *NOT) ShouldResolveIP() bool {
return not.rule.ShouldResolveIP() return not.rule != nil && not.rule.ShouldResolveIP()
} }

View file

@ -1,8 +1,6 @@
package logic package logic
import ( import (
"fmt"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/rule/common" "github.com/Dreamacro/clash/rule/common"
) )
@ -47,16 +45,12 @@ func (or *OR) ShouldResolveIP() bool {
func NewOR(payload string, adapter string) (*OR, error) { func NewOR(payload string, adapter string) (*OR, error) {
or := &OR{Base: &common.Base{}, payload: payload, adapter: adapter} or := &OR{Base: &common.Base{}, payload: payload, adapter: adapter}
rules, err := parseRuleByPayload(payload, true) rules, err := parseRuleByPayload(payload)
if err != nil { if err != nil {
return nil, err return nil, err
} }
or.rules = rules or.rules = rules
if len(or.rules) == 0 {
return nil, fmt.Errorf("Or rule is error, may be format error or not contain least one rule")
}
for _, rule := range rules { for _, rule := range rules {
if rule.ShouldResolveIP() { if rule.ShouldResolveIP() {
or.needIP = true or.needIP = true