fix: npe when parse rule
This commit is contained in:
parent
36a719e2f8
commit
45fe6e996b
6 changed files with 86 additions and 42 deletions
|
@ -13,22 +13,23 @@ type NetworkType struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetworkType(network, adapter string) (*NetworkType, error) {
|
func NewNetworkType(network, adapter string) (*NetworkType, error) {
|
||||||
var netType C.NetWork
|
ntType := NetworkType{
|
||||||
|
Base: &Base{},
|
||||||
|
}
|
||||||
|
|
||||||
|
ntType.adapter = adapter
|
||||||
switch strings.ToUpper(network) {
|
switch strings.ToUpper(network) {
|
||||||
case "TCP":
|
case "TCP":
|
||||||
netType = C.TCP
|
ntType.network = C.TCP
|
||||||
break
|
break
|
||||||
case "UDP":
|
case "UDP":
|
||||||
netType = C.UDP
|
ntType.network = C.UDP
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported network type, only TCP/UDP")
|
return nil, fmt.Errorf("unsupported network type, only TCP/UDP")
|
||||||
}
|
}
|
||||||
return &NetworkType{
|
|
||||||
Base: &Base{},
|
return &ntType, nil
|
||||||
network: netType,
|
|
||||||
adapter: adapter,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetworkType) RuleType() C.RuleType {
|
func (n *NetworkType) RuleType() C.RuleType {
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package logic
|
package logic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
|
|
||||||
C "github.com/Dreamacro/clash/constant"
|
C "github.com/Dreamacro/clash/constant"
|
||||||
"github.com/Dreamacro/clash/rule/common"
|
"github.com/Dreamacro/clash/rule/common"
|
||||||
)
|
)
|
||||||
|
@ -21,16 +19,12 @@ func (A *AND) ShouldFindProcess() bool {
|
||||||
|
|
||||||
func NewAND(payload string, adapter string) (*AND, error) {
|
func NewAND(payload string, adapter string) (*AND, error) {
|
||||||
and := &AND{Base: &common.Base{}, payload: payload, adapter: adapter}
|
and := &AND{Base: &common.Base{}, payload: payload, adapter: adapter}
|
||||||
rules, err := parseRuleByPayload(payload, true)
|
rules, err := parseRuleByPayload(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
and.rules = rules
|
and.rules = rules
|
||||||
if len(and.rules) == 0 {
|
|
||||||
return nil, fmt.Errorf("And rule is error, may be format error or not contain least one rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.ShouldResolveIP() {
|
if rule.ShouldResolveIP() {
|
||||||
and.needIP = true
|
and.needIP = true
|
||||||
|
|
|
@ -2,20 +2,19 @@ package logic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/Dreamacro/clash/common/collections"
|
"github.com/Dreamacro/clash/common/collections"
|
||||||
C "github.com/Dreamacro/clash/constant"
|
C "github.com/Dreamacro/clash/constant"
|
||||||
"github.com/Dreamacro/clash/log"
|
"github.com/Dreamacro/clash/log"
|
||||||
RC "github.com/Dreamacro/clash/rule/common"
|
RC "github.com/Dreamacro/clash/rule/common"
|
||||||
"github.com/Dreamacro/clash/rule/provider"
|
"github.com/Dreamacro/clash/rule/provider"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseRuleByPayload(payload string, skip bool) ([]C.Rule, error) {
|
func parseRuleByPayload(payload string) ([]C.Rule, error) {
|
||||||
regex, err := regexp.Compile("\\(.*\\)")
|
regex, err := regexp.Compile("\\(.*\\)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -28,7 +27,7 @@ func parseRuleByPayload(payload string, skip bool) ([]C.Rule, error) {
|
||||||
}
|
}
|
||||||
rules := make([]C.Rule, 0, len(subAllRanges))
|
rules := make([]C.Rule, 0, len(subAllRanges))
|
||||||
|
|
||||||
subRanges := findSubRuleRange(payload, subAllRanges, skip)
|
subRanges := findSubRuleRange(payload, subAllRanges)
|
||||||
for _, subRange := range subRanges {
|
for _, subRange := range subRanges {
|
||||||
subPayload := payload[subRange.start+1 : subRange.end]
|
subPayload := payload[subRange.start+1 : subRange.end]
|
||||||
|
|
||||||
|
@ -53,7 +52,7 @@ func containRange(r Range, preStart, preEnd int) bool {
|
||||||
func payloadToRule(subPayload string) (C.Rule, error) {
|
func payloadToRule(subPayload string) (C.Rule, error) {
|
||||||
splitStr := strings.SplitN(subPayload, ",", 2)
|
splitStr := strings.SplitN(subPayload, ",", 2)
|
||||||
if len(splitStr) < 2 {
|
if len(splitStr) < 2 {
|
||||||
return nil, fmt.Errorf("The logic rule contain a rule of error format")
|
return nil, fmt.Errorf("[%s] format is error", subPayload)
|
||||||
}
|
}
|
||||||
|
|
||||||
tp := splitStr[0]
|
tp := splitStr[0]
|
||||||
|
@ -91,9 +90,9 @@ func parseRule(tp, payload string, params []string) (C.Rule, error) {
|
||||||
parsed, parseErr = RC.NewGEOIP(payload, "", noResolve)
|
parsed, parseErr = RC.NewGEOIP(payload, "", noResolve)
|
||||||
case "IP-CIDR", "IP-CIDR6":
|
case "IP-CIDR", "IP-CIDR6":
|
||||||
noResolve := RC.HasNoResolve(params)
|
noResolve := RC.HasNoResolve(params)
|
||||||
parsed, parseErr = RC.NewIPCIDR(payload, "", nil, RC.WithIPCIDRNoResolve(noResolve))
|
parsed, parseErr = RC.NewIPCIDR(payload, "", RC.WithIPCIDRNoResolve(noResolve))
|
||||||
case "SRC-IP-CIDR":
|
case "SRC-IP-CIDR":
|
||||||
parsed, parseErr = RC.NewIPCIDR(payload, "", nil, RC.WithIPCIDRSourceIP(true), RC.WithIPCIDRNoResolve(true))
|
parsed, parseErr = RC.NewIPCIDR(payload, "", RC.WithIPCIDRSourceIP(true), RC.WithIPCIDRNoResolve(true))
|
||||||
case "SRC-PORT":
|
case "SRC-PORT":
|
||||||
parsed, parseErr = RC.NewPort(payload, "", true)
|
parsed, parseErr = RC.NewPort(payload, "", true)
|
||||||
case "DST-PORT":
|
case "DST-PORT":
|
||||||
|
@ -113,7 +112,7 @@ func parseRule(tp, payload string, params []string) (C.Rule, error) {
|
||||||
case "NETWORK":
|
case "NETWORK":
|
||||||
parsed, parseErr = RC.NewNetworkType(payload, "")
|
parsed, parseErr = RC.NewNetworkType(payload, "")
|
||||||
default:
|
default:
|
||||||
parseErr = fmt.Errorf("unsupported rule type %s", tp)
|
parsed, parseErr = nil, fmt.Errorf("unsupported rule type %s", tp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
|
@ -151,6 +150,10 @@ func format(payload string) ([]Range, error) {
|
||||||
num++
|
num++
|
||||||
stack.Push(sr)
|
stack.Push(sr)
|
||||||
} else if c == ')' {
|
} else if c == ')' {
|
||||||
|
if stack.Len() == 0 {
|
||||||
|
return nil, fmt.Errorf("missing '('")
|
||||||
|
}
|
||||||
|
|
||||||
sr := stack.Pop().(Range)
|
sr := stack.Pop().(Range)
|
||||||
sr.end = i
|
sr.end = i
|
||||||
subRanges = append(subRanges, sr)
|
subRanges = append(subRanges, sr)
|
||||||
|
@ -169,11 +172,11 @@ func format(payload string) ([]Range, error) {
|
||||||
return sortResult, nil
|
return sortResult, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func findSubRuleRange(payload string, ruleRanges []Range, skip bool) []Range {
|
func findSubRuleRange(payload string, ruleRanges []Range) []Range {
|
||||||
payloadLen := len(payload)
|
payloadLen := len(payload)
|
||||||
subRuleRange := make([]Range, 0)
|
subRuleRange := make([]Range, 0)
|
||||||
for _, rr := range ruleRanges {
|
for _, rr := range ruleRanges {
|
||||||
if rr.start == 0 && rr.end == payloadLen-1 && skip {
|
if rr.start == 0 && rr.end == payloadLen-1 {
|
||||||
// 最大范围跳过
|
// 最大范围跳过
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
49
rule/logic/logic_test.go
Normal file
49
rule/logic/logic_test.go
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
package logic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Dreamacro/clash/constant"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAND(t *testing.T) {
|
||||||
|
and, err := NewAND("((DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "DIRECT", and.adapter)
|
||||||
|
assert.Equal(t, false, and.ShouldResolveIP())
|
||||||
|
assert.Equal(t, true, and.Match(&constant.Metadata{
|
||||||
|
Host: "baidu.com",
|
||||||
|
AddrType: constant.AtypDomainName,
|
||||||
|
NetWork: constant.TCP,
|
||||||
|
DstPort: "20000",
|
||||||
|
}))
|
||||||
|
|
||||||
|
and, err = NewAND("(DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT")
|
||||||
|
assert.NotEqual(t, nil, err)
|
||||||
|
|
||||||
|
and, err = NewAND("((AND,(DOMAIN,baidu.com),(NETWORK,TCP)),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNOT(t *testing.T) {
|
||||||
|
not, err := NewNOT("((DST-PORT,6000-6500))", "REJECT")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, false, not.Match(&constant.Metadata{
|
||||||
|
DstPort: "6100",
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, err = NewNOT("((DST-PORT,5600-6666),(DOMAIN,baidu.com))", "DIRECT")
|
||||||
|
assert.NotEqual(t, nil, err)
|
||||||
|
|
||||||
|
_, err = NewNOT("(())", "DIRECT")
|
||||||
|
assert.NotEqual(t, nil, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOR(t *testing.T) {
|
||||||
|
or, err := NewOR("((DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, true, or.Match(&constant.Metadata{
|
||||||
|
NetWork: constant.TCP,
|
||||||
|
}))
|
||||||
|
assert.Equal(t, false, or.ShouldResolveIP())
|
||||||
|
}
|
|
@ -19,16 +19,19 @@ func (not *NOT) ShouldFindProcess() bool {
|
||||||
|
|
||||||
func NewNOT(payload string, adapter string) (*NOT, error) {
|
func NewNOT(payload string, adapter string) (*NOT, error) {
|
||||||
not := &NOT{Base: &common.Base{}, payload: payload, adapter: adapter}
|
not := &NOT{Base: &common.Base{}, payload: payload, adapter: adapter}
|
||||||
rule, err := parseRuleByPayload(payload, false)
|
rule, err := parseRuleByPayload(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(rule) < 1 {
|
if len(rule) > 1 {
|
||||||
return nil, fmt.Errorf("NOT rule have not a rule")
|
return nil, fmt.Errorf("not rule can contain at most one rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rule) > 0 {
|
||||||
|
not.rule = rule[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
not.rule = rule[0]
|
|
||||||
return not, nil
|
return not, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,7 +40,7 @@ func (not *NOT) RuleType() C.RuleType {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (not *NOT) Match(metadata *C.Metadata) bool {
|
func (not *NOT) Match(metadata *C.Metadata) bool {
|
||||||
return !not.rule.Match(metadata)
|
return not.rule == nil || !not.rule.Match(metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (not *NOT) Adapter() string {
|
func (not *NOT) Adapter() string {
|
||||||
|
@ -49,5 +52,5 @@ func (not *NOT) Payload() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (not *NOT) ShouldResolveIP() bool {
|
func (not *NOT) ShouldResolveIP() bool {
|
||||||
return not.rule.ShouldResolveIP()
|
return not.rule != nil && not.rule.ShouldResolveIP()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package logic
|
package logic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
|
|
||||||
C "github.com/Dreamacro/clash/constant"
|
C "github.com/Dreamacro/clash/constant"
|
||||||
"github.com/Dreamacro/clash/rule/common"
|
"github.com/Dreamacro/clash/rule/common"
|
||||||
)
|
)
|
||||||
|
@ -47,16 +45,12 @@ func (or *OR) ShouldResolveIP() bool {
|
||||||
|
|
||||||
func NewOR(payload string, adapter string) (*OR, error) {
|
func NewOR(payload string, adapter string) (*OR, error) {
|
||||||
or := &OR{Base: &common.Base{}, payload: payload, adapter: adapter}
|
or := &OR{Base: &common.Base{}, payload: payload, adapter: adapter}
|
||||||
rules, err := parseRuleByPayload(payload, true)
|
rules, err := parseRuleByPayload(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
or.rules = rules
|
or.rules = rules
|
||||||
if len(or.rules) == 0 {
|
|
||||||
return nil, fmt.Errorf("Or rule is error, may be format error or not contain least one rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.ShouldResolveIP() {
|
if rule.ShouldResolveIP() {
|
||||||
or.needIP = true
|
or.needIP = true
|
||||||
|
|
Loading…
Reference in a new issue