From 8054749b408bb49fd1aa1d7b1bc8937ce0b16f35 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 22 Apr 2022 16:27:51 +0800 Subject: [PATCH] feat: support uid rule eg. UID,1000/5000-6000,Proxy --- component/process/process.go | 13 ++++- component/process/process_linux.go | 3 +- constant/rule.go | 3 + rule/common/uid.go | 89 ++++++++++++++++++++++++++++++ rule/parser.go | 2 + tunnel/tunnel.go | 2 +- 6 files changed, 106 insertions(+), 6 deletions(-) create mode 100644 rule/common/uid.go diff --git a/component/process/process.go b/component/process/process.go index c8433b94..3ecc4b1e 100644 --- a/component/process/process.go +++ b/component/process/process.go @@ -2,12 +2,11 @@ package process import ( "errors" + "github.com/Dreamacro/clash/common/nnip" + C "github.com/Dreamacro/clash/constant" "net" "net/netip" "runtime" - - "github.com/Dreamacro/clash/common/nnip" - C "github.com/Dreamacro/clash/constant" ) var ( @@ -25,6 +24,14 @@ func FindProcessName(network string, srcIP netip.Addr, srcPort int) (string, err return findProcessName(network, srcIP, srcPort) } +func FindUid(network string, srcIP netip.Addr, srcPort int) (int32, error) { + _, uid, err := resolveSocketByNetlink(network, srcIP, srcPort) + if err != nil { + return -1, err + } + return uid, nil +} + func ShouldFindProcess(metadata *C.Metadata) bool { if runtime.GOOS == "android" { return false diff --git a/component/process/process_linux.go b/component/process/process_linux.go index 2c01f17f..5a98008e 100644 --- a/component/process/process_linux.go +++ b/component/process/process_linux.go @@ -37,7 +37,6 @@ func findProcessName(network string, ip netip.Addr, srcPort int) (string, error) if err != nil { return "", err } - return resolveProcessNameByProcSearch(inode, uid) } @@ -108,7 +107,7 @@ func resolveSocketByNetlink(network string, ip netip.Addr, srcPort int) (int32, return 0, 0, fmt.Errorf("netlink message: NLMSG_ERROR") } - inode, uid := unpackSocketDiagResponse(&messages[0]) + inode, uid := unpackSocketDiagResponse(&message) if inode < 0 || uid < 0 { return 0, 0, fmt.Errorf("invalid inode(%d) or uid(%d)", inode, uid) } diff --git a/constant/rule.go b/constant/rule.go index 68d1b4b1..1617979f 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -16,6 +16,7 @@ const ( Script RuleSet Network + Uid MATCH AND OR @@ -56,6 +57,8 @@ func (rt RuleType) String() string { return "RuleSet" case Network: return "Network" + case Uid: + return "Uid" case AND: return "AND" case OR: diff --git a/rule/common/uid.go b/rule/common/uid.go new file mode 100644 index 00000000..80c7d73a --- /dev/null +++ b/rule/common/uid.go @@ -0,0 +1,89 @@ +package common + +import ( + "github.com/Dreamacro/clash/common/utils" + "github.com/Dreamacro/clash/component/process" + C "github.com/Dreamacro/clash/constant" + "strconv" + "strings" +) + +type Uid struct { + *Base + uids []utils.Range[int32] + oUid string + adapter string +} + +func NewUid(oUid, adapter string) (*Uid, error) { + //if len(_uids) > 28 { + // return nil, fmt.Errorf("%s, too many uid to use, maximum support 28 uid", errPayload.Error()) + //} + + var uidRange []utils.Range[int32] + for _, u := range strings.Split(oUid, "/") { + if u == "" { + continue + } + + subUids := strings.Split(u, "-") + subUidsLen := len(subUids) + if subUidsLen > 2 { + return nil, errPayload + } + + uidStart, err := strconv.ParseUint(strings.Trim(subUids[0], "[ ]"), 10, 32) + if err != nil { + return nil, errPayload + } + + switch subUidsLen { + case 1: + uidRange = append(uidRange, *utils.NewRange(int32(uidStart), int32(uidStart))) + case 2: + uidEnd, err := strconv.ParseUint(strings.Trim(subUids[1], "[ ]"), 10, 32) + if err != nil { + return nil, errPayload + } + + uidRange = append(uidRange, *utils.NewRange(int32(uidStart), int32(uidEnd))) + } + } + + if len(uidRange) == 0 { + return nil, errPayload + } + return &Uid{ + Base: &Base{}, + adapter: adapter, + oUid: oUid, + uids: uidRange, + }, nil +} + +func (u *Uid) RuleType() C.RuleType { + return C.Uid +} + +func (u *Uid) Match(metadata *C.Metadata) bool { + srcPort, err := strconv.Atoi(metadata.SrcPort) + if err != nil { + return false + } + if uid, err := process.FindUid(metadata.NetWork.String(), metadata.SrcIP, srcPort); err == nil { + for _, _uid := range u.uids { + if _uid.Contains(uid) { + return true + } + } + } + return false +} + +func (u *Uid) Adapter() string { + return u.adapter +} + +func (u *Uid) Payload() string { + return u.oUid +} diff --git a/rule/parser.go b/rule/parser.go index 010b2205..a14fe16b 100644 --- a/rule/parser.go +++ b/rule/parser.go @@ -45,6 +45,8 @@ func ParseRule(tp, payload, target string, params []string) (C.Rule, error) { parsed, parseErr = RP.NewRuleSet(payload, target) case "NETWORK": parsed, parseErr = RC.NewNetworkType(payload, target) + case "UID": + parsed, parseErr = RC.NewUid(payload, target) case "AND": parsed, parseErr = logic.NewAND(payload, target) case "OR": diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index e720dff9..6d0265d6 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -328,7 +328,7 @@ func handleTCPConn(connCtx C.ConnContext) { if rule == nil { log.Warnln("[TCP] dial %s to %s error: %s", proxy.Name(), metadata.RemoteAddress(), err.Error()) } else { - log.Warnln("[TCP] dial %s (match %s/%s) to %s error: %s", proxy.Name(), rule.RuleType().String(), rule.Payload(), metadata.RemoteAddress(), err.Error()) + log.Warnln("[TCP] dial %s (match %s(%s)) to %s error: %s", proxy.Name(), rule.RuleType().String(), rule.Payload(), metadata.RemoteAddress(), err.Error()) } return }