fix: DoH/DoQ doesn't use context
This commit is contained in:
parent
f00dc69bb6
commit
b8b3c9ef9f
2 changed files with 35 additions and 42 deletions
63
dns/doh.go
63
dns/doh.go
|
@ -118,7 +118,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.
|
|||
|
||||
// Check if there was already an active client before sending the request.
|
||||
// We'll only attempt to re-connect if there was one.
|
||||
client, isCached, err := doh.getClient()
|
||||
client, isCached, err := doh.getClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init http client: %w", err)
|
||||
}
|
||||
|
@ -132,7 +132,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.
|
|||
// the case when the connection was closed (due to inactivity for example)
|
||||
// AND the server refuses to open a 0-RTT connection.
|
||||
for i := 0; isCached && doh.shouldRetry(err) && i < 2; i++ {
|
||||
client, err = doh.resetClient(err)
|
||||
client, err = doh.resetClient(ctx, err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to reset http client: %w", err)
|
||||
}
|
||||
|
@ -142,7 +142,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.
|
|||
|
||||
if err != nil {
|
||||
// If the request failed anyway, make sure we don't use this client.
|
||||
_, resErr := doh.resetClient(err)
|
||||
_, resErr := doh.resetClient(ctx, err)
|
||||
|
||||
return nil, fmt.Errorf("err:%v,resErr:%v", err, resErr)
|
||||
}
|
||||
|
@ -183,7 +183,6 @@ func (doh *dnsOverHTTPS) closeClient(client *http.Client) (err error) {
|
|||
// exchangeHTTPS logs the request and its result and calls exchangeHTTPSClient.
|
||||
func (doh *dnsOverHTTPS) exchangeHTTPS(ctx context.Context, client *http.Client, req *D.Msg) (resp *D.Msg, err error) {
|
||||
resp, err = doh.exchangeHTTPSClient(ctx, client, req)
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
|
@ -207,23 +206,24 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient(
|
|||
method = http3.MethodGet0RTT
|
||||
}
|
||||
|
||||
doh.url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, method, doh.url.String(), nil)
|
||||
url := doh.url
|
||||
url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, method, url.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating http request to %s: %w", doh.url, err)
|
||||
return nil, fmt.Errorf("creating http request to %s: %w", url, err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Accept", "application/dns-message")
|
||||
httpReq.Header.Set("User-Agent", "")
|
||||
httpResp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("requesting %s: %w", doh.url, err)
|
||||
return nil, fmt.Errorf("requesting %s: %w", url, err)
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading %s: %w", doh.url, err)
|
||||
return nil, fmt.Errorf("reading %s: %w", url, err)
|
||||
}
|
||||
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
|
@ -232,7 +232,7 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient(
|
|||
"expected status %d, got %d from %s",
|
||||
http.StatusOK,
|
||||
httpResp.StatusCode,
|
||||
doh.url,
|
||||
url,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -241,7 +241,7 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"unpacking response from %s: body is %s: %w",
|
||||
doh.url,
|
||||
url,
|
||||
body,
|
||||
err,
|
||||
)
|
||||
|
@ -281,7 +281,7 @@ func (doh *dnsOverHTTPS) shouldRetry(err error) (ok bool) {
|
|||
// resetClient triggers re-creation of the *http.Client that is used by this
|
||||
// upstream. This method accepts the error that caused resetting client as
|
||||
// depending on the error we may also reset the QUIC config.
|
||||
func (doh *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) {
|
||||
func (doh *dnsOverHTTPS) resetClient(ctx context.Context, resetErr error) (client *http.Client, err error) {
|
||||
doh.clientMu.Lock()
|
||||
defer doh.clientMu.Unlock()
|
||||
|
||||
|
@ -299,7 +299,7 @@ func (doh *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err e
|
|||
}
|
||||
|
||||
log.Debugln("re-creating the http client due to %v", resetErr)
|
||||
doh.client, err = doh.createClient()
|
||||
doh.client, err = doh.createClient(ctx)
|
||||
|
||||
return doh.client, err
|
||||
}
|
||||
|
@ -325,7 +325,7 @@ func (doh *dnsOverHTTPS) resetQUICConfig() {
|
|||
|
||||
// getClient gets or lazily initializes an HTTP client (and transport) that will
|
||||
// be used for this DoH resolver.
|
||||
func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) {
|
||||
func (doh *dnsOverHTTPS) getClient(ctx context.Context) (c *http.Client, isCached bool, err error) {
|
||||
startTime := time.Now()
|
||||
|
||||
doh.clientMu.Lock()
|
||||
|
@ -342,7 +342,7 @@ func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error)
|
|||
}
|
||||
|
||||
log.Debugln("creating a new http client")
|
||||
doh.client, err = doh.createClient()
|
||||
doh.client, err = doh.createClient(ctx)
|
||||
|
||||
return doh.client, false, err
|
||||
}
|
||||
|
@ -351,8 +351,8 @@ func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error)
|
|||
// will depend on whether HTTP3 is allowed and provided by this upstream. Note,
|
||||
// that we'll attempt to establish a QUIC connection when creating the client in
|
||||
// order to check whether HTTP3 is supported.
|
||||
func (doh *dnsOverHTTPS) createClient() (*http.Client, error) {
|
||||
transport, err := doh.createTransport()
|
||||
func (doh *dnsOverHTTPS) createClient(ctx context.Context) (*http.Client, error) {
|
||||
transport, err := doh.createTransport(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing http transport: %w", err)
|
||||
}
|
||||
|
@ -374,7 +374,7 @@ func (doh *dnsOverHTTPS) createClient() (*http.Client, error) {
|
|||
// that this function will first attempt to establish a QUIC connection (if
|
||||
// HTTP3 is enabled in the upstream options). If this attempt is successful,
|
||||
// it returns an HTTP3 transport, otherwise it returns the H1/H2 transport.
|
||||
func (doh *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
|
||||
func (doh *dnsOverHTTPS) createTransport(ctx context.Context) (t http.RoundTripper, err error) {
|
||||
tlsConfig := tlsC.GetGlobalFingerprintTLCConfig(
|
||||
&tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
|
@ -390,7 +390,7 @@ func (doh *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
|
|||
// First, we attempt to create an HTTP3 transport. If the probe QUIC
|
||||
// connection is established successfully, we'll be using HTTP3 for this
|
||||
// upstream.
|
||||
transportH3, err := doh.createTransportH3(tlsConfig, dialContext)
|
||||
transportH3, err := doh.createTransportH3(ctx, tlsConfig, dialContext)
|
||||
if err == nil {
|
||||
log.Debugln("using HTTP/3 for this upstream: QUIC was faster")
|
||||
return transportH3, nil
|
||||
|
@ -483,6 +483,7 @@ func (h *http3Transport) Close() (err error) {
|
|||
// in parallel (one for TLS, the other one for QUIC) and if QUIC is faster it
|
||||
// will create the *http3.RoundTripper instance.
|
||||
func (doh *dnsOverHTTPS) createTransportH3(
|
||||
ctx context.Context,
|
||||
tlsConfig *tls.Config,
|
||||
dialContext dialHandler,
|
||||
) (roundTripper http.RoundTripper, err error) {
|
||||
|
@ -490,7 +491,7 @@ func (doh *dnsOverHTTPS) createTransportH3(
|
|||
return nil, errors.New("HTTP3 support is not enabled")
|
||||
}
|
||||
|
||||
addr, err := doh.probeH3(tlsConfig, dialContext)
|
||||
addr, err := doh.probeH3(ctx, tlsConfig, dialContext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -552,13 +553,14 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls.
|
|||
// upstream. If the test is successful it will return the address that we
|
||||
// should use to establish the QUIC connections.
|
||||
func (doh *dnsOverHTTPS) probeH3(
|
||||
ctx context.Context,
|
||||
tlsConfig *tls.Config,
|
||||
dialContext dialHandler,
|
||||
) (addr string, err error) {
|
||||
// We're using bootstrapped address instead of what's passed to the function
|
||||
// it does not create an actual connection, but it helps us determine
|
||||
// what IP is actually reachable (when there are v4/v6 addresses).
|
||||
rawConn, err := dialContext(context.Background(), "udp", doh.url.Host)
|
||||
rawConn, err := dialContext(ctx, "udp", doh.url.Host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to dial: %w", err)
|
||||
}
|
||||
|
@ -597,8 +599,8 @@ func (doh *dnsOverHTTPS) probeH3(
|
|||
// Run probeQUIC and probeTLS in parallel and see which one is faster.
|
||||
chQuic := make(chan error, 1)
|
||||
chTLS := make(chan error, 1)
|
||||
go doh.probeQUIC(addr, probeTLSCfg, chQuic)
|
||||
go doh.probeTLS(dialContext, probeTLSCfg, chTLS)
|
||||
go doh.probeQUIC(ctx, addr, probeTLSCfg, chQuic)
|
||||
go doh.probeTLS(ctx, dialContext, probeTLSCfg, chTLS)
|
||||
|
||||
select {
|
||||
case quicErr := <-chQuic:
|
||||
|
@ -622,13 +624,8 @@ func (doh *dnsOverHTTPS) probeH3(
|
|||
|
||||
// probeQUIC attempts to establish a QUIC connection to the specified address.
|
||||
// We run probeQUIC and probeTLS in parallel and see which one is faster.
|
||||
func (doh *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) {
|
||||
func (doh *dnsOverHTTPS) probeQUIC(ctx context.Context, addr string, tlsConfig *tls.Config, ch chan error) {
|
||||
startTime := time.Now()
|
||||
|
||||
timeout := DefaultTimeout
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout))
|
||||
defer cancel()
|
||||
|
||||
conn, err := doh.dialQuic(ctx, addr, tlsConfig, doh.getQUICConfig())
|
||||
if err != nil {
|
||||
ch <- fmt.Errorf("opening QUIC connection to %s: %w", doh.Address(), err)
|
||||
|
@ -646,10 +643,10 @@ func (doh *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan e
|
|||
|
||||
// probeTLS attempts to establish a TLS connection to the specified address. We
|
||||
// run probeQUIC and probeTLS in parallel and see which one is faster.
|
||||
func (doh *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config, ch chan error) {
|
||||
func (doh *dnsOverHTTPS) probeTLS(ctx context.Context, dialContext dialHandler, tlsConfig *tls.Config, ch chan error) {
|
||||
startTime := time.Now()
|
||||
|
||||
conn, err := doh.tlsDial(dialContext, "tcp", tlsConfig)
|
||||
conn, err := doh.tlsDial(ctx, dialContext, "tcp", tlsConfig)
|
||||
if err != nil {
|
||||
ch <- fmt.Errorf("opening TLS connection: %w", err)
|
||||
return
|
||||
|
@ -705,10 +702,10 @@ func isHTTP3(client *http.Client) (ok bool) {
|
|||
|
||||
// tlsDial is basically the same as tls.DialWithDialer, but we will call our own
|
||||
// dialContext function to get connection.
|
||||
func (doh *dnsOverHTTPS) tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) {
|
||||
func (doh *dnsOverHTTPS) tlsDial(ctx context.Context, dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) {
|
||||
// We're using bootstrapped address instead of what's passed
|
||||
// to the function.
|
||||
rawConn, err := dialContext(context.Background(), network, doh.url.Host)
|
||||
rawConn, err := dialContext(ctx, network, doh.url.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
14
dns/doq.go
14
dns/doq.go
|
@ -157,7 +157,7 @@ func (doq *dnsOverQUIC) Close() (err error) {
|
|||
// through it and return the response it got from the server.
|
||||
func (doq *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg, err error) {
|
||||
var conn quic.Connection
|
||||
conn, err = doq.getConnection(true)
|
||||
conn, err = doq.getConnection(ctx,true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -225,7 +225,7 @@ func (doq *dnsOverQUIC) getBytesPool() (pool *sync.Pool) {
|
|||
// argument controls whether we should try to use the existing cached
|
||||
// connection. If it is false, we will forcibly create a new connection and
|
||||
// close the existing one if needed.
|
||||
func (doq *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) {
|
||||
func (doq *dnsOverQUIC) getConnection(ctx context.Context,useCached bool) (quic.Connection, error) {
|
||||
var conn quic.Connection
|
||||
doq.connMu.RLock()
|
||||
conn = doq.conn
|
||||
|
@ -244,7 +244,7 @@ func (doq *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) {
|
|||
defer doq.connMu.Unlock()
|
||||
|
||||
var err error
|
||||
conn, err = doq.openConnection()
|
||||
conn, err = doq.openConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -292,7 +292,7 @@ func (doq *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (q
|
|||
|
||||
// We can get here if the old QUIC connection is not valid anymore. We
|
||||
// should try to re-create the connection again in this case.
|
||||
newConn, err := doq.getConnection(false)
|
||||
newConn, err := doq.getConnection(ctx,false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -301,7 +301,7 @@ func (doq *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (q
|
|||
}
|
||||
|
||||
// openConnection opens a new QUIC connection.
|
||||
func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
|
||||
func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connection, err error) {
|
||||
tlsConfig := tlsC.GetGlobalFingerprintTLCConfig(
|
||||
&tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
|
@ -313,14 +313,12 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
|
|||
// we're using bootstrapped address instead of what's passed to the function
|
||||
// it does not create an actual connection, but it helps us determine
|
||||
// what IP is actually reachable (when there're v4/v6 addresses).
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
rawConn, err := getDialHandler(doq.r, doq.proxyAdapter)(ctx, "udp", doq.addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open a QUIC connection: %w", err)
|
||||
}
|
||||
// It's never actually used
|
||||
_ = rawConn.Close()
|
||||
cancel()
|
||||
var addr string
|
||||
udpConn, ok := rawConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
|
@ -365,8 +363,6 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
|
|||
udp = wrapConn
|
||||
}
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
host, _, err := net.SplitHostPort(doq.addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
Loading…
Reference in a new issue