refactor: h3 for doh

This commit is contained in:
Skyxim 2022-07-06 20:53:34 +08:00
parent baee951657
commit 0c91a4e0f3
2 changed files with 61 additions and 38 deletions

View file

@ -157,6 +157,7 @@ type Config struct {
type RawDNS struct { type RawDNS struct {
Enable bool `yaml:"enable"` Enable bool `yaml:"enable"`
PreferH3 bool `yaml:"prefer-h3"`
IPv6 bool `yaml:"ipv6"` IPv6 bool `yaml:"ipv6"`
UseHosts bool `yaml:"use-hosts"` UseHosts bool `yaml:"use-hosts"`
NameServer []string `yaml:"nameserver"` NameServer []string `yaml:"nameserver"`
@ -767,6 +768,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
dnsCfg := &DNS{ dnsCfg := &DNS{
Enable: cfg.Enable, Enable: cfg.Enable,
Listen: cfg.Listen, Listen: cfg.Listen,
PreferH3: cfg.PreferH3,
IPv6: cfg.IPv6, IPv6: cfg.IPv6,
EnhancedMode: cfg.EnhancedMode, EnhancedMode: cfg.EnhancedMode,
FallbackFilter: FallbackFilter{ FallbackFilter: FallbackFilter{

View file

@ -11,6 +11,7 @@ import (
D "github.com/miekg/dns" D "github.com/miekg/dns"
"go.uber.org/atomic" "go.uber.org/atomic"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
@ -22,12 +23,8 @@ const (
) )
type dohClient struct { type dohClient struct {
url string url string
proxyAdapter string transport http.RoundTripper
transport *http.Transport
h3Transport *http3.RoundTripper
supportH3 *atomic.Bool
firstTest *atomic.Bool
} }
func (dc *dohClient) Exchange(m *D.Msg) (msg *D.Msg, err error) { func (dc *dohClient) Exchange(m *D.Msg) (msg *D.Msg, err error) {
@ -70,30 +67,7 @@ func (dc *dohClient) newRequest(m *D.Msg) (*http.Request, error) {
} }
func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) { func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) {
if dc.supportH3.Load() { client := &http.Client{Transport: dc.transport}
msg, err = dc.doRequestWithTransport(req, dc.h3Transport)
if err != nil {
if dc.firstTest.CAS(true, false) {
dc.supportH3.Store(false)
_ = dc.h3Transport.Close()
dc.h3Transport = nil
}
} else {
if dc.firstTest.CAS(true, false) {
dc.supportH3.Store(true)
dc.transport.CloseIdleConnections()
dc.transport = nil
}
}
} else {
msg, err = dc.doRequestWithTransport(req, dc.transport)
}
return
}
func (dc *dohClient) doRequestWithTransport(req *http.Request, transport http.RoundTripper) (*D.Msg, error) {
client := &http.Client{Transport: transport}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
if err != nil { if err != nil {
@ -107,16 +81,28 @@ func (dc *dohClient) doRequestWithTransport(req *http.Request, transport http.Ro
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg := &D.Msg{} msg = &D.Msg{}
err = msg.Unpack(buf) err = msg.Unpack(buf)
return msg, err return msg, err
} }
func newDoHClient(url string, r *Resolver, preferH3 bool, proxyAdapter string) *dohClient { func newDoHClient(url string, r *Resolver, preferH3 bool, proxyAdapter string) *dohClient {
return &dohClient{ return &dohClient{
url: url, url: url,
proxyAdapter: proxyAdapter, transport: newDohTransport(r, preferH3, proxyAdapter),
transport: &http.Transport{ }
}
type dohTransport struct {
*http.Transport
h3 *http3.RoundTripper
preferH3 bool
canUseH3 atomic.Bool
}
func newDohTransport(r *Resolver, preferH3 bool, proxyAdapter string) *dohTransport {
dohT := &dohTransport{
Transport: &http.Transport{
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
@ -136,8 +122,12 @@ func newDoHClient(url string, r *Resolver, preferH3 bool, proxyAdapter string) *
} }
}, },
}, },
preferH3: preferH3,
}
h3Transport: &http3.RoundTripper{ dohT.canUseH3.Store(preferH3)
if preferH3 {
dohT.h3 = &http3.RoundTripper{
Dial: func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { Dial: func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
@ -168,8 +158,39 @@ func newDoHClient(url string, r *Resolver, preferH3 bool, proxyAdapter string) *
} }
} }
}, },
}, }
supportH3: atomic.NewBool(preferH3),
firstTest: atomic.NewBool(true),
} }
return dohT
}
func (doh *dohTransport) RoundTrip(req *http.Request) (*http.Response, error) {
var resp *http.Response
var err error
var bodyBytes []byte
if req.Body != nil {
bodyBytes, err = ioutil.ReadAll(req.Body)
}
req.Body = ioutil.NopCloser(bytes.NewReader(bodyBytes))
if doh.preferH3 && doh.canUseH3.Load() {
resp, err = doh.h3.RoundTrip(req)
if err == nil {
return resp, err
} else {
doh.canUseH3.Store(false)
req.Body = ioutil.NopCloser(bytes.NewReader(bodyBytes))
}
}
resp, err = doh.Transport.RoundTrip(req)
if err != nil {
if doh.preferH3 {
doh.canUseH3.Store(true)
}
return resp, err
}
return resp, err
} }