Merge remote-tracking branch 'tun/with-tun' into Alpha

This commit is contained in:
Meta 2022-03-28 10:51:59 +08:00
commit 64a5fd02da
24 changed files with 241 additions and 180 deletions

View file

@ -14,6 +14,7 @@ type Direct struct {
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
opts = append(opts, dialer.WithDirect())
c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -24,6 +25,7 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
opts = append(opts, dialer.WithDirect())
pc, err := dialer.ListenPacket(ctx, "udp", "", d.Base.DialOptions(opts...)...) pc, err := dialer.ListenPacket(ctx, "udp", "", d.Base.DialOptions(opts...)...)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -67,7 +67,7 @@ func resolveUDPAddr(network, address string) (*net.UDPAddr, error) {
return nil, err return nil, err
} }
ip, err := resolver.ResolveIP(host) ip, err := resolver.ResolveProxyServerHost(host)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -31,7 +31,13 @@ type HealthCheck struct {
func (hc *HealthCheck) process() { func (hc *HealthCheck) process() {
ticker := time.NewTicker(time.Duration(hc.interval) * time.Second) ticker := time.NewTicker(time.Duration(hc.interval) * time.Second)
go hc.check() go func() {
t := time.NewTicker(30 * time.Second)
<-t.C
t.Stop()
hc.check()
}()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:

View file

@ -32,14 +32,14 @@ func DialContext(ctx context.Context, network, address string, options ...Option
var ip net.IP var ip net.IP
switch network { switch network {
case "tcp4", "udp4": case "tcp4", "udp4":
if opt.interfaceName != "" { if !opt.direct {
ip, err = resolver.ResolveIPv4WithMain(host) ip, err = resolver.ResolveIPv4ProxyServerHost(host)
} else { } else {
ip, err = resolver.ResolveIPv4(host) ip, err = resolver.ResolveIPv4(host)
} }
default: default:
if opt.interfaceName != "" { if !opt.direct {
ip, err = resolver.ResolveIPv6WithMain(host) ip, err = resolver.ResolveIPv6ProxyServerHost(host)
} else { } else {
ip, err = resolver.ResolveIPv6(host) ip, err = resolver.ResolveIPv6(host)
} }
@ -121,7 +121,7 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt
results := make(chan dialResult) results := make(chan dialResult)
var primary, fallback dialResult var primary, fallback dialResult
startRacer := func(ctx context.Context, network, host string, ipv6 bool) { startRacer := func(ctx context.Context, network, host string, direct bool, ipv6 bool) {
result := dialResult{ipv6: ipv6, done: true} result := dialResult{ipv6: ipv6, done: true}
defer func() { defer func() {
select { select {
@ -135,14 +135,14 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt
var ip net.IP var ip net.IP
if ipv6 { if ipv6 {
if opt.interfaceName != "" { if !direct {
ip, result.error = resolver.ResolveIPv6WithMain(host) ip, result.error = resolver.ResolveIPv6ProxyServerHost(host)
} else { } else {
ip, result.error = resolver.ResolveIPv6(host) ip, result.error = resolver.ResolveIPv6(host)
} }
} else { } else {
if opt.interfaceName != "" { if !direct {
ip, result.error = resolver.ResolveIPv4WithMain(host) ip, result.error = resolver.ResolveIPv4ProxyServerHost(host)
} else { } else {
ip, result.error = resolver.ResolveIPv4(host) ip, result.error = resolver.ResolveIPv4(host)
} }
@ -155,8 +155,8 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt
result.Conn, result.error = dialContext(ctx, network, ip, port, opt) result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
} }
go startRacer(ctx, network+"4", host, false) go startRacer(ctx, network+"4", host, opt.direct, false)
go startRacer(ctx, network+"6", host, true) go startRacer(ctx, network+"6", host, opt.direct, true)
for res := range results { for res := range results {
if res.error == nil { if res.error == nil {

View file

@ -12,6 +12,7 @@ type option struct {
interfaceName string interfaceName string
addrReuse bool addrReuse bool
routingMark int routingMark int
direct bool
} }
type Option func(opt *option) type Option func(opt *option)
@ -33,3 +34,9 @@ func WithRoutingMark(mark int) Option {
opt.routingMark = mark opt.routingMark = mark
} }
} }
func WithDirect() Option {
return func(opt *option) {
opt.direct = true
}
}

View file

@ -15,8 +15,8 @@ var (
// DefaultResolver aim to resolve ip // DefaultResolver aim to resolve ip
DefaultResolver Resolver DefaultResolver Resolver
// MainResolver resolve ip with main domain server // ProxyServerHostResolver resolve ip to proxies server host
MainResolver Resolver ProxyServerHostResolver Resolver
// DisableIPv6 means don't resolve ipv6 host // DisableIPv6 means don't resolve ipv6 host
// default value is true // default value is true
@ -46,10 +46,6 @@ func ResolveIPv4(host string) (net.IP, error) {
return ResolveIPv4WithResolver(host, DefaultResolver) return ResolveIPv4WithResolver(host, DefaultResolver)
} }
func ResolveIPv4WithMain(host string) (net.IP, error) {
return ResolveIPv4WithResolver(host, MainResolver)
}
func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) { func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) {
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data.(net.IP).To4(); ip != nil { if ip := node.Data.(net.IP).To4(); ip != nil {
@ -69,6 +65,7 @@ func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) {
return r.ResolveIPv4(host) return r.ResolveIPv4(host)
} }
if DefaultResolver == nil {
ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout)
defer cancel() defer cancel()
ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", host) ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", host)
@ -81,15 +78,14 @@ func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) {
return ipAddrs[rand.Intn(len(ipAddrs))], nil return ipAddrs[rand.Intn(len(ipAddrs))], nil
} }
return nil, ErrIPNotFound
}
// ResolveIPv6 with a host, return ipv6 // ResolveIPv6 with a host, return ipv6
func ResolveIPv6(host string) (net.IP, error) { func ResolveIPv6(host string) (net.IP, error) {
return ResolveIPv6WithResolver(host, DefaultResolver) return ResolveIPv6WithResolver(host, DefaultResolver)
} }
func ResolveIPv6WithMain(host string) (net.IP, error) {
return ResolveIPv6WithResolver(host, MainResolver)
}
func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) {
if DisableIPv6 { if DisableIPv6 {
return nil, ErrIPv6Disabled return nil, ErrIPv6Disabled
@ -113,6 +109,7 @@ func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) {
return r.ResolveIPv6(host) return r.ResolveIPv6(host)
} }
if DefaultResolver == nil {
ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout)
defer cancel() defer cancel()
ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip6", host) ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip6", host)
@ -125,6 +122,9 @@ func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) {
return ipAddrs[rand.Intn(len(ipAddrs))], nil return ipAddrs[rand.Intn(len(ipAddrs))], nil
} }
return nil, ErrIPNotFound
}
// ResolveIPWithResolver same as ResolveIP, but with a resolver // ResolveIPWithResolver same as ResolveIP, but with a resolver
func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) { func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) {
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
@ -145,6 +145,7 @@ func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) {
return ip, nil return ip, nil
} }
if DefaultResolver == nil {
ipAddr, err := net.ResolveIPAddr("ip", host) ipAddr, err := net.ResolveIPAddr("ip", host)
if err != nil { if err != nil {
return nil, err return nil, err
@ -153,12 +154,34 @@ func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) {
return ipAddr.IP, nil return ipAddr.IP, nil
} }
return nil, ErrIPNotFound
}
// ResolveIP with a host, return ip // ResolveIP with a host, return ip
func ResolveIP(host string) (net.IP, error) { func ResolveIP(host string) (net.IP, error) {
return ResolveIPWithResolver(host, DefaultResolver) return ResolveIPWithResolver(host, DefaultResolver)
} }
// ResolveIPWithMainResolver with a host, use main resolver, return ip // ResolveIPv4ProxyServerHost proxies server host only
func ResolveIPWithMainResolver(host string) (net.IP, error) { func ResolveIPv4ProxyServerHost(host string) (net.IP, error) {
return ResolveIPWithResolver(host, MainResolver) if ProxyServerHostResolver != nil {
return ResolveIPv4WithResolver(host, ProxyServerHostResolver)
}
return ResolveIPv4(host)
}
// ResolveIPv6ProxyServerHost proxies server host only
func ResolveIPv6ProxyServerHost(host string) (net.IP, error) {
if ProxyServerHostResolver != nil {
return ResolveIPv6WithResolver(host, ProxyServerHostResolver)
}
return ResolveIPv6(host)
}
// ResolveProxyServerHost proxies server host only
func ResolveProxyServerHost(host string) (net.IP, error) {
if ProxyServerHostResolver != nil {
return ResolveIPWithResolver(host, ProxyServerHostResolver)
}
return ResolveIP(host)
} }

View file

@ -81,6 +81,7 @@ type DNS struct {
FakeIPRange *fakeip.Pool FakeIPRange *fakeip.Pool
Hosts *trie.DomainTrie Hosts *trie.DomainTrie
NameServerPolicy map[string]dns.NameServer NameServerPolicy map[string]dns.NameServer
ProxyServerNameserver []dns.NameServer
} }
// FallbackFilter config // FallbackFilter config
@ -158,6 +159,7 @@ type RawDNS struct {
FakeIPFilter []string `yaml:"fake-ip-filter"` FakeIPFilter []string `yaml:"fake-ip-filter"`
DefaultNameserver []string `yaml:"default-nameserver"` DefaultNameserver []string `yaml:"default-nameserver"`
NameServerPolicy map[string]string `yaml:"nameserver-policy"` NameServerPolicy map[string]string `yaml:"nameserver-policy"`
ProxyServerNameserver []string `yaml:"proxy-server-nameserver"`
} }
type RawFallbackFilter struct { type RawFallbackFilter struct {
@ -805,6 +807,10 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS,
return nil, err return nil, err
} }
if dnsCfg.ProxyServerNameserver, err = parseNameServer(cfg.ProxyServerNameserver); err != nil {
return nil, err
}
if len(cfg.DefaultNameserver) == 0 { if len(cfg.DefaultNameserver) == 0 {
return nil, errors.New("default nameserver should have at least one nameserver") return nil, errors.New("default nameserver should have at least one nameserver")
} }

View file

@ -41,6 +41,7 @@ type Resolver struct {
group singleflight.Group group singleflight.Group
lruCache *cache.LruCache lruCache *cache.LruCache
policy *trie.DomainTrie policy *trie.DomainTrie
proxyServer []dnsClient
} }
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA // ResolveIP request with TypeA and TypeAAAA, priority return TypeA
@ -300,6 +301,11 @@ func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D
return ch return ch
} }
// HasProxyServer has proxy server dns client
func (r *Resolver) HasProxyServer() bool {
return len(r.main) > 0
}
type NameServer struct { type NameServer struct {
Net string Net string
Addr string Addr string
@ -318,6 +324,7 @@ type FallbackFilter struct {
type Config struct { type Config struct {
Main, Fallback []NameServer Main, Fallback []NameServer
Default []NameServer Default []NameServer
ProxyServer []NameServer
IPv6 bool IPv6 bool
EnhancedMode C.DNSMode EnhancedMode C.DNSMode
FallbackFilter FallbackFilter FallbackFilter FallbackFilter
@ -343,6 +350,10 @@ func NewResolver(config Config) *Resolver {
r.fallback = transform(config.Fallback, defaultResolver) r.fallback = transform(config.Fallback, defaultResolver)
} }
if len(config.ProxyServer) != 0 {
r.proxyServer = transform(config.ProxyServer, defaultResolver)
}
if len(config.Policy) != 0 { if len(config.Policy) != 0 {
r.policy = trie.New() r.policy = trie.New()
for domain, nameserver := range config.Policy { for domain, nameserver := range config.Policy {
@ -376,10 +387,10 @@ func NewResolver(config Config) *Resolver {
return r return r
} }
func NewMainResolver(old *Resolver) *Resolver { func NewProxyServerHostResolver(old *Resolver) *Resolver {
r := &Resolver{ r := &Resolver{
ipv6: old.ipv6, ipv6: old.ipv6,
main: old.main, main: old.proxyServer,
lruCache: old.lruCache, lruCache: old.lruCache,
hosts: old.hosts, hosts: old.hosts,
policy: old.policy, policy: old.policy,

View file

@ -133,10 +133,11 @@ func updateDNS(c *config.DNS, t *config.Tun) {
}, },
Default: c.DefaultNameserver, Default: c.DefaultNameserver,
Policy: c.NameServerPolicy, Policy: c.NameServerPolicy,
ProxyServer: c.ProxyServerNameserver,
} }
r := dns.NewResolver(cfg) r := dns.NewResolver(cfg)
mr := dns.NewMainResolver(r) pr := dns.NewProxyServerHostResolver(r)
m := dns.NewEnhancer(cfg) m := dns.NewEnhancer(cfg)
// reuse cache of old host mapper // reuse cache of old host mapper
@ -145,9 +146,12 @@ func updateDNS(c *config.DNS, t *config.Tun) {
} }
resolver.DefaultResolver = r resolver.DefaultResolver = r
resolver.MainResolver = mr
resolver.DefaultHostMapper = m resolver.DefaultHostMapper = m
if pr.HasProxyServer() {
resolver.ProxyServerHostResolver = pr
}
if t.Enable { if t.Enable {
resolver.DefaultLocalServer = dns.NewLocalServer(r, m) resolver.DefaultLocalServer = dns.NewLocalServer(r, m)
} }
@ -157,9 +161,9 @@ func updateDNS(c *config.DNS, t *config.Tun) {
} else { } else {
if !t.Enable { if !t.Enable {
resolver.DefaultResolver = nil resolver.DefaultResolver = nil
resolver.MainResolver = nil
resolver.DefaultHostMapper = nil resolver.DefaultHostMapper = nil
resolver.DefaultLocalServer = nil resolver.DefaultLocalServer = nil
resolver.ProxyServerHostResolver = nil
} }
dns.ReCreateServer("", nil, nil) dns.ReCreateServer("", nil, nil)
} }
@ -365,7 +369,9 @@ func updateIPTables(cfg *config.Config) {
log.Infoln("[IPTABLES] Setting iptables completed") log.Infoln("[IPTABLES] Setting iptables completed")
} }
func Cleanup() { func Shutdown() {
P.Cleanup() P.Cleanup()
tproxy.CleanupTProxyIPTables() tproxy.CleanupTProxyIPTables()
log.Warnln("Clash shutting down")
} }

View file

@ -48,7 +48,3 @@ func Parse(options ...Option) error {
executor.ApplyConfig(cfg, true) executor.ApplyConfig(cfg, true)
return nil return nil
} }
func Cleanup() {
executor.Cleanup()
}

View file

@ -29,4 +29,7 @@ type Device interface {
// UseIOBased work for other ip stack // UseIOBased work for other ip stack
UseIOBased() error UseIOBased() error
// Wait waits for the device to close.
Wait()
} }

View file

@ -103,7 +103,7 @@ func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) {
case header.IPv6Version: case header.IPv6Version:
e.InjectInbound(header.IPv6ProtocolNumber, pkt) e.InjectInbound(header.IPv6ProtocolNumber, pkt)
} }
pkt.DecRef() /* release */ pkt.DecRef()
} }
} }

View file

@ -106,6 +106,9 @@ func (t *TUN) Write(packet []byte) (int, error) {
} }
func (t *TUN) Close() error { func (t *TUN) Close() error {
if t.Endpoint != nil {
t.Endpoint.Close()
}
return t.nt.Close() return t.nt.Close()
} }

View file

@ -3,6 +3,12 @@ package adapter
// Handler is a TCP/UDP connection handler that implements // Handler is a TCP/UDP connection handler that implements
// HandleTCPConn and HandleUDPConn methods. // HandleTCPConn and HandleUDPConn methods.
type Handler interface { type Handler interface {
HandleTCPConn(TCPConn) HandleTCP(TCPConn)
HandleUDPConn(UDPConn) HandleUDP(UDPConn)
} }
// TCPHandleFunc handles incoming TCP connection.
type TCPHandleFunc func(TCPConn)
// UDPHandleFunc handles incoming UDP connection.
type UDPHandleFunc func(UDPConn)

View file

@ -24,7 +24,7 @@ type GVHandler struct {
UDPIn chan<- *inbound.PacketAdapter UDPIn chan<- *inbound.PacketAdapter
} }
func (gh *GVHandler) HandleTCPConn(tunConn adapter.TCPConn) { func (gh *GVHandler) HandleTCP(tunConn adapter.TCPConn) {
id := tunConn.ID() id := tunConn.ID()
rAddr := &net.UDPAddr{ rAddr := &net.UDPAddr{
@ -77,7 +77,7 @@ func (gh *GVHandler) HandleTCPConn(tunConn adapter.TCPConn) {
gh.TCPIn <- inbound.NewSocket(socks5.ParseAddrToSocksAddr(rAddr), tunConn, C.TUN) gh.TCPIn <- inbound.NewSocket(socks5.ParseAddrToSocksAddr(rAddr), tunConn, C.TUN)
} }
func (gh *GVHandler) HandleUDPConn(tunConn adapter.UDPConn) { func (gh *GVHandler) HandleUDP(tunConn adapter.UDPConn) {
id := tunConn.ID() id := tunConn.ID()
rAddr := &net.UDPAddr{ rAddr := &net.UDPAddr{

View file

@ -3,14 +3,13 @@ package gvisor
import ( import (
"fmt" "fmt"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
) )
const ( const (
// defaultNICID is the ID of default NIC used by DefaultStack.
defaultNICID tcpip.NICID = 0x01
// nicPromiscuousModeEnabled is the value used by stack to enable // nicPromiscuousModeEnabled is the value used by stack to enable
// or disable NIC's promiscuous mode. // or disable NIC's promiscuous mode.
nicPromiscuousModeEnabled = true nicPromiscuousModeEnabled = true
@ -21,9 +20,9 @@ const (
) )
// withCreatingNIC creates NIC for stack. // withCreatingNIC creates NIC for stack.
func withCreatingNIC(ep stack.LinkEndpoint) Option { func withCreatingNIC(nicID tcpip.NICID, ep stack.LinkEndpoint) option.Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
if err := s.CreateNICWithOptions(s.nicID, ep, if err := s.CreateNICWithOptions(nicID, ep,
stack.NICOptions{ stack.NICOptions{
Disabled: false, Disabled: false,
// If no queueing discipline was specified // If no queueing discipline was specified
@ -37,21 +36,21 @@ func withCreatingNIC(ep stack.LinkEndpoint) Option {
} }
} }
// withPromiscuousMode sets promiscuous mode in the given NIC. // withPromiscuousMode sets promiscuous mode in the given NICs.
func withPromiscuousMode(v bool) Option { func withPromiscuousMode(nicID tcpip.NICID, v bool) option.Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
if err := s.SetPromiscuousMode(s.nicID, v); err != nil { if err := s.SetPromiscuousMode(nicID, v); err != nil {
return fmt.Errorf("set promiscuous mode: %s", err) return fmt.Errorf("set promiscuous mode: %s", err)
} }
return nil return nil
} }
} }
// withSpoofing sets address spoofing in the given NIC, allowing // withSpoofing sets address spoofing in the given NICs, allowing
// endpoints to bind to any address in the NIC. // endpoints to bind to any address in the NIC.
func withSpoofing(v bool) Option { func withSpoofing(nicID tcpip.NICID, v bool) option.Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
if err := s.SetSpoofing(s.nicID, v); err != nil { if err := s.SetSpoofing(nicID, v); err != nil {
return fmt.Errorf("set spoofing: %s", err) return fmt.Errorf("set spoofing: %s", err)
} }
return nil return nil

View file

@ -1,4 +1,4 @@
package gvisor package option
import ( import (
"fmt" "fmt"
@ -7,6 +7,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
) )
@ -56,11 +57,11 @@ const (
tcpDefaultBufferSize = 212 << 10 // 212 KiB tcpDefaultBufferSize = 212 << 10 // 212 KiB
) )
type Option func(*gvStack) error type Option func(*stack.Stack) error
// WithDefault sets all default values for stack. // WithDefault sets all default values for stack.
func WithDefault() Option { func WithDefault() Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
opts := []Option{ opts := []Option{
WithDefaultTTL(defaultTimeToLive), WithDefaultTTL(defaultTimeToLive),
WithForwarding(ipForwardingEnabled), WithForwarding(ipForwardingEnabled),
@ -110,7 +111,7 @@ func WithDefault() Option {
// WithDefaultTTL sets the default TTL used by stack. // WithDefaultTTL sets the default TTL used by stack.
func WithDefaultTTL(ttl uint8) Option { func WithDefaultTTL(ttl uint8) Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
opt := tcpip.DefaultTTLOption(ttl) opt := tcpip.DefaultTTLOption(ttl)
if err := s.SetNetworkProtocolOption(ipv4.ProtocolNumber, &opt); err != nil { if err := s.SetNetworkProtocolOption(ipv4.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set ipv4 default TTL: %s", err) return fmt.Errorf("set ipv4 default TTL: %s", err)
@ -124,7 +125,7 @@ func WithDefaultTTL(ttl uint8) Option {
// WithForwarding sets packet forwarding between NICs for IPv4 & IPv6. // WithForwarding sets packet forwarding between NICs for IPv4 & IPv6.
func WithForwarding(v bool) Option { func WithForwarding(v bool) Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, v); err != nil { if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, v); err != nil {
return fmt.Errorf("set ipv4 forwarding: %s", err) return fmt.Errorf("set ipv4 forwarding: %s", err)
} }
@ -138,7 +139,7 @@ func WithForwarding(v bool) Option {
// WithICMPBurst sets the number of ICMP messages that can be sent // WithICMPBurst sets the number of ICMP messages that can be sent
// in a single burst. // in a single burst.
func WithICMPBurst(burst int) Option { func WithICMPBurst(burst int) Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
s.SetICMPBurst(burst) s.SetICMPBurst(burst)
return nil return nil
} }
@ -147,7 +148,7 @@ func WithICMPBurst(burst int) Option {
// WithICMPLimit sets the maximum number of ICMP messages permitted // WithICMPLimit sets the maximum number of ICMP messages permitted
// by rate limiter. // by rate limiter.
func WithICMPLimit(limit rate.Limit) Option { func WithICMPLimit(limit rate.Limit) Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
s.SetICMPLimit(limit) s.SetICMPLimit(limit)
return nil return nil
} }
@ -155,7 +156,7 @@ func WithICMPLimit(limit rate.Limit) Option {
// WithTCPBufferSizeRange sets the receive and send buffer size range for TCP. // WithTCPBufferSizeRange sets the receive and send buffer size range for TCP.
func WithTCPBufferSizeRange(a, b, c int) Option { func WithTCPBufferSizeRange(a, b, c int) Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c} rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil {
return fmt.Errorf("set TCP receive buffer size range: %s", err) return fmt.Errorf("set TCP receive buffer size range: %s", err)
@ -170,7 +171,7 @@ func WithTCPBufferSizeRange(a, b, c int) Option {
// WithTCPCongestionControl sets the current congestion control algorithm. // WithTCPCongestionControl sets the current congestion control algorithm.
func WithTCPCongestionControl(cc string) Option { func WithTCPCongestionControl(cc string) Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
opt := tcpip.CongestionControlOption(cc) opt := tcpip.CongestionControlOption(cc)
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set TCP congestion control algorithm: %s", err) return fmt.Errorf("set TCP congestion control algorithm: %s", err)
@ -181,7 +182,7 @@ func WithTCPCongestionControl(cc string) Option {
// WithTCPDelay enables or disables Nagle's algorithm in TCP. // WithTCPDelay enables or disables Nagle's algorithm in TCP.
func WithTCPDelay(v bool) Option { func WithTCPDelay(v bool) Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
opt := tcpip.TCPDelayEnabled(v) opt := tcpip.TCPDelayEnabled(v)
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set TCP delay: %s", err) return fmt.Errorf("set TCP delay: %s", err)
@ -192,7 +193,7 @@ func WithTCPDelay(v bool) Option {
// WithTCPModerateReceiveBuffer sets receive buffer moderation for TCP. // WithTCPModerateReceiveBuffer sets receive buffer moderation for TCP.
func WithTCPModerateReceiveBuffer(v bool) Option { func WithTCPModerateReceiveBuffer(v bool) Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
opt := tcpip.TCPModerateReceiveBufferOption(v) opt := tcpip.TCPModerateReceiveBufferOption(v)
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set TCP moderate receive buffer: %s", err) return fmt.Errorf("set TCP moderate receive buffer: %s", err)
@ -203,7 +204,7 @@ func WithTCPModerateReceiveBuffer(v bool) Option {
// WithTCPSACKEnabled sets the SACK option for TCP. // WithTCPSACKEnabled sets the SACK option for TCP.
func WithTCPSACKEnabled(v bool) Option { func WithTCPSACKEnabled(v bool) Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
opt := tcpip.TCPSACKEnabled(v) opt := tcpip.TCPSACKEnabled(v)
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set TCP SACK: %s", err) return fmt.Errorf("set TCP SACK: %s", err)
@ -214,7 +215,7 @@ func WithTCPSACKEnabled(v bool) Option {
// WithTCPRecovery sets the recovery option for TCP. // WithTCPRecovery sets the recovery option for TCP.
func WithTCPRecovery(v tcpip.TCPRecovery) Option { func WithTCPRecovery(v tcpip.TCPRecovery) Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
return fmt.Errorf("set TCP Recovery: %s", err) return fmt.Errorf("set TCP Recovery: %s", err)
} }

View file

@ -1,22 +1,23 @@
package gvisor package gvisor
import ( import (
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
) )
func withICMPHandler() Option { func withRouteTable(nicID tcpip.NICID) option.Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
// Add default route table for IPv4 and IPv6.
// This will handle all incoming ICMP packets.
s.SetRouteTable([]tcpip.Route{ s.SetRouteTable([]tcpip.Route{
{ {
Destination: header.IPv4EmptySubnet, Destination: header.IPv4EmptySubnet,
NIC: s.nicID, NIC: nicID,
}, },
{ {
Destination: header.IPv6EmptySubnet, Destination: header.IPv6EmptySubnet,
NIC: s.nicID, NIC: nicID,
}, },
}) })
return nil return nil

View file

@ -5,6 +5,7 @@ import (
"github.com/Dreamacro/clash/listener/tun/device" "github.com/Dreamacro/clash/listener/tun/device"
"github.com/Dreamacro/clash/listener/tun/ipstack" "github.com/Dreamacro/clash/listener/tun/ipstack"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@ -18,23 +19,23 @@ import (
type gvStack struct { type gvStack struct {
*stack.Stack *stack.Stack
device device.Device device device.Device
handler adapter.Handler
nicID tcpip.NICID
} }
func (s *gvStack) Close() error { func (s *gvStack) Close() error {
var err error
if s.device != nil {
err = s.device.Close()
s.device.Wait()
}
if s.Stack != nil { if s.Stack != nil {
s.Stack.Close() s.Stack.Close()
s.Stack.Wait()
} }
if s.device != nil { return err
_ = s.device.Close()
}
return nil
} }
// New allocates a new *gvStack with given options. // New allocates a new *gvStack with given options.
func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack.Stack, error) { func New(device device.Device, handler adapter.Handler, opts ...option.Option) (ipstack.Stack, error) {
s := &gvStack{ s := &gvStack{
Stack: stack.New(stack.Options{ Stack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ NetworkProtocols: []stack.NetworkProtocolFactory{
@ -50,18 +51,14 @@ func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack
}), }),
device: device, device: device,
handler: handler,
nicID: defaultNICID,
} }
opts = append(opts, // Generate unique NIC id.
// Important: We must initiate transport protocol handlers nicID := tcpip.NICID(s.Stack.UniqueID())
// before creating NIC, otherwise NIC would dispatch packets
// to stack and cause race condition.
withICMPHandler(), withTCPHandler(), withUDPHandler(),
// Create stack NIC and then bind link endpoint. opts = append(opts,
withCreatingNIC(device.(stack.LinkEndpoint)), // Create stack NIC and then bind link endpoint to it.
withCreatingNIC(nicID, device),
// In the past we did s.AddAddressRange to assign 0.0.0.0/0 // In the past we did s.AddAddressRange to assign 0.0.0.0/0
// onto the interface. We need that to be able to terminate // onto the interface. We need that to be able to terminate
@ -70,27 +67,34 @@ func New(device device.Device, handler adapter.Handler, opts ...Option) (ipstack
// Promiscuous mode. https://github.com/google/gvisor/issues/3876 // Promiscuous mode. https://github.com/google/gvisor/issues/3876
// //
// Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go // Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go
withPromiscuousMode(nicPromiscuousModeEnabled), withPromiscuousMode(nicID, nicPromiscuousModeEnabled),
// Enable spoofing if a stack may send packets from unowned addresses. // Enable spoofing if a stack may send packets from unowned
// This change required changes to some netgophers since previously, // addresses. This change required changes to some netgophers
// promiscuous mode was enough to let the netstack respond to all // since previously, promiscuous mode was enough to let the
// incoming packets regardless of the packet's destination address. Now // netstack respond to all incoming packets regardless of the
// that a stack.Route is not held for each incoming packet, finding a route // packet's destination address. Now that a stack.Route is not
// may fail with local addresses we don't own but accepted packets for // held for each incoming packet, finding a route may fail with
// while in promiscuous mode. Since we also want to be able to send from // local addresses we don't own but accepted packets for while
// any address (in response the received promiscuous mode packets), we need // in promiscuous mode. Since we also want to be able to send
// to enable spoofing. // from any address (in response the received promiscuous mode
// packets), we need to enable spoofing.
// //
// Ref: https://github.com/google/gvisor/commit/8c0701462a84ff77e602f1626aec49479c308127 // Ref: https://github.com/google/gvisor/commit/8c0701462a84ff77e602f1626aec49479c308127
withSpoofing(nicSpoofingEnabled), withSpoofing(nicID, nicSpoofingEnabled),
// Add default route table for IPv4 and IPv6. This will handle
// all incoming ICMP packets.
withRouteTable(nicID),
// Initiate transport protocol (TCP/UDP) with given handler.
withTCPHandler(handler.HandleTCP), withUDPHandler(handler.HandleUDP),
) )
for _, opt := range opts { for _, opt := range opts {
if err := opt(s); err != nil { if err := opt(s.Stack); err != nil {
return nil, err return nil, err
} }
} }
return s, nil return s, nil
} }

View file

@ -4,6 +4,9 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
@ -37,9 +40,9 @@ const (
tcpKeepaliveInterval = 30 * time.Second tcpKeepaliveInterval = 30 * time.Second
) )
func withTCPHandler() Option { func withTCPHandler(handle adapter.TCPHandleFunc) option.Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
tcpForwarder := tcp.NewForwarder(s.Stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { tcpForwarder := tcp.NewForwarder(s, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue var wq waiter.Queue
ep, err := r.CreateEndpoint(&wq) ep, err := r.CreateEndpoint(&wq)
if err != nil { if err != nil {
@ -55,7 +58,7 @@ func withTCPHandler() Option {
TCPConn: gonet.NewTCPConn(&wq, ep), TCPConn: gonet.NewTCPConn(&wq, ep),
id: r.ID(), id: r.ID(),
} }
s.handler.HandleTCPConn(conn) handle(conn)
}) })
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
return nil return nil

View file

@ -5,6 +5,7 @@ import (
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/adapter"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
@ -12,9 +13,9 @@ import (
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
) )
func withUDPHandler() Option { func withUDPHandler(handle adapter.UDPHandleFunc) option.Option {
return func(s *gvStack) error { return func(s *stack.Stack) error {
udpForwarder := udp.NewForwarder(s.Stack, func(r *udp.ForwarderRequest) { udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
var wq waiter.Queue var wq waiter.Queue
ep, err := r.CreateEndpoint(&wq) ep, err := r.CreateEndpoint(&wq)
if err != nil { if err != nil {
@ -23,10 +24,10 @@ func withUDPHandler() Option {
} }
conn := &udpConn{ conn := &udpConn{
UDPConn: gonet.NewUDPConn(s.Stack, &wq, ep), UDPConn: gonet.NewUDPConn(s, &wq, ep),
id: r.ID(), id: r.ID(),
} }
s.handler.HandleUDPConn(conn) handle(conn)
}) })
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
return nil return nil

View file

@ -36,8 +36,6 @@ func (s sysStack) Close() error {
return nil return nil
} }
var ipv4LoopBack = netip.MustParsePrefix("127.0.0.0/8")
func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) { func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) {
var ( var (
gateway = tunAddress.Masked().Addr().Next() gateway = tunAddress.Masked().Addr().Next()
@ -71,12 +69,6 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
rAddrIp, _ := netip.AddrFromSlice(rAddr.IP) rAddrIp, _ := netip.AddrFromSlice(rAddr.IP)
rAddrPort := netip.AddrPortFrom(rAddrIp, uint16(rAddr.Port)) rAddrPort := netip.AddrPortFrom(rAddrIp, uint16(rAddr.Port))
if ipv4LoopBack.Contains(rAddrIp) {
conn.Close()
continue
}
if D.ShouldHijackDns(dnsAddr, rAddrPort) { if D.ShouldHijackDns(dnsAddr, rAddrPort) {
go func() { go func() {
log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String()) log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String())
@ -149,12 +141,6 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
rAddrIp, _ := netip.AddrFromSlice(rAddr.IP) rAddrIp, _ := netip.AddrFromSlice(rAddr.IP)
rAddrPort := netip.AddrPortFrom(rAddrIp, uint16(rAddr.Port)) rAddrPort := netip.AddrPortFrom(rAddrIp, uint16(rAddr.Port))
if ipv4LoopBack.Contains(rAddrIp) {
pool.Put(buf)
continue
}
if D.ShouldHijackDns(dnsAddr, rAddrPort) { if D.ShouldHijackDns(dnsAddr, rAddrPort) {
go func() { go func() {
defer pool.Put(buf) defer pool.Put(buf)

View file

@ -13,6 +13,7 @@ import (
"github.com/Dreamacro/clash/listener/tun/ipstack" "github.com/Dreamacro/clash/listener/tun/ipstack"
"github.com/Dreamacro/clash/listener/tun/ipstack/commons" "github.com/Dreamacro/clash/listener/tun/ipstack/commons"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor" "github.com/Dreamacro/clash/listener/tun/ipstack/gvisor"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor/option"
"github.com/Dreamacro/clash/listener/tun/ipstack/system" "github.com/Dreamacro/clash/listener/tun/ipstack/system"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
"net/netip" "net/netip"
@ -72,7 +73,7 @@ func New(tunConf *config.Tun, dnsConf *config.DNS, tcpIn chan<- C.ConnContext, u
DNSAdds: tunConf.DNSHijack, DNSAdds: tunConf.DNSHijack,
TCPIn: tcpIn, UDPIn: udpIn, TCPIn: tcpIn, UDPIn: udpIn,
}, },
gvisor.WithDefault(), option.WithDefault(),
) )
if err != nil { if err != nil {

View file

@ -106,13 +106,9 @@ func main() {
log.Fatalln("Parse config error: %s", err.Error()) log.Fatalln("Parse config error: %s", err.Error())
} }
defer executor.Shutdown()
sigCh := make(chan os.Signal, 1) sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh <-sigCh
// cleanup
log.Warnln("Clash cleanup")
hub.Cleanup()
log.Warnln("Clash shutting down")
} }