diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 211db5e0..09254c58 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -11,6 +11,7 @@ import ( func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) { opt := &option{ interfaceName: DefaultInterface.Load(), + routingMark: int(DefaultRoutingMark.Load()), } for _, o := range DefaultOptions { @@ -58,6 +59,7 @@ func DialContext(ctx context.Context, network, address string, options ...Option func ListenPacket(ctx context.Context, network, address string, options ...Option) (net.PacketConn, error) { cfg := &option{ interfaceName: DefaultInterface.Load(), + routingMark: int(DefaultRoutingMark.Load()), } for _, o := range DefaultOptions { diff --git a/component/script/build_local.go b/component/script/build_local.go new file mode 100644 index 00000000..b3e5640e --- /dev/null +++ b/component/script/build_local.go @@ -0,0 +1,9 @@ +//go:build build_local +// +build build_local + +package script + +/* +#cgo pkg-config: python3-embed +*/ +import "C" diff --git a/component/script/build_xgo.go b/component/script/build_xgo.go new file mode 100644 index 00000000..4a066c90 --- /dev/null +++ b/component/script/build_xgo.go @@ -0,0 +1,25 @@ +//go:build !build_local && cgo +// +build !build_local,cgo + +package script + +/* +#cgo linux,amd64 pkg-config: python3-embed + +#cgo darwin,amd64 CFLAGS: -I/build/python/python-3.9.7-darwin-amd64/include/python3.9 +#cgo darwin,arm64 CFLAGS: -I/build/python/python-3.9.7-darwin-arm64/include/python3.9 +#cgo windows,amd64 CFLAGS: -I/build/python/python-3.9.7-windows-amd64/include -DMS_WIN64 +#cgo windows,386 CFLAGS: -I/build/python/python-3.9.7-windows-386/include +//#cgo linux,amd64 CFLAGS: -I/home/runner/work/clash/clash/bin/python/python-3.9.7-linux-amd64/include/python3.9 +//#cgo linux,arm64 CFLAGS: -I/build/python/python-3.9.7-linux-arm64/include/python3.9 +//#cgo linux,386 CFLAGS: -I/build/python/python-3.9.7-linux-386/include/python3.9 + +#cgo darwin,amd64 LDFLAGS: -L/build/python/python-3.9.7-darwin-amd64/lib -lpython3.9 -ldl -framework CoreFoundation +#cgo darwin,arm64 LDFLAGS: -L/build/python/python-3.9.7-darwin-arm64/lib -lpython3.9 -ldl -framework CoreFoundation +#cgo windows,amd64 LDFLAGS: -L/build/python/python-3.9.7-windows-amd64/lib -lpython39 -lpthread -lm +#cgo windows,386 LDFLAGS: -L/build/python/python-3.9.7-windows-386/lib -lpython39 -lpthread -lm +//#cgo linux,amd64 LDFLAGS: -L/home/runner/work/clash/clash/bin/python/python-3.9.7-linux-amd64/lib -lpython3.9 -lpthread -ldl -lutil -lm +//#cgo linux,arm64 LDFLAGS: -L/build/python/python-3.9.7-linux-arm64/lib -lpython3.9 -lpthread -ldl -lutil -lm +//#cgo linux,386 LDFLAGS: -L/build/python/python-3.9.7-linux-386/lib -lpython3.9 -lpthread -ldl -lutil -lm +*/ +import "C" diff --git a/component/script/clash_module.c b/component/script/clash_module.c new file mode 100644 index 00000000..f15f0191 --- /dev/null +++ b/component/script/clash_module.c @@ -0,0 +1,735 @@ +#define PY_SSIZE_T_CLEAN + +#include "clash_module.h" +#include + +PyObject *clash_module; +PyObject *main_fn; +PyObject *clash_context; + +// init_python +void init_python(const char *program, const char *path) { + +// Py_NoSiteFlag = 1; +// Py_FrozenFlag = 1; +// Py_IgnoreEnvironmentFlag = 1; +// Py_IsolatedFlag = 1; + + append_inittab(); + + wchar_t *programName = Py_DecodeLocale(program, NULL); + if (programName != NULL) { + Py_SetProgramName(programName); + PyMem_RawFree(programName); + } + +// wchar_t *newPath = Py_DecodeLocale(path, NULL); +// if (newPath != NULL) { +// Py_SetPath(newPath); +// PyMem_RawFree(newPath); +// } + +// Py_Initialize(); + Py_InitializeEx(0); + + char *pathPrefix = "import sys; sys.path.append('"; + char *pathSuffix = "')"; + char *newPath = (char *) malloc(strlen(pathPrefix) + strlen(path) + strlen(pathSuffix)); + sprintf(newPath, "%s%s%s", pathPrefix, path, pathSuffix); + + PyRun_SimpleString(newPath); + free(newPath); + + /* Optionally import the module; alternatively, + import can be deferred until the embedded script + imports it. */ + clash_module = PyImport_ImportModule("clash"); +} + +// Load function, same as "import module_name.func_name as obj" in Python +// Returns the function object or NULL if not found +PyObject *load_func(const char *module_name, char *func_name) { + // Import the module + PyObject *py_mod_name = PyUnicode_FromString(module_name); + if (py_mod_name == NULL) { + return NULL; + } + + PyObject *module = PyImport_Import(py_mod_name); + Py_DECREF(py_mod_name); + if (module == NULL) { + return NULL; + } + + // Get function, same as "getattr(module, func_name)" in Python + PyObject *func = PyObject_GetAttrString(module, func_name); + Py_DECREF(module); + return func; +} + +// Return last error as char *, NULL if there was no error +const char *py_last_error() { + PyObject *err = PyErr_Occurred(); + if (err == NULL) { + return NULL; + } + + PyObject *type, *value, *traceback; + PyErr_Fetch(&type, &value, &traceback); + + if (value == NULL) { + return NULL; + } + + PyObject *str = PyObject_Str(value); + const char *utf8 = PyUnicode_AsUTF8(str); + Py_DECREF(str); + PyErr_Clear(); + return utf8; +} + +void py_clear(PyObject *obj) { + Py_CLEAR(obj); +} + +void load_main_func() { + main_fn = load_func(CLASH_SCRIPT_MODULE_NAME, "main"); +} + +/** callback function, that call go function by python3 script. **/ + +resolve_ip_callback resolve_ip_callback_fn; + +geoip_callback geoip_callback_fn; + +rule_provider_callback rule_provider_callback_fn; + +log_callback log_callback_fn; + +void +set_resolve_ip_callback(resolve_ip_callback cb) +{ + resolve_ip_callback_fn = cb; +} + +void +set_geoip_callback(geoip_callback cb) +{ + geoip_callback_fn = cb; +} + +void +set_rule_provider_callback(rule_provider_callback cb) +{ + rule_provider_callback_fn = cb; +} + +void +set_log_callback(log_callback cb) +{ + log_callback_fn = cb; +} + +/** end callback function **/ + +/* --------------------------------------------------------------------- */ + +/* RuleProvider objects */ + +typedef struct { + PyObject_HEAD + PyObject *name; /* rule provider name */ +} RuleProviderObject; + +static int +RuleProvider_traverse(RuleProviderObject *self, visitproc visit, void *arg) +{ + Py_VISIT(self->name); + return 0; +} + +static int +RuleProvider_clear(RuleProviderObject *self) +{ + Py_CLEAR(self->name); + return 0; +} + +static void +RuleProvider_dealloc(RuleProviderObject *self) +{ + PyObject_GC_UnTrack(self); + RuleProvider_clear(self); + Py_TYPE(self)->tp_free((PyObject *) self); +} + +static PyObject * +RuleProvider_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + RuleProviderObject *self; + self = (RuleProviderObject *) type->tp_alloc(type, 0); + if (self != NULL) { + self->name = PyUnicode_FromString(""); + if (self->name == NULL) { + Py_DECREF(self); + return NULL; + } + } + return (PyObject *) self; +} + +static int +RuleProvider_init(RuleProviderObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"name", NULL}; + PyObject *name = NULL, *tmp; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|Us", kwlist, &name)) + return -1; + + if (name) { + tmp = self->name; + Py_INCREF(name); + self->name = name; + Py_DECREF(tmp); + } + return 0; +} + +//static PyMemberDef RuleProvider_members[] = { +// {"adapter_type", T_STRING, offsetof(RuleProviderObject, adapter_type), 0, +// "adapter type"}, +// {NULL} /* Sentinel */ +//}; + +static PyObject * +RuleProvider_getname(RuleProviderObject *self, void *closure) +{ + Py_INCREF(self->name); + return self->name; +} + +static int +RuleProvider_setname(RuleProviderObject *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, "Cannot delete the name attribute"); + return -1; + } + if (!PyUnicode_Check(value)) { + PyErr_SetString(PyExc_TypeError, + "The name attribute value must be a string"); + return -1; + } + Py_INCREF(value); + Py_CLEAR(self->name); + self->name = value; + return 0; +} + +static PyGetSetDef RuleProvider_getsetters[] = { + {"name", (getter) RuleProvider_getname, (setter) RuleProvider_setname, + "name", NULL}, + {NULL} /* Sentinel */ +}; + +static PyObject * +RuleProvider_name(RuleProviderObject *self, PyObject *Py_UNUSED(ignored)) +{ + Py_INCREF(self->name); + return self->name; +} + +static PyObject * +RuleProvider_match(RuleProviderObject *self, PyObject *args) +{ + PyObject *result; + PyObject *tmp; + const char *provider_name; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &tmp)) //Format "O","O!","O&": Borrowed reference. + return NULL; + + if (tmp == NULL) + Py_RETURN_FALSE; + + Py_INCREF(tmp); +// PyObject *py_src_port = PyDict_GetItemString(tmp, "src_port"); //Return value: Borrowed reference. +// PyObject *py_dst_port = PyDict_GetItemString(tmp, "dst_port"); //Return value: Borrowed reference. +// Py_INCREF(py_src_port); +// Py_INCREF(py_dst_port); +// char *c_src_port = (char *) malloc(PyLong_AsSize_t(py_src_port)); +// char *c_dst_port = (char *) malloc(PyLong_AsSize_t(py_dst_port)); +// sprintf(c_src_port, "%ld", PyLong_AsLong(py_src_port)); +// sprintf(c_dst_port, "%ld", PyLong_AsLong(py_dst_port)); + + struct Metadata metadata = { + .type = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "type")), // PyDict_GetItemString() Return value: Borrowed reference. + .network = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "network")), + .process_name = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "process_name")), + .host = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "host")), + .src_ip = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "src_ip")), + .src_port = (unsigned short)PyLong_AsUnsignedLong(PyDict_GetItemString(tmp, "src_port")), + .dst_ip = PyUnicode_AsUTF8(PyDict_GetItemString(tmp, "dst_ip")), + .dst_port = (unsigned short)PyLong_AsUnsignedLong(PyDict_GetItemString(tmp, "dst_port")) + }; + +// Py_DECREF(py_src_port); +// Py_DECREF(py_dst_port); + + Py_INCREF(self->name); + provider_name = PyUnicode_AsUTF8(self->name); + Py_DECREF(self->name); + Py_DECREF(tmp); + + int rs = rule_provider_callback_fn(provider_name, &metadata); + + result = (rs == 1) ? Py_True : Py_False; + Py_INCREF(result); + return result; +} + +static PyMethodDef RuleProvider_methods[] = { + {"name", (PyCFunction) RuleProvider_name, METH_NOARGS, + "Return the RuleProvider name" + }, + {"match", (PyCFunction) RuleProvider_match, METH_VARARGS, + "Match the rule by the RuleProvider, match(metadata) -> boolean" + }, + {NULL} /* Sentinel */ +}; + +static PyTypeObject RuleProviderType = { + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "clash.RuleProvider", + .tp_doc = "Clash RuleProvider objects", + .tp_basicsize = sizeof(RuleProviderObject), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, + .tp_new = RuleProvider_new, + .tp_init = (initproc) RuleProvider_init, + .tp_dealloc = (destructor) RuleProvider_dealloc, + .tp_traverse = (traverseproc) RuleProvider_traverse, + .tp_clear = (inquiry) RuleProvider_clear, +// .tp_members = RuleProvider_members, + .tp_methods = RuleProvider_methods, + .tp_getset = RuleProvider_getsetters, +}; + +/* end RuleProvider objects */ +/* --------------------------------------------------------------------- */ + +/* Context objects */ + +typedef struct { + PyObject_HEAD + PyObject *rule_providers; /* Dict */ +} ContextObject; + +static int +Context_traverse(ContextObject *self, visitproc visit, void *arg) +{ + Py_VISIT(self->rule_providers); + return 0; +} + +static int +Context_clear(ContextObject *self) +{ + Py_CLEAR(self->rule_providers); + return 0; +} + +static void +Context_dealloc(ContextObject *self) +{ + PyObject_GC_UnTrack(self); + Context_clear(self); + Py_TYPE(self)->tp_free((PyObject *) self); +} + +static PyObject * +Context_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + ContextObject *self; + self = (ContextObject *) type->tp_alloc(type, 0); + if (self != NULL) { + self->rule_providers = PyDict_New(); + if (self->rule_providers == NULL) { + Py_DECREF(self); + return NULL; + } + } + return (PyObject *) self; +} + +static int +Context_init(ContextObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"rule_providers", NULL}; + PyObject *rule_providers = NULL, *tmp; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O", kwlist, + &rule_providers)) + return -1; + + if (rule_providers) { + tmp = self->rule_providers; + Py_INCREF(rule_providers); + self->rule_providers = rule_providers; + Py_DECREF(tmp); + } + return 0; +} + +static PyObject * +Context_getrule_providers(ContextObject *self, void *closure) +{ + Py_INCREF(self->rule_providers); + return self->rule_providers; +} + +static int +Context_setrule_providers(ContextObject *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, "Cannot delete the rule_providers attribute"); + return -1; + } + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, + "The rule_providers attribute value must be a dict"); + return -1; + } + Py_INCREF(value); + Py_CLEAR(self->rule_providers); + self->rule_providers = value; + return 0; +} + +static PyGetSetDef Context_getsetters[] = { + {"rule_providers", (getter) Context_getrule_providers, (setter) Context_setrule_providers, + "rule_providers", NULL}, + {NULL} /* Sentinel */ +}; + +static PyObject * +Context_resolve_ip(PyObject *self, PyObject *args) +{ + const char *host; + const char *ip; + + if (!PyArg_ParseTuple(args, "s", &host)) + return NULL; + + if (host == NULL) + return PyUnicode_FromString(""); + + ip = resolve_ip_callback_fn(host); + + return PyUnicode_FromString(ip); +} + +static PyObject * +Context_geoip(PyObject *self, PyObject *args) +{ + const char *ip; + const char *countryCode; + + if (!PyArg_ParseTuple(args, "s", &ip)) + return NULL; + + if (ip == NULL) + return PyUnicode_FromString(""); + + countryCode = geoip_callback_fn(ip); + + return PyUnicode_FromString(countryCode); +} + +static PyObject * +Context_log(PyObject *self, PyObject *args) +{ + const char *msg; + + if (!PyArg_ParseTuple(args, "s", &msg)) + return NULL; + + log_callback_fn(msg); + + Py_RETURN_NONE; +} + +static PyMethodDef Context_methods[] = { + {"resolve_ip", (PyCFunction) Context_resolve_ip, METH_VARARGS, + "resolve_ip(host) -> string" + }, + {"geoip", (PyCFunction) Context_geoip, METH_VARARGS, + "geoip(ip) -> string" + }, + {"log", (PyCFunction) Context_log, METH_VARARGS, + "log(msg) -> void" + }, + {NULL} /* Sentinel */ +}; + +static PyTypeObject ContextType = { + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "clash.Context", + .tp_doc = "Clash Context objects", + .tp_basicsize = sizeof(ContextObject), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, + .tp_new = Context_new, + .tp_init = (initproc) Context_init, + .tp_dealloc = (destructor) Context_dealloc, + .tp_traverse = (traverseproc) Context_traverse, + .tp_clear = (inquiry) Context_clear, + .tp_methods = Context_methods, + .tp_getset = Context_getsetters, +}; + +static PyModuleDef clashmodule = { + PyModuleDef_HEAD_INIT, + .m_name = "clash", + .m_doc = "Clash module that creates an extension module for python3.", + .m_size = -1, +}; + +PyMODINIT_FUNC +PyInit_clash(void) +{ + PyObject *m; + + m = PyModule_Create(&clashmodule); + if (m == NULL) + return NULL; + + if (PyType_Ready(&RuleProviderType) < 0) + return NULL; + + Py_INCREF(&RuleProviderType); + if (PyModule_AddObject(m, "RuleProvider", (PyObject *) &RuleProviderType) < 0) { + Py_DECREF(&RuleProviderType); + Py_DECREF(m); + return NULL; + } + + if (PyType_Ready(&ContextType) < 0) + return NULL; + + Py_INCREF(&ContextType); + if (PyModule_AddObject(m, "Context", (PyObject *) &ContextType) < 0) { + Py_DECREF(&ContextType); + Py_DECREF(m); + return NULL; + } + + return m; +} + +/* end Context objects */ + +/* --------------------------------------------------------------------- */ + +void +append_inittab() +{ + /* Add a built-in module, before Py_Initialize */ + PyImport_AppendInittab("clash", PyInit_clash); +} + +int new_clash_py_context(const char *provider_name_arr[], int size) { + PyObject *dict = PyDict_New(); //Return value: New reference. + if (dict == NULL) { + PyErr_SetString(PyExc_TypeError, + "PyDict_New failure"); + return 0; + } + + for (int i = 0; i < size; i++) { + PyObject *rule_provider = RuleProvider_new(&RuleProviderType, NULL, NULL); + if (rule_provider == NULL) { + Py_DECREF(dict); + PyErr_SetString(PyExc_TypeError, + "RuleProvider_new failure"); + return 0; + } + + RuleProviderObject *providerObj = (RuleProviderObject *) rule_provider; + + PyObject *py_name = PyUnicode_FromString(provider_name_arr[i]); //Return value: New reference. + RuleProvider_setname(providerObj, py_name, NULL); + Py_DECREF(py_name); + + PyDict_SetItemString(dict, provider_name_arr[i], rule_provider); //Parameter value: New reference. + Py_DECREF(rule_provider); + } + + clash_context = Context_new(&ContextType, NULL, NULL); + + if (clash_context == NULL) { + Py_DECREF(dict); + PyErr_SetString(PyExc_TypeError, + "Context_new failure"); + return 0; + } + + Context_setrule_providers((ContextObject *) clash_context, dict, NULL); + Py_DECREF(dict); + return 1; +} + +const char *call_main( + const char *type, + const char *network, + const char *process_name, + const char *host, + const char *src_ip, + unsigned short src_port, + const char *dst_ip, + unsigned short dst_port) { + + PyObject *metadataDict; + PyObject *tupleArgs; + PyObject *result; + + metadataDict = PyDict_New(); //Return value: New reference. + + if (metadataDict == NULL) { + PyErr_SetString(PyExc_TypeError, + "PyDict_New failure"); + return "-1"; + } + + PyObject *p_type = PyUnicode_FromString(type); //Return value: New reference. + PyObject *p_network = PyUnicode_FromString(network); //Return value: New reference. + PyObject *p_process_name = PyUnicode_FromString(process_name); //Return value: New reference. + PyObject *p_host = PyUnicode_FromString(host); //Return value: New reference. + PyObject *p_src_ip = PyUnicode_FromString(src_ip); //Return value: New reference. + PyObject *p_src_port = PyLong_FromUnsignedLong((unsigned long)src_port); //Return value: New reference. + PyObject *p_dst_ip = PyUnicode_FromString(dst_ip); //Return value: New reference. + PyObject *p_dst_port = PyLong_FromUnsignedLong((unsigned long)dst_port); //Return value: New reference. + + PyDict_SetItemString(metadataDict, "type", p_type); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "network", p_network); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "process_name", p_process_name); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "host", p_host); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "src_ip", p_src_ip); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "src_port", p_src_port); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "dst_ip", p_dst_ip); //Parameter value: New reference. + PyDict_SetItemString(metadataDict, "dst_port", p_dst_port); //Parameter value: New reference. + + Py_DECREF(p_type); + Py_DECREF(p_network); + Py_DECREF(p_process_name); + Py_DECREF(p_host); + Py_DECREF(p_src_ip); + Py_DECREF(p_src_port); + Py_DECREF(p_dst_ip); + Py_DECREF(p_dst_port); + + tupleArgs = PyTuple_New(2); //Return value: New reference. + if (tupleArgs == NULL) { + Py_DECREF(metadataDict); + PyErr_SetString(PyExc_TypeError, + "PyTuple_New failure"); + return "-1"; + } + + Py_INCREF(clash_context); + PyTuple_SetItem(tupleArgs, 0, clash_context); //clash_context Parameter value: Stolen reference. + PyTuple_SetItem(tupleArgs, 1, metadataDict); //metadataDict Parameter value: Stolen reference. + + Py_INCREF(main_fn); + result = PyObject_CallObject(main_fn, tupleArgs); //Return value: New reference. + Py_DECREF(main_fn); + Py_DECREF(tupleArgs); + + if (result == NULL) { + return "-1"; + } + + if (!PyUnicode_Check(result)) { + Py_DECREF(result); + PyErr_SetString(PyExc_TypeError, + "script main function return value must be a string"); + return "-1"; + } + + const char *adapter = PyUnicode_AsUTF8(result); + + Py_DECREF(result); + + return adapter; +} + +int call_shortcut(PyObject *shortcut_fn, + const char *type, + const char *network, + const char *process_name, + const char *host, + const char *src_ip, + unsigned short src_port, + const char *dst_ip, + unsigned short dst_port) { + + PyObject *args; + PyObject *result; + + args = Py_BuildValue("{s:O, s:s, s:s, s:s, s:s, s:H, s:s, s:H}", + "ctx", clash_context, + "network", network, + "process_name", process_name, + "host", host, + "src_ip", src_ip, + "src_port", src_port, + "dst_ip", dst_ip, + "dst_port", dst_port); //Return value: New reference. + + if (args == NULL) { + PyErr_SetString(PyExc_TypeError, + "Py_BuildValue failure"); + return -1; + } + + PyObject *tupleArgs = PyTuple_New(0); //Return value: New reference. + + Py_INCREF(clash_context); + Py_INCREF(shortcut_fn); + result = PyObject_Call(shortcut_fn, tupleArgs, args); //Return value: New reference. + Py_DECREF(shortcut_fn); + Py_DECREF(clash_context); + Py_DECREF(tupleArgs); + Py_DECREF(args); + + if (result == NULL) { + return -1; + } + + if (!PyBool_Check(result)) { + Py_DECREF(result); + PyErr_SetString(PyExc_TypeError, + "script shortcut return value must be as boolean"); + return -1; + } + + int rs = (result == Py_True) ? 1 : 0; + + Py_DECREF(result); + + return rs; +} + +void finalize_Python() { + Py_CLEAR(main_fn); + Py_CLEAR(clash_context); + Py_CLEAR(clash_module); + Py_FinalizeEx(); + +// clash_module = NULL; +// main_fn = NULL; +// clash_context = NULL; +} + +/* --------------------------------------------------------------------- */ \ No newline at end of file diff --git a/component/script/clash_module.go b/component/script/clash_module.go new file mode 100644 index 00000000..5fccb4c1 --- /dev/null +++ b/component/script/clash_module.go @@ -0,0 +1,337 @@ +package script + +/* +#include "clash_module.h" + +extern const char *resolveIPCallbackFn(const char *host); + +void +go_set_resolve_ip_callback() { + set_resolve_ip_callback(resolveIPCallbackFn); +} + +extern const char *geoipCallbackFn(const char *ip); + +void +go_set_geoip_callback() { + set_geoip_callback(geoipCallbackFn); +} + +extern const int ruleProviderCallbackFn(const char *provider_name, struct Metadata *metadata); + +void +go_set_rule_provider_callback() { + set_rule_provider_callback(ruleProviderCallbackFn); +} + +extern void logCallbackFn(const char *msg); + +void +go_set_log_callback() { + set_log_callback(logCallbackFn); +} +*/ +import "C" +import ( + "errors" + "fmt" + "os" + "runtime" + "strconv" + "strings" + "sync" + "syscall" + "unsafe" + + "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" +) + +const ClashScriptModuleName = C.CLASH_SCRIPT_MODULE_NAME + +var lock sync.Mutex + +type PyObject C.PyObject + +func togo(cobject *C.PyObject) *PyObject { + return (*PyObject)(cobject) +} + +func toc(object *PyObject) *C.PyObject { + return (*C.PyObject)(object) +} + +func (pyObject *PyObject) IncRef() { + C.Py_IncRef(toc(pyObject)) +} + +func (pyObject *PyObject) DecRef() { + C.Py_DecRef(toc(pyObject)) +} + +func (pyObject *PyObject) Clear() { + C.py_clear(toc(pyObject)) +} + +// Py_Initialize initialize Python3 +func Py_Initialize(program string, path string) error { + lock.Lock() + defer lock.Unlock() + + if C.Py_IsInitialized() != 0 { + if pyThreadState != nil { + PyEval_RestoreThread(pyThreadState) + } + C.finalize_Python() + } + + path = strings.ReplaceAll(path, "\\", "/") + cPath := C.CString(path) + //defer C.free(unsafe.Pointer(cPath)) + + C.init_python(C.CString(program), cPath) + err := PyLastError() + + if err != nil { + if C.Py_IsInitialized() != 0 { + C.finalize_Python() + _ = os.RemoveAll(constant.Path.ScriptDir()) + } + return err + } else if C.Py_IsInitialized() == 0 { + err = errors.New("initialized script module failure") + return err + } + + initPython3Callback() + return nil +} + +func Py_IsInitialized() bool { + lock.Lock() + defer lock.Unlock() + + return C.Py_IsInitialized() != 0 +} + +func Py_Finalize() { + lock.Lock() + defer lock.Unlock() + + if C.Py_IsInitialized() != 0 { + if pyThreadState != nil { + PyEval_RestoreThread(pyThreadState) + } + C.finalize_Python() + _ = os.RemoveAll(constant.Path.ScriptDir()) + log.Warnln("Clash clean up script mode.") + } +} + +//Py_GetVersion get +func Py_GetVersion() string { + cversion := C.Py_GetVersion() + return strings.Split(C.GoString(cversion), "\n")[0] +} + +// loadPyFunc loads a Python function by module and function name +func loadPyFunc(moduleName, funcName string) (*C.PyObject, error) { + // Convert names to C char* + cMod := C.CString(moduleName) + cFunc := C.CString(funcName) + + // Free memory allocated by C.CString + defer func() { + C.free(unsafe.Pointer(cMod)) + C.free(unsafe.Pointer(cFunc)) + }() + + fnc := C.load_func(cMod, cFunc) + if fnc == nil { + return nil, PyLastError() + } + + return fnc, nil +} + +//PyLastError python last error +func PyLastError() error { + cp := C.py_last_error() + if cp == nil { + return nil + } + + return errors.New(C.GoString(cp)) +} + +func LoadShortcutFunction(shortcut string) (*PyObject, error) { + fnc, err := loadPyFunc(ClashScriptModuleName, shortcut) + if err != nil { + return nil, err + } + return togo(fnc), nil +} + +func LoadMainFunction() error { + C.load_main_func() + err := PyLastError() + if err != nil { + return err + } + return nil +} + +//CallPyMainFunction call python script main function +//return the proxy adapter name. +func CallPyMainFunction(mtd *constant.Metadata) (string, error) { + _type := C.CString(mtd.Type.String()) + network := C.CString(mtd.NetWork.String()) + processName := C.CString(mtd.Process) + host := C.CString(mtd.Host) + + srcPortGo, _ := strconv.ParseUint(mtd.SrcPort, 10, 16) + dstPortGo, _ := strconv.ParseUint(mtd.DstPort, 10, 16) + srcPort := C.ushort(srcPortGo) + dstPort := C.ushort(dstPortGo) + + dstIpGo := "" + srcIpGo := "" + if mtd.SrcIP != nil { + srcIpGo = mtd.SrcIP.String() + } + if mtd.DstIP != nil { + dstIpGo = mtd.DstIP.String() + } + srcIp := C.CString(srcIpGo) + dstIp := C.CString(dstIpGo) + + defer func() { + C.free(unsafe.Pointer(_type)) + C.free(unsafe.Pointer(network)) + C.free(unsafe.Pointer(processName)) + C.free(unsafe.Pointer(host)) + C.free(unsafe.Pointer(srcIp)) + C.free(unsafe.Pointer(dstIp)) + }() + + runtime.LockOSThread() + gilState := PyGILState_Ensure() + defer PyGILState_Release(gilState) + + cRs := C.call_main(_type, network, processName, host, srcIp, srcPort, dstIp, dstPort) + + rs := C.GoString(cRs) + if rs == "-1" { + err := PyLastError() + if err != nil { + log.Errorln("[Script] script code error: %s", err.Error()) + killSelf() + return "", fmt.Errorf("script code error: %w", err) + } else { + return "", fmt.Errorf("script code error, result: %v", rs) + } + } + + return rs, nil +} + +//CallPyShortcut call python script shortcuts function +//param: shortcut name +//return the match result. +func CallPyShortcut(fn *PyObject, mtd *constant.Metadata) (bool, error) { + _type := C.CString(mtd.Type.String()) + network := C.CString(mtd.NetWork.String()) + processName := C.CString(mtd.Process) + host := C.CString(mtd.Host) + + srcPortGo, _ := strconv.ParseUint(mtd.SrcPort, 10, 16) + dstPortGo, _ := strconv.ParseUint(mtd.DstPort, 10, 16) + srcPort := C.ushort(srcPortGo) + dstPort := C.ushort(dstPortGo) + + dstIpGo := "" + srcIpGo := "" + if mtd.SrcIP != nil { + srcIpGo = mtd.SrcIP.String() + } + if mtd.DstIP != nil { + dstIpGo = mtd.DstIP.String() + } + srcIp := C.CString(srcIpGo) + dstIp := C.CString(dstIpGo) + + defer func() { + C.free(unsafe.Pointer(_type)) + C.free(unsafe.Pointer(network)) + C.free(unsafe.Pointer(processName)) + C.free(unsafe.Pointer(host)) + C.free(unsafe.Pointer(srcIp)) + C.free(unsafe.Pointer(dstIp)) + }() + + runtime.LockOSThread() + gilState := PyGILState_Ensure() + defer PyGILState_Release(gilState) + + cRs := C.call_shortcut(toc(fn), _type, network, processName, host, srcIp, srcPort, dstIp, dstPort) + + rs := int(cRs) + if rs == -1 { + err := PyLastError() + if err != nil { + log.Errorln("[Script] script shortcut code error: %s", err.Error()) + killSelf() + return false, fmt.Errorf("script shortcut code error: %w", err) + } else { + return false, fmt.Errorf("script shortcut code error: result: %d", rs) + } + } + + if rs == 1 { + return true, nil + } else { + return false, nil + } +} + +func initPython3Callback() { + C.go_set_resolve_ip_callback() + C.go_set_geoip_callback() + C.go_set_rule_provider_callback() + C.go_set_log_callback() +} + +//NewClashPyContext new clash context for python +func NewClashPyContext(ruleProvidersName []string) error { + length := len(ruleProvidersName) + cStringArr := make([]*C.char, length) + for i, v := range ruleProvidersName { + cStringArr[i] = C.CString(v) + defer C.free(unsafe.Pointer(cStringArr[i])) + } + + cArrPointer := unsafe.Pointer(nil) + if length > 0 { + cArrPointer = unsafe.Pointer(&cStringArr[0]) + } + + rs := C.new_clash_py_context((**C.char)(cArrPointer), C.int(length)) + + if int(rs) == 0 { + err := PyLastError() + return fmt.Errorf("new script module context failure: %s", err.Error()) + } + + return nil +} + +func killSelf() { + p, err := os.FindProcess(os.Getpid()) + + if err != nil { + os.Exit(int(syscall.SIGINT)) + return + } + + _ = p.Signal(syscall.SIGINT) +} diff --git a/component/script/clash_module.h b/component/script/clash_module.h new file mode 100644 index 00000000..3e03e16d --- /dev/null +++ b/component/script/clash_module.h @@ -0,0 +1,62 @@ +#ifndef CLASH_CALLBACK_MODULE_H__ +#define CLASH_CALLBACK_MODULE_H__ + +#include + +#define CLASH_SCRIPT_MODULE_NAME "clash_script" + +struct Metadata { + const char *type; /* type socks5/http */ + const char *network; /* network tcp/udp */ + const char *process_name; + const char *host; + const char *src_ip; + unsigned short src_port; + const char *dst_ip; + unsigned short dst_port; +}; + +/** callback function, that call go function by python3 script. **/ +typedef const char *(*resolve_ip_callback)(const char *host); +typedef const char *(*geoip_callback)(const char *ip); +typedef const int (*rule_provider_callback)(const char *provider_name, struct Metadata *metadata); +typedef void (*log_callback)(const char *msg); + +void set_resolve_ip_callback(resolve_ip_callback cb); +void set_geoip_callback(geoip_callback cb); +void set_rule_provider_callback(rule_provider_callback cb); +void set_log_callback(log_callback cb); +/*---------------------------------------------------------------*/ + +void append_inittab(); +void init_python(const char *program, const char *path); +void load_main_func(); +void finalize_Python(); +void py_clear(PyObject *obj); +const char *py_last_error(); + +PyObject *load_func(const char *module_name, char *func_name); + +int new_clash_py_context(const char *provider_name_arr[], int size); + +const char *call_main( + const char *type, + const char *network, + const char *process_name, + const char *host, + const char *src_ip, + unsigned short src_port, + const char *dst_ip, + unsigned short dst_port); + +int call_shortcut(PyObject *shortcut_fn, + const char *type, + const char *network, + const char *process_name, + const char *host, + const char *src_ip, + unsigned short src_port, + const char *dst_ip, + unsigned short dst_port); + +#endif // CLASH_CALLBACK_MODULE_H__ \ No newline at end of file diff --git a/component/script/clash_module_export.go b/component/script/clash_module_export.go new file mode 100644 index 00000000..b3f5355b --- /dev/null +++ b/component/script/clash_module_export.go @@ -0,0 +1,136 @@ +package script + +/* +#include "clash_module.h" +*/ +import "C" +import ( + "net" + "strconv" + "strings" + "unsafe" + + "github.com/Dreamacro/clash/component/mmdb" + "github.com/Dreamacro/clash/component/resolver" + "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" +) + +var ( + ruleProviders = map[string]constant.Rule{} + pyThreadState *PyThreadState +) + +func UpdateRuleProviders(rpd map[string]constant.Rule) { + ruleProviders = rpd + if Py_IsInitialized() { + pyThreadState = PyEval_SaveThread() + } +} + +//export resolveIPCallbackFn +func resolveIPCallbackFn(cHost *C.char) *C.char { + host := C.GoString(cHost) + if len(host) == 0 { + cip := C.CString("") + defer C.free(unsafe.Pointer(cip)) + return cip + } + if ip, err := resolver.ResolveIP(host); err == nil { + cip := C.CString(ip.String()) + defer C.free(unsafe.Pointer(cip)) + return cip + } else { + log.Errorln("[Script] resolve ip error: %s", err.Error()) + cip := C.CString("") + defer C.free(unsafe.Pointer(cip)) + return cip + } +} + +//export geoipCallbackFn +func geoipCallbackFn(cIP *C.char) *C.char { + dstIP := net.ParseIP(C.GoString(cIP)) + + if dstIP == nil { + emptyC := C.CString("") + defer C.free(unsafe.Pointer(emptyC)) + + return emptyC + } + + if dstIP.IsPrivate() || constant.TunBroadcastAddr.Equal(dstIP) { + lanC := C.CString("LAN") + defer C.free(unsafe.Pointer(lanC)) + + return lanC + } + + record, _ := mmdb.Instance().Country(dstIP) + + rc := C.CString(strings.ToUpper(record.Country.IsoCode)) + defer C.free(unsafe.Pointer(rc)) + + return rc +} + +//export ruleProviderCallbackFn +func ruleProviderCallbackFn(cProviderName *C.char, cMetadata *C.struct_Metadata) C.int { + //_type := C.GoString(cMetadata._type) + //network := C.GoString(cMetadata.network) + processName := C.GoString(cMetadata.process_name) + host := C.GoString(cMetadata.host) + srcIp := C.GoString(cMetadata.src_ip) + srcPort := strconv.Itoa(int(cMetadata.src_port)) + dstIp := C.GoString(cMetadata.dst_ip) + dstPort := strconv.Itoa(int(cMetadata.dst_port)) + + dst := net.ParseIP(dstIp) + addrType := constant.AtypDomainName + + if dst != nil { + if dst.To4() != nil { + addrType = constant.AtypIPv4 + } else { + addrType = constant.AtypIPv6 + } + } + + metadata := &constant.Metadata{ + Process: processName, + SrcIP: net.ParseIP(srcIp), + DstIP: dst, + SrcPort: srcPort, + DstPort: dstPort, + AddrType: addrType, + Host: host, + } + + providerName := C.GoString(cProviderName) + + rule, ok := ruleProviders[providerName] + if !ok { + log.Warnln("[Script] rule provider [%s] not found", providerName) + return C.int(0) + } + + if strings.HasPrefix(providerName, "geosite:") { + if len(host) == 0 { + return C.int(0) + } + metadata.AddrType = constant.AtypDomainName + } + + rs := rule.Match(metadata) + + if rs { + return C.int(1) + } + return C.int(0) +} + +//export logCallbackFn +func logCallbackFn(msg *C.char) { + + log.Infoln(C.GoString(msg)) +} diff --git a/component/script/thread.go b/component/script/thread.go new file mode 100644 index 00000000..8b7735c0 --- /dev/null +++ b/component/script/thread.go @@ -0,0 +1,52 @@ +package script + +/* +#include "Python.h" +*/ +import "C" + +//PyThreadState : https://docs.python.org/3/c-api/init.html#c.PyThreadState +type PyThreadState C.PyThreadState + +//PyGILState is an opaque “handle” to the thread state when PyGILState_Ensure() was called, and must be passed to PyGILState_Release() to ensure Python is left in the same state +type PyGILState C.PyGILState_STATE + +//PyEval_SaveThread : https://docs.python.org/3/c-api/init.html#c.PyEval_SaveThread +func PyEval_SaveThread() *PyThreadState { + return (*PyThreadState)(C.PyEval_SaveThread()) +} + +//PyEval_RestoreThread : https://docs.python.org/3/c-api/init.html#c.PyEval_RestoreThread +func PyEval_RestoreThread(tstate *PyThreadState) { + C.PyEval_RestoreThread((*C.PyThreadState)(tstate)) +} + +//PyThreadState_Get : https://docs.python.org/3/c-api/init.html#c.PyThreadState_Get +func PyThreadState_Get() *PyThreadState { + return (*PyThreadState)(C.PyThreadState_Get()) +} + +//PyThreadState_Swap : https://docs.python.org/3/c-api/init.html#c.PyThreadState_Swap +func PyThreadState_Swap(tstate *PyThreadState) *PyThreadState { + return (*PyThreadState)(C.PyThreadState_Swap((*C.PyThreadState)(tstate))) +} + +//PyGILState_Ensure : https://docs.python.org/3/c-api/init.html#c.PyGILState_Ensure +func PyGILState_Ensure() PyGILState { + return PyGILState(C.PyGILState_Ensure()) +} + +//PyGILState_Release : https://docs.python.org/3/c-api/init.html#c.PyGILState_Release +func PyGILState_Release(state PyGILState) { + C.PyGILState_Release(C.PyGILState_STATE(state)) +} + +//PyGILState_GetThisThreadState : https://docs.python.org/3/c-api/init.html#c.PyGILState_GetThisThreadState +func PyGILState_GetThisThreadState() *PyThreadState { + return (*PyThreadState)(C.PyGILState_GetThisThreadState()) +} + +//PyGILState_Check : https://docs.python.org/3/c-api/init.html#c.PyGILState_Check +func PyGILState_Check() bool { + return C.PyGILState_Check() == 1 +} diff --git a/constant/rule_extra.go b/constant/rule_extra.go index 119b42ca..9df17418 100644 --- a/constant/rule_extra.go +++ b/constant/rule_extra.go @@ -2,6 +2,7 @@ package constant import ( "net" + "strings" "github.com/Dreamacro/clash/component/geodata/router" ) @@ -9,8 +10,9 @@ import ( var TunBroadcastAddr = net.IPv4(198, 18, 255, 255) type RuleExtra struct { - Network NetWork - SourceIPs []*net.IPNet + Network NetWork + SourceIPs []*net.IPNet + ProcessNames []string } func (re *RuleExtra) NotMatchNetwork(network NetWork) bool { @@ -30,6 +32,19 @@ func (re *RuleExtra) NotMatchSourceIP(srcIP net.IP) bool { return true } +func (re *RuleExtra) NotMatchProcessName(processName string) bool { + if re.ProcessNames == nil { + return false + } + + for _, pn := range re.ProcessNames { + if strings.EqualFold(pn, processName) { + return false + } + } + return true +} + type RuleGeoSite interface { GetDomainMatcher() *router.DomainMatcher } diff --git a/go.mod b/go.mod index 64f5b329..fdd14c9e 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,8 @@ require ( github.com/marten-seemann/qtls-go1-18 v0.1.0-beta.1 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/onsi/ginkgo v1.16.4 // indirect + github.com/google/btree v1.0.1 // indirect + github.com/kr/pretty v0.2.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/u-root/uio v0.0.0-20210528114334-82958018845c // indirect golang.org/x/mod v0.4.2 // indirect diff --git a/hub/route/configs.go b/hub/route/configs.go index 99337dd9..6ec86963 100644 --- a/hub/route/configs.go +++ b/hub/route/configs.go @@ -79,15 +79,7 @@ func patchConfigs(w http.ResponseWriter, r *http.Request) { P.ReCreateMixed(pointerOrDefault(general.MixedPort, ports.MixedPort), tcpIn, udpIn) if general.Tun != nil { - err := P.ReCreateTun(*general.Tun, tcpIn, udpIn) - if err == nil { - log.Infoln("Recreate tun success.") - } else { - log.Errorln("Recreate tun failed: %s", err.Error()) - render.Status(r, http.StatusBadRequest) - render.JSON(w, r, newError(err.Error())) - return - } + P.ReCreateTun(*general.Tun, tcpIn, udpIn) } if general.Mode != nil { diff --git a/listener/tun/dev/dev_windows.go b/listener/tun/dev/dev_windows.go index 2919b80d..f387c60c 100644 --- a/listener/tun/dev/dev_windows.go +++ b/listener/tun/dev/dev_windows.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "os" - "sync/atomic" "time" _ "unsafe" diff --git a/listener/tun/dev/wintun/session_windows.go b/listener/tun/dev/wintun/session_windows.go index 11081158..f023baf7 100644 --- a/listener/tun/dev/wintun/session_windows.go +++ b/listener/tun/dev/wintun/session_windows.go @@ -86,7 +86,5 @@ func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err er } func (session Session) SendPacket(packet []byte) { - if packet != nil && len(packet) > 0 { - syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) - } + syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) } diff --git a/listener/tun/ipstack/gvisor/tun.go b/listener/tun/ipstack/gvisor/tun.go index b19e6fa7..77cd780f 100644 --- a/listener/tun/ipstack/gvisor/tun.go +++ b/listener/tun/ipstack/gvisor/tun.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "strconv" "strings" "sync" @@ -47,14 +48,13 @@ type gvisorAdapter struct { writeHandle *channel.NotificationHandle } -// NewAdapter GvisorAdapter create GvisorAdapter +// GvisorAdapter create GvisorAdapter func NewAdapter(device dev.TunDevice, conf config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.TunAdapter, error) { ipstack := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, }) - adapter := &gvisorAdapter{ device: device, ipstack: ipstack, udpIn: udpIn, @@ -85,10 +85,13 @@ func NewAdapter(device dev.TunDevice, conf config.Tun, tcpIn chan<- C.ConnContex // maximum number of half-open tcp connection set to 1024 // receive buffer size set to 20k tcpFwd := tcp.NewForwarder(ipstack, pool.RelayBufferSize, 1024, func(r *tcp.ForwarderRequest) { + src := net.JoinHostPort(r.ID().RemoteAddress.String(), strconv.Itoa((int)(r.ID().RemotePort))) + dst := net.JoinHostPort(r.ID().LocalAddress.String(), strconv.Itoa((int)(r.ID().LocalPort))) + log.Debugln("Get TCP Syn %v -> %s in ipstack", src, dst) var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) if err != nil { - log.Warnln("Can't create TCP Endpoint in ipstack: %v", err) + log.Warnln("Can't create TCP Endpoint(%s -> %s) in ipstack: %v", src, dst, err) r.Complete(true) return } @@ -113,7 +116,7 @@ func NewAdapter(device dev.TunDevice, conf config.Tun, tcpIn chan<- C.ConnContex ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, adapter.udpHandlePacket) if resolver.DefaultResolver != nil { - err = adapter.ReCreateDNSServer(resolver.DefaultResolver.(*dns.Resolver), resolver.DefaultHostMapper.(*dns.ResolverEnhancer), conf.DnsHijack) + err = adapter.ReCreateDNSServer(resolver.DefaultResolver.(*dns.Resolver), resolver.DefaultHostMapper.(*dns.ResolverEnhancer), conf.DNSListen) if err != nil { return nil, err } @@ -190,6 +193,9 @@ func (t *gvisorAdapter) AsLinkEndpoint() (result stack.LinkEndpoint, err error) for !t.device.IsClose() { packet := make([]byte, mtu) n, err := t.device.Read(packet) + if n == 0 { + continue + } if err != nil && !t.device.IsClose() { log.Errorln("can not read from tun: %v", err) continue @@ -202,9 +208,12 @@ func (t *gvisorAdapter) AsLinkEndpoint() (result stack.LinkEndpoint, err error) p = header.IPv6ProtocolNumber } if linkEP.IsAttached() { - linkEP.InjectInbound(p, stack.NewPacketBuffer(stack.PacketBufferOptions{ + packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(packet[:n]).ToVectorisedView(), - })) + }) + + linkEP.InjectInbound(p, packetBuffer) + packetBuffer.DecRef() } else { log.Debugln("received packet from tun when %s is not attached to any dispatcher.", t.device.Name()) } @@ -222,14 +231,14 @@ func (t *gvisorAdapter) AsLinkEndpoint() (result stack.LinkEndpoint, err error) // WriteNotify implements channel.Notification.WriteNotify. func (t *gvisorAdapter) WriteNotify() { - packet, ok := t.linkCache.Read() - if ok { + packetBuffer := t.linkCache.Read() + if packetBuffer != nil { var vv buffer.VectorisedView // Append upper headers. - vv.AppendView(packet.Pkt.NetworkHeader().View()) - vv.AppendView(packet.Pkt.TransportHeader().View()) + vv.AppendView(packetBuffer.NetworkHeader().View()) + vv.AppendView(packetBuffer.TransportHeader().View()) // Append data payload. - vv.Append(packet.Pkt.Data().ExtractVV()) + vv.Append(packetBuffer.Data().ExtractVV()) _, err := t.device.Write(vv.ToView()) if err != nil && !t.device.IsClose() { diff --git a/listener/tun/ipstack/lwip/dns.go b/listener/tun/ipstack/lwip/dns.go new file mode 100644 index 00000000..6a314c08 --- /dev/null +++ b/listener/tun/ipstack/lwip/dns.go @@ -0,0 +1,87 @@ +package lwip + +import ( + "encoding/binary" + "io" + "net" + "time" + + "github.com/Dreamacro/clash/component/resolver" + D "github.com/Dreamacro/clash/listener/tun/ipstack/commons" + "github.com/Dreamacro/clash/log" + "github.com/yaling888/go-lwip" +) + +const defaultDnsReadTimeout = time.Second * 8 + +func shouldHijackDns(dnsIP net.IP, targetIp net.IP, targetPort int) bool { + if targetPort != 53 { + return false + } + + return dnsIP.Equal(net.IPv4zero) || dnsIP.Equal(targetIp) +} + +func hijackUDPDns(conn golwip.UDPConn, pkt []byte, addr *net.UDPAddr) { + go func() { + defer func(conn golwip.UDPConn) { + _ = conn.Close() + }(conn) + + answer, err := D.RelayDnsPacket(pkt) + if err != nil { + return + } + _, _ = conn.WriteFrom(answer, addr) + }() +} + +func hijackTCPDns(conn net.Conn) { + go func() { + defer func(conn net.Conn) { + _ = conn.Close() + }(conn) + + if err := conn.SetDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil { + return + } + + for { + var length uint16 + if binary.Read(conn, binary.BigEndian, &length) != nil { + return + } + + data := make([]byte, length) + + _, err := io.ReadFull(conn, data) + if err != nil { + return + } + + rb, err := D.RelayDnsPacket(data) + if err != nil { + continue + } + + if binary.Write(conn, binary.BigEndian, uint16(len(rb))) != nil { + return + } + + if _, err = conn.Write(rb); err != nil { + return + } + } + }() +} + +type dnsHandler struct{} + +func newDnsHandler() golwip.DnsHandler { + return &dnsHandler{} +} + +func (d dnsHandler) ResolveIP(host string) (net.IP, error) { + log.Debugln("[TUN] lwip resolve ip for host: %s", host) + return resolver.ResolveIP(host) +} diff --git a/listener/tun/ipstack/lwip/tcp.go b/listener/tun/ipstack/lwip/tcp.go new file mode 100644 index 00000000..c62a6beb --- /dev/null +++ b/listener/tun/ipstack/lwip/tcp.go @@ -0,0 +1,61 @@ +package lwip + +import ( + "net" + "strconv" + + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/context" + "github.com/Dreamacro/clash/log" + "github.com/yaling888/go-lwip" +) + +type tcpHandler struct { + dnsIP net.IP + tcpIn chan<- C.ConnContext +} + +func newTCPHandler(dnsIP net.IP, tcpIn chan<- C.ConnContext) golwip.TCPConnHandler { + return &tcpHandler{dnsIP, tcpIn} +} + +func (h *tcpHandler) Handle(conn net.Conn, target *net.TCPAddr) error { + if shouldHijackDns(h.dnsIP, target.IP, target.Port) { + hijackTCPDns(conn) + log.Debugln("[TUN] hijack dns tcp: %s:%d", target.IP.String(), target.Port) + return nil + } + + if conn.RemoteAddr() == nil { + _ = conn.Close() + return nil + } + + src, _ := conn.LocalAddr().(*net.TCPAddr) + dst, _ := conn.RemoteAddr().(*net.TCPAddr) + + addrType := C.AtypIPv4 + if dst.IP.To4() == nil { + addrType = C.AtypIPv6 + } + + metadata := &C.Metadata{ + NetWork: C.TCP, + Type: C.TUN, + SrcIP: src.IP, + DstIP: dst.IP, + SrcPort: strconv.Itoa(src.Port), + DstPort: strconv.Itoa(dst.Port), + AddrType: addrType, + Host: "", + } + + go func(conn net.Conn, metadata *C.Metadata) { + //if c, ok := conn.(*net.TCPConn); ok { + // c.SetKeepAlive(true) + //} + h.tcpIn <- context.NewConnContext(conn, metadata) + }(conn, metadata) + + return nil +} diff --git a/listener/tun/ipstack/lwip/tun.go b/listener/tun/ipstack/lwip/tun.go new file mode 100644 index 00000000..9037be6e --- /dev/null +++ b/listener/tun/ipstack/lwip/tun.go @@ -0,0 +1,121 @@ +package lwip + +import ( + "io" + "net" + "sync" + + "github.com/Dreamacro/clash/adapter/inbound" + "github.com/Dreamacro/clash/common/pool" + "github.com/Dreamacro/clash/config" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/listener/tun/dev" + "github.com/Dreamacro/clash/listener/tun/ipstack" + "github.com/Dreamacro/clash/log" + "github.com/yaling888/go-lwip" +) + +type lwipAdapter struct { + device dev.TunDevice + lwipStack golwip.LWIPStack + lock sync.Mutex + mtu int + stackName string + dnsListen string + autoRoute bool +} + +func NewAdapter(device dev.TunDevice, conf config.Tun, mtu int, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.TunAdapter, error) { + adapter := &lwipAdapter{ + device: device, + mtu: mtu, + stackName: conf.Stack, + dnsListen: conf.DNSListen, + autoRoute: conf.AutoRoute, + } + + adapter.lock.Lock() + defer adapter.lock.Unlock() + + dnsHost, _, err := net.SplitHostPort(conf.DNSListen) + if err != nil { + return nil, err + } + + dnsIP := net.ParseIP(dnsHost) + + // Register output function, write packets from lwip stack to tun device + golwip.RegisterOutputFn(func(data []byte) (int, error) { + return device.Write(data) + }) + + // Set custom buffer pool + golwip.SetPoolAllocator(newLWIPPool()) + + // Setup TCP/IP stack. + lwipStack, err := golwip.NewLWIPStack(mtu) + if err != nil { + return nil, err + } + adapter.lwipStack = lwipStack + + golwip.RegisterDnsHandler(newDnsHandler()) + golwip.RegisterTCPConnHandler(newTCPHandler(dnsIP, tcpIn)) + golwip.RegisterUDPConnHandler(newUDPHandler(dnsIP, udpIn)) + + // Copy packets from tun device to lwip stack, it's the loop. + go func(lwipStack golwip.LWIPStack, device dev.TunDevice, mtu int) { + _, err := io.CopyBuffer(lwipStack.(io.Writer), device, make([]byte, mtu)) + if err != nil { + log.Debugln("copying data failed: %v", err) + } + }(lwipStack, device, mtu) + + return adapter, nil +} + +func (l *lwipAdapter) Stack() string { + return l.stackName +} + +func (l *lwipAdapter) AutoRoute() bool { + return l.autoRoute +} + +func (l *lwipAdapter) DNSListen() string { + return l.dnsListen +} + +func (l *lwipAdapter) Close() { + l.lock.Lock() + defer l.lock.Unlock() + + l.stopLocked() +} + +func (l *lwipAdapter) stopLocked() { + if l.lwipStack != nil { + _ = l.lwipStack.Close() + } + + if l.device != nil { + _ = l.device.Close() + } + + l.lwipStack = nil + l.device = nil +} + +type lwipPool struct{} + +func (p lwipPool) Get(size int) []byte { + return pool.Get(size) +} + +func (p lwipPool) Put(buf []byte) error { + return pool.Put(buf) +} + +func newLWIPPool() golwip.LWIPPool { + return &lwipPool{} +} diff --git a/listener/tun/ipstack/lwip/udp.go b/listener/tun/ipstack/lwip/udp.go new file mode 100644 index 00000000..747796bf --- /dev/null +++ b/listener/tun/ipstack/lwip/udp.go @@ -0,0 +1,74 @@ +package lwip + +import ( + "io" + "net" + + "github.com/Dreamacro/clash/adapter/inbound" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" + "github.com/Dreamacro/clash/transport/socks5" + "github.com/yaling888/go-lwip" +) + +type udpPacket struct { + source *net.UDPAddr + payload []byte + sender golwip.UDPConn +} + +func (u *udpPacket) Data() []byte { + return u.payload +} + +func (u *udpPacket) WriteBack(b []byte, addr net.Addr) (n int, err error) { + _, ok := addr.(*net.UDPAddr) + if !ok { + return 0, io.ErrClosedPipe + } + + return u.sender.WriteFrom(b, u.source) +} + +func (u *udpPacket) Drop() { +} + +func (u *udpPacket) LocalAddr() net.Addr { + return u.source +} + +type udpHandler struct { + dnsIP net.IP + udpIn chan<- *inbound.PacketAdapter +} + +func newUDPHandler(dnsIP net.IP, udpIn chan<- *inbound.PacketAdapter) golwip.UDPConnHandler { + return &udpHandler{dnsIP, udpIn} +} + +func (h *udpHandler) Connect(golwip.UDPConn, *net.UDPAddr) error { + return nil +} + +func (h *udpHandler) ReceiveTo(conn golwip.UDPConn, data []byte, addr *net.UDPAddr) error { + if shouldHijackDns(h.dnsIP, addr.IP, addr.Port) { + hijackUDPDns(conn, data, addr) + log.Debugln("[TUN] hijack dns udp: %s:%d", addr.IP.String(), addr.Port) + return nil + } + + packet := &udpPacket{ + source: conn.LocalAddr(), + payload: data, + sender: conn, + } + + go func(addr *net.UDPAddr, packet *udpPacket) { + select { + case h.udpIn <- inbound.NewPacket(socks5.ParseAddrToSocksAddr(addr), packet, C.TUN): + default: + } + }(addr, packet) + + return nil +} diff --git a/rule/common/process.go b/rule/common/process.go index 47731443..9481e1fc 100644 --- a/rule/common/process.go +++ b/rule/common/process.go @@ -30,7 +30,7 @@ func (ps *Process) Match(metadata *C.Metadata) bool { // ignore match in proxy type "tproxy" //if metadata.Type == C.TPROXY || !C.AutoIptables { - if C.AutoIptables == "Enable" { + if metadata.Type == C.TPROXY || C.AutoIptables == "Enable" { return false } diff --git a/rule/geosite.go b/rule/geosite.go new file mode 100644 index 00000000..875320bd --- /dev/null +++ b/rule/geosite.go @@ -0,0 +1,70 @@ +package rules + +import ( + "fmt" + + "github.com/Dreamacro/clash/component/geodata" + "github.com/Dreamacro/clash/component/geodata/router" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" + + _ "github.com/Dreamacro/clash/component/geodata/standard" +) + +type GEOSITE struct { + country string + adapter string + ruleExtra *C.RuleExtra + matcher *router.DomainMatcher +} + +func (gs *GEOSITE) RuleType() C.RuleType { + return C.GEOSITE +} + +func (gs *GEOSITE) Match(metadata *C.Metadata) bool { + if metadata.AddrType != C.AtypDomainName { + return false + } + + domain := metadata.Host + return gs.matcher.ApplyDomain(domain) +} + +func (gs *GEOSITE) Adapter() string { + return gs.adapter +} + +func (gs *GEOSITE) Payload() string { + return gs.country +} + +func (gs *GEOSITE) ShouldResolveIP() bool { + return false +} + +func (gs *GEOSITE) RuleExtra() *C.RuleExtra { + return gs.ruleExtra +} + +func (gs *GEOSITE) GetDomainMatcher() *router.DomainMatcher { + return gs.matcher +} + +func NewGEOSITE(country string, adapter string, ruleExtra *C.RuleExtra) (*GEOSITE, error) { + matcher, recordsCount, err := geodata.LoadGeoSiteMatcher(country) + if err != nil { + return nil, fmt.Errorf("load GeoSite data error, %s", err.Error()) + } + + log.Infoln("Start initial GeoSite rule %s => %s, records: %d", country, adapter, recordsCount) + + geoSite := &GEOSITE{ + country: country, + adapter: adapter, + ruleExtra: ruleExtra, + matcher: matcher, + } + + return geoSite, nil +} diff --git a/rule/parser.go b/rule/parser.go index 700c51fd..88c8b375 100644 --- a/rule/parser.go +++ b/rule/parser.go @@ -17,6 +17,7 @@ func ParseRule(tp, payload, target string, params []string) (C.Rule, error) { ruleExtra := &C.RuleExtra{ Network: RC.FindNetwork(params), SourceIPs: RC.FindSourceIPs(params), + ProcessNames: RC.FindProcessName(params), } switch tp { diff --git a/rule/port.go b/rule/port.go new file mode 100644 index 00000000..e978e28d --- /dev/null +++ b/rule/port.go @@ -0,0 +1,125 @@ +package rules + +import ( + "fmt" + "strconv" + "strings" + + C "github.com/Dreamacro/clash/constant" +) + +type portReal struct { + portStart int + portEnd int +} + +type Port struct { + adapter string + port string + isSource bool + portList []portReal + ruleExtra *C.RuleExtra +} + +func (p *Port) RuleType() C.RuleType { + if p.isSource { + return C.SrcPort + } + return C.DstPort +} + +func (p *Port) Match(metadata *C.Metadata) bool { + if p.isSource { + return p.matchPortReal(metadata.SrcPort) + } + return p.matchPortReal(metadata.DstPort) +} + +func (p *Port) Adapter() string { + return p.adapter +} + +func (p *Port) Payload() string { + return p.port +} + +func (p *Port) ShouldResolveIP() bool { + return false +} + +func (p *Port) RuleExtra() *C.RuleExtra { + return p.ruleExtra +} + +func (p *Port) matchPortReal(portRef string) bool { + port, err := strconv.Atoi(portRef) + if err != nil { + return false + } + + var rs bool + for _, pr := range p.portList { + if pr.portEnd == -1 { + rs = port == pr.portStart + } else { + rs = port >= pr.portStart && port <= pr.portEnd + } + if rs { + return true + } + } + return false +} + +func NewPort(port string, adapter string, isSource bool, ruleExtra *C.RuleExtra) (*Port, error) { + ports := strings.Split(port, "/") + if len(ports) > 28 { + return nil, fmt.Errorf("%s, too many ports to use, maximum support 28 ports", errPayload.Error()) + } + + var portList []portReal + for _, p := range ports { + if p == "" { + continue + } + + subPorts := strings.Split(p, "-") + subPortsLen := len(subPorts) + if subPortsLen > 2 { + return nil, errPayload + } + + portStart, err := strconv.Atoi(strings.Trim(subPorts[0], "[ ]")) + if err != nil || portStart < 0 || portStart > 65535 { + return nil, errPayload + } + + if subPortsLen == 1 { + portList = append(portList, portReal{portStart, -1}) + } else if subPortsLen == 2 { + portEnd, err1 := strconv.Atoi(strings.Trim(subPorts[1], "[ ]")) + if err1 != nil || portEnd < 0 || portEnd > 65535 { + return nil, errPayload + } + + shouldReverse := portStart > portEnd + if shouldReverse { + portList = append(portList, portReal{portEnd, portStart}) + } else { + portList = append(portList, portReal{portStart, portEnd}) + } + } + } + + if len(portList) == 0 { + return nil, errPayload + } + + return &Port{ + adapter: adapter, + port: port, + isSource: isSource, + portList: portList, + ruleExtra: ruleExtra, + }, nil +} diff --git a/rule/script.go b/rule/script.go new file mode 100644 index 00000000..b1ccb5fb --- /dev/null +++ b/rule/script.go @@ -0,0 +1,73 @@ +package rules + +import ( + "fmt" + "runtime" + "strings" + + S "github.com/Dreamacro/clash/component/script" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" +) + +type Script struct { + shortcut string + adapter string + shortcutFunction *S.PyObject +} + +func (s *Script) RuleType() C.RuleType { + return C.Script +} + +func (s *Script) Match(metadata *C.Metadata) bool { + rs, err := S.CallPyShortcut(s.shortcutFunction, metadata) + if err != nil { + log.Errorln("[Script] match rule error: %s", err.Error()) + return false + } + + return rs +} + +func (s *Script) Adapter() string { + return s.adapter +} + +func (s *Script) Payload() string { + return s.shortcut +} + +func (s *Script) ShouldResolveIP() bool { + return false +} + +func (s *Script) RuleExtra() *C.RuleExtra { + return nil +} + +func NewScript(shortcut string, adapter string) (*Script, error) { + shortcut = strings.ToLower(shortcut) + if !S.Py_IsInitialized() { + return nil, fmt.Errorf("load script shortcut [%s] failure, can't find any shortcuts in the config file", shortcut) + } + + shortcutFunction, err := S.LoadShortcutFunction(shortcut) + if err != nil { + return nil, fmt.Errorf("can't find script shortcut [%s] in the config file", shortcut) + } + + obj := &Script{ + shortcut: shortcut, + adapter: adapter, + shortcutFunction: shortcutFunction, + } + + runtime.SetFinalizer(obj, func(s *Script) { + s.shortcutFunction.Clear() + }) + + log.Infoln("Start initial script shortcut rule %s => %s", shortcut, adapter) + + return obj, nil +} diff --git a/transport/vless/xtls.go b/transport/vless/xtls.go index 69035aa0..6e2a4d46 100644 --- a/transport/vless/xtls.go +++ b/transport/vless/xtls.go @@ -1,8 +1,10 @@ package vless import ( + "context" "net" + C "github.com/Dreamacro/clash/constant" xtls "github.com/xtls/go" ) @@ -20,6 +22,10 @@ func StreamXTLSConn(conn net.Conn, cfg *XTLSConfig) (net.Conn, error) { } xtlsConn := xtls.Client(conn, xtlsConfig) - err := xtlsConn.Handshake() + + // fix tls handshake not timeout + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + defer cancel() + err := xtlsConn.HandshakeContext(ctx) return xtlsConn, err }