From b8b3c9ef9fcbd7df567bddda991f00f2bff2e938 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Sat, 19 Nov 2022 10:31:50 +0800 Subject: [PATCH] fix: DoH/DoQ doesn't use context --- dns/doh.go | 63 ++++++++++++++++++++++++++---------------------------- dns/doq.go | 14 +++++------- 2 files changed, 35 insertions(+), 42 deletions(-) diff --git a/dns/doh.go b/dns/doh.go index f34246d5..84135b72 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -118,7 +118,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D. // Check if there was already an active client before sending the request. // We'll only attempt to re-connect if there was one. - client, isCached, err := doh.getClient() + client, isCached, err := doh.getClient(ctx) if err != nil { return nil, fmt.Errorf("failed to init http client: %w", err) } @@ -132,7 +132,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D. // the case when the connection was closed (due to inactivity for example) // AND the server refuses to open a 0-RTT connection. for i := 0; isCached && doh.shouldRetry(err) && i < 2; i++ { - client, err = doh.resetClient(err) + client, err = doh.resetClient(ctx, err) if err != nil { return nil, fmt.Errorf("failed to reset http client: %w", err) } @@ -142,7 +142,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D. if err != nil { // If the request failed anyway, make sure we don't use this client. - _, resErr := doh.resetClient(err) + _, resErr := doh.resetClient(ctx, err) return nil, fmt.Errorf("err:%v,resErr:%v", err, resErr) } @@ -183,7 +183,6 @@ func (doh *dnsOverHTTPS) closeClient(client *http.Client) (err error) { // exchangeHTTPS logs the request and its result and calls exchangeHTTPSClient. func (doh *dnsOverHTTPS) exchangeHTTPS(ctx context.Context, client *http.Client, req *D.Msg) (resp *D.Msg, err error) { resp, err = doh.exchangeHTTPSClient(ctx, client, req) - return resp, err } @@ -207,23 +206,24 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient( method = http3.MethodGet0RTT } - doh.url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf)) - httpReq, err := http.NewRequestWithContext(ctx, method, doh.url.String(), nil) + url := doh.url + url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf)) + httpReq, err := http.NewRequestWithContext(ctx, method, url.String(), nil) if err != nil { - return nil, fmt.Errorf("creating http request to %s: %w", doh.url, err) + return nil, fmt.Errorf("creating http request to %s: %w", url, err) } httpReq.Header.Set("Accept", "application/dns-message") httpReq.Header.Set("User-Agent", "") httpResp, err := client.Do(httpReq) if err != nil { - return nil, fmt.Errorf("requesting %s: %w", doh.url, err) + return nil, fmt.Errorf("requesting %s: %w", url, err) } defer httpResp.Body.Close() body, err := io.ReadAll(httpResp.Body) if err != nil { - return nil, fmt.Errorf("reading %s: %w", doh.url, err) + return nil, fmt.Errorf("reading %s: %w", url, err) } if httpResp.StatusCode != http.StatusOK { @@ -232,7 +232,7 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient( "expected status %d, got %d from %s", http.StatusOK, httpResp.StatusCode, - doh.url, + url, ) } @@ -241,7 +241,7 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient( if err != nil { return nil, fmt.Errorf( "unpacking response from %s: body is %s: %w", - doh.url, + url, body, err, ) @@ -281,7 +281,7 @@ func (doh *dnsOverHTTPS) shouldRetry(err error) (ok bool) { // resetClient triggers re-creation of the *http.Client that is used by this // upstream. This method accepts the error that caused resetting client as // depending on the error we may also reset the QUIC config. -func (doh *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) { +func (doh *dnsOverHTTPS) resetClient(ctx context.Context, resetErr error) (client *http.Client, err error) { doh.clientMu.Lock() defer doh.clientMu.Unlock() @@ -299,7 +299,7 @@ func (doh *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err e } log.Debugln("re-creating the http client due to %v", resetErr) - doh.client, err = doh.createClient() + doh.client, err = doh.createClient(ctx) return doh.client, err } @@ -325,7 +325,7 @@ func (doh *dnsOverHTTPS) resetQUICConfig() { // getClient gets or lazily initializes an HTTP client (and transport) that will // be used for this DoH resolver. -func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) { +func (doh *dnsOverHTTPS) getClient(ctx context.Context) (c *http.Client, isCached bool, err error) { startTime := time.Now() doh.clientMu.Lock() @@ -342,7 +342,7 @@ func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) } log.Debugln("creating a new http client") - doh.client, err = doh.createClient() + doh.client, err = doh.createClient(ctx) return doh.client, false, err } @@ -351,8 +351,8 @@ func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) // will depend on whether HTTP3 is allowed and provided by this upstream. Note, // that we'll attempt to establish a QUIC connection when creating the client in // order to check whether HTTP3 is supported. -func (doh *dnsOverHTTPS) createClient() (*http.Client, error) { - transport, err := doh.createTransport() +func (doh *dnsOverHTTPS) createClient(ctx context.Context) (*http.Client, error) { + transport, err := doh.createTransport(ctx) if err != nil { return nil, fmt.Errorf("initializing http transport: %w", err) } @@ -374,7 +374,7 @@ func (doh *dnsOverHTTPS) createClient() (*http.Client, error) { // that this function will first attempt to establish a QUIC connection (if // HTTP3 is enabled in the upstream options). If this attempt is successful, // it returns an HTTP3 transport, otherwise it returns the H1/H2 transport. -func (doh *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) { +func (doh *dnsOverHTTPS) createTransport(ctx context.Context) (t http.RoundTripper, err error) { tlsConfig := tlsC.GetGlobalFingerprintTLCConfig( &tls.Config{ InsecureSkipVerify: false, @@ -390,7 +390,7 @@ func (doh *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) { // First, we attempt to create an HTTP3 transport. If the probe QUIC // connection is established successfully, we'll be using HTTP3 for this // upstream. - transportH3, err := doh.createTransportH3(tlsConfig, dialContext) + transportH3, err := doh.createTransportH3(ctx, tlsConfig, dialContext) if err == nil { log.Debugln("using HTTP/3 for this upstream: QUIC was faster") return transportH3, nil @@ -483,6 +483,7 @@ func (h *http3Transport) Close() (err error) { // in parallel (one for TLS, the other one for QUIC) and if QUIC is faster it // will create the *http3.RoundTripper instance. func (doh *dnsOverHTTPS) createTransportH3( + ctx context.Context, tlsConfig *tls.Config, dialContext dialHandler, ) (roundTripper http.RoundTripper, err error) { @@ -490,7 +491,7 @@ func (doh *dnsOverHTTPS) createTransportH3( return nil, errors.New("HTTP3 support is not enabled") } - addr, err := doh.probeH3(tlsConfig, dialContext) + addr, err := doh.probeH3(ctx, tlsConfig, dialContext) if err != nil { return nil, err } @@ -552,13 +553,14 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls. // upstream. If the test is successful it will return the address that we // should use to establish the QUIC connections. func (doh *dnsOverHTTPS) probeH3( + ctx context.Context, tlsConfig *tls.Config, dialContext dialHandler, ) (addr string, err error) { // We're using bootstrapped address instead of what's passed to the function // it does not create an actual connection, but it helps us determine // what IP is actually reachable (when there are v4/v6 addresses). - rawConn, err := dialContext(context.Background(), "udp", doh.url.Host) + rawConn, err := dialContext(ctx, "udp", doh.url.Host) if err != nil { return "", fmt.Errorf("failed to dial: %w", err) } @@ -597,8 +599,8 @@ func (doh *dnsOverHTTPS) probeH3( // Run probeQUIC and probeTLS in parallel and see which one is faster. chQuic := make(chan error, 1) chTLS := make(chan error, 1) - go doh.probeQUIC(addr, probeTLSCfg, chQuic) - go doh.probeTLS(dialContext, probeTLSCfg, chTLS) + go doh.probeQUIC(ctx, addr, probeTLSCfg, chQuic) + go doh.probeTLS(ctx, dialContext, probeTLSCfg, chTLS) select { case quicErr := <-chQuic: @@ -622,13 +624,8 @@ func (doh *dnsOverHTTPS) probeH3( // probeQUIC attempts to establish a QUIC connection to the specified address. // We run probeQUIC and probeTLS in parallel and see which one is faster. -func (doh *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) { +func (doh *dnsOverHTTPS) probeQUIC(ctx context.Context, addr string, tlsConfig *tls.Config, ch chan error) { startTime := time.Now() - - timeout := DefaultTimeout - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout)) - defer cancel() - conn, err := doh.dialQuic(ctx, addr, tlsConfig, doh.getQUICConfig()) if err != nil { ch <- fmt.Errorf("opening QUIC connection to %s: %w", doh.Address(), err) @@ -646,10 +643,10 @@ func (doh *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan e // probeTLS attempts to establish a TLS connection to the specified address. We // run probeQUIC and probeTLS in parallel and see which one is faster. -func (doh *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config, ch chan error) { +func (doh *dnsOverHTTPS) probeTLS(ctx context.Context, dialContext dialHandler, tlsConfig *tls.Config, ch chan error) { startTime := time.Now() - conn, err := doh.tlsDial(dialContext, "tcp", tlsConfig) + conn, err := doh.tlsDial(ctx, dialContext, "tcp", tlsConfig) if err != nil { ch <- fmt.Errorf("opening TLS connection: %w", err) return @@ -705,10 +702,10 @@ func isHTTP3(client *http.Client) (ok bool) { // tlsDial is basically the same as tls.DialWithDialer, but we will call our own // dialContext function to get connection. -func (doh *dnsOverHTTPS) tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) { +func (doh *dnsOverHTTPS) tlsDial(ctx context.Context, dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) { // We're using bootstrapped address instead of what's passed // to the function. - rawConn, err := dialContext(context.Background(), network, doh.url.Host) + rawConn, err := dialContext(ctx, network, doh.url.Host) if err != nil { return nil, err } diff --git a/dns/doq.go b/dns/doq.go index 316417ef..7823a403 100644 --- a/dns/doq.go +++ b/dns/doq.go @@ -157,7 +157,7 @@ func (doq *dnsOverQUIC) Close() (err error) { // through it and return the response it got from the server. func (doq *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg, err error) { var conn quic.Connection - conn, err = doq.getConnection(true) + conn, err = doq.getConnection(ctx,true) if err != nil { return nil, err } @@ -225,7 +225,7 @@ func (doq *dnsOverQUIC) getBytesPool() (pool *sync.Pool) { // argument controls whether we should try to use the existing cached // connection. If it is false, we will forcibly create a new connection and // close the existing one if needed. -func (doq *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { +func (doq *dnsOverQUIC) getConnection(ctx context.Context,useCached bool) (quic.Connection, error) { var conn quic.Connection doq.connMu.RLock() conn = doq.conn @@ -244,7 +244,7 @@ func (doq *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { defer doq.connMu.Unlock() var err error - conn, err = doq.openConnection() + conn, err = doq.openConnection(ctx) if err != nil { return nil, err } @@ -292,7 +292,7 @@ func (doq *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (q // We can get here if the old QUIC connection is not valid anymore. We // should try to re-create the connection again in this case. - newConn, err := doq.getConnection(false) + newConn, err := doq.getConnection(ctx,false) if err != nil { return nil, err } @@ -301,7 +301,7 @@ func (doq *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (q } // openConnection opens a new QUIC connection. -func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { +func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connection, err error) { tlsConfig := tlsC.GetGlobalFingerprintTLCConfig( &tls.Config{ InsecureSkipVerify: false, @@ -313,14 +313,12 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { // we're using bootstrapped address instead of what's passed to the function // it does not create an actual connection, but it helps us determine // what IP is actually reachable (when there're v4/v6 addresses). - ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout) rawConn, err := getDialHandler(doq.r, doq.proxyAdapter)(ctx, "udp", doq.addr) if err != nil { return nil, fmt.Errorf("failed to open a QUIC connection: %w", err) } // It's never actually used _ = rawConn.Close() - cancel() var addr string udpConn, ok := rawConn.(*net.UDPConn) if !ok { @@ -365,8 +363,6 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { udp = wrapConn } - ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) - defer cancel() host, _, err := net.SplitHostPort(doq.addr) if err != nil { return nil, err