From 54c22a2fceac86a364c7acd7f88d74d9f62590eb Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 15 Oct 2021 14:11:14 +0800 Subject: [PATCH] Feature: add mode script --- Makefile | 8 +- README.md | 113 +++- component/mmdb/mmdb.go | 45 ++ component/script/clash_module.c | 726 ++++++++++++++++++++++++ component/script/clash_module.go | 321 +++++++++++ component/script/clash_module.h | 61 ++ component/script/clash_module_export.go | 124 ++++ component/script/thread.go | 52 ++ config/config.go | 163 +++++- config/initial.go | 80 ++- constant/path.go | 18 + constant/rule.go | 7 + dns/filters.go | 54 +- go.mod | 5 + go.sum | 13 + hub/executor/executor.go | 11 + listener/tun/ipstack/lwip/tcp.go | 29 +- main.go | 2 +- rule/geodata/memconservative/cache.go | 6 +- rule/geodata/router/condition_geoip.go | 217 ++----- rule/geodata/standard/standard.go | 4 +- rule/geoip.go | 88 +-- rule/geosite.go | 6 +- rule/parser.go | 2 + rule/script.go | 72 +++ transport/vless/conn.go | 3 +- tunnel/mode.go | 4 + tunnel/tunnel.go | 40 +- 28 files changed, 1918 insertions(+), 356 deletions(-) create mode 100644 component/mmdb/mmdb.go create mode 100644 component/script/clash_module.c create mode 100644 component/script/clash_module.go create mode 100644 component/script/clash_module.h create mode 100644 component/script/clash_module_export.go create mode 100644 component/script/thread.go create mode 100644 rule/script.go diff --git a/Makefile b/Makefile index e94c67a4..98a10072 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ GOCMD=go -XGOCMD=xgo -go go-1.17.x -GOBUILD=CGO_ENABLED=1 $(GOCMD) build -a -trimpath +XGOCMD=xgo -go=go-1.17.x +GOBUILD=CGO_ENABLED=1 $(GOCMD) build -trimpath GOCLEAN=$(GOCMD) clean NAME=clash BINDIR=$(shell pwd)/bin @@ -72,11 +72,11 @@ all-arch: $(PLATFORM_LIST) $(WINDOWS_ARCH_LIST) releases: $(gz_releases) $(zip_releases) clean: - rm -rf $(BINDIR) + rm -rf $(BINDIR)/ mkdir -p $(BINDIR) cleancache: # go build cache may need to cleanup if changing C source code $(GOCLEAN) -cache - rm -rf $(BINDIR) + rm -rf $(BINDIR)/ mkdir -p $(BINDIR) \ No newline at end of file diff --git a/README.md b/README.md index e7fe4a6f..24e2ecb7 100644 --- a/README.md +++ b/README.md @@ -35,30 +35,37 @@ ## Getting Started Documentations are now moved to [GitHub Wiki](https://github.com/Dreamacro/clash/wiki). -## Advanced usage for this fork branch +## Advanced usage for this branch ### TUN configuration Supports macOS, Linux and Windows. -Support lwIP stack, a lightweight TCP/IP stack, recommend set to tun. +Support lwIP stack, a lightweight TCP/IP stack, it's recommended set to tun. On Windows, you should download the [Wintun](https://www.wintun.net) driver and copy `wintun.dll` into Clash home directory. ```yaml # Enable the TUN listener tun: enable: true - stack: lwip # lwip(recommend), system or gvisor + stack: lwip # lwip(recommended), system or gvisor dns-listen: 0.0.0.0:53 # additional dns server listen on TUN auto-route: true # auto set global route ``` ### Rules configuration - Support rule `GEOSITE`. +- Support rule `SCRIPT`. - Support `multiport` condition for rule `SRC-PORT` and `DST-PORT`. -- Support not match condition for rule `GEOIP`. - Support `network` condition for all rules. - Support source IPCIDR condition for all rules, just append to the end. -The `GEOSITE` and `GEOIP` databases via https://github.com/Loyalsoldier/v2ray-rules-dat. +The `GEOSITE` databases via https://github.com/Loyalsoldier/v2ray-rules-dat. ```yaml +mode: rule + +script: + shortcuts: + quic: 'network == "udp" and dst_port == 443' + privacy: '"analytics" in host or "adservice" in host or "firebase" in host or "safebrowsing" in host or "doubleclick" in host' + rules: # network condition for all rules - DOMAIN-SUFFIX,bilibili.com,DIRECT,tcp @@ -67,6 +74,10 @@ rules: # multiport condition for rules SRC-PORT and DST-PORT - DST-PORT,123/136/137-139,DIRECT,udp + # rule SCRIPT + - SCRIPT,quic,REJECT # Disable QUIC, same as rule "- DST-PORT,443,REJECT,udp" + - SCRIPT,privacy,REJECT + # rule GEOSITE - GEOSITE,category-ads-all,REJECT - GEOSITE,icloud@cn,DIRECT @@ -76,23 +87,92 @@ rules: - GEOSITE,facebook,PROXY - GEOSITE,youtube,PROXY - GEOSITE,geolocation-cn,DIRECT - - GEOSITE,gfw,PROXY - - GEOSITE,greatfire,PROXY - #- GEOSITE,geolocation-!cn,PROXY + - GEOSITE,geolocation-!cn,PROXY + + # source IPCIDR condition for all rules in gateway proxy + #- GEOSITE,apple,PROXY,192.168.1.88/32,192.168.1.99/32 - GEOIP,telegram,PROXY,no-resolve - GEOIP,private,DIRECT,no-resolve - GEOIP,cn,DIRECT - - # Not match condition for rule GEOIP - #- GEOIP,!cn,PROXY - - # source IPCIDR condition for all rules in gateway proxy - #- GEOIP,!cn,PROXY,192.168.1.88/32,192.168.1.99/32 - + - MATCH,PROXY ``` +### Script configuration +Script enables users to programmatically select a policy for the packets with more flexibility. + +```yaml +mode: script + +rules: + # the rule GEOSITE just as a rule provider in script mode + - GEOSITE,category-ads-all,Whatever + - GEOSITE,youtube,Whatever + - GEOSITE,geolocation-cn,Whatever + +script: + code: | + def main(ctx, metadata): + if metadata["process_name"] == 'apsd': + return "DIRECT" + + if metadata["network"] == 'udp' and metadata["dst_port"] == 443: + return "REJECT" + + host = metadata["host"] + for kw in ['analytics', 'adservice', 'firebase', 'bugly', 'safebrowsing', 'doubleclick']: + if kw in host: + return "REJECT" + + now = time.now() + if (now.hour < 8 or now.hour > 17) and metadata["src_ip"] == '192.168.1.99': + return "REJECT" + + if ctx.rule_providers["geosite:category-ads-all"].match(metadata): + return "REJECT" + + if ctx.rule_providers["geosite:youtube"].match(metadata): + ctx.log('[Script] domain %s matched youtube' % host) + return "Proxy" + + if ctx.rule_providers["geosite:geolocation-cn"].match(metadata): + ctx.log('[Script] domain %s matched geolocation-cn' % host) + return "CN" + + ip = metadata["dst_ip"] + if host != "": + ip = ctx.resolve_ip(host) + if ip == "": + return "Proxy" + + code = ctx.geoip(ip) + if code == "LAN" or code == "CN": + return "DIRECT" + + return "Proxy" # default policy for requests which are not matched by any other script +``` +the context and metadata +```python +interface Metadata { +type: string // socks5、http +network: string // tcp +host: string +process_name: string +src_ip: string +src_port: int +dst_ip: string +dst_port: int +} + +interface Context { +resolve_ip: (host: string) => string // ip string +geoip: (ip: string) => string // country code +log: (log: string) => void +rule_providers: Record boolean }> +} +``` + ### Proxies configuration Support outbound transport protocol `VLESS`. @@ -170,9 +250,6 @@ Add field `Process` to `Metadata` and prepare to get process name for Restful AP To display process name in GUI please use https://yaling888.github.io/yacd/. -## Premium Release -[Release](https://github.com/Dreamacro/clash/releases/tag/premium) - ## Development If you want to build an application that uses clash as a library, check out the the [GitHub Wiki](https://github.com/Dreamacro/clash/wiki/use-clash-as-a-library) diff --git a/component/mmdb/mmdb.go b/component/mmdb/mmdb.go new file mode 100644 index 00000000..e120055d --- /dev/null +++ b/component/mmdb/mmdb.go @@ -0,0 +1,45 @@ +package mmdb + +import ( + "sync" + + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" + + "github.com/oschwald/geoip2-golang" +) + +var ( + mmdb *geoip2.Reader + once sync.Once +) + +func LoadFromBytes(buffer []byte) { + once.Do(func() { + var err error + mmdb, err = geoip2.FromBytes(buffer) + if err != nil { + log.Fatalln("Can't load mmdb: %s", err.Error()) + } + }) +} + +func Verify() bool { + instance, err := geoip2.Open(C.Path.MMDB()) + if err == nil { + instance.Close() + } + return err == nil +} + +func Instance() *geoip2.Reader { + once.Do(func() { + var err error + mmdb, err = geoip2.Open(C.Path.MMDB()) + if err != nil { + log.Fatalln("Can't load mmdb: %s", err.Error()) + } + }) + + return mmdb +} diff --git a/component/script/clash_module.c b/component/script/clash_module.c new file mode 100644 index 00000000..a7d458e8 --- /dev/null +++ b/component/script/clash_module.c @@ -0,0 +1,726 @@ +#define PY_SSIZE_T_CLEAN + +#include "clash_module.h" +#include + +PyObject *clash_module; +PyObject *main_fn; +PyObject *clash_context; + +// init_python +void init_python(const char *path) { + + append_inittab(); + + Py_Initialize(); + + wchar_t *program = Py_DecodeLocale("clash", NULL); + if (program != NULL) { + Py_SetProgramName(program); + PyMem_RawFree(program); + } + +// wchar_t *newPath = Py_DecodeLocale(path, NULL); +// if (newPath != NULL) { +// Py_SetPath(newPath); +// } + + char *pathPrefix = "import sys; sys.path.append('"; + char *pathSuffix = "')"; + char *newPath = (char *) malloc(strlen(pathPrefix) + strlen(path) + strlen(pathSuffix)); + sprintf(newPath, "%s%s%s", pathPrefix, path, pathSuffix); + + PyRun_SimpleString(newPath); + free(newPath); + + /* Optionally import the module; alternatively, + import can be deferred until the embedded script + imports it. */ + clash_module = PyImport_ImportModule("clash"); + + main_fn = load_func(CLASH_SCRIPT_MODULE_NAME, "main"); +} + +// Load function, same as "import module_name.func_name as obj" in Python +// Returns the function object or NULL if not found +PyObject *load_func(const char *module_name, char *func_name) { + // Import the module + PyObject *py_mod_name = PyUnicode_FromString(module_name); + if (py_mod_name == NULL) { + return NULL; + } + + PyObject *module = PyImport_Import(py_mod_name); + Py_DECREF(py_mod_name); + if (module == NULL) { + return NULL; + } + + // Get function, same as "getattr(module, func_name)" in Python + PyObject *func = PyObject_GetAttrString(module, func_name); + Py_DECREF(module); + return func; +} + +// Return last error as char *, NULL if there was no error +const char *py_last_error() { + PyObject *err = PyErr_Occurred(); + if (err == NULL) { + return NULL; + } + + PyObject *type, *value, *traceback; + PyErr_Fetch(&type, &value, &traceback); + + if (value == NULL) { + return NULL; + } + + PyObject *str = PyObject_Str(value); + const char *utf8 = PyUnicode_AsUTF8(str); + Py_DECREF(str); + PyErr_Clear(); + return utf8; +} + +void py_clear(PyObject *obj) { + Py_CLEAR(obj); +} + +/** callback function, that call go function by python3 script. **/ + +resolve_ip_callback resolve_ip_callback_fn; + +geoip_callback geoip_callback_fn; + +rule_provider_callback rule_provider_callback_fn; + +log_callback log_callback_fn; + +void +set_resolve_ip_callback(resolve_ip_callback cb) +{ + resolve_ip_callback_fn = cb; +} + +void +set_geoip_callback(geoip_callback cb) +{ + geoip_callback_fn = cb; +} + +void +set_rule_provider_callback(rule_provider_callback cb) +{ + rule_provider_callback_fn = cb; +} + +void +set_log_callback(log_callback cb) +{ + log_callback_fn = cb; +} + +/** end callback function **/ + +/* --------------------------------------------------------------------- */ + +/* RuleProvider objects */ + +typedef struct { + PyObject_HEAD + PyObject *name; /* rule provider name */ +} RuleProviderObject; + +static int +RuleProvider_traverse(RuleProviderObject *self, visitproc visit, void *arg) +{ + Py_VISIT(self->name); + return 0; +} + +static int +RuleProvider_clear(RuleProviderObject *self) +{ + Py_CLEAR(self->name); + return 0; +} + +static void +RuleProvider_dealloc(RuleProviderObject *self) +{ + PyObject_GC_UnTrack(self); + RuleProvider_clear(self); + Py_TYPE(self)->tp_free((PyObject *) self); +} + +static PyObject * +RuleProvider_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + RuleProviderObject *self; + self = (RuleProviderObject *) type->tp_alloc(type, 0); + if (self != NULL) { + self->name = PyUnicode_FromString(""); + if (self->name == NULL) { + Py_DECREF(self); + return NULL; + } + } + return (PyObject *) self; +} + +static int +RuleProvider_init(RuleProviderObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"name", NULL}; + PyObject *name = NULL, *tmp; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|Us", kwlist, &name)) + return -1; + + if (name) { + tmp = self->name; + Py_INCREF(name); + self->name = name; + Py_DECREF(tmp); + } + return 0; +} + +//static PyMemberDef RuleProvider_members[] = { +// {"adapter_type", T_STRING, offsetof(RuleProviderObject, adapter_type), 0, +// "adapter type"}, +// {NULL} /* Sentinel */ +//}; + +static PyObject * +RuleProvider_getname(RuleProviderObject *self, void *closure) +{ + Py_INCREF(self->name); + return self->name; +} + +static int +RuleProvider_setname(RuleProviderObject *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, "Cannot delete the name attribute"); + return -1; + } + if (!PyUnicode_Check(value)) { + PyErr_SetString(PyExc_TypeError, + "The name attribute value must be a string"); + return -1; + } + Py_INCREF(value); + Py_CLEAR(self->name); + self->name = value; + return 0; +} + +static PyGetSetDef RuleProvider_getsetters[] = { + {"name", (getter) RuleProvider_getname, (setter) RuleProvider_setname, + "name", NULL}, + {NULL} /* Sentinel */ +}; + +static PyObject * +RuleProvider_name(RuleProviderObject *self, PyObject *Py_UNUSED(ignored)) +{ + Py_INCREF(self->name); + return self->name; +} + +static PyObject * +RuleProvider_match(RuleProviderObject *self, PyObject *args) +{ + PyObject *result; + PyObject *tmp; + const char *provider_name; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &tmp)) //Format "O","O!","O&": Borrowed reference. + return NULL; + + if (tmp == NULL) + Py_RETURN_FALSE; + + Py_INCREF(tmp); +// PyObject *py_src_port = PyDict_GetItemString(tmp, "src_port"); //Return value: Borrowed reference. +// PyObject *py_dst_port = PyDict_GetItemString(tmp, "dst_port"); //Return value: Borrowed reference. +// Py_INCREF(py_src_port); +// Py_INCREF(py_dst_port); +// char *c_src_port = (char *) malloc(PyLong_AsSize_t(py_src_port)); +// char *c_dst_port = (char *) malloc(PyLong_AsSize_t(py_dst_port)); +// sprintf(c_src_port, "%ld", PyLong_AsLong(py_src_port)); +// sprintf(c_dst_port, "%ld", PyLong_AsLong(py_dst_port)); + + struct Metadata metadata = { + .type = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "type")), // PyDict_GetItemString() Return value: Borrowed reference. + .network = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "network")), + .process_name = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "process_name")), + .host = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "host")), + .src_ip = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "src_ip")), + .src_port = (unsigned short)PyLong_AsUnsignedLong(PyDict_GetItemString(tmp, "src_port")), + .dst_ip = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "dst_ip")), + .dst_port = (unsigned short)PyLong_AsUnsignedLong(PyDict_GetItemString(tmp, "dst_port")) + }; + +// Py_DECREF(py_src_port); +// Py_DECREF(py_dst_port); + + Py_INCREF(self->name); + provider_name = PyUnicode_AsUTF8(self->name); + Py_DECREF(self->name); + Py_DECREF(tmp); + + int rs = rule_provider_callback_fn(provider_name, &metadata); + + result = (rs == 1) ? Py_True : Py_False; + Py_INCREF(result); + return result; +} + +static PyMethodDef RuleProvider_methods[] = { + {"name", (PyCFunction) RuleProvider_name, METH_NOARGS, + "Return the RuleProvider name" + }, + {"match", (PyCFunction) RuleProvider_match, METH_VARARGS, + "Match the rule by the RuleProvider, match(metadata) -> boolean" + }, + {NULL} /* Sentinel */ +}; + +static PyTypeObject RuleProviderType = { + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "clash.RuleProvider", + .tp_doc = "Clash RuleProvider objects", + .tp_basicsize = sizeof(RuleProviderObject), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, + .tp_new = RuleProvider_new, + .tp_init = (initproc) RuleProvider_init, + .tp_dealloc = (destructor) RuleProvider_dealloc, + .tp_traverse = (traverseproc) RuleProvider_traverse, + .tp_clear = (inquiry) RuleProvider_clear, +// .tp_members = RuleProvider_members, + .tp_methods = RuleProvider_methods, + .tp_getset = RuleProvider_getsetters, +}; + +/* end RuleProvider objects */ +/* --------------------------------------------------------------------- */ + +/* Context objects */ + +typedef struct { + PyObject_HEAD + PyObject *rule_providers; /* Dict */ +} ContextObject; + +static int +Context_traverse(ContextObject *self, visitproc visit, void *arg) +{ + Py_VISIT(self->rule_providers); + return 0; +} + +static int +Context_clear(ContextObject *self) +{ + Py_CLEAR(self->rule_providers); + return 0; +} + +static void +Context_dealloc(ContextObject *self) +{ + PyObject_GC_UnTrack(self); + Context_clear(self); + Py_TYPE(self)->tp_free((PyObject *) self); +} + +static PyObject * +Context_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + ContextObject *self; + self = (ContextObject *) type->tp_alloc(type, 0); + if (self != NULL) { + self->rule_providers = PyDict_New(); + if (self->rule_providers == NULL) { + Py_DECREF(self); + return NULL; + } + } + return (PyObject *) self; +} + +static int +Context_init(ContextObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"rule_providers", NULL}; + PyObject *rule_providers = NULL, *tmp; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O", kwlist, + &rule_providers)) + return -1; + + if (rule_providers) { + tmp = self->rule_providers; + Py_INCREF(rule_providers); + self->rule_providers = rule_providers; + Py_DECREF(tmp); + } + return 0; +} + +static PyObject * +Context_getrule_providers(ContextObject *self, void *closure) +{ + Py_INCREF(self->rule_providers); + return self->rule_providers; +} + +static int +Context_setrule_providers(ContextObject *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, "Cannot delete the rule_providers attribute"); + return -1; + } + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, + "The rule_providers attribute value must be a dict"); + return -1; + } + Py_INCREF(value); + Py_CLEAR(self->rule_providers); + self->rule_providers = value; + return 0; +} + +static PyGetSetDef Context_getsetters[] = { + {"rule_providers", (getter) Context_getrule_providers, (setter) Context_setrule_providers, + "rule_providers", NULL}, + {NULL} /* Sentinel */ +}; + +static PyObject * +Context_resolve_ip(PyObject *self, PyObject *args) +{ + const char *host; + const char *ip; + + if (!PyArg_ParseTuple(args, "s", &host)) + return NULL; + + if (host == NULL) + return PyUnicode_FromString(""); + + ip = resolve_ip_callback_fn(host); + + return PyUnicode_FromString(ip); +} + +static PyObject * +Context_geoip(PyObject *self, PyObject *args) +{ + const char *ip; + const char *countryCode; + + if (!PyArg_ParseTuple(args, "s", &ip)) + return NULL; + + if (ip == NULL) + return PyUnicode_FromString(""); + + countryCode = geoip_callback_fn(ip); + + return PyUnicode_FromString(countryCode); +} + +static PyObject * +Context_log(PyObject *self, PyObject *args) +{ + const char *msg; + + if (!PyArg_ParseTuple(args, "s", &msg)) + return NULL; + + log_callback_fn(msg); + + Py_RETURN_NONE; +} + +static PyMethodDef Context_methods[] = { + {"resolve_ip", (PyCFunction) Context_resolve_ip, METH_VARARGS, + "resolve_ip(host) -> string" + }, + {"geoip", (PyCFunction) Context_geoip, METH_VARARGS, + "geoip(ip) -> string" + }, + {"log", (PyCFunction) Context_log, METH_VARARGS, + "log(msg) -> void" + }, + {NULL} /* Sentinel */ +}; + +static PyTypeObject ContextType = { + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "clash.Context", + .tp_doc = "Clash Context objects", + .tp_basicsize = sizeof(ContextObject), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, + .tp_new = Context_new, + .tp_init = (initproc) Context_init, + .tp_dealloc = (destructor) Context_dealloc, + .tp_traverse = (traverseproc) Context_traverse, + .tp_clear = (inquiry) Context_clear, + .tp_methods = Context_methods, + .tp_getset = Context_getsetters, +}; + +static PyModuleDef clashmodule = { + PyModuleDef_HEAD_INIT, + .m_name = "clash", + .m_doc = "Clash module that creates an extension module for python3.", + .m_size = -1, +}; + +PyMODINIT_FUNC +PyInit_clash(void) +{ + PyObject *m; + + m = PyModule_Create(&clashmodule); + if (m == NULL) + return NULL; + + if (PyType_Ready(&RuleProviderType) < 0) + return NULL; + + Py_INCREF(&RuleProviderType); + if (PyModule_AddObject(m, "RuleProvider", (PyObject *) &RuleProviderType) < 0) { + Py_DECREF(&RuleProviderType); + Py_DECREF(m); + return NULL; + } + + if (PyType_Ready(&ContextType) < 0) + return NULL; + + Py_INCREF(&ContextType); + if (PyModule_AddObject(m, "Context", (PyObject *) &ContextType) < 0) { + Py_DECREF(&ContextType); + Py_DECREF(m); + return NULL; + } + + return m; +} + +/* end Context objects */ + +/* --------------------------------------------------------------------- */ + +void +append_inittab() +{ + /* Add a built-in module, before Py_Initialize */ + PyImport_AppendInittab("clash", PyInit_clash); +} + +int new_clash_py_context(const char *provider_name_arr[], int size) { + PyObject *dict = PyDict_New(); //Return value: New reference. + if (dict == NULL) { + PyErr_SetString(PyExc_TypeError, + "PyDict_New failure"); + return 0; + } + + for (int i = 0; i < size; i++) { + PyObject *rule_provider = RuleProvider_new(&RuleProviderType, NULL, NULL); + if (rule_provider == NULL) { + Py_DECREF(dict); + PyErr_SetString(PyExc_TypeError, + "RuleProvider_new failure"); + return 0; + } + + RuleProviderObject *providerObj = (RuleProviderObject *) rule_provider; + + PyObject *py_name = PyUnicode_FromString(provider_name_arr[i]); //Return value: New reference. + RuleProvider_setname(providerObj, py_name, NULL); + Py_DECREF(py_name); + + PyDict_SetItemString(dict, provider_name_arr[i], rule_provider); //Parameter value: New reference. + Py_DECREF(rule_provider); + } + + clash_context = Context_new(&ContextType, NULL, NULL); + + if (clash_context == NULL) { + Py_DECREF(dict); + PyErr_SetString(PyExc_TypeError, + "Context_new failure"); + return 0; + } + + Context_setrule_providers((ContextObject *) clash_context, dict, NULL); + Py_DECREF(dict); + return 1; +} + +const char *call_main( + const char *type, + const char *network, + const char *process_name, + const char *host, + const char *src_ip, + unsigned short src_port, + const char *dst_ip, + unsigned short dst_port) { + + PyObject *metadataDict; + PyObject *tupleArgs; + PyObject *result; + + metadataDict = PyDict_New(); //Return value: New reference. + + if (metadataDict == NULL) { + PyErr_SetString(PyExc_TypeError, + "PyDict_New failure"); + return "-1"; + } + + PyObject *p_type = PyUnicode_FromString(type); //Return value: New reference. + PyObject *p_network = PyUnicode_FromString(network); //Return value: New reference. + PyObject *p_process_name = PyUnicode_FromString(process_name); //Return value: New reference. + PyObject *p_host = PyUnicode_FromString(host); //Return value: New reference. + PyObject *p_src_ip = PyUnicode_FromString(src_ip); //Return value: New reference. + PyObject *p_src_port = PyLong_FromUnsignedLong((unsigned long)src_port); //Return value: New reference. + PyObject *p_dst_ip = PyUnicode_FromString(dst_ip); //Return value: New reference. + PyObject *p_dst_port = PyLong_FromUnsignedLong((unsigned long)dst_port); //Return value: New reference. + + PyDict_SetItemString(metadataDict, "type", p_type); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "network", p_network); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "process_name", p_process_name); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "host", p_host); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "src_ip", p_src_ip); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "src_port", p_src_port); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "dst_ip", p_dst_ip); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "dst_port", p_dst_port); //Parameter value: New reference. + + Py_DECREF(p_type); + Py_DECREF(p_network); + Py_DECREF(p_process_name); + Py_DECREF(p_host); + Py_DECREF(p_src_ip); + Py_DECREF(p_src_port); + Py_DECREF(p_dst_ip); + Py_DECREF(p_dst_port); + + tupleArgs = PyTuple_New(2); //Return value: New reference. + if (tupleArgs == NULL) { + Py_DECREF(metadataDict); + PyErr_SetString(PyExc_TypeError, + "PyTuple_New failure"); + return "-1"; + } + + Py_INCREF(clash_context); + PyTuple_SetItem(tupleArgs, 0, clash_context); //clash_context Parameter value: Stolen reference. + PyTuple_SetItem(tupleArgs, 1, metadataDict); //metadataDict Parameter value: Stolen reference. + + Py_INCREF(main_fn); + result = PyObject_CallObject(main_fn, tupleArgs); //Return value: New reference. + Py_DECREF(main_fn); + Py_DECREF(tupleArgs); + + if (result == NULL) { + return "-1"; + } + + if (!PyUnicode_Check(result)) { + Py_DECREF(result); + PyErr_SetString(PyExc_TypeError, + "script main function return value must be a string"); + return "-1"; + } + + const char *adapter = PyUnicode_AsUTF8(result); + + Py_DECREF(result); + + return adapter; +} + +int call_shortcut(PyObject *shortcut_fn, + const char *type, + const char *network, + const char *process_name, + const char *host, + const char *src_ip, + unsigned short src_port, + const char *dst_ip, + unsigned short dst_port) { + + PyObject *args; + PyObject *result; + + args = Py_BuildValue("{s:O, s:s, s:s, s:s, s:s, s:H, s:s, s:H}", + "ctx", clash_context, + "network", network, + "process_name", process_name, + "host", host, + "src_ip", src_ip, + "src_port", src_port, + "dst_ip", dst_ip, + "dst_port", dst_port); //Return value: New reference. + + if (args == NULL) { + PyErr_SetString(PyExc_TypeError, + "Py_BuildValue failure"); + return -1; + } + + PyObject *tupleArgs = PyTuple_New(0); //Return value: New reference. + + Py_INCREF(clash_context); + Py_INCREF(shortcut_fn); + result = PyObject_Call(shortcut_fn, tupleArgs, args); //Return value: New reference. + Py_DECREF(shortcut_fn); + Py_DECREF(clash_context); + Py_DECREF(tupleArgs); + Py_DECREF(args); + + if (result == NULL) { + return -1; + } + + if (!PyBool_Check(result)) { + Py_DECREF(result); + PyErr_SetString(PyExc_TypeError, + "script shortcut return value must be as boolean"); + return -1; + } + + int rs = (result == Py_True) ? 1 : 0; + + Py_DECREF(result); + + return rs; +} + +void finalize_Python() { + Py_CLEAR(main_fn); + Py_CLEAR(clash_context); + Py_CLEAR(clash_module); + Py_FinalizeEx(); + +// clash_module = NULL; +// main_fn = NULL; +// clash_context = NULL; +} + +/* --------------------------------------------------------------------- */ \ No newline at end of file diff --git a/component/script/clash_module.go b/component/script/clash_module.go new file mode 100644 index 00000000..5fd8005f --- /dev/null +++ b/component/script/clash_module.go @@ -0,0 +1,321 @@ +package script + +/* +#cgo pkg-config: python3-embed +//#cgo pkg-config: python3 +//#cgo LDFLAGS: -lpython3 + +#include "clash_module.h" + +extern const char *resolveIPCallbackFn(const char *host); + +void +go_set_resolve_ip_callback() { + set_resolve_ip_callback(resolveIPCallbackFn); +} + +extern const char *geoipCallbackFn(const char *ip); + +void +go_set_geoip_callback() { + set_geoip_callback(geoipCallbackFn); +} + +extern const int ruleProviderCallbackFn(const char *provider_name, struct Metadata *metadata); + +void +go_set_rule_provider_callback() { + set_rule_provider_callback(ruleProviderCallbackFn); +} + +extern void logCallbackFn(const char *msg); + +void +go_set_log_callback() { + set_log_callback(logCallbackFn); +} +*/ +import "C" +import ( + "errors" + "fmt" + "os" + "runtime" + "strconv" + "sync" + "syscall" + "unsafe" + + "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" +) + +const ClashScriptModuleName = C.CLASH_SCRIPT_MODULE_NAME + +var lock sync.Mutex + +type PyObject C.PyObject + +func togo(cobject *C.PyObject) *PyObject { + return (*PyObject)(cobject) +} + +func toc(object *PyObject) *C.PyObject { + return (*C.PyObject)(object) +} + +func (pyObject *PyObject) IncRef() { + C.Py_IncRef(toc(pyObject)) +} + +func (pyObject *PyObject) DecRef() { + C.Py_DecRef(toc(pyObject)) +} + +func (pyObject *PyObject) Clear() { + C.py_clear(toc(pyObject)) +} + +// Py_Initialize initialize Python3 +func Py_Initialize(path string) error { + lock.Lock() + defer lock.Unlock() + + if C.Py_IsInitialized() != 0 { + if pyThreadState != nil { + PyEval_RestoreThread(pyThreadState) + } + C.finalize_Python() + } + + cPath := C.CString(path) + //defer C.free(unsafe.Pointer(cPath)) + + C.init_python(cPath) + //C.Py_Initialize() + err := PyLastError() + + if err != nil { + if C.Py_IsInitialized() != 0 { + C.finalize_Python() + _ = os.RemoveAll(constant.Path.ScriptDir()) + } + return err + } else if C.Py_IsInitialized() == 0 { + err = errors.New("initialized script module failure") + return err + } + + initPython3Callback() + return nil +} + +func Py_IsInitialized() bool { + lock.Lock() + defer lock.Unlock() + + return C.Py_IsInitialized() != 0 +} + +func Py_Finalize() { + lock.Lock() + defer lock.Unlock() + + if C.Py_IsInitialized() != 0 { + if pyThreadState != nil { + PyEval_RestoreThread(pyThreadState) + } + C.finalize_Python() + _ = os.RemoveAll(constant.Path.ScriptDir()) + log.Warnln("Clash clean up script mode.") + } +} + +func Py_GetVersion() string { + cversion := C.Py_GetVersion() + return C.GoString(cversion) +} + +func PyRun_SimpleString(command string) int { + ccommand := C.CString(command) + defer C.free(unsafe.Pointer(ccommand)) + + // C.PyRun_SimpleString is a macro, using C.PyRun_SimpleStringFlags instead + return int(C.PyRun_SimpleStringFlags(ccommand, nil)) +} + +// loadPyFunc loads a Python function by module and function name +func loadPyFunc(moduleName, funcName string) (*C.PyObject, error) { + // Convert names to C char* + cMod := C.CString(moduleName) + cFunc := C.CString(funcName) + + // Free memory allocated by C.CString + defer func() { + C.free(unsafe.Pointer(cMod)) + C.free(unsafe.Pointer(cFunc)) + }() + + fnc := C.load_func(cMod, cFunc) + if fnc == nil { + return nil, PyLastError() + } + + return fnc, nil +} + +//PyLastError python last error +func PyLastError() error { + cp := C.py_last_error() + if cp == nil { + return nil + } + + return errors.New(C.GoString(cp)) +} + +func LoadShortcutFunction(shortcut string) (*PyObject, error) { + fnc, err := loadPyFunc(ClashScriptModuleName, shortcut) + if err != nil { + return nil, err + } + return togo(fnc), nil +} + +//CallPyMainFunction call python script main function +//return the proxy adapter name. +func CallPyMainFunction(mtd *constant.Metadata) (string, error) { + _type := C.CString(mtd.Type.String()) + network := C.CString(mtd.NetWork.String()) + processName := C.CString(mtd.Process) + host := C.CString(mtd.Host) + + srcPortGo, _ := strconv.ParseUint(mtd.SrcPort, 10, 16) + dstPortGo, _ := strconv.ParseUint(mtd.DstPort, 10, 16) + srcPort := C.ushort(srcPortGo) + dstPort := C.ushort(dstPortGo) + + dstIpGo := "" + srcIpGo := "" + if mtd.SrcIP != nil { + srcIpGo = mtd.SrcIP.String() + } + if mtd.DstIP != nil { + dstIpGo = mtd.DstIP.String() + } + srcIp := C.CString(srcIpGo) + dstIp := C.CString(dstIpGo) + + defer func() { + C.free(unsafe.Pointer(_type)) + C.free(unsafe.Pointer(network)) + C.free(unsafe.Pointer(processName)) + C.free(unsafe.Pointer(host)) + C.free(unsafe.Pointer(srcIp)) + C.free(unsafe.Pointer(dstIp)) + }() + + runtime.LockOSThread() + gilState := PyGILState_Ensure() + defer PyGILState_Release(gilState) + + cRs := C.call_main(_type, network, processName, host, srcIp, srcPort, dstIp, dstPort) + + rs := C.GoString(cRs) + if rs == "-1" { + err := PyLastError() + if err != nil { + log.Errorln("[Script] script code error: %s", err.Error()) + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + return "", fmt.Errorf("script code error: %w", err) + } else { + return "", fmt.Errorf("script code error, result: %v", rs) + } + } + + return rs, nil +} + +//CallPyShortcut call python script shortcuts function +//param: shortcut name +//return the match result. +func CallPyShortcut(fn *PyObject, mtd *constant.Metadata) (bool, error) { + _type := C.CString(mtd.Type.String()) + network := C.CString(mtd.NetWork.String()) + processName := C.CString(mtd.Process) + host := C.CString(mtd.Host) + + srcPortGo, _ := strconv.ParseUint(mtd.SrcPort, 10, 16) + dstPortGo, _ := strconv.ParseUint(mtd.DstPort, 10, 16) + srcPort := C.ushort(srcPortGo) + dstPort := C.ushort(dstPortGo) + + dstIpGo := "" + srcIpGo := "" + if mtd.SrcIP != nil { + srcIpGo = mtd.SrcIP.String() + } + if mtd.DstIP != nil { + dstIpGo = mtd.DstIP.String() + } + srcIp := C.CString(srcIpGo) + dstIp := C.CString(dstIpGo) + + defer func() { + C.free(unsafe.Pointer(_type)) + C.free(unsafe.Pointer(network)) + C.free(unsafe.Pointer(processName)) + C.free(unsafe.Pointer(host)) + C.free(unsafe.Pointer(srcIp)) + C.free(unsafe.Pointer(dstIp)) + }() + + runtime.LockOSThread() + gilState := PyGILState_Ensure() + defer PyGILState_Release(gilState) + + cRs := C.call_shortcut(toc(fn), _type, network, processName, host, srcIp, srcPort, dstIp, dstPort) + + rs := int(cRs) + if rs == -1 { + err := PyLastError() + if err != nil { + log.Errorln("[Script] script shortcut code error: %s", err.Error()) + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + return false, fmt.Errorf("script shortcut code error: %w", err) + } else { + return false, fmt.Errorf("script shortcut code error: result: %d", rs) + } + } + + if rs == 1 { + return true, nil + } else { + return false, nil + } +} + +func initPython3Callback() { + C.go_set_resolve_ip_callback() + C.go_set_geoip_callback() + C.go_set_rule_provider_callback() + C.go_set_log_callback() +} + +//NewClashPyContext new clash context for python +func NewClashPyContext(ruleProvidersName []string) error { + cStringArr := make([]*C.char, len(ruleProvidersName)) + for i, v := range ruleProvidersName { + cStringArr[i] = C.CString(v) + defer C.free(unsafe.Pointer(cStringArr[i])) + } + + rs := int(C.new_clash_py_context((**C.char)(unsafe.Pointer(&cStringArr[0])), C.int(len(ruleProvidersName)))) + + if rs == 0 { + err := PyLastError() + return fmt.Errorf("new script module context failure: %s", err.Error()) + } + + return nil +} diff --git a/component/script/clash_module.h b/component/script/clash_module.h new file mode 100644 index 00000000..4610fecc --- /dev/null +++ b/component/script/clash_module.h @@ -0,0 +1,61 @@ +#ifndef CLASH_CALLBACK_MODULE_H__ +#define CLASH_CALLBACK_MODULE_H__ + +#include + +#define CLASH_SCRIPT_MODULE_NAME "clash_script" + +struct Metadata { + const char *type; /* type socks5/http */ + const char *network; /* network tcp/udp */ + const char *process_name; + const char *host; + const char *src_ip; + unsigned short src_port; + const char *dst_ip; + unsigned short dst_port; +}; + +/** callback function, that call go function by python3 script. **/ +typedef const char *(*resolve_ip_callback)(const char *host); +typedef const char *(*geoip_callback)(const char *ip); +typedef const int (*rule_provider_callback)(const char *provider_name, struct Metadata *metadata); +typedef void (*log_callback)(const char *msg); + +void set_resolve_ip_callback(resolve_ip_callback cb); +void set_geoip_callback(geoip_callback cb); +void set_rule_provider_callback(rule_provider_callback cb); +void set_log_callback(log_callback cb); +/*---------------------------------------------------------------*/ + +void append_inittab(); +void init_python(const char *path); +void finalize_Python(); +void py_clear(PyObject *obj); +const char *py_last_error(); + +PyObject *load_func(const char *module_name, char *func_name); + +int new_clash_py_context(const char *provider_name_arr[], int size); + +const char *call_main( + const char *type, + const char *network, + const char *process_name, + const char *host, + const char *src_ip, + unsigned short src_port, + const char *dst_ip, + unsigned short dst_port); + +int call_shortcut(PyObject *shortcut_fn, + const char *type, + const char *network, + const char *process_name, + const char *host, + const char *src_ip, + unsigned short src_port, + const char *dst_ip, + unsigned short dst_port); + +#endif // CLASH_CALLBACK_MODULE_H__ \ No newline at end of file diff --git a/component/script/clash_module_export.go b/component/script/clash_module_export.go new file mode 100644 index 00000000..0857c4b4 --- /dev/null +++ b/component/script/clash_module_export.go @@ -0,0 +1,124 @@ +package script + +/* +#include "clash_module.h" +*/ +import "C" +import ( + "net" + "strconv" + "strings" + "unsafe" + + "github.com/Dreamacro/clash/component/mmdb" + "github.com/Dreamacro/clash/component/resolver" + "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" +) + +var ( + ruleProviders = map[string]constant.Rule{} + pyThreadState *PyThreadState +) + +func UpdateRuleProviders(rpd map[string]constant.Rule) { + ruleProviders = rpd + if Py_IsInitialized() { + pyThreadState = PyEval_SaveThread() + } +} + +//export resolveIPCallbackFn +func resolveIPCallbackFn(cHost *C.char) *C.char { + host := C.GoString(cHost) + if len(host) == 0 { + cip := C.CString("") + defer C.free(unsafe.Pointer(cip)) + return cip + } + if ip, err := resolver.ResolveIP(host); err == nil { + cip := C.CString(ip.String()) + defer C.free(unsafe.Pointer(cip)) + return cip + } else { + log.Errorln("[Script] resolve ip error: %s", err.Error()) + cip := C.CString("") + defer C.free(unsafe.Pointer(cip)) + return cip + } +} + +//export geoipCallbackFn +func geoipCallbackFn(cIP *C.char) *C.char { + dstIP := net.ParseIP(C.GoString(cIP)) + + if dstIP == nil { + emptyC := C.CString("") + defer C.free(unsafe.Pointer(emptyC)) + + return emptyC + } + + if dstIP.IsPrivate() || constant.TunBroadcastAddr.Equal(dstIP) { + lanC := C.CString("LAN") + defer C.free(unsafe.Pointer(lanC)) + + return lanC + } + + record, _ := mmdb.Instance().Country(dstIP) + + rc := C.CString(strings.ToUpper(record.Country.IsoCode)) + defer C.free(unsafe.Pointer(rc)) + + return rc +} + +//export ruleProviderCallbackFn +func ruleProviderCallbackFn(cProviderName *C.char, cMetadata *C.struct_Metadata) C.int { + //_type := C.GoString(cMetadata._type) + //network := C.GoString(cMetadata.network) + processName := C.GoString(cMetadata.process_name) + host := C.GoString(cMetadata.host) + srcIp := C.GoString(cMetadata.src_ip) + srcPort := strconv.Itoa(int(cMetadata.src_port)) + dstIp := C.GoString(cMetadata.dst_ip) + dstPort := strconv.Itoa(int(cMetadata.dst_port)) + + metadata := &constant.Metadata{ + Process: processName, + SrcIP: net.ParseIP(srcIp), + DstIP: net.ParseIP(dstIp), + SrcPort: srcPort, + DstPort: dstPort, + Host: host, + } + + providerName := C.GoString(cProviderName) + + rule, ok := ruleProviders[providerName] + if !ok { + log.Warnln("rule provider [%s] not found", providerName) + return C.int(0) + } + + if strings.HasPrefix(providerName, "geosite:") { + if len(host) == 0 { + return C.int(0) + } + metadata.AddrType = constant.AtypDomainName + } + + rs := rule.Match(metadata) + + if rs { + return C.int(1) + } + return C.int(0) +} + +//export logCallbackFn +func logCallbackFn(msg *C.char) { + + log.Infoln(C.GoString(msg)) +} diff --git a/component/script/thread.go b/component/script/thread.go new file mode 100644 index 00000000..8b7735c0 --- /dev/null +++ b/component/script/thread.go @@ -0,0 +1,52 @@ +package script + +/* +#include "Python.h" +*/ +import "C" + +//PyThreadState : https://docs.python.org/3/c-api/init.html#c.PyThreadState +type PyThreadState C.PyThreadState + +//PyGILState is an opaque “handle” to the thread state when PyGILState_Ensure() was called, and must be passed to PyGILState_Release() to ensure Python is left in the same state +type PyGILState C.PyGILState_STATE + +//PyEval_SaveThread : https://docs.python.org/3/c-api/init.html#c.PyEval_SaveThread +func PyEval_SaveThread() *PyThreadState { + return (*PyThreadState)(C.PyEval_SaveThread()) +} + +//PyEval_RestoreThread : https://docs.python.org/3/c-api/init.html#c.PyEval_RestoreThread +func PyEval_RestoreThread(tstate *PyThreadState) { + C.PyEval_RestoreThread((*C.PyThreadState)(tstate)) +} + +//PyThreadState_Get : https://docs.python.org/3/c-api/init.html#c.PyThreadState_Get +func PyThreadState_Get() *PyThreadState { + return (*PyThreadState)(C.PyThreadState_Get()) +} + +//PyThreadState_Swap : https://docs.python.org/3/c-api/init.html#c.PyThreadState_Swap +func PyThreadState_Swap(tstate *PyThreadState) *PyThreadState { + return (*PyThreadState)(C.PyThreadState_Swap((*C.PyThreadState)(tstate))) +} + +//PyGILState_Ensure : https://docs.python.org/3/c-api/init.html#c.PyGILState_Ensure +func PyGILState_Ensure() PyGILState { + return PyGILState(C.PyGILState_Ensure()) +} + +//PyGILState_Release : https://docs.python.org/3/c-api/init.html#c.PyGILState_Release +func PyGILState_Release(state PyGILState) { + C.PyGILState_Release(C.PyGILState_STATE(state)) +} + +//PyGILState_GetThisThreadState : https://docs.python.org/3/c-api/init.html#c.PyGILState_GetThisThreadState +func PyGILState_GetThisThreadState() *PyThreadState { + return (*PyThreadState)(C.PyGILState_GetThisThreadState()) +} + +//PyGILState_Check : https://docs.python.org/3/c-api/init.html#c.PyGILState_Check +func PyGILState_Check() bool { + return C.PyGILState_Check() == 1 +} diff --git a/config/config.go b/config/config.go index 15baf641..66d64149 100644 --- a/config/config.go +++ b/config/config.go @@ -6,6 +6,7 @@ import ( "net" "net/url" "os" + "regexp" "runtime" "strings" @@ -15,6 +16,7 @@ import ( "github.com/Dreamacro/clash/adapter/provider" "github.com/Dreamacro/clash/component/auth" "github.com/Dreamacro/clash/component/fakeip" + S "github.com/Dreamacro/clash/component/script" "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" providerTypes "github.com/Dreamacro/clash/constant/provider" @@ -92,21 +94,28 @@ type Tun struct { AutoRoute bool `yaml:"auto-route" json:"auto-route"` } +// Script config +type Script struct { + MainCode string `yaml:"code" json:"code"` + ShortcutsCode map[string]string `yaml:"shortcuts" json:"shortcuts"` +} + // Experimental config type Experimental struct{} // Config is clash config manager type Config struct { - General *General - Tun *Tun - DNS *DNS - Experimental *Experimental - Hosts *trie.DomainTrie - Profile *Profile - Rules []C.Rule - Users []auth.AuthUser - Proxies map[string]C.Proxy - Providers map[string]providerTypes.ProxyProvider + General *General + Tun *Tun + DNS *DNS + Experimental *Experimental + Hosts *trie.DomainTrie + Profile *Profile + Rules []C.Rule + RuleProviders map[string]C.Rule + Users []auth.AuthUser + Proxies map[string]C.Proxy + Providers map[string]providerTypes.ProxyProvider } type RawDNS struct { @@ -157,6 +166,7 @@ type RawConfig struct { Proxy []map[string]interface{} `yaml:"proxies"` ProxyGroup []map[string]interface{} `yaml:"proxy-groups"` Rule []string `yaml:"rules"` + Script Script `yaml:"script"` } // Parse config @@ -204,6 +214,10 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) { Profile: Profile{ StoreSelected: true, }, + Script: Script{ + MainCode: "", + ShortcutsCode: map[string]string{}, + }, } if err := yaml.Unmarshal(buf, rawCfg); err != nil { @@ -232,11 +246,17 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) { config.Proxies = proxies config.Providers = providers - rules, err := parseRules(rawCfg, proxies) + err = parseScript(rawCfg) + if err != nil { + return nil, err + } + + rules, ruleProviders, err := parseRules(rawCfg, proxies) if err != nil { return nil, err } config.Rules = rules + config.RuleProviders = ruleProviders hosts, err := parseHosts(rawCfg) if err != nil { @@ -396,9 +416,80 @@ func parseProxies(cfg *RawConfig) (proxies map[string]C.Proxy, providersMap map[ return proxies, providersMap, nil } -func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, error) { +func parseScript(cfg *RawConfig) error { + mode := cfg.Mode + script := cfg.Script + mainCode := cleanPyKeywords(script.MainCode) + shortcutsCode := script.ShortcutsCode + + if mode != T.Script && len(shortcutsCode) == 0 { + return nil + } else if mode == T.Script && len(mainCode) == 0 { + return fmt.Errorf("initialized script module failure, can't find script code in the config file") + } + + content := + `# -*- coding: UTF-8 -*- + +from datetime import datetime as whatever + +class ClashTime: + def now(self): + return whatever.now() + + def unix(self): + return int(whatever.now().timestamp()) + + def unix_nano(self): + return int(round(whatever.now().timestamp() * 1000)) + +time = ClashTime() + +` + + var shouldInitPy bool + if mode == T.Script { + content += mainCode + "\n\n" + shouldInitPy = true + } + + for k, v := range shortcutsCode { + v = cleanPyKeywords(v) + v = strings.TrimSpace(v) + if len(v) == 0 { + return fmt.Errorf("initialized rule SCRIPT failure, shortcut [%s] code invalid syntax", k) + } + + content += "def " + strings.ToLower(k) + "(ctx, network, process_name, host, src_ip, src_port, dst_ip, dst_port):\n return " + v + "\n\n" + shouldInitPy = true + } + + if !shouldInitPy { + return nil + } + + err := os.WriteFile(C.Path.Script(), []byte(content), 0644) + if err != nil { + return fmt.Errorf("initialized script module failure, %s", err.Error()) + } + + if err = S.Py_Initialize(C.Path.ScriptDir()); err != nil { + return fmt.Errorf("initialized script module failure, %s", err.Error()) + } else { + log.Infoln("Start initial script module successful") + } + + return nil +} + +func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[string]C.Rule, error) { rules := []C.Rule{} + ruleProviders := map[string]C.Rule{} rulesConfig := cfg.Rule + mode := cfg.Mode + + providerNames := []string{} + isPyInit := S.Py_IsInitialized() // parse rules for idx, line := range rulesConfig { @@ -410,6 +501,10 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, error) { ruleName = strings.ToUpper(rule[0]) ) + if mode == T.Script && ruleName != "GEOSITE" { + continue + } + switch l := len(rule); { case l == 2: target = rule[1] @@ -427,11 +522,11 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, error) { target = rule[2] params = rule[3:] default: - return nil, fmt.Errorf("rules[%d] [%s] error: format invalid", idx, line) + return nil, nil, fmt.Errorf("rules[%d] [%s] error: format invalid", idx, line) } - if _, ok := proxies[target]; !ok { - return nil, fmt.Errorf("rules[%d] [%s] error: proxy [%s] not found", idx, line, target) + if _, ok := proxies[target]; mode != T.Script && !ok { + return nil, nil, fmt.Errorf("rules[%d] [%s] error: proxy [%s] not found", idx, line, target) } //rule = trimArr(rule) @@ -439,15 +534,34 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, error) { parsed, parseErr := R.ParseRule(ruleName, payload, target, params) if parseErr != nil { - return nil, fmt.Errorf("rules[%d] [%s] error: %s", idx, line, parseErr.Error()) + return nil, nil, fmt.Errorf("rules[%d] [%s] error: %s", idx, line, parseErr.Error()) } - rules = append(rules, parsed) + if isPyInit { + if ruleName == "GEOSITE" { + pvName := "geosite:" + strings.ToLower(payload) + providerNames = append(providerNames, pvName) + ruleProviders[pvName] = parsed + } + } + + if mode != T.Script { + rules = append(rules, parsed) + } } runtime.GC() - return rules, nil + if isPyInit { + err := S.NewClashPyContext(providerNames) + if err != nil { + return nil, nil, err + } else { + log.Infoln("Start initial script context successful") + } + } + + return rules, ruleProviders, nil } func parseHosts(cfg *RawConfig) (*trie.DomainTrie, error) { @@ -657,3 +771,16 @@ func parseAuthentication(rawRecords []string) []auth.AuthUser { } return users } + +func cleanPyKeywords(code string) string { + if len(code) == 0 { + return code + } + keywords := []string{"import", "print"} + + for _, kw := range keywords { + reg := regexp.MustCompile("(?m)[\r\n]+^.*" + kw + ".*$") + code = reg.ReplaceAllString(code, "") + } + return code +} diff --git a/config/initial.go b/config/initial.go index 79ae1878..36203d45 100644 --- a/config/initial.go +++ b/config/initial.go @@ -6,18 +6,19 @@ import ( "net/http" "os" + "github.com/Dreamacro/clash/component/mmdb" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" ) -func downloadGeoIP(path string) (err error) { - resp, err := http.Get("https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geoip.dat") +func downloadMMDB(path string) (err error) { + resp, err := http.Get("https://cdn.jsdelivr.net/gh/Loyalsoldier/geoip@release/Country.mmdb") if err != nil { return } defer resp.Body.Close() - f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return err } @@ -27,6 +28,45 @@ func downloadGeoIP(path string) (err error) { return err } +func initMMDB() error { + if _, err := os.Stat(C.Path.MMDB()); os.IsNotExist(err) { + log.Infoln("Can't find MMDB, start download") + if err := downloadMMDB(C.Path.MMDB()); err != nil { + return fmt.Errorf("can't download MMDB: %s", err.Error()) + } + } + + if !mmdb.Verify() { + log.Warnln("MMDB invalid, remove and download") + if err := os.Remove(C.Path.MMDB()); err != nil { + return fmt.Errorf("can't remove invalid MMDB: %s", err.Error()) + } + + if err := downloadMMDB(C.Path.MMDB()); err != nil { + return fmt.Errorf("can't download MMDB: %s", err.Error()) + } + } + + return nil +} + +//func downloadGeoIP(path string) (err error) { +// resp, err := http.Get("https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geoip.dat") +// if err != nil { +// return +// } +// defer resp.Body.Close() +// +// f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644) +// if err != nil { +// return err +// } +// defer f.Close() +// _, err = io.Copy(f, resp.Body) +// +// return err +//} + func downloadGeoSite(path string) (err error) { resp, err := http.Get("https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geosite.dat") if err != nil { @@ -44,17 +84,18 @@ func downloadGeoSite(path string) (err error) { return err } -func initGeoIP() error { - if _, err := os.Stat(C.Path.GeoIP()); os.IsNotExist(err) { - log.Infoln("Can't find GeoIP.dat, start download") - if err := downloadGeoIP(C.Path.GeoIP()); err != nil { - return fmt.Errorf("can't download GeoIP.dat: %s", err.Error()) - } - log.Infoln("Download GeoIP.dat finish") - } - - return nil -} +// +//func initGeoIP() error { +// if _, err := os.Stat(C.Path.GeoIP()); os.IsNotExist(err) { +// log.Infoln("Can't find GeoIP.dat, start download") +// if err := downloadGeoIP(C.Path.GeoIP()); err != nil { +// return fmt.Errorf("can't download GeoIP.dat: %s", err.Error()) +// } +// log.Infoln("Download GeoIP.dat finish") +// } +// +// return nil +//} func initGeoSite() error { if _, err := os.Stat(C.Path.GeoSite()); os.IsNotExist(err) { @@ -88,9 +129,14 @@ func Init(dir string) error { f.Close() } - // initial GeoIP - if err := initGeoIP(); err != nil { - return fmt.Errorf("can't initial GeoIP: %w", err) + //// initial GeoIP + //if err := initGeoIP(); err != nil { + // return fmt.Errorf("can't initial GeoIP: %w", err) + //} + + // initial mmdb + if err := initMMDB(); err != nil { + return fmt.Errorf("can't initial MMDB: %w", err) } // initial GeoSite diff --git a/constant/path.go b/constant/path.go index 348dcd34..54f06667 100644 --- a/constant/path.go +++ b/constant/path.go @@ -22,6 +22,7 @@ var Path = func() *path { type path struct { homeDir string configFile string + scriptDir string } // SetHomeDir is used to set the configuration path @@ -67,6 +68,23 @@ func (p *path) GeoSite() string { return P.Join(p.homeDir, "geosite.dat") } +func (p *path) ScriptDir() string { + if len(p.scriptDir) != 0 { + return p.scriptDir + } + if dir, err := os.MkdirTemp("", Name+"-"); err == nil { + p.scriptDir = dir + } else { + p.scriptDir = P.Join(os.TempDir(), Name) + os.MkdirAll(p.scriptDir, 0644) + } + return p.scriptDir +} + +func (p *path) Script() string { + return P.Join(p.ScriptDir(), "clash_script.py") +} + func (p *path) GetAssetLocation(file string) string { return P.Join(p.homeDir, file) } diff --git a/constant/rule.go b/constant/rule.go index f36a3b51..87ae2a32 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -1,5 +1,7 @@ package constant +import "net" + // Rule Type const ( Domain RuleType = iota @@ -12,6 +14,7 @@ const ( SrcPort DstPort Process + Script MATCH ) @@ -39,6 +42,8 @@ func (rt RuleType) String() string { return "DstPort" case Process: return "Process" + case Script: + return "Script" case MATCH: return "Match" default: @@ -54,3 +59,5 @@ type Rule interface { ShouldResolveIP() bool RuleExtra() *RuleExtra } + +var TunBroadcastAddr = net.IPv4(198, 18, 255, 255) diff --git a/dns/filters.go b/dns/filters.go index c7c56f8a..26dc684f 100644 --- a/dns/filters.go +++ b/dns/filters.go @@ -2,16 +2,13 @@ package dns import ( "net" + "strings" + "github.com/Dreamacro/clash/component/mmdb" "github.com/Dreamacro/clash/component/trie" - "github.com/Dreamacro/clash/log" - "github.com/Dreamacro/clash/rule/geodata" - "github.com/Dreamacro/clash/rule/geodata/router" - _ "github.com/Dreamacro/clash/rule/geodata/standard" + //_ "github.com/Dreamacro/clash/rule/geodata/standard" ) -var multiGeoIPMatcher *router.MultiGeoIPMatcher - type fallbackIPFilter interface { Match(net.IP) bool } @@ -21,49 +18,8 @@ type geoipFilter struct { } func (gf *geoipFilter) Match(ip net.IP) bool { - if multiGeoIPMatcher == nil { - countryCode := gf.code - countryCodePrivate := "private" - geoLoader, err := geodata.GetGeoDataLoader("standard") - if err != nil { - log.Errorln("[GeoIPFilter] GetGeoDataLoader error: %s", err.Error()) - return false - } - - recordsCN, err := geoLoader.LoadGeoIP(countryCode) - if err != nil { - log.Errorln("[GeoIPFilter] LoadGeoIP error: %s", err.Error()) - return false - } - - recordsPrivate, err := geoLoader.LoadGeoIP(countryCodePrivate) - if err != nil { - log.Errorln("[GeoIPFilter] LoadGeoIP error: %s", err.Error()) - return false - } - - geoips := []*router.GeoIP{ - { - CountryCode: countryCode, - Cidr: recordsCN, - ReverseMatch: false, - }, - { - CountryCode: countryCodePrivate, - Cidr: recordsPrivate, - ReverseMatch: false, - }, - } - - multiGeoIPMatcher, err = router.NewMultiGeoIPMatcher(geoips) - - if err != nil { - log.Errorln("[GeoIPFilter] NewMultiGeoIPMatcher error: %s", err.Error()) - return false - } - } - - return !multiGeoIPMatcher.ApplyIp(ip) + record, _ := mmdb.Instance().Country(ip) + return !strings.EqualFold(record.Country.IsoCode, gf.code) && !ip.IsPrivate() } type ipnetFilter struct { diff --git a/go.mod b/go.mod index d34a4e74..4224be32 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/insomniacslk/dhcp v0.0.0-20210827173440-b95caade3eac github.com/kr328/tun2socket v0.0.0-20210412191540-3d56c47e2d99 github.com/miekg/dns v1.1.43 + github.com/oschwald/geoip2-golang v1.5.0 github.com/sirupsen/logrus v1.8.1 github.com/stretchr/testify v1.7.0 github.com/xtls/go v0.0.0-20201118062508-3632bf3b7499 @@ -24,13 +25,17 @@ require ( google.golang.org/protobuf v1.27.1 gopkg.in/yaml.v2 v2.4.0 gvisor.dev/gvisor v0.0.0-20210922003438-b39716d116fd + inet.af/netaddr v0.0.0-20210903134321-85fa6c94624e ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.0.1 // indirect + github.com/oschwald/maxminddb-golang v1.8.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/u-root/uio v0.0.0-20210528114334-82958018845c // indirect + go4.org/intern v0.0.0-20210108033219-3eb7198706b2 // indirect + go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 // indirect golang.org/x/text v0.3.6 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect diff --git a/go.sum b/go.sum index 197f76c3..1f45498d 100644 --- a/go.sum +++ b/go.sum @@ -186,6 +186,7 @@ github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/dvyukov/go-fuzz v0.0.0-20210103155950-6a8e9d1f2415/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw= github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= @@ -463,6 +464,10 @@ github.com/opencontainers/runtime-spec v1.0.2-0.20190207185410-29686dbc5559/go.m github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-tools v0.0.0-20181011054405-1d69bd0f9c39/go.mod h1:r3f7wjNzSs2extwzU3Y+6pKfobzPh+kKFJ3ofN+3nfs= github.com/opencontainers/selinux v1.8.0/go.mod h1:RScLhm78qiWa2gbVCcGkC7tCGdgk3ogry1nUQF8Evvo= +github.com/oschwald/geoip2-golang v1.5.0 h1:igg2yQIrrcRccB1ytFXqBfOHCjXWIoMv85lVJ1ONZzw= +github.com/oschwald/geoip2-golang v1.5.0/go.mod h1:xdvYt5xQzB8ORWFqPnqMwZpCpgNagttWdoZLlJQzg7s= +github.com/oschwald/maxminddb-golang v1.8.0 h1:Uh/DSnGoxsyp/KYbY1AuP0tYEwfs0sCph9p/UMXK/Hk= +github.com/oschwald/maxminddb-golang v1.8.0/go.mod h1:RXZtst0N6+FY/3qCNmZMBApR19cdQj43/NM9VkrNAis= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= @@ -596,6 +601,11 @@ go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go4.org/intern v0.0.0-20210108033219-3eb7198706b2 h1:VFTf+jjIgsldaz/Mr00VaCSswHJrI2hIjQygE/W4IMg= +go4.org/intern v0.0.0-20210108033219-3eb7198706b2/go.mod h1:vLqJ+12kCw61iCWsPto0EOHhBS+o4rO5VIucbc9g2Cc= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222175341-b30ae309168e/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 h1:1tk03FUNpulq2cuWpXZWj649rwJpk0d20rxWiopKRmc= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181009213950-7c1a557ab941/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -752,6 +762,7 @@ golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1014,6 +1025,8 @@ honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.1.1/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= +inet.af/netaddr v0.0.0-20210903134321-85fa6c94624e h1:tvgqez5ZQoBBiBAGNU/fmJy247yB/7++kcLOEoMYup0= +inet.af/netaddr v0.0.0-20210903134321-85fa6c94624e/go.mod h1:z0nx+Dh+7N7CC8V5ayHtHGpZpxLQZZxkIaaz6HN65Ls= k8s.io/api v0.16.13/go.mod h1:QWu8UWSTiuQZMMeYjwLs6ILu5O74qKSJ0c+4vrchDxs= k8s.io/apimachinery v0.16.13/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ= k8s.io/apimachinery v0.16.14-rc.0/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ= diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 25ad9bf3..137babaf 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -18,6 +18,7 @@ import ( "github.com/Dreamacro/clash/component/profile" "github.com/Dreamacro/clash/component/profile/cachefile" "github.com/Dreamacro/clash/component/resolver" + S "github.com/Dreamacro/clash/component/script" "github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/config" C "github.com/Dreamacro/clash/constant" @@ -68,6 +69,9 @@ func ParseWithPath(path string) (*config.Config, error) { // ParseWithBytes config with buffer func ParseWithBytes(buf []byte) (*config.Config, error) { + mux.Lock() + defer mux.Unlock() + return config.Parse(buf) } @@ -79,6 +83,7 @@ func ApplyConfig(cfg *config.Config, force bool) { updateUsers(cfg.Users) updateProxies(cfg.Proxies, cfg.Providers) updateRules(cfg.Rules) + updateRuleProviders(cfg.RuleProviders) updateHosts(cfg.Hosts) updateProfile(cfg) updateIPTables(cfg.DNS, cfg.General) @@ -179,6 +184,10 @@ func updateRules(rules []C.Rule) { tunnel.UpdateRules(rules) } +func updateRuleProviders(providers map[string]C.Rule) { + S.UpdateRuleProviders(providers) +} + func updateGeneral(general *config.General, force bool) { tunnel.SetMode(general.Mode) resolver.DisableIPv6 = !general.IPv6 @@ -321,4 +330,6 @@ func CleanUp() { if runtime.GOOS == "linux" { tproxy.CleanUpTProxyLinuxIPTables() } + + S.Py_Finalize() } diff --git a/listener/tun/ipstack/lwip/tcp.go b/listener/tun/ipstack/lwip/tcp.go index 0c668df6..f2771d26 100644 --- a/listener/tun/ipstack/lwip/tcp.go +++ b/listener/tun/ipstack/lwip/tcp.go @@ -34,27 +34,22 @@ func (h *tcpHandler) Handle(conn net.Conn, target *net.TCPAddr) error { return nil } - //if err := conn.SetDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil { - // _ = conn.Close() - // return nil - //} - src, _ := conn.LocalAddr().(*net.TCPAddr) dst, _ := conn.RemoteAddr().(*net.TCPAddr) - addrType := C.AtypIPv4 - if dst.IP.To4() == nil { - addrType = C.AtypIPv6 - } + //addrType := C.AtypIPv4 + //if dst.IP.To4() == nil { + // addrType = C.AtypIPv6 + //} metadata := &C.Metadata{ - NetWork: C.TCP, - Type: C.TUN, - SrcIP: src.IP, - DstIP: dst.IP, - SrcPort: strconv.Itoa(src.Port), - DstPort: strconv.Itoa(dst.Port), - AddrType: addrType, - Host: "", + NetWork: C.TCP, + Type: C.TUN, + SrcIP: src.IP, + DstIP: dst.IP, + SrcPort: strconv.Itoa(src.Port), + DstPort: strconv.Itoa(dst.Port), + //AddrType: addrType, + Host: "", } go func(conn net.Conn, metadata *C.Metadata) { diff --git a/main.go b/main.go index 79b48804..7691aded 100644 --- a/main.go +++ b/main.go @@ -45,7 +45,7 @@ func init() { func main() { if version { - fmt.Printf("Clash %s %s %s with %s %s\n", C.Version, runtime.GOOS, runtime.GOARCH, runtime.Version(), C.BuildTime) + fmt.Printf("Clash Plus Pro %s %s %s with %s %s\n", C.Version, runtime.GOOS, runtime.GOARCH, runtime.Version(), C.BuildTime) return } diff --git a/rule/geodata/memconservative/cache.go b/rule/geodata/memconservative/cache.go index 3a1ac7e6..d66438b4 100644 --- a/rule/geodata/memconservative/cache.go +++ b/rule/geodata/memconservative/cache.go @@ -2,7 +2,7 @@ package memconservative import ( "fmt" - "io/ioutil" + "os" "strings" C "github.com/Dreamacro/clash/constant" @@ -54,7 +54,7 @@ func (g GeoIPCache) Unmarshal(filename, code string) (*router.GeoIP, error) { case errFailedToReadBytes, errFailedToReadExpectedLenBytes, errInvalidGeodataFile, errInvalidGeodataVarintLength: log.Warnln("failed to decode geoip file: %s%s", filename, ", fallback to the original ReadFile method") - geoipBytes, err = ioutil.ReadFile(asset) + geoipBytes, err = os.ReadFile(asset) if err != nil { return nil, err } @@ -119,7 +119,7 @@ func (g GeoSiteCache) Unmarshal(filename, code string) (*router.GeoSite, error) case errFailedToReadBytes, errFailedToReadExpectedLenBytes, errInvalidGeodataFile, errInvalidGeodataVarintLength: log.Warnln("failed to decode geoip file: %s%s", filename, ", fallback to the original ReadFile method") - geositeBytes, err = ioutil.ReadFile(asset) + geositeBytes, err = os.ReadFile(asset) if err != nil { return nil, err } diff --git a/rule/geodata/router/condition_geoip.go b/rule/geodata/router/condition_geoip.go index 5a4bb5ca..f9341051 100644 --- a/rule/geodata/router/condition_geoip.go +++ b/rule/geodata/router/condition_geoip.go @@ -1,120 +1,43 @@ package router import ( - "encoding/binary" "fmt" "net" - "sort" + + "inet.af/netaddr" ) -// CIDRList is an alias of []*CIDR to provide sort.Interface. -type CIDRList []*CIDR - -// Len implements sort.Interface. -func (l *CIDRList) Len() int { - return len(*l) -} - -// Less implements sort.Interface. -func (l *CIDRList) Less(i int, j int) bool { - ci := (*l)[i] - cj := (*l)[j] - - if len(ci.Ip) < len(cj.Ip) { - return true - } - - if len(ci.Ip) > len(cj.Ip) { - return false - } - - for k := 0; k < len(ci.Ip); k++ { - if ci.Ip[k] < cj.Ip[k] { - return true - } - - if ci.Ip[k] > cj.Ip[k] { - return false - } - } - - return ci.Prefix < cj.Prefix -} - -// Swap implements sort.Interface. -func (l *CIDRList) Swap(i int, j int) { - (*l)[i], (*l)[j] = (*l)[j], (*l)[i] -} - -type ipv6 struct { - a uint64 - b uint64 -} - type GeoIPMatcher struct { countryCode string reverseMatch bool - ip4 []uint32 - prefix4 []uint8 - ip6 []ipv6 - prefix6 []uint8 -} - -func normalize4(ip uint32, prefix uint8) uint32 { - return (ip >> (32 - prefix)) << (32 - prefix) -} - -func normalize6(ip ipv6, prefix uint8) ipv6 { - if prefix <= 64 { - ip.a = (ip.a >> (64 - prefix)) << (64 - prefix) - ip.b = 0 - } else { - ip.b = (ip.b >> (128 - prefix)) << (128 - prefix) - } - return ip + ip4 *netaddr.IPSet + ip6 *netaddr.IPSet } func (m *GeoIPMatcher) Init(cidrs []*CIDR) error { - ip4Count := 0 - ip6Count := 0 - + var builder4, builder6 netaddr.IPSetBuilder for _, cidr := range cidrs { - ip := cidr.Ip - switch len(ip) { - case 4: - ip4Count++ - case 16: - ip6Count++ - default: - return fmt.Errorf("unexpect ip length: %d", len(ip)) + netaddrIP, ok := netaddr.FromStdIP(net.IP(cidr.GetIp())) + if !ok { + return fmt.Errorf("invalid IP address %v", cidr) + } + ipPrefix := netaddr.IPPrefixFrom(netaddrIP, uint8(cidr.GetPrefix())) + switch { + case netaddrIP.Is4(): + builder4.AddPrefix(ipPrefix) + case netaddrIP.Is6(): + builder6.AddPrefix(ipPrefix) } } - cidrList := CIDRList(cidrs) - sort.Sort(&cidrList) - - m.ip4 = make([]uint32, 0, ip4Count) - m.prefix4 = make([]uint8, 0, ip4Count) - m.ip6 = make([]ipv6, 0, ip6Count) - m.prefix6 = make([]uint8, 0, ip6Count) - - for _, cidr := range cidrs { - ip := cidr.Ip - prefix := uint8(cidr.Prefix) - switch len(ip) { - case 4: - m.ip4 = append(m.ip4, normalize4(binary.BigEndian.Uint32(ip), prefix)) - m.prefix4 = append(m.prefix4, prefix) - case 16: - ip6 := ipv6{ - a: binary.BigEndian.Uint64(ip[0:8]), - b: binary.BigEndian.Uint64(ip[8:16]), - } - ip6 = normalize6(ip6, prefix) - - m.ip6 = append(m.ip6, ip6) - m.prefix6 = append(m.prefix6, prefix) - } + var err error + m.ip4, err = builder4.IPSet() + if err != nil { + return err + } + m.ip6, err = builder6.IPSet() + if err != nil { + return err } return nil @@ -124,91 +47,35 @@ func (m *GeoIPMatcher) SetReverseMatch(isReverseMatch bool) { m.reverseMatch = isReverseMatch } -func (m *GeoIPMatcher) match4(ip uint32) bool { - if len(m.ip4) == 0 { +func (m *GeoIPMatcher) match4(ip net.IP) bool { + nip, ok := netaddr.FromStdIP(ip) + if !ok { return false } - - if ip < m.ip4[0] { - return false - } - - size := uint32(len(m.ip4)) - l := uint32(0) - r := size - for l < r { - x := ((l + r) >> 1) - if ip < m.ip4[x] { - r = x - continue - } - - nip := normalize4(ip, m.prefix4[x]) - if nip == m.ip4[x] { - return true - } - - l = x + 1 - } - - return l > 0 && normalize4(ip, m.prefix4[l-1]) == m.ip4[l-1] + return m.ip4.Contains(nip) } -func less6(a ipv6, b ipv6) bool { - return a.a < b.a || (a.a == b.a && a.b < b.b) -} - -func (m *GeoIPMatcher) match6(ip ipv6) bool { - if len(m.ip6) == 0 { +func (m *GeoIPMatcher) match6(ip net.IP) bool { + nip, ok := netaddr.FromStdIP(ip) + if !ok { return false } - - if less6(ip, m.ip6[0]) { - return false - } - - size := uint32(len(m.ip6)) - l := uint32(0) - r := size - for l < r { - x := (l + r) / 2 - if less6(ip, m.ip6[x]) { - r = x - continue - } - - if normalize6(ip, m.prefix6[x]) == m.ip6[x] { - return true - } - - l = x + 1 - } - - return l > 0 && normalize6(ip, m.prefix6[l-1]) == m.ip6[l-1] + return m.ip6.Contains(nip) } // Match returns true if the given ip is included by the GeoIP. func (m *GeoIPMatcher) Match(ip net.IP) bool { + isMatched := false switch len(ip) { - case 4: - if m.reverseMatch { - return !m.match4(binary.BigEndian.Uint32(ip)) - } - return m.match4(binary.BigEndian.Uint32(ip)) - case 16: - if m.reverseMatch { - return !m.match6(ipv6{ - a: binary.BigEndian.Uint64(ip[0:8]), - b: binary.BigEndian.Uint64(ip[8:16]), - }) - } - return m.match6(ipv6{ - a: binary.BigEndian.Uint64(ip[0:8]), - b: binary.BigEndian.Uint64(ip[8:16]), - }) - default: - return false + case net.IPv4len: + isMatched = m.match4(ip) + case net.IPv6len: + isMatched = m.match6(ip) } + if m.reverseMatch { + return !isMatched + } + return isMatched } // GeoIPMatcherContainer is a container for GeoIPMatchers. It keeps unique copies of GeoIPMatcher by country code. @@ -219,7 +86,7 @@ type GeoIPMatcherContainer struct { // Add adds a new GeoIP set into the container. // If the country code of GeoIP is not empty, GeoIPMatcherContainer will try to find an existing one, instead of adding a new one. func (c *GeoIPMatcherContainer) Add(geoip *GeoIP) (*GeoIPMatcher, error) { - if len(geoip.CountryCode) > 0 { + if geoip.CountryCode != "" { for _, m := range c.matchers { if m.countryCode == geoip.CountryCode && m.reverseMatch == geoip.ReverseMatch { return m, nil @@ -234,7 +101,7 @@ func (c *GeoIPMatcherContainer) Add(geoip *GeoIP) (*GeoIPMatcher, error) { if err := m.Init(geoip.Cidr); err != nil { return nil, err } - if len(geoip.CountryCode) > 0 { + if geoip.CountryCode != "" { c.matchers = append(c.matchers, m) } return m, nil diff --git a/rule/geodata/standard/standard.go b/rule/geodata/standard/standard.go index 21e437a3..190b5bdc 100644 --- a/rule/geodata/standard/standard.go +++ b/rule/geodata/standard/standard.go @@ -2,7 +2,7 @@ package standard import ( "fmt" - "io/ioutil" + "io" "os" "strings" @@ -19,7 +19,7 @@ func ReadFile(path string) ([]byte, error) { } defer reader.Close() - return ioutil.ReadAll(reader) + return io.ReadAll(reader) } func ReadAsset(file string) ([]byte, error) { diff --git a/rule/geoip.go b/rule/geoip.go index f7c5f430..1169bb58 100644 --- a/rule/geoip.go +++ b/rule/geoip.go @@ -1,22 +1,21 @@ package rules import ( - "fmt" "strings" + "github.com/Dreamacro/clash/component/mmdb" C "github.com/Dreamacro/clash/constant" - "github.com/Dreamacro/clash/log" - "github.com/Dreamacro/clash/rule/geodata" - "github.com/Dreamacro/clash/rule/geodata/router" - _ "github.com/Dreamacro/clash/rule/geodata/standard" + //"github.com/Dreamacro/clash/rule/geodata" + //"github.com/Dreamacro/clash/rule/geodata/router" + //_ "github.com/Dreamacro/clash/rule/geodata/standard" ) type GEOIP struct { - country string - adapter string - noResolveIP bool - ruleExtra *C.RuleExtra - geoIPMatcher *router.GeoIPMatcher + country string + adapter string + noResolveIP bool + ruleExtra *C.RuleExtra + //geoIPMatcher *router.GeoIPMatcher } func (g *GEOIP) RuleType() C.RuleType { @@ -29,10 +28,11 @@ func (g *GEOIP) Match(metadata *C.Metadata) bool { return false } - if strings.EqualFold(g.country, "LAN") { + if strings.EqualFold(g.country, "LAN") || C.TunBroadcastAddr.Equal(ip) { return ip.IsPrivate() } - return g.geoIPMatcher.Match(ip) + record, _ := mmdb.Instance().Country(ip) + return strings.EqualFold(record.Country.IsoCode, g.country) } func (g *GEOIP) Adapter() string { @@ -51,39 +51,43 @@ func (g *GEOIP) RuleExtra() *C.RuleExtra { return g.ruleExtra } +func (g *GEOIP) GetCountry() string { + return g.country +} + func NewGEOIP(country string, adapter string, noResolveIP bool, ruleExtra *C.RuleExtra) (*GEOIP, error) { - geoLoaderName := "standard" - //geoLoaderName := "memconservative" - geoLoader, err := geodata.GetGeoDataLoader(geoLoaderName) - if err != nil { - return nil, fmt.Errorf("[GeoIP] %s", err.Error()) - } - - records, err := geoLoader.LoadGeoIP(strings.ReplaceAll(country, "!", "")) - if err != nil { - return nil, fmt.Errorf("[GeoIP] %s", err.Error()) - } - - geoIP := &router.GeoIP{ - CountryCode: country, - Cidr: records, - ReverseMatch: strings.Contains(country, "!"), - } - - geoIPMatcher, err := router.NewGeoIPMatcher(geoIP) - - if err != nil { - return nil, fmt.Errorf("[GeoIP] %s", err.Error()) - } - - log.Infoln("Start initial GeoIP rule %s => %s, records: %d", country, adapter, len(records)) + //geoLoaderName := "standard" + ////geoLoaderName := "memconservative" + //geoLoader, err := geodata.GetGeoDataLoader(geoLoaderName) + //if err != nil { + // return nil, fmt.Errorf("load GeoIP data error, %s", err.Error()) + //} + // + //records, err := geoLoader.LoadGeoIP(strings.ReplaceAll(country, "!", "")) + //if err != nil { + // return nil, fmt.Errorf("load GeoIP data error, %s", err.Error()) + //} + // + //geoIP := &router.GeoIP{ + // CountryCode: country, + // Cidr: records, + // ReverseMatch: strings.Contains(country, "!"), + //} + // + //geoIPMatcher, err := router.NewGeoIPMatcher(geoIP) + // + //if err != nil { + // return nil, fmt.Errorf("load GeoIP data error, %s", err.Error()) + //} + // + //log.Infoln("Start initial GeoIP rule %s => %s, records: %d, reverse match: %v", country, adapter, len(records), geoIP.ReverseMatch) geoip := &GEOIP{ - country: country, - adapter: adapter, - noResolveIP: noResolveIP, - ruleExtra: ruleExtra, - geoIPMatcher: geoIPMatcher, + country: country, + adapter: adapter, + noResolveIP: noResolveIP, + ruleExtra: ruleExtra, + //geoIPMatcher: geoIPMatcher, } return geoip, nil diff --git a/rule/geosite.go b/rule/geosite.go index 9351b01e..d80027b7 100644 --- a/rule/geosite.go +++ b/rule/geosite.go @@ -52,12 +52,12 @@ func NewGEOSITE(country string, adapter string, ruleExtra *C.RuleExtra) (*GEOSIT //geoLoaderName := "memconservative" geoLoader, err := geodata.GetGeoDataLoader(geoLoaderName) if err != nil { - return nil, fmt.Errorf("[GeoSite] %s", err.Error()) + return nil, fmt.Errorf("load GeoSite data error, %s", err.Error()) } domains, err := geoLoader.LoadGeoSite(country) if err != nil { - return nil, fmt.Errorf("[GeoSite] %s", err.Error()) + return nil, fmt.Errorf("load GeoSite data error, %s", err.Error()) } //linear: linear algorithm @@ -66,7 +66,7 @@ func NewGEOSITE(country string, adapter string, ruleExtra *C.RuleExtra) (*GEOSIT //mph:minimal perfect hash algorithm matcher, err := router.NewMphMatcherGroup(domains) if err != nil { - return nil, fmt.Errorf("[GeoSite] %s", err.Error()) + return nil, fmt.Errorf("load GeoSite data error, %s", err.Error()) } log.Infoln("Start initial GeoSite rule %s => %s, records: %d", country, adapter, len(domains)) diff --git a/rule/parser.go b/rule/parser.go index fa73c2ae..60bffefb 100644 --- a/rule/parser.go +++ b/rule/parser.go @@ -40,6 +40,8 @@ func ParseRule(tp, payload, target string, params []string) (C.Rule, error) { parsed, parseErr = NewPort(payload, target, false, ruleExtra) case "PROCESS-NAME": parsed, parseErr = NewProcess(payload, target, ruleExtra) + case "SCRIPT": + parsed, parseErr = NewScript(payload, target) case "MATCH": parsed = NewMatch(target, ruleExtra) default: diff --git a/rule/script.go b/rule/script.go new file mode 100644 index 00000000..e0546951 --- /dev/null +++ b/rule/script.go @@ -0,0 +1,72 @@ +package rules + +import ( + "fmt" + S "github.com/Dreamacro/clash/component/script" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" + "runtime" + "strings" +) + +type Script struct { + shortcut string + adapter string + shortcutFunction *S.PyObject +} + +func (s *Script) RuleType() C.RuleType { + return C.Script +} + +func (s *Script) Match(metadata *C.Metadata) bool { + rs, err := S.CallPyShortcut(s.shortcutFunction, metadata) + if err != nil { + log.Errorln("[Script] match rule error: %s", err.Error()) + return false + } + + return rs +} + +func (s *Script) Adapter() string { + return s.adapter +} + +func (s *Script) Payload() string { + return s.shortcut +} + +func (s *Script) ShouldResolveIP() bool { + return false +} + +func (s *Script) RuleExtra() *C.RuleExtra { + return nil +} + +func NewScript(shortcut string, adapter string) (*Script, error) { + shortcut = strings.ToLower(shortcut) + if !S.Py_IsInitialized() { + return nil, fmt.Errorf("load script shortcut [%s] failure, can't find any shortcuts in the config file", shortcut) + } + + shortcutFunction, err := S.LoadShortcutFunction(shortcut) + if err != nil { + return nil, fmt.Errorf("can't find script shortcut [%s] in the config file", shortcut) + } + + obj := &Script{ + shortcut: shortcut, + adapter: adapter, + shortcutFunction: shortcutFunction, + } + + runtime.SetFinalizer(obj, func(s *Script) { + s.shortcutFunction.Clear() + }) + + log.Infoln("Start initial script shortcut rule %s => %s", shortcut, adapter) + + return obj, nil +} diff --git a/transport/vless/conn.go b/transport/vless/conn.go index 36e8918e..e6e6e34c 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "github.com/gofrs/uuid" @@ -87,7 +86,7 @@ func (vc *Conn) recvResponse() error { length := int64(buf[0]) if length != 0 { // addon data length > 0 - io.CopyN(ioutil.Discard, vc.Conn, length) // just discard + io.CopyN(io.Discard, vc.Conn, length) // just discard } return nil diff --git a/tunnel/mode.go b/tunnel/mode.go index 6e07a060..72d6c81c 100644 --- a/tunnel/mode.go +++ b/tunnel/mode.go @@ -13,6 +13,7 @@ var ( ModeMapping = map[string]TunnelMode{ Global.String(): Global, Rule.String(): Rule, + Script.String(): Script, Direct.String(): Direct, } ) @@ -20,6 +21,7 @@ var ( const ( Global TunnelMode = iota Rule + Script Direct ) @@ -63,6 +65,8 @@ func (m TunnelMode) String() string { return "global" case Rule: return "rule" + case Script: + return "script" case Direct: return "direct" default: diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 294b36a4..a947f50c 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -10,6 +10,7 @@ import ( "github.com/Dreamacro/clash/adapter/inbound" "github.com/Dreamacro/clash/component/nat" "github.com/Dreamacro/clash/component/resolver" + S "github.com/Dreamacro/clash/component/script" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/constant/provider" "github.com/Dreamacro/clash/context" @@ -34,8 +35,6 @@ var ( udpTimeout = 60 * time.Second preProcessCacheFinder, _ = R.NewProcess("", "", nil) - - tunBroadcastAddr = net.IPv4(198, 18, 255, 255) ) func init() { @@ -143,7 +142,7 @@ func preHandleMetadata(metadata *C.Metadata) error { // redir-host should lookup the hosts metadata.DstIP = node.Data.(net.IP) } - } else if resolver.IsFakeIP(metadata.DstIP) && !tunBroadcastAddr.Equal(metadata.DstIP) { + } else if resolver.IsFakeIP(metadata.DstIP) && !C.TunBroadcastAddr.Equal(metadata.DstIP) { return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) } } @@ -157,6 +156,8 @@ func resolveMetadata(ctx C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, r proxy = proxies["DIRECT"] case Global: proxy = proxies["GLOBAL"] + case Script: + proxy, err = matchScript(metadata) // Rule default: proxy, rule, err = match(metadata) @@ -235,7 +236,9 @@ func handleUDPConn(packet *inbound.PacketAdapter) { switch true { case rule != nil: - log.Infoln("[UDP] %s(%s) --> %s:%s match %s(%s) using %s", metadata.SourceAddress(), metadata.Process, metadata.RemoteAddress(), metadata.DstPort, rule.RuleType().String(), rule.Payload(), rawPc.Chains().String()) + log.Infoln("[UDP] %s(%s) --> %s match %s(%s) using %s", metadata.SourceAddress(), metadata.Process, metadata.RemoteAddress(), rule.RuleType().String(), rule.Payload(), rawPc.Chains().String()) + case mode == Script: + log.Infoln("[UDP] %s --> %s using SCRIPT %s", metadata.SourceAddress(), metadata.RemoteAddress(), rawPc.Chains().String()) case mode == Global: log.Infoln("[UDP] %s(%s) --> %s using GLOBAL", metadata.SourceAddress(), metadata.Process, metadata.RemoteAddress()) case mode == Direct: @@ -285,7 +288,9 @@ func handleTCPConn(ctx C.ConnContext) { switch true { case rule != nil: - log.Infoln("[TCP] %s(%s) --> %s:%s match %s(%s) using %s", metadata.SourceAddress(), metadata.Process, metadata.RemoteAddress(), metadata.DstPort, rule.RuleType().String(), rule.Payload(), remoteConn.Chains().String()) + log.Infoln("[TCP] %s(%s) --> %s match %s(%s) using %s", metadata.SourceAddress(), metadata.Process, metadata.RemoteAddress(), rule.RuleType().String(), rule.Payload(), remoteConn.Chains().String()) + case mode == Script: + log.Infoln("[TCP] %s(%s) --> %s using SCRIPT %s", metadata.SourceAddress(), metadata.Process, metadata.RemoteAddress(), remoteConn.Chains().String()) case mode == Global: log.Infoln("[TCP] %s(%s) --> %s using GLOBAL", metadata.SourceAddress(), metadata.Process, metadata.RemoteAddress()) case mode == Direct: @@ -357,3 +362,28 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { //return proxies["DIRECT"], nil, nil return proxies["REJECT"], nil, nil } + +func matchScript(metadata *C.Metadata) (C.Proxy, error) { + configMux.RLock() + defer configMux.RUnlock() + + if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { + ip := node.Data.(net.IP) + metadata.DstIP = ip + } + + // preset process name and cache it + preProcessCacheFinder.Match(metadata) + + adapter, err := S.CallPyMainFunction(metadata) + + if err != nil { + return nil, err + } + + if _, ok := proxies[adapter]; !ok { + return nil, fmt.Errorf("proxy [%s] not found by script", adapter) + } + + return proxies[adapter], nil +}