fix: DoH/DoQ doesn't use context

This commit is contained in:
Skyxim 2022-11-19 10:31:50 +08:00
parent f00dc69bb6
commit b8b3c9ef9f
2 changed files with 35 additions and 42 deletions

View file

@ -118,7 +118,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.
// Check if there was already an active client before sending the request.
// We'll only attempt to re-connect if there was one.
client, isCached, err := doh.getClient()
client, isCached, err := doh.getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to init http client: %w", err)
}
@ -132,7 +132,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.
// the case when the connection was closed (due to inactivity for example)
// AND the server refuses to open a 0-RTT connection.
for i := 0; isCached && doh.shouldRetry(err) && i < 2; i++ {
client, err = doh.resetClient(err)
client, err = doh.resetClient(ctx, err)
if err != nil {
return nil, fmt.Errorf("failed to reset http client: %w", err)
}
@ -142,7 +142,7 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.
if err != nil {
// If the request failed anyway, make sure we don't use this client.
_, resErr := doh.resetClient(err)
_, resErr := doh.resetClient(ctx, err)
return nil, fmt.Errorf("err:%v,resErr:%v", err, resErr)
}
@ -183,7 +183,6 @@ func (doh *dnsOverHTTPS) closeClient(client *http.Client) (err error) {
// exchangeHTTPS logs the request and its result and calls exchangeHTTPSClient.
func (doh *dnsOverHTTPS) exchangeHTTPS(ctx context.Context, client *http.Client, req *D.Msg) (resp *D.Msg, err error) {
resp, err = doh.exchangeHTTPSClient(ctx, client, req)
return resp, err
}
@ -207,23 +206,24 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient(
method = http3.MethodGet0RTT
}
doh.url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf))
httpReq, err := http.NewRequestWithContext(ctx, method, doh.url.String(), nil)
url := doh.url
url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf))
httpReq, err := http.NewRequestWithContext(ctx, method, url.String(), nil)
if err != nil {
return nil, fmt.Errorf("creating http request to %s: %w", doh.url, err)
return nil, fmt.Errorf("creating http request to %s: %w", url, err)
}
httpReq.Header.Set("Accept", "application/dns-message")
httpReq.Header.Set("User-Agent", "")
httpResp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("requesting %s: %w", doh.url, err)
return nil, fmt.Errorf("requesting %s: %w", url, err)
}
defer httpResp.Body.Close()
body, err := io.ReadAll(httpResp.Body)
if err != nil {
return nil, fmt.Errorf("reading %s: %w", doh.url, err)
return nil, fmt.Errorf("reading %s: %w", url, err)
}
if httpResp.StatusCode != http.StatusOK {
@ -232,7 +232,7 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient(
"expected status %d, got %d from %s",
http.StatusOK,
httpResp.StatusCode,
doh.url,
url,
)
}
@ -241,7 +241,7 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient(
if err != nil {
return nil, fmt.Errorf(
"unpacking response from %s: body is %s: %w",
doh.url,
url,
body,
err,
)
@ -281,7 +281,7 @@ func (doh *dnsOverHTTPS) shouldRetry(err error) (ok bool) {
// resetClient triggers re-creation of the *http.Client that is used by this
// upstream. This method accepts the error that caused resetting client as
// depending on the error we may also reset the QUIC config.
func (doh *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) {
func (doh *dnsOverHTTPS) resetClient(ctx context.Context, resetErr error) (client *http.Client, err error) {
doh.clientMu.Lock()
defer doh.clientMu.Unlock()
@ -299,7 +299,7 @@ func (doh *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err e
}
log.Debugln("re-creating the http client due to %v", resetErr)
doh.client, err = doh.createClient()
doh.client, err = doh.createClient(ctx)
return doh.client, err
}
@ -325,7 +325,7 @@ func (doh *dnsOverHTTPS) resetQUICConfig() {
// getClient gets or lazily initializes an HTTP client (and transport) that will
// be used for this DoH resolver.
func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) {
func (doh *dnsOverHTTPS) getClient(ctx context.Context) (c *http.Client, isCached bool, err error) {
startTime := time.Now()
doh.clientMu.Lock()
@ -342,7 +342,7 @@ func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error)
}
log.Debugln("creating a new http client")
doh.client, err = doh.createClient()
doh.client, err = doh.createClient(ctx)
return doh.client, false, err
}
@ -351,8 +351,8 @@ func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error)
// will depend on whether HTTP3 is allowed and provided by this upstream. Note,
// that we'll attempt to establish a QUIC connection when creating the client in
// order to check whether HTTP3 is supported.
func (doh *dnsOverHTTPS) createClient() (*http.Client, error) {
transport, err := doh.createTransport()
func (doh *dnsOverHTTPS) createClient(ctx context.Context) (*http.Client, error) {
transport, err := doh.createTransport(ctx)
if err != nil {
return nil, fmt.Errorf("initializing http transport: %w", err)
}
@ -374,7 +374,7 @@ func (doh *dnsOverHTTPS) createClient() (*http.Client, error) {
// that this function will first attempt to establish a QUIC connection (if
// HTTP3 is enabled in the upstream options). If this attempt is successful,
// it returns an HTTP3 transport, otherwise it returns the H1/H2 transport.
func (doh *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
func (doh *dnsOverHTTPS) createTransport(ctx context.Context) (t http.RoundTripper, err error) {
tlsConfig := tlsC.GetGlobalFingerprintTLCConfig(
&tls.Config{
InsecureSkipVerify: false,
@ -390,7 +390,7 @@ func (doh *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
// First, we attempt to create an HTTP3 transport. If the probe QUIC
// connection is established successfully, we'll be using HTTP3 for this
// upstream.
transportH3, err := doh.createTransportH3(tlsConfig, dialContext)
transportH3, err := doh.createTransportH3(ctx, tlsConfig, dialContext)
if err == nil {
log.Debugln("using HTTP/3 for this upstream: QUIC was faster")
return transportH3, nil
@ -483,6 +483,7 @@ func (h *http3Transport) Close() (err error) {
// in parallel (one for TLS, the other one for QUIC) and if QUIC is faster it
// will create the *http3.RoundTripper instance.
func (doh *dnsOverHTTPS) createTransportH3(
ctx context.Context,
tlsConfig *tls.Config,
dialContext dialHandler,
) (roundTripper http.RoundTripper, err error) {
@ -490,7 +491,7 @@ func (doh *dnsOverHTTPS) createTransportH3(
return nil, errors.New("HTTP3 support is not enabled")
}
addr, err := doh.probeH3(tlsConfig, dialContext)
addr, err := doh.probeH3(ctx, tlsConfig, dialContext)
if err != nil {
return nil, err
}
@ -552,13 +553,14 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls.
// upstream. If the test is successful it will return the address that we
// should use to establish the QUIC connections.
func (doh *dnsOverHTTPS) probeH3(
ctx context.Context,
tlsConfig *tls.Config,
dialContext dialHandler,
) (addr string, err error) {
// We're using bootstrapped address instead of what's passed to the function
// it does not create an actual connection, but it helps us determine
// what IP is actually reachable (when there are v4/v6 addresses).
rawConn, err := dialContext(context.Background(), "udp", doh.url.Host)
rawConn, err := dialContext(ctx, "udp", doh.url.Host)
if err != nil {
return "", fmt.Errorf("failed to dial: %w", err)
}
@ -597,8 +599,8 @@ func (doh *dnsOverHTTPS) probeH3(
// Run probeQUIC and probeTLS in parallel and see which one is faster.
chQuic := make(chan error, 1)
chTLS := make(chan error, 1)
go doh.probeQUIC(addr, probeTLSCfg, chQuic)
go doh.probeTLS(dialContext, probeTLSCfg, chTLS)
go doh.probeQUIC(ctx, addr, probeTLSCfg, chQuic)
go doh.probeTLS(ctx, dialContext, probeTLSCfg, chTLS)
select {
case quicErr := <-chQuic:
@ -622,13 +624,8 @@ func (doh *dnsOverHTTPS) probeH3(
// probeQUIC attempts to establish a QUIC connection to the specified address.
// We run probeQUIC and probeTLS in parallel and see which one is faster.
func (doh *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) {
func (doh *dnsOverHTTPS) probeQUIC(ctx context.Context, addr string, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now()
timeout := DefaultTimeout
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout))
defer cancel()
conn, err := doh.dialQuic(ctx, addr, tlsConfig, doh.getQUICConfig())
if err != nil {
ch <- fmt.Errorf("opening QUIC connection to %s: %w", doh.Address(), err)
@ -646,10 +643,10 @@ func (doh *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan e
// probeTLS attempts to establish a TLS connection to the specified address. We
// run probeQUIC and probeTLS in parallel and see which one is faster.
func (doh *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config, ch chan error) {
func (doh *dnsOverHTTPS) probeTLS(ctx context.Context, dialContext dialHandler, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now()
conn, err := doh.tlsDial(dialContext, "tcp", tlsConfig)
conn, err := doh.tlsDial(ctx, dialContext, "tcp", tlsConfig)
if err != nil {
ch <- fmt.Errorf("opening TLS connection: %w", err)
return
@ -705,10 +702,10 @@ func isHTTP3(client *http.Client) (ok bool) {
// tlsDial is basically the same as tls.DialWithDialer, but we will call our own
// dialContext function to get connection.
func (doh *dnsOverHTTPS) tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) {
func (doh *dnsOverHTTPS) tlsDial(ctx context.Context, dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) {
// We're using bootstrapped address instead of what's passed
// to the function.
rawConn, err := dialContext(context.Background(), network, doh.url.Host)
rawConn, err := dialContext(ctx, network, doh.url.Host)
if err != nil {
return nil, err
}

View file

@ -157,7 +157,7 @@ func (doq *dnsOverQUIC) Close() (err error) {
// through it and return the response it got from the server.
func (doq *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg, err error) {
var conn quic.Connection
conn, err = doq.getConnection(true)
conn, err = doq.getConnection(ctx,true)
if err != nil {
return nil, err
}
@ -225,7 +225,7 @@ func (doq *dnsOverQUIC) getBytesPool() (pool *sync.Pool) {
// argument controls whether we should try to use the existing cached
// connection. If it is false, we will forcibly create a new connection and
// close the existing one if needed.
func (doq *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) {
func (doq *dnsOverQUIC) getConnection(ctx context.Context,useCached bool) (quic.Connection, error) {
var conn quic.Connection
doq.connMu.RLock()
conn = doq.conn
@ -244,7 +244,7 @@ func (doq *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) {
defer doq.connMu.Unlock()
var err error
conn, err = doq.openConnection()
conn, err = doq.openConnection(ctx)
if err != nil {
return nil, err
}
@ -292,7 +292,7 @@ func (doq *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (q
// We can get here if the old QUIC connection is not valid anymore. We
// should try to re-create the connection again in this case.
newConn, err := doq.getConnection(false)
newConn, err := doq.getConnection(ctx,false)
if err != nil {
return nil, err
}
@ -301,7 +301,7 @@ func (doq *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (q
}
// openConnection opens a new QUIC connection.
func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
func (doq *dnsOverQUIC) openConnection(ctx context.Context) (conn quic.Connection, err error) {
tlsConfig := tlsC.GetGlobalFingerprintTLCConfig(
&tls.Config{
InsecureSkipVerify: false,
@ -313,14 +313,12 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
// we're using bootstrapped address instead of what's passed to the function
// it does not create an actual connection, but it helps us determine
// what IP is actually reachable (when there're v4/v6 addresses).
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
rawConn, err := getDialHandler(doq.r, doq.proxyAdapter)(ctx, "udp", doq.addr)
if err != nil {
return nil, fmt.Errorf("failed to open a QUIC connection: %w", err)
}
// It's never actually used
_ = rawConn.Close()
cancel()
var addr string
udpConn, ok := rawConn.(*net.UDPConn)
if !ok {
@ -365,8 +363,6 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
udp = wrapConn
}
ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
host, _, err := net.SplitHostPort(doq.addr)
if err != nil {
return nil, err