From db4fb69b1043670e1b790c2af62e94e7f3677803 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Wed, 6 Jul 2022 20:53:34 +0800 Subject: [PATCH] refactor: h3 for doh --- config/config.go | 2 + dns/doh.go | 97 +++++++++++++++++++++++++++++------------------- 2 files changed, 61 insertions(+), 38 deletions(-) diff --git a/config/config.go b/config/config.go index dc4ba2b4..d0d36c32 100644 --- a/config/config.go +++ b/config/config.go @@ -157,6 +157,7 @@ type Config struct { type RawDNS struct { Enable bool `yaml:"enable"` + PreferH3 bool `yaml:"prefer-h3"` IPv6 bool `yaml:"ipv6"` UseHosts bool `yaml:"use-hosts"` NameServer []string `yaml:"nameserver"` @@ -767,6 +768,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R dnsCfg := &DNS{ Enable: cfg.Enable, Listen: cfg.Listen, + PreferH3: cfg.PreferH3, IPv6: cfg.IPv6, EnhancedMode: cfg.EnhancedMode, FallbackFilter: FallbackFilter{ diff --git a/dns/doh.go b/dns/doh.go index 81af438c..c79e7ba2 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -11,6 +11,7 @@ import ( D "github.com/miekg/dns" "go.uber.org/atomic" "io" + "io/ioutil" "net" "net/http" "strconv" @@ -22,12 +23,8 @@ const ( ) type dohClient struct { - url string - proxyAdapter string - transport *http.Transport - h3Transport *http3.RoundTripper - supportH3 *atomic.Bool - firstTest *atomic.Bool + url string + transport http.RoundTripper } func (dc *dohClient) Exchange(m *D.Msg) (msg *D.Msg, err error) { @@ -70,30 +67,7 @@ func (dc *dohClient) newRequest(m *D.Msg) (*http.Request, error) { } func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) { - if dc.supportH3.Load() { - msg, err = dc.doRequestWithTransport(req, dc.h3Transport) - if err != nil { - if dc.firstTest.CAS(true, false) { - dc.supportH3.Store(false) - _ = dc.h3Transport.Close() - dc.h3Transport = nil - } - } else { - if dc.firstTest.CAS(true, false) { - dc.supportH3.Store(true) - dc.transport.CloseIdleConnections() - dc.transport = nil - } - } - } else { - msg, err = dc.doRequestWithTransport(req, dc.transport) - } - - return -} - -func (dc *dohClient) doRequestWithTransport(req *http.Request, transport http.RoundTripper) (*D.Msg, error) { - client := &http.Client{Transport: transport} + client := &http.Client{Transport: dc.transport} resp, err := client.Do(req) if err != nil { if err != nil { @@ -107,16 +81,28 @@ func (dc *dohClient) doRequestWithTransport(req *http.Request, transport http.Ro if err != nil { return nil, err } - msg := &D.Msg{} + msg = &D.Msg{} err = msg.Unpack(buf) return msg, err } func newDoHClient(url string, r *Resolver, preferH3 bool, proxyAdapter string) *dohClient { return &dohClient{ - url: url, - proxyAdapter: proxyAdapter, - transport: &http.Transport{ + url: url, + transport: newDohTransport(r, preferH3, proxyAdapter), + } +} + +type dohTransport struct { + *http.Transport + h3 *http3.RoundTripper + preferH3 bool + canUseH3 atomic.Bool +} + +func newDohTransport(r *Resolver, preferH3 bool, proxyAdapter string) *dohTransport { + dohT := &dohTransport{ + Transport: &http.Transport{ ForceAttemptHTTP2: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) @@ -136,8 +122,12 @@ func newDoHClient(url string, r *Resolver, preferH3 bool, proxyAdapter string) * } }, }, + preferH3: preferH3, + } - h3Transport: &http3.RoundTripper{ + dohT.canUseH3.Store(preferH3) + if preferH3 { + dohT.h3 = &http3.RoundTripper{ Dial: func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { host, port, err := net.SplitHostPort(addr) if err != nil { @@ -168,8 +158,39 @@ func newDoHClient(url string, r *Resolver, preferH3 bool, proxyAdapter string) * } } }, - }, - supportH3: atomic.NewBool(preferH3), - firstTest: atomic.NewBool(true), + } } + + return dohT +} + +func (doh *dohTransport) RoundTrip(req *http.Request) (*http.Response, error) { + var resp *http.Response + var err error + var bodyBytes []byte + if req.Body != nil { + bodyBytes, err = ioutil.ReadAll(req.Body) + } + + req.Body = ioutil.NopCloser(bytes.NewReader(bodyBytes)) + if doh.preferH3 && doh.canUseH3.Load() { + resp, err = doh.h3.RoundTrip(req) + if err == nil { + return resp, err + } else { + doh.canUseH3.Store(false) + req.Body = ioutil.NopCloser(bytes.NewReader(bodyBytes)) + } + } + + resp, err = doh.Transport.RoundTrip(req) + if err != nil { + if doh.preferH3 { + doh.canUseH3.Store(true) + } + + return resp, err + } + + return resp, err }