diff --git a/config/config.go b/config/config.go index 8386c0c6..f912103c 100644 --- a/config/config.go +++ b/config/config.go @@ -78,7 +78,7 @@ type DNS struct { EnhancedMode C.DNSMode `yaml:"enhanced-mode"` DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` FakeIPRange *fakeip.Pool - Hosts *trie.DomainTrie + Hosts *trie.DomainTrie[netip.Addr] NameServerPolicy map[string]dns.NameServer ProxyServerNameserver []dns.NameServer } @@ -113,12 +113,6 @@ type Tun struct { AutoRoute bool `yaml:"auto-route" json:"auto-route"` } -// Script config -type Script struct { - MainCode string `yaml:"code" json:"code"` - ShortcutsCode map[string]string `yaml:"shortcuts" json:"shortcuts"` -} - // IPTables config type IPTables struct { Enable bool `yaml:"enable" json:"enable"` @@ -142,7 +136,7 @@ type Config struct { IPTables *IPTables DNS *DNS Experimental *Experimental - Hosts *trie.DomainTrie + Hosts *trie.DomainTrie[netip.Addr] Profile *Profile Rules []C.Rule Users []auth.AuthUser @@ -564,7 +558,7 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[strin params = rule[l:] } - if _, ok := proxies[target]; mode != T.Script && !ok { + if _, ok := proxies[target]; !ok { return nil, nil, fmt.Errorf("rules[%d] [%s] error: proxy [%s] not found", idx, line, target) } @@ -581,9 +575,7 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[strin return nil, nil, fmt.Errorf("rules[%d] [%s] error: %s", idx, line, parseErr.Error()) } - if mode != T.Script { - rules = append(rules, parsed) - } + rules = append(rules, parsed) } runtime.GC() @@ -591,18 +583,18 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[strin return rules, ruleProviders, nil } -func parseHosts(cfg *RawConfig) (*trie.DomainTrie, error) { - tree := trie.New() +func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) { + tree := trie.New[netip.Addr]() // add default hosts - if err := tree.Insert("localhost", net.IP{127, 0, 0, 1}); err != nil { + if err := tree.Insert("localhost", netip.AddrFrom4([4]byte{127, 0, 0, 1})); err != nil { log.Errorln("insert localhost to host error: %s", err.Error()) } if len(cfg.Hosts) != 0 { for domain, ipStr := range cfg.Hosts { - ip := net.ParseIP(ipStr) - if ip == nil { + ip, err := netip.ParseAddr(ipStr) + if err != nil { return nil, fmt.Errorf("%s is not a valid IP", ipStr) } _ = tree.Insert(domain, ip) @@ -750,7 +742,7 @@ func parseFallbackGeoSite(countries []string, rules []C.Rule) ([]*router.DomainM return sites, nil } -func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS, error) { +func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.Rule) (*DNS, error) { cfg := rawCfg.DNS if cfg.Enable && len(cfg.NameServer) == 0 { return nil, fmt.Errorf("if DNS configuration is turned on, NameServer cannot be empty") @@ -806,10 +798,10 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS, return nil, err } - var host *trie.DomainTrie + var host *trie.DomainTrie[bool] // fake ip skip host filter if len(cfg.FakeIPFilter) != 0 { - host = trie.New() + host = trie.New[bool]() for _, domain := range cfg.FakeIPFilter { _ = host.Insert(domain, true) } @@ -817,7 +809,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS, if len(dnsCfg.Fallback) != 0 { if host == nil { - host = trie.New() + host = trie.New[bool]() } for _, fb := range dnsCfg.Fallback { if net.ParseIP(fb.Addr) != nil { diff --git a/dns/middleware.go b/dns/middleware.go index 7259df66..eeb40c1d 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -21,7 +21,7 @@ type ( middleware func(next handler) handler ) -func withHosts(hosts *trie.DomainTrie[netip.Addr]) middleware { +func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[string, string]) middleware { return func(next handler) handler { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] @@ -30,7 +30,9 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr]) middleware { return next(ctx, r) } - record := hosts.Search(strings.TrimRight(q.Name, ".")) + host := strings.TrimRight(q.Name, ".") + + record := hosts.Search(host) if record == nil { return next(ctx, r) } @@ -40,13 +42,13 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr]) middleware { if ip.Is4() && q.Qtype == D.TypeA { rr := &D.A{} - rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10} rr.A = ip.AsSlice() msg.Answer = []D.RR{rr} } else if ip.Is6() && q.Qtype == D.TypeAAAA { rr := &D.AAAA{} - rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10} rr.AAAA = ip.AsSlice() msg.Answer = []D.RR{rr} @@ -54,6 +56,10 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr]) middleware { return next(ctx, r) } + if mapping != nil { + mapping.SetWithExpire(ip.Unmap().String(), host, time.Now().Add(time.Second*10)) + } + ctx.SetType(context.DNSTypeHost) msg.SetRcode(r, D.RcodeSuccess) msg.Authoritative = true @@ -177,7 +183,7 @@ func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler { middlewares := []middleware{} if resolver.hosts != nil { - middlewares = append(middlewares, withHosts(resolver.hosts)) + middlewares = append(middlewares, withHosts(resolver.hosts, mapper.mapping)) } if mapper.mode == C.DNSFakeIP { diff --git a/listener/http/utils.go b/listener/http/utils.go index 74b12005..bcee60f0 100644 --- a/listener/http/utils.go +++ b/listener/http/utils.go @@ -40,7 +40,7 @@ func removeExtraHTTPHostPort(req *http.Request) { host = req.URL.Host } - if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" { + if pHost, port, err := net.SplitHostPort(host); err == nil && (port == "80" || port == "443") { host = pHost } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index c3623957..0072d4f7 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -152,7 +152,7 @@ func preHandleMetadata(metadata *C.Metadata) error { metadata.DNSMode = C.DNSFakeIP } else if node := resolver.DefaultHosts.Search(host); node != nil { // redir-host should lookup the hosts - metadata.DstIP = node.Data.(net.IP) + metadata.DstIP = node.Data.AsSlice() } } else if resolver.IsFakeIP(metadata.DstIP) { return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) @@ -175,7 +175,7 @@ func preHandleMetadata(metadata *C.Metadata) error { return nil } -func resolveMetadata(ctx C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) { +func resolveMetadata(_ C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) { switch mode { case Direct: proxy = proxies["DIRECT"] @@ -211,7 +211,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) { handle := func() bool { pc := natTable.Get(key) if pc != nil { - handleUDPToRemote(packet, pc, metadata) + _ = handleUDPToRemote(packet, pc, metadata) return true } return false @@ -247,7 +247,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) { ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout) defer cancel() - rawPc, err := proxy.ListenPacketContext(ctx, metadata) + rawPc, err := proxy.ListenPacketContext(ctx, metadata.Pure()) if err != nil { if rule == nil { log.Warnln("[UDP] dial %s to %s error: %s", proxy.Name(), metadata.RemoteAddress(), err.Error()) @@ -284,7 +284,9 @@ func handleUDPConn(packet *inbound.PacketAdapter) { } func handleTCPConn(connCtx C.ConnContext) { - defer connCtx.Conn().Close() + defer func(conn net.Conn) { + _ = conn.Close() + }(connCtx.Conn()) metadata := connCtx.Metadata() if !metadata.Valid() { @@ -307,9 +309,7 @@ func handleTCPConn(connCtx C.ConnContext) { return } - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) - defer cancel() - remoteConn, err := proxy.DialContext(ctx, metadata) + remoteConn, err := proxy.DialContext(ctx, metadata.Pure()) if err != nil { if rule == nil { log.Warnln("[TCP] dial %s to %s error: %s", proxy.Name(), metadata.RemoteAddress(), err.Error()) @@ -319,7 +319,9 @@ func handleTCPConn(connCtx C.ConnContext) { return } remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule) - defer remoteConn.Close() + defer func(remoteConn C.Conn) { + _ = remoteConn.Close() + }(remoteConn) switch true { case rule != nil: @@ -352,8 +354,7 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { var resolved bool if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { - ip := node.Data.(net.IP) - metadata.DstIP = ip + metadata.DstIP = node.Data.AsSlice() resolved = true }