diff --git a/dns/doh.go b/dns/doh.go index 157ab67b..8c38e2f4 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -83,8 +83,9 @@ func newDoHClient(urlString string, r *Resolver, preferH3 bool, params map[strin } doh := &dnsOverHTTPS{ - url: u, - r: r, + url: u, + r: r, + proxyAdapter: proxyAdapter, quicConfig: &quic.Config{ KeepAlivePeriod: QUICKeepAlivePeriod, TokenStore: newQUICTokenStore(), @@ -98,8 +99,8 @@ func newDoHClient(urlString string, r *Resolver, preferH3 bool, params map[strin } // Address implements the Upstream interface for *dnsOverHTTPS. -func (p *dnsOverHTTPS) Address() string { return p.url.String() } -func (p *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { +func (doh *dnsOverHTTPS) Address() string { return doh.url.String() } +func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { // Quote from https://www.rfc-editor.org/rfc/rfc8484.html: // In order to maximize HTTP cache friendliness, DoH clients using media // formats that include the ID field from the DNS message header, such @@ -117,31 +118,31 @@ func (p *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Ms // Check if there was already an active client before sending the request. // We'll only attempt to re-connect if there was one. - client, isCached, err := p.getClient() + client, isCached, err := doh.getClient() if err != nil { return nil, fmt.Errorf("failed to init http client: %w", err) } // Make the first attempt to send the DNS query. - msg, err = p.exchangeHTTPS(ctx, client, m) + msg, err = doh.exchangeHTTPS(ctx, client, m) // Make up to 2 attempts to re-create the HTTP client and send the request // again. There are several cases (mostly, with QUIC) where this workaround // is necessary to make HTTP client usable. We need to make 2 attempts in // the case when the connection was closed (due to inactivity for example) // AND the server refuses to open a 0-RTT connection. - for i := 0; isCached && p.shouldRetry(err) && i < 2; i++ { - client, err = p.resetClient(err) + for i := 0; isCached && doh.shouldRetry(err) && i < 2; i++ { + client, err = doh.resetClient(err) if err != nil { return nil, fmt.Errorf("failed to reset http client: %w", err) } - msg, err = p.exchangeHTTPS(ctx, client, m) + msg, err = doh.exchangeHTTPS(ctx, client, m) } if err != nil { // If the request failed anyway, make sure we don't use this client. - _, resErr := p.resetClient(err) + _, resErr := doh.resetClient(err) return nil, fmt.Errorf("err:%v,resErr:%v", err, resErr) } @@ -150,28 +151,28 @@ func (p *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Ms } // Exchange implements the Upstream interface for *dnsOverHTTPS. -func (p *dnsOverHTTPS) Exchange(m *D.Msg) (*D.Msg, error) { - return p.ExchangeContext(context.Background(), m) +func (doh *dnsOverHTTPS) Exchange(m *D.Msg) (*D.Msg, error) { + return doh.ExchangeContext(context.Background(), m) } // Close implements the Upstream interface for *dnsOverHTTPS. -func (p *dnsOverHTTPS) Close() (err error) { - p.clientMu.Lock() - defer p.clientMu.Unlock() +func (doh *dnsOverHTTPS) Close() (err error) { + doh.clientMu.Lock() + defer doh.clientMu.Unlock() - runtime.SetFinalizer(p, nil) + runtime.SetFinalizer(doh, nil) - if p.client == nil { + if doh.client == nil { return nil } - return p.closeClient(p.client) + return doh.closeClient(doh.client) } // closeClient cleans up resources used by client if necessary. Note, that at // this point it should only be done for HTTP/3 as it may leak due to keep-alive // connections. -func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) { +func (doh *dnsOverHTTPS) closeClient(client *http.Client) (err error) { if isHTTP3(client) { return client.Transport.(io.Closer).Close() } @@ -180,15 +181,15 @@ func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) { } // exchangeHTTPS logs the request and its result and calls exchangeHTTPSClient. -func (p *dnsOverHTTPS) exchangeHTTPS(ctx context.Context, client *http.Client, req *D.Msg) (resp *D.Msg, err error) { - resp, err = p.exchangeHTTPSClient(ctx, client, req) +func (doh *dnsOverHTTPS) exchangeHTTPS(ctx context.Context, client *http.Client, req *D.Msg) (resp *D.Msg, err error) { + resp, err = doh.exchangeHTTPSClient(ctx, client, req) return resp, err } // exchangeHTTPSClient sends the DNS query to a DoH resolver using the specified // http.Client instance. -func (p *dnsOverHTTPS) exchangeHTTPSClient( +func (doh *dnsOverHTTPS) exchangeHTTPSClient( ctx context.Context, client *http.Client, req *D.Msg, @@ -206,10 +207,10 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient( method = http3.MethodGet0RTT } - p.url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf)) - httpReq, err := http.NewRequest(method, p.url.String(), nil) + doh.url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf)) + httpReq, err := http.NewRequest(method, doh.url.String(), nil) if err != nil { - return nil, fmt.Errorf("creating http request to %s: %w", p.url, err) + return nil, fmt.Errorf("creating http request to %s: %w", doh.url, err) } httpReq.Header.Set("Accept", "application/dns-message") @@ -217,13 +218,13 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient( _ = httpReq.WithContext(ctx) httpResp, err := client.Do(httpReq) if err != nil { - return nil, fmt.Errorf("requesting %s: %w", p.url, err) + return nil, fmt.Errorf("requesting %s: %w", doh.url, err) } defer httpResp.Body.Close() body, err := io.ReadAll(httpResp.Body) if err != nil { - return nil, fmt.Errorf("reading %s: %w", p.url, err) + return nil, fmt.Errorf("reading %s: %w", doh.url, err) } if httpResp.StatusCode != http.StatusOK { @@ -232,7 +233,7 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient( "expected status %d, got %d from %s", http.StatusOK, httpResp.StatusCode, - p.url, + doh.url, ) } @@ -241,7 +242,7 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient( if err != nil { return nil, fmt.Errorf( "unpacking response from %s: body is %s: %w", - p.url, + doh.url, body, err, ) @@ -256,7 +257,7 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient( // shouldRetry checks what error we have received and returns true if we should // re-create the HTTP client and retry the request. -func (p *dnsOverHTTPS) shouldRetry(err error) (ok bool) { +func (doh *dnsOverHTTPS) shouldRetry(err error) (ok bool) { if err == nil { return false } @@ -281,57 +282,57 @@ func (p *dnsOverHTTPS) shouldRetry(err error) (ok bool) { // resetClient triggers re-creation of the *http.Client that is used by this // upstream. This method accepts the error that caused resetting client as // depending on the error we may also reset the QUIC config. -func (p *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) { - p.clientMu.Lock() - defer p.clientMu.Unlock() +func (doh *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) { + doh.clientMu.Lock() + defer doh.clientMu.Unlock() if errors.Is(resetErr, quic.Err0RTTRejected) { // Reset the TokenStore only if 0-RTT was rejected. - p.resetQUICConfig() + doh.resetQUICConfig() } - oldClient := p.client + oldClient := doh.client if oldClient != nil { - closeErr := p.closeClient(oldClient) + closeErr := doh.closeClient(oldClient) if closeErr != nil { log.Warnln("warning: failed to close the old http client: %v", closeErr) } } log.Debugln("re-creating the http client due to %v", resetErr) - p.client, err = p.createClient() + doh.client, err = doh.createClient() - return p.client, err + return doh.client, err } // getQUICConfig returns the QUIC config in a thread-safe manner. Note, that // this method returns a pointer, it is forbidden to change its properties. -func (p *dnsOverHTTPS) getQUICConfig() (c *quic.Config) { - p.quicConfigGuard.Lock() - defer p.quicConfigGuard.Unlock() +func (doh *dnsOverHTTPS) getQUICConfig() (c *quic.Config) { + doh.quicConfigGuard.Lock() + defer doh.quicConfigGuard.Unlock() - return p.quicConfig + return doh.quicConfig } // resetQUICConfig Re-create the token store to make sure we're not trying to // use invalid for 0-RTT. -func (p *dnsOverHTTPS) resetQUICConfig() { - p.quicConfigGuard.Lock() - defer p.quicConfigGuard.Unlock() +func (doh *dnsOverHTTPS) resetQUICConfig() { + doh.quicConfigGuard.Lock() + defer doh.quicConfigGuard.Unlock() - p.quicConfig = p.quicConfig.Clone() - p.quicConfig.TokenStore = newQUICTokenStore() + doh.quicConfig = doh.quicConfig.Clone() + doh.quicConfig.TokenStore = newQUICTokenStore() } // getClient gets or lazily initializes an HTTP client (and transport) that will // be used for this DoH resolver. -func (p *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) { +func (doh *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) { startTime := time.Now() - p.clientMu.Lock() - defer p.clientMu.Unlock() - if p.client != nil { - return p.client, true, nil + doh.clientMu.Lock() + defer doh.clientMu.Unlock() + if doh.client != nil { + return doh.client, true, nil } // Timeout can be exceeded while waiting for the lock. This happens quite @@ -342,17 +343,17 @@ func (p *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) { } log.Debugln("creating a new http client") - p.client, err = p.createClient() + doh.client, err = doh.createClient() - return p.client, false, err + return doh.client, false, err } // createClient creates a new *http.Client instance. The HTTP protocol version // will depend on whether HTTP3 is allowed and provided by this upstream. Note, // that we'll attempt to establish a QUIC connection when creating the client in // order to check whether HTTP3 is supported. -func (p *dnsOverHTTPS) createClient() (*http.Client, error) { - transport, err := p.createTransport() +func (doh *dnsOverHTTPS) createClient() (*http.Client, error) { + transport, err := doh.createTransport() if err != nil { return nil, fmt.Errorf("initializing http transport: %w", err) } @@ -363,9 +364,9 @@ func (p *dnsOverHTTPS) createClient() (*http.Client, error) { Jar: nil, } - p.client = client + doh.client = client - return p.client, nil + return doh.client, nil } // createTransport initializes an HTTP transport that will be used specifically @@ -374,7 +375,7 @@ func (p *dnsOverHTTPS) createClient() (*http.Client, error) { // that this function will first attempt to establish a QUIC connection (if // HTTP3 is enabled in the upstream options). If this attempt is successful, // it returns an HTTP3 transport, otherwise it returns the H1/H2 transport. -func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) { +func (doh *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) { tlsConfig := tlsC.GetGlobalFingerprintTLCConfig( &tls.Config{ InsecureSkipVerify: false, @@ -382,15 +383,15 @@ func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) { SessionTicketsDisabled: false, }) var nextProtos []string - for _, v := range p.httpVersions { + for _, v := range doh.httpVersions { nextProtos = append(nextProtos, string(v)) } tlsConfig.NextProtos = nextProtos - dialContext := getDialHandler(p.r, p.proxyAdapter) + dialContext := getDialHandler(doh.r, doh.proxyAdapter) // First, we attempt to create an HTTP3 transport. If the probe QUIC // connection is established successfully, we'll be using HTTP3 for this // upstream. - transportH3, err := p.createTransportH3(tlsConfig, dialContext) + transportH3, err := doh.createTransportH3(tlsConfig, dialContext) if err == nil { log.Debugln("using HTTP/3 for this upstream: QUIC was faster") return transportH3, nil @@ -398,7 +399,7 @@ func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) { log.Debugln("using HTTP/2 for this upstream: %v", err) - if !p.supportsHTTP() { + if !doh.supportsHTTP() { return nil, errors.New("HTTP1/1 and HTTP2 are not supported by this upstream") } @@ -551,14 +552,14 @@ func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls. // probeH3 runs a test to check whether QUIC is faster than TLS for this // upstream. If the test is successful it will return the address that we // should use to establish the QUIC connections. -func (p *dnsOverHTTPS) probeH3( +func (doh *dnsOverHTTPS) probeH3( tlsConfig *tls.Config, dialContext dialHandler, ) (addr string, err error) { // We're using bootstrapped address instead of what's passed to the function // it does not create an actual connection, but it helps us determine // what IP is actually reachable (when there are v4/v6 addresses). - rawConn, err := dialContext(context.Background(), "udp", p.url.Host) + rawConn, err := dialContext(context.Background(), "udp", doh.url.Host) if err != nil { return "", fmt.Errorf("failed to dial: %w", err) } @@ -567,13 +568,17 @@ func (p *dnsOverHTTPS) probeH3( udpConn, ok := rawConn.(*net.UDPConn) if !ok { - return "", fmt.Errorf("not a UDP connection to %s", p.Address()) + if packetConn, ok := rawConn.(*wrapPacketConn); !ok { + return "", fmt.Errorf("not a UDP connection to %s", doh.Address()) + } else { + addr = packetConn.RemoteAddr().String() + } + } else { + addr = udpConn.RemoteAddr().String() } - addr = udpConn.RemoteAddr().String() - // Avoid spending time on probing if this upstream only supports HTTP/3. - if p.supportsH3() && !p.supportsHTTP() { + if doh.supportsH3() && !doh.supportsHTTP() { return addr, nil } @@ -593,8 +598,8 @@ func (p *dnsOverHTTPS) probeH3( // Run probeQUIC and probeTLS in parallel and see which one is faster. chQuic := make(chan error, 1) chTLS := make(chan error, 1) - go p.probeQUIC(addr, probeTLSCfg, chQuic) - go p.probeTLS(dialContext, probeTLSCfg, chTLS) + go doh.probeQUIC(addr, probeTLSCfg, chQuic) + go doh.probeTLS(dialContext, probeTLSCfg, chTLS) select { case quicErr := <-chQuic: @@ -618,16 +623,16 @@ func (p *dnsOverHTTPS) probeH3( // probeQUIC attempts to establish a QUIC connection to the specified address. // We run probeQUIC and probeTLS in parallel and see which one is faster. -func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) { +func (doh *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) { startTime := time.Now() timeout := DefaultTimeout ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout)) defer cancel() - conn, err := p.dialQuic(ctx, addr, tlsConfig, p.getQUICConfig()) + conn, err := doh.dialQuic(ctx, addr, tlsConfig, doh.getQUICConfig()) if err != nil { - ch <- fmt.Errorf("opening QUIC connection to %s: %w", p.Address(), err) + ch <- fmt.Errorf("opening QUIC connection to %s: %w", doh.Address(), err) return } @@ -642,10 +647,10 @@ func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan err // probeTLS attempts to establish a TLS connection to the specified address. We // run probeQUIC and probeTLS in parallel and see which one is faster. -func (p *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config, ch chan error) { +func (doh *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config, ch chan error) { startTime := time.Now() - conn, err := p.tlsDial(dialContext, "tcp", tlsConfig) + conn, err := doh.tlsDial(dialContext, "tcp", tlsConfig) if err != nil { ch <- fmt.Errorf("opening TLS connection: %w", err) return @@ -661,8 +666,8 @@ func (p *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config, } // supportsH3 returns true if HTTP/3 is supported by this upstream. -func (p *dnsOverHTTPS) supportsH3() (ok bool) { - for _, v := range p.supportedHTTPVersions() { +func (doh *dnsOverHTTPS) supportsH3() (ok bool) { + for _, v := range doh.supportedHTTPVersions() { if v == C.HTTPVersion3 { return true } @@ -672,8 +677,8 @@ func (p *dnsOverHTTPS) supportsH3() (ok bool) { } // supportsHTTP returns true if HTTP/1.1 or HTTP2 is supported by this upstream. -func (p *dnsOverHTTPS) supportsHTTP() (ok bool) { - for _, v := range p.supportedHTTPVersions() { +func (doh *dnsOverHTTPS) supportsHTTP() (ok bool) { + for _, v := range doh.supportedHTTPVersions() { if v == C.HTTPVersion11 || v == C.HTTPVersion2 { return true } @@ -683,8 +688,8 @@ func (p *dnsOverHTTPS) supportsHTTP() (ok bool) { } // supportedHTTPVersions returns the list of supported HTTP versions. -func (p *dnsOverHTTPS) supportedHTTPVersions() (v []C.HTTPVersion) { - v = p.httpVersions +func (doh *dnsOverHTTPS) supportedHTTPVersions() (v []C.HTTPVersion) { + v = doh.httpVersions if v == nil { v = DefaultHTTPVersions } diff --git a/dns/doq.go b/dns/doq.go index 07cceff6..316417ef 100644 --- a/dns/doq.go +++ b/dns/doq.go @@ -88,9 +88,9 @@ func newDoQ(resolver *Resolver, addr string, adapter string) (dnsClient, error) } // Address implements the Upstream interface for *dnsOverQUIC. -func (p *dnsOverQUIC) Address() string { return p.addr } +func (doq *dnsOverQUIC) Address() string { return doq.addr } -func (p *dnsOverQUIC) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { +func (doq *dnsOverQUIC) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { // When sending queries over a QUIC connection, the DNS Message ID MUST be // set to zero. id := m.Id @@ -105,49 +105,49 @@ func (p *dnsOverQUIC) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg // Check if there was already an active conn before sending the request. // We'll only attempt to re-connect if there was one. - hasConnection := p.hasConnection() + hasConnection := doq.hasConnection() // Make the first attempt to send the DNS query. - msg, err = p.exchangeQUIC(ctx, m) + msg, err = doq.exchangeQUIC(ctx, m) // Make up to 2 attempts to re-open the QUIC connection and send the request // again. There are several cases where this workaround is necessary to // make DoQ usable. We need to make 2 attempts in the case when the // connection was closed (due to inactivity for example) AND the server // refuses to open a 0-RTT connection. - for i := 0; hasConnection && p.shouldRetry(err) && i < 2; i++ { + for i := 0; hasConnection && doq.shouldRetry(err) && i < 2; i++ { log.Debugln("re-creating the QUIC connection and retrying due to %v", err) // Close the active connection to make sure we'll try to re-connect. - p.closeConnWithError(err) + doq.closeConnWithError(err) // Retry sending the request. - msg, err = p.exchangeQUIC(ctx, m) + msg, err = doq.exchangeQUIC(ctx, m) } if err != nil { // If we're unable to exchange messages, make sure the connection is // closed and signal about an internal error. - p.closeConnWithError(err) + doq.closeConnWithError(err) } return msg, err } // Exchange implements the Upstream interface for *dnsOverQUIC. -func (p *dnsOverQUIC) Exchange(m *D.Msg) (msg *D.Msg, err error) { - return p.ExchangeContext(context.Background(), m) +func (doq *dnsOverQUIC) Exchange(m *D.Msg) (msg *D.Msg, err error) { + return doq.ExchangeContext(context.Background(), m) } // Close implements the Upstream interface for *dnsOverQUIC. -func (p *dnsOverQUIC) Close() (err error) { - p.connMu.Lock() - defer p.connMu.Unlock() +func (doq *dnsOverQUIC) Close() (err error) { + doq.connMu.Lock() + defer doq.connMu.Unlock() - runtime.SetFinalizer(p, nil) + runtime.SetFinalizer(doq, nil) - if p.conn != nil { - err = p.conn.CloseWithError(QUICCodeNoError, "") + if doq.conn != nil { + err = doq.conn.CloseWithError(QUICCodeNoError, "") } return err @@ -155,9 +155,9 @@ func (p *dnsOverQUIC) Close() (err error) { // exchangeQUIC attempts to open a QUIC connection, send the DNS message // through it and return the response it got from the server. -func (p *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg, err error) { +func (doq *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg, err error) { var conn quic.Connection - conn, err = p.getConnection(true) + conn, err = doq.getConnection(true) if err != nil { return nil, err } @@ -169,7 +169,7 @@ func (p *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg } var stream quic.Stream - stream, err = p.openStream(ctx, conn) + stream, err = doq.openStream(ctx, conn) if err != nil { return nil, err } @@ -185,7 +185,7 @@ func (p *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg // write-direction of the stream, but does not prevent reading from it. _ = stream.Close() - return p.readMsg(stream) + return doq.readMsg(stream) } // AddPrefix adds a 2-byte prefix with the DNS message length. @@ -199,17 +199,17 @@ func AddPrefix(b []byte) (m []byte) { // shouldRetry checks what error we received and decides whether it is required // to re-open the connection and retry sending the request. -func (p *dnsOverQUIC) shouldRetry(err error) (ok bool) { +func (doq *dnsOverQUIC) shouldRetry(err error) (ok bool) { return isQUICRetryError(err) } // getBytesPool returns (creates if needed) a pool we store byte buffers in. -func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) { - p.bytesPoolGuard.Lock() - defer p.bytesPoolGuard.Unlock() +func (doq *dnsOverQUIC) getBytesPool() (pool *sync.Pool) { + doq.bytesPoolGuard.Lock() + defer doq.bytesPoolGuard.Unlock() - if p.bytesPool == nil { - p.bytesPool = &sync.Pool{ + if doq.bytesPool == nil { + doq.bytesPool = &sync.Pool{ New: func() interface{} { b := make([]byte, MaxMsgSize) @@ -218,19 +218,19 @@ func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) { } } - return p.bytesPool + return doq.bytesPool } // getConnection opens or returns an existing quic.Connection. useCached // argument controls whether we should try to use the existing cached // connection. If it is false, we will forcibly create a new connection and // close the existing one if needed. -func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { +func (doq *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { var conn quic.Connection - p.connMu.RLock() - conn = p.conn + doq.connMu.RLock() + conn = doq.conn if conn != nil && useCached { - p.connMu.RUnlock() + doq.connMu.RUnlock() return conn, nil } @@ -238,50 +238,50 @@ func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { // we're recreating the connection, let's create a new one. _ = conn.CloseWithError(QUICCodeNoError, "") } - p.connMu.RUnlock() + doq.connMu.RUnlock() - p.connMu.Lock() - defer p.connMu.Unlock() + doq.connMu.Lock() + defer doq.connMu.Unlock() var err error - conn, err = p.openConnection() + conn, err = doq.openConnection() if err != nil { return nil, err } - p.conn = conn + doq.conn = conn return conn, nil } // hasConnection returns true if there's an active QUIC connection. -func (p *dnsOverQUIC) hasConnection() (ok bool) { - p.connMu.Lock() - defer p.connMu.Unlock() +func (doq *dnsOverQUIC) hasConnection() (ok bool) { + doq.connMu.Lock() + defer doq.connMu.Unlock() - return p.conn != nil + return doq.conn != nil } // getQUICConfig returns the QUIC config in a thread-safe manner. Note, that // this method returns a pointer, it is forbidden to change its properties. -func (p *dnsOverQUIC) getQUICConfig() (c *quic.Config) { - p.quicConfigGuard.Lock() - defer p.quicConfigGuard.Unlock() +func (doq *dnsOverQUIC) getQUICConfig() (c *quic.Config) { + doq.quicConfigGuard.Lock() + defer doq.quicConfigGuard.Unlock() - return p.quicConfig + return doq.quicConfig } // resetQUICConfig re-creates the tokens store as we may need to use a new one // if we failed to connect. -func (p *dnsOverQUIC) resetQUICConfig() { - p.quicConfigGuard.Lock() - defer p.quicConfigGuard.Unlock() +func (doq *dnsOverQUIC) resetQUICConfig() { + doq.quicConfigGuard.Lock() + defer doq.quicConfigGuard.Unlock() - p.quicConfig = p.quicConfig.Clone() - p.quicConfig.TokenStore = newQUICTokenStore() + doq.quicConfig = doq.quicConfig.Clone() + doq.quicConfig.TokenStore = newQUICTokenStore() } // openStream opens a new QUIC stream for the specified connection. -func (p *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (quic.Stream, error) { +func (doq *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (quic.Stream, error) { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -292,7 +292,7 @@ func (p *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (qui // We can get here if the old QUIC connection is not valid anymore. We // should try to re-create the connection again in this case. - newConn, err := p.getConnection(false) + newConn, err := doq.getConnection(false) if err != nil { return nil, err } @@ -321,14 +321,18 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { // It's never actually used _ = rawConn.Close() cancel() - + var addr string udpConn, ok := rawConn.(*net.UDPConn) if !ok { - return nil, fmt.Errorf("failed to open connection to %s", doq.addr) + if packetConn, ok := rawConn.(*wrapPacketConn); !ok { + return nil, fmt.Errorf("failed to open connection to %s", doq.addr) + } else { + addr = packetConn.RemoteAddr().String() + } + } else { + addr = udpConn.RemoteAddr().String() } - addr := udpConn.RemoteAddr().String() - ip, port, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -379,11 +383,11 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { // closeConnWithError closes the active connection with error to make sure that // new queries were processed in another connection. We can do that in the case // of a fatal error. -func (p *dnsOverQUIC) closeConnWithError(err error) { - p.connMu.Lock() - defer p.connMu.Unlock() +func (doq *dnsOverQUIC) closeConnWithError(err error) { + doq.connMu.Lock() + defer doq.connMu.Unlock() - if p.conn == nil { + if doq.conn == nil { // Do nothing, there's no active conn anyways. return } @@ -395,19 +399,19 @@ func (p *dnsOverQUIC) closeConnWithError(err error) { if errors.Is(err, quic.Err0RTTRejected) { // Reset the TokenStore only if 0-RTT was rejected. - p.resetQUICConfig() + doq.resetQUICConfig() } - err = p.conn.CloseWithError(code, "") + err = doq.conn.CloseWithError(code, "") if err != nil { log.Errorln("failed to close the conn: %v", err) } - p.conn = nil + doq.conn = nil } // readMsg reads the incoming DNS message from the QUIC stream. -func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *D.Msg, err error) { - pool := p.getBytesPool() +func (doq *dnsOverQUIC) readMsg(stream quic.Stream) (m *D.Msg, err error) { + pool := doq.getBytesPool() bufPtr := pool.Get().(*[]byte) defer pool.Put(bufPtr) @@ -415,7 +419,7 @@ func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *D.Msg, err error) { respBuf := *bufPtr n, err := stream.Read(respBuf) if err != nil && n == 0 { - return nil, fmt.Errorf("reading response from %s: %w", p.Address(), err) + return nil, fmt.Errorf("reading response from %s: %w", doq.Address(), err) } // All DNS messages (queries and responses) sent over DoQ connections MUST @@ -426,7 +430,7 @@ func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *D.Msg, err error) { m = new(D.Msg) err = m.Unpack(respBuf[2:]) if err != nil { - return nil, fmt.Errorf("unpacking response from %s: %w", p.Address(), err) + return nil, fmt.Errorf("unpacking response from %s: %w", doq.Address(), err) } return m, nil @@ -512,7 +516,7 @@ func getDialHandler(r *Resolver, proxyAdapter string) dialHandler { if len(proxyAdapter) == 0 { return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port), dialer.WithDirect()) } else { - return dialContextExtra(ctx, proxyAdapter, network, ip.Unmap(), port, dialer.WithDirect()) + return dialContextExtra(ctx, proxyAdapter, network, ip.Unmap(), port) } } }