fix: DoH/DoQ doesn't use context

This commit is contained in:
Skyxim 2022-11-19 10:31:50 +08:00
parent f00dc69bb6
commit b8b3c9ef9f
2 changed files with 35 additions and 42 deletions

View file

@ -118,7 +118,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.
// Check if there was already an active client before sending the request. // Check if there was already an active client before sending the request.
// We'll only attempt to re-connect if there was one. // We'll only attempt to re-connect if there was one.
client, isCached, err := doh.getClient() client, isCached, err := doh.getClient(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to init http client: %w", err) return nil, fmt.Errorf("failed to init http client: %w", err)
} }
@ -132,7 +132,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.
// the case when the connection was closed (due to inactivity for example) // the case when the connection was closed (due to inactivity for example)
// AND the server refuses to open a 0-RTT connection. // AND the server refuses to open a 0-RTT connection.
for i := 0; isCached && doh.shouldRetry(err) && i < 2; i++ { for i := 0; isCached && doh.shouldRetry(err) && i < 2; i++ {
client, err = doh.resetClient(err) client, err = doh.resetClient(ctx, err)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to reset http client: %w", err) return nil, fmt.Errorf("failed to reset http client: %w", err)
} }
@ -142,7 +142,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.
if err != nil { if err != nil {
// If the request failed anyway, make sure we don't use this client. // If the request failed anyway, make sure we don't use this client.
_, resErr := doh.resetClient(err) _, resErr := doh.resetClient(ctx, err)
return nil, fmt.Errorf("err:%v,resErr:%v", err, resErr) return nil, fmt.Errorf("err:%v,resErr:%v", err, resErr)
} }
@ -183,7 +183,6 @@ func (doh *dnsOverHTTPS) closeClient(client *http.Client) (err error) {
// exchangeHTTPS logs the request and its result and calls exchangeHTTPSClient. // exchangeHTTPS logs the request and its result and calls exchangeHTTPSClient.
func (doh *dnsOverHTTPS) exchangeHTTPS(ctx context.Context, client *http.Client, req *D.Msg) (resp *D.Msg, err error) { func (doh *dnsOverHTTPS) exchangeHTTPS(ctx context.Context, client *http.Client, req *D.Msg) (resp *D.Msg, err error) {
resp, err = doh.exchangeHTTPSClient(ctx, client, req) resp, err = doh.exchangeHTTPSClient(ctx, client, req)
return resp, err return resp, err
} }
@ -207,23 +206,24 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient(
method = http3.MethodGet0RTT method = http3.MethodGet0RTT
} }
doh.url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf)) url := doh.url
httpReq, err := http.NewRequestWithContext(ctx, method, doh.url.String(), nil) url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf))
httpReq, err := http.NewRequestWithContext(ctx, method, url.String(), nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating http request to %s: %w", doh.url, err) return nil, fmt.Errorf("creating http request to %s: %w", url, err)
} }
httpReq.Header.Set("Accept", "application/dns-message") httpReq.Header.Set("Accept", "application/dns-message")
httpReq.Header.Set("User-Agent", "") httpReq.Header.Set("User-Agent", "")
httpResp, err := client.Do(httpReq) httpResp, err := client.Do(httpReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("requesting %s: %w", doh.url, err) return nil, fmt.Errorf("requesting %s: %w", url, err)
} }
defer httpResp.Body.Close() defer httpResp.Body.Close()
body, err := io.ReadAll(httpResp.Body) body, err := io.ReadAll(httpResp.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("reading %s: %w", doh.url, err) return nil, fmt.Errorf("reading %s: %w", url, err)
} }
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
@ -232,7 +232,7 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient(
"expected status %d, got %d from %s", "expected status %d, got %d from %s",
http.StatusOK, http.StatusOK,
httpResp.StatusCode, httpResp.StatusCode,
doh.url, url,
) )
} }
@ -241,7 +241,7 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient(
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"unpacking response from %s: body is %s: %w", "unpacking response from %s: body is %s: %w",
doh.url, url,
body, body,
err, err,
) )
@ -281,7 +281,7 @@ func (doh *dnsOverHTTPS) shouldRetry(err error) (ok bool) {
// resetClient triggers re-creation of the *http.Client that is used by this // resetClient triggers re-creation of the *http.Client that is used by this
// upstream. This method accepts the error that caused resetting client as // upstream. This method accepts the error that caused resetting client as
// depending on the error we may also reset the QUIC config. // depending on the error we may also reset the QUIC config.
func (doh *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) { func (doh *dnsOverHTTPS) resetClient(ctx context.Context, resetErr error) (client *http.Client, err error) {
doh.clientMu.Lock() doh.clientMu.Lock()
defer doh.clientMu.Unlock() defer doh.clientMu.Unlock()
@ -299,7 +299,7 @@ func (doh *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err e
} }
log.Debugln("re-creating the http client due to %v", resetErr) log.Debugln("re-creating the http client due to %v", resetErr)
doh.client, err = doh.createClient() doh.client, err = doh.createClient(ctx)
return doh.client, err return doh.client, err
} }
@ -325,7 +325,7 @@ func (doh *dnsOverHTTPS) resetQUICConfig() {
// getClient gets or lazily initializes an HTTP client (and transport) that will // getClient gets or lazily initializes an HTTP client (and transport) that will
// be used for this DoH resolver. // be used for this DoH resolver.
func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) { func (doh *dnsOverHTTPS) getClient(ctx context.Context) (c *http.Client, isCached bool, err error) {
startTime := time.Now() startTime := time.Now()
doh.clientMu.Lock() doh.clientMu.Lock()
@ -342,7 +342,7 @@ func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error)
} }
log.Debugln("creating a new http client") log.Debugln("creating a new http client")
doh.client, err = doh.createClient() doh.client, err = doh.createClient(ctx)
return doh.client, false, err return doh.client, false, err
} }
@ -351,8 +351,8 @@ func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error)
// will depend on whether HTTP3 is allowed and provided by this upstream. Note, // will depend on whether HTTP3 is allowed and provided by this upstream. Note,
// that we'll attempt to establish a QUIC connection when creating the client in // that we'll attempt to establish a QUIC connection when creating the client in
// order to check whether HTTP3 is supported. // order to check whether HTTP3 is supported.
func (doh *dnsOverHTTPS) createClient() (*http.Client, error) { func (doh *dnsOverHTTPS) createClient(ctx context.Context) (*http.Client, error) {
transport, err := doh.createTransport() transport, err := doh.createTransport(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("initializing http transport: %w", err) return nil, fmt.Errorf("initializing http transport: %w", err)
} }
@ -374,7 +374,7 @@ func (doh *dnsOverHTTPS) createClient() (*http.Client, error) {
// that this function will first attempt to establish a QUIC connection (if // that this function will first attempt to establish a QUIC connection (if
// HTTP3 is enabled in the upstream options). If this attempt is successful, // HTTP3 is enabled in the upstream options). If this attempt is successful,
// it returns an HTTP3 transport, otherwise it returns the H1/H2 transport. // it returns an HTTP3 transport, otherwise it returns the H1/H2 transport.
func (doh *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) { func (doh *dnsOverHTTPS) createTransport(ctx context.Context) (t http.RoundTripper, err error) {
tlsConfig := tlsC.GetGlobalFingerprintTLCConfig( tlsConfig := tlsC.GetGlobalFingerprintTLCConfig(
&tls.Config{ &tls.Config{
InsecureSkipVerify: false, InsecureSkipVerify: false,
@ -390,7 +390,7 @@ func (doh *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
// First, we attempt to create an HTTP3 transport. If the probe QUIC // First, we attempt to create an HTTP3 transport. If the probe QUIC
// connection is established successfully, we'll be using HTTP3 for this // connection is established successfully, we'll be using HTTP3 for this
// upstream. // upstream.
transportH3, err := doh.createTransportH3(tlsConfig, dialContext) transportH3, err := doh.createTransportH3(ctx, tlsConfig, dialContext)
if err == nil { if err == nil {
log.Debugln("using HTTP/3 for this upstream: QUIC was faster") log.Debugln("using HTTP/3 for this upstream: QUIC was faster")
return transportH3, nil return transportH3, nil
@ -483,6 +483,7 @@ func (h *http3Transport) Close() (err error) {
// in parallel (one for TLS, the other one for QUIC) and if QUIC is faster it // in parallel (one for TLS, the other one for QUIC) and if QUIC is faster it
// will create the *http3.RoundTripper instance. // will create the *http3.RoundTripper instance.
func (doh *dnsOverHTTPS) createTransportH3( func (doh *dnsOverHTTPS) createTransportH3(
ctx context.Context,
tlsConfig *tls.Config, tlsConfig *tls.Config,
dialContext dialHandler, dialContext dialHandler,
) (roundTripper http.RoundTripper, err error) { ) (roundTripper http.RoundTripper, err error) {
@ -490,7 +491,7 @@ func (doh *dnsOverHTTPS) createTransportH3(
return nil, errors.New("HTTP3 support is not enabled") return nil, errors.New("HTTP3 support is not enabled")
} }
addr, err := doh.probeH3(tlsConfig, dialContext) addr, err := doh.probeH3(ctx, tlsConfig, dialContext)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -552,13 +553,14 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls.
// upstream. If the test is successful it will return the address that we // upstream. If the test is successful it will return the address that we
// should use to establish the QUIC connections. // should use to establish the QUIC connections.
func (doh *dnsOverHTTPS) probeH3( func (doh *dnsOverHTTPS) probeH3(
ctx context.Context,
tlsConfig *tls.Config, tlsConfig *tls.Config,
dialContext dialHandler, dialContext dialHandler,
) (addr string, err error) { ) (addr string, err error) {
// We're using bootstrapped address instead of what's passed to the function // We're using bootstrapped address instead of what's passed to the function
// it does not create an actual connection, but it helps us determine // it does not create an actual connection, but it helps us determine
// what IP is actually reachable (when there are v4/v6 addresses). // what IP is actually reachable (when there are v4/v6 addresses).
rawConn, err := dialContext(context.Background(), "udp", doh.url.Host) rawConn, err := dialContext(ctx, "udp", doh.url.Host)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to dial: %w", err) return "", fmt.Errorf("failed to dial: %w", err)
} }
@ -597,8 +599,8 @@ func (doh *dnsOverHTTPS) probeH3(
// Run probeQUIC and probeTLS in parallel and see which one is faster. // Run probeQUIC and probeTLS in parallel and see which one is faster.
chQuic := make(chan error, 1) chQuic := make(chan error, 1)
chTLS := make(chan error, 1) chTLS := make(chan error, 1)
go doh.probeQUIC(addr, probeTLSCfg, chQuic) go doh.probeQUIC(ctx, addr, probeTLSCfg, chQuic)
go doh.probeTLS(dialContext, probeTLSCfg, chTLS) go doh.probeTLS(ctx, dialContext, probeTLSCfg, chTLS)
select { select {
case quicErr := <-chQuic: case quicErr := <-chQuic:
@ -622,13 +624,8 @@ func (doh *dnsOverHTTPS) probeH3(
// probeQUIC attempts to establish a QUIC connection to the specified address. // probeQUIC attempts to establish a QUIC connection to the specified address.
// We run probeQUIC and probeTLS in parallel and see which one is faster. // We run probeQUIC and probeTLS in parallel and see which one is faster.
func (doh *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) { func (doh *dnsOverHTTPS) probeQUIC(ctx context.Context, addr string, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now() startTime := time.Now()
timeout := DefaultTimeout
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout))
defer cancel()
conn, err := doh.dialQuic(ctx, addr, tlsConfig, doh.getQUICConfig()) conn, err := doh.dialQuic(ctx, addr, tlsConfig, doh.getQUICConfig())
if err != nil { if err != nil {
ch <- fmt.Errorf("opening QUIC connection to %s: %w", doh.Address(), err) ch <- fmt.Errorf("opening QUIC connection to %s: %w", doh.Address(), err)
@ -646,10 +643,10 @@ func (doh *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan e
// probeTLS attempts to establish a TLS connection to the specified address. We // probeTLS attempts to establish a TLS connection to the specified address. We
// run probeQUIC and probeTLS in parallel and see which one is faster. // run probeQUIC and probeTLS in parallel and see which one is faster.
func (doh *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config, ch chan error) { func (doh *dnsOverHTTPS) probeTLS(ctx context.Context, dialContext dialHandler, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now() startTime := time.Now()
conn, err := doh.tlsDial(dialContext, "tcp", tlsConfig) conn, err := doh.tlsDial(ctx, dialContext, "tcp", tlsConfig)
if err != nil { if err != nil {
ch <- fmt.Errorf("opening TLS connection: %w", err) ch <- fmt.Errorf("opening TLS connection: %w", err)
return return
@ -705,10 +702,10 @@ func isHTTP3(client *http.Client) (ok bool) {
// tlsDial is basically the same as tls.DialWithDialer, but we will call our own // tlsDial is basically the same as tls.DialWithDialer, but we will call our own
// dialContext function to get connection. // dialContext function to get connection.
func (doh *dnsOverHTTPS) tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) { func (doh *dnsOverHTTPS) tlsDial(ctx context.Context, dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) {
// We're using bootstrapped address instead of what's passed // We're using bootstrapped address instead of what's passed
// to the function. // to the function.
rawConn, err := dialContext(context.Background(), network, doh.url.Host) rawConn, err := dialContext(ctx, network, doh.url.Host)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -157,7 +157,7 @@ func (doq *dnsOverQUIC) Close() (err error) {
// through it and return the response it got from the server. // through it and return the response it got from the server.
func (doq *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg, err error) { func (doq *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg, err error) {
var conn quic.Connection var conn quic.Connection
conn, err = doq.getConnection(true) conn, err = doq.getConnection(ctx,true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -225,7 +225,7 @@ func (doq *dnsOverQUIC) getBytesPool() (pool *sync.Pool) {
// argument controls whether we should try to use the existing cached // argument controls whether we should try to use the existing cached
// connection. If it is false, we will forcibly create a new connection and // connection. If it is false, we will forcibly create a new connection and
// close the existing one if needed. // close the existing one if needed.
func (doq *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { func (doq *dnsOverQUIC) getConnection(ctx context.Context,useCached bool) (quic.Connection, error) {
var conn quic.Connection var conn quic.Connection
doq.connMu.RLock() doq.connMu.RLock()
conn = doq.conn conn = doq.conn
@ -244,7 +244,7 @@ func (doq *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) {
defer doq.connMu.Unlock() defer doq.connMu.Unlock()
var err error var err error
conn, err = doq.openConnection() conn, err = doq.openConnection(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -292,7 +292,7 @@ func (doq *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (q
// We can get here if the old QUIC connection is not valid anymore. We // We can get here if the old QUIC connection is not valid anymore. We
// should try to re-create the connection again in this case. // should try to re-create the connection again in this case.
newConn, err := doq.getConnection(false) newConn, err := doq.getConnection(ctx,false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -301,7 +301,7 @@ func (doq *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (q
} }
// openConnection opens a new QUIC connection. // openConnection opens a new QUIC connection.
func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connection, err error) {
tlsConfig := tlsC.GetGlobalFingerprintTLCConfig( tlsConfig := tlsC.GetGlobalFingerprintTLCConfig(
&tls.Config{ &tls.Config{
InsecureSkipVerify: false, InsecureSkipVerify: false,
@ -313,14 +313,12 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
// we're using bootstrapped address instead of what's passed to the function // we're using bootstrapped address instead of what's passed to the function
// it does not create an actual connection, but it helps us determine // it does not create an actual connection, but it helps us determine
// what IP is actually reachable (when there're v4/v6 addresses). // what IP is actually reachable (when there're v4/v6 addresses).
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
rawConn, err := getDialHandler(doq.r, doq.proxyAdapter)(ctx, "udp", doq.addr) rawConn, err := getDialHandler(doq.r, doq.proxyAdapter)(ctx, "udp", doq.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open a QUIC connection: %w", err) return nil, fmt.Errorf("failed to open a QUIC connection: %w", err)
} }
// It's never actually used // It's never actually used
_ = rawConn.Close() _ = rawConn.Close()
cancel()
var addr string var addr string
udpConn, ok := rawConn.(*net.UDPConn) udpConn, ok := rawConn.(*net.UDPConn)
if !ok { if !ok {
@ -365,8 +363,6 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
udp = wrapConn udp = wrapConn
} }
ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
host, _, err := net.SplitHostPort(doq.addr) host, _, err := net.SplitHostPort(doq.addr)
if err != nil { if err != nil {
return nil, err return nil, err