From c71a4619b82f8ccaef22e18fcd058899ef2565e8 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 18 May 2023 13:15:08 +0800 Subject: [PATCH] chore: more context passing in outbounds --- adapter/outbound/base.go | 4 ++-- adapter/outbound/http.go | 8 +++----- adapter/outbound/shadowsocks.go | 11 ++--------- adapter/outbound/shadowsocksr.go | 6 +++--- adapter/outbound/snell.go | 6 +++--- adapter/outbound/socks5.go | 8 +++----- adapter/outbound/trojan.go | 16 ++++++++-------- adapter/outbound/vless.go | 20 ++++++++++---------- adapter/outbound/vmess.go | 16 ++++++++-------- adapter/outbound/wireguard.go | 4 ++-- constant/adapters.go | 4 ++-- transport/trojan/trojan.go | 6 +++--- transport/v2ray-plugin/websocket.go | 5 +++-- transport/vless/xtls.go | 6 +----- transport/vmess/tls.go | 11 +---------- transport/vmess/websocket.go | 22 +++++++++++----------- 16 files changed, 65 insertions(+), 88 deletions(-) diff --git a/adapter/outbound/base.go b/adapter/outbound/base.go index c5901d7f..f2ce56c9 100644 --- a/adapter/outbound/base.go +++ b/adapter/outbound/base.go @@ -45,8 +45,8 @@ func (b *Base) Type() C.AdapterType { return b.tp } -// StreamConn implements C.ProxyAdapter -func (b *Base) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { +// StreamConnContext implements C.ProxyAdapter +func (b *Base) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { return c, C.ErrNotSupport } diff --git a/adapter/outbound/http.go b/adapter/outbound/http.go index 82fff5bb..78735b2d 100644 --- a/adapter/outbound/http.go +++ b/adapter/outbound/http.go @@ -40,12 +40,10 @@ type HttpOption struct { Headers map[string]string `proxy:"headers,omitempty"` } -// StreamConn implements C.ProxyAdapter -func (h *Http) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { +// StreamConnContext implements C.ProxyAdapter +func (h *Http) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { if h.tlsConfig != nil { cc := tls.Client(c, h.tlsConfig) - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() err := cc.HandshakeContext(ctx) c = cc if err != nil { @@ -82,7 +80,7 @@ func (h *Http) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metad safeConnClose(c, err) }(c) - c, err = h.StreamConn(c, metadata) + c, err = h.StreamConnContext(ctx, c, metadata) if err != nil { return nil, err } diff --git a/adapter/outbound/shadowsocks.go b/adapter/outbound/shadowsocks.go index 02e975ef..e6c74592 100644 --- a/adapter/outbound/shadowsocks.go +++ b/adapter/outbound/shadowsocks.go @@ -84,14 +84,7 @@ type restlsOption struct { RestlsScript string `obfs:"restls-script,omitempty"` } -// StreamConn implements C.ProxyAdapter -func (ss *ShadowSocks) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { - // fix tls handshake not timeout - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() - return ss.StreamConnContext(ctx, c, metadata) -} - +// StreamConnContext implements C.ProxyAdapter func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { useEarly := false switch ss.obfsMode { @@ -102,7 +95,7 @@ func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metada c = obfs.NewHTTPObfs(c, ss.obfsOption.Host, port) case "websocket": var err error - c, err = v2rayObfs.NewV2rayObfs(c, ss.v2rayOption) + c, err = v2rayObfs.NewV2rayObfs(ctx, c, ss.v2rayOption) if err != nil { return nil, fmt.Errorf("%s connect error: %w", ss.addr, err) } diff --git a/adapter/outbound/shadowsocksr.go b/adapter/outbound/shadowsocksr.go index 2b94ab0c..d33d6586 100644 --- a/adapter/outbound/shadowsocksr.go +++ b/adapter/outbound/shadowsocksr.go @@ -38,8 +38,8 @@ type ShadowSocksROption struct { UDP bool `proxy:"udp,omitempty"` } -// StreamConn implements C.ProxyAdapter -func (ssr *ShadowSocksR) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { +// StreamConnContext implements C.ProxyAdapter +func (ssr *ShadowSocksR) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { c = ssr.obfs.StreamConn(c) c = ssr.cipher.StreamConn(c) var ( @@ -83,7 +83,7 @@ func (ssr *ShadowSocksR) DialContextWithDialer(ctx context.Context, dialer C.Dia safeConnClose(c, err) }(c) - c, err = ssr.StreamConn(c, metadata) + c, err = ssr.StreamConnContext(ctx, c, metadata) return NewConn(c, ssr), err } diff --git a/adapter/outbound/snell.go b/adapter/outbound/snell.go index 1ec0a430..fc1f4eb3 100644 --- a/adapter/outbound/snell.go +++ b/adapter/outbound/snell.go @@ -52,8 +52,8 @@ func streamConn(c net.Conn, option streamOption) *snell.Snell { return snell.StreamConn(c, option.psk, option.version) } -// StreamConn implements C.ProxyAdapter -func (s *Snell) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { +// StreamConnContext implements C.ProxyAdapter +func (s *Snell) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { c = streamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption}) if metadata.NetWork == C.UDP { err := snell.WriteUDPHeader(c, s.version) @@ -101,7 +101,7 @@ func (s *Snell) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta safeConnClose(c, err) }(c) - c, err = s.StreamConn(c, metadata) + c, err = s.StreamConnContext(ctx, c, metadata) return NewConn(c, s), err } diff --git a/adapter/outbound/socks5.go b/adapter/outbound/socks5.go index 26f5733b..9af4d0fc 100644 --- a/adapter/outbound/socks5.go +++ b/adapter/outbound/socks5.go @@ -39,12 +39,10 @@ type Socks5Option struct { Fingerprint string `proxy:"fingerprint,omitempty"` } -// StreamConn implements C.ProxyAdapter -func (ss *Socks5) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { +// StreamConnContext implements C.ProxyAdapter +func (ss *Socks5) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { if ss.tls { cc := tls.Client(c, ss.tlsConfig) - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() err := cc.HandshakeContext(ctx) c = cc if err != nil { @@ -88,7 +86,7 @@ func (ss *Socks5) DialContextWithDialer(ctx context.Context, dialer C.Dialer, me safeConnClose(c, err) }(c) - c, err = ss.StreamConn(c, metadata) + c, err = ss.StreamConnContext(ctx, c, metadata) if err != nil { return nil, err } diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index b9bbc33f..81fb1ceb 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -50,7 +50,7 @@ type TrojanOption struct { ClientFingerprint string `proxy:"client-fingerprint,omitempty"` } -func (t *Trojan) plainStream(c net.Conn) (net.Conn, error) { +func (t *Trojan) plainStream(ctx context.Context, c net.Conn) (net.Conn, error) { if t.option.Network == "ws" { host, port, _ := net.SplitHostPort(t.addr) wsOpts := &trojan.WebsocketOption{ @@ -71,14 +71,14 @@ func (t *Trojan) plainStream(c net.Conn) (net.Conn, error) { wsOpts.Headers = header } - return t.instance.StreamWebsocketConn(c, wsOpts) + return t.instance.StreamWebsocketConn(ctx, c, wsOpts) } - return t.instance.StreamConn(c) + return t.instance.StreamConn(ctx, c) } -// StreamConn implements C.ProxyAdapter -func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { +// StreamConnContext implements C.ProxyAdapter +func (t *Trojan) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { var err error if tlsC.HaveGlobalFingerprint() && len(t.option.ClientFingerprint) == 0 { @@ -88,7 +88,7 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) if t.transport != nil { c, err = gun.StreamGunWithConn(c, t.gunTLSConfig, t.gunConfig, t.realityConfig) } else { - c, err = t.plainStream(c) + c, err = t.plainStream(ctx, c) } if err != nil { @@ -151,7 +151,7 @@ func (t *Trojan) DialContextWithDialer(ctx context.Context, dialer C.Dialer, met safeConnClose(c, err) }(c) - c, err = t.StreamConn(c, metadata) + c, err = t.StreamConnContext(ctx, c, metadata) if err != nil { return nil, err } @@ -199,7 +199,7 @@ func (t *Trojan) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, me safeConnClose(c, err) }(c) tcpKeepAlive(c) - c, err = t.plainStream(c) + c, err = t.plainStream(ctx, c) if err != nil { return nil, fmt.Errorf("%s connect error: %w", t.addr, err) } diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index 048350e9..e3aff5fb 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -75,7 +75,7 @@ type VlessOption struct { ClientFingerprint string `proxy:"client-fingerprint,omitempty"` } -func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { +func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { var err error if tlsC.HaveGlobalFingerprint() && len(v.option.ClientFingerprint) == 0 { @@ -129,10 +129,10 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { convert.SetUserAgent(wsOpts.Headers) } } - c, err = vmess.StreamWebsocketConn(c, wsOpts) + c, err = vmess.StreamWebsocketConn(ctx, c, wsOpts) case "http": // readability first, so just copy default TLS logic - c, err = v.streamTLSOrXTLSConn(c, false) + c, err = v.streamTLSOrXTLSConn(ctx, c, false) if err != nil { return nil, err } @@ -147,7 +147,7 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { c = vmess.StreamHTTPConn(c, httpOpts) case "h2": - c, err = v.streamTLSOrXTLSConn(c, true) + c, err = v.streamTLSOrXTLSConn(ctx, c, true) if err != nil { return nil, err } @@ -163,7 +163,7 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { default: // default tcp network // handle TLS And XTLS - c, err = v.streamTLSOrXTLSConn(c, false) + c, err = v.streamTLSOrXTLSConn(ctx, c, false) } if err != nil { @@ -201,7 +201,7 @@ func (v *Vless) streamConn(c net.Conn, metadata *C.Metadata) (conn net.Conn, err return } -func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error) { +func (v *Vless) streamTLSOrXTLSConn(ctx context.Context, conn net.Conn, isH2 bool) (net.Conn, error) { host, _, _ := net.SplitHostPort(v.addr) if v.isLegacyXTLSEnabled() && !isH2 { @@ -215,7 +215,7 @@ func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error) xtlsOpts.Host = v.option.ServerName } - return vless.StreamXTLSConn(conn, &xtlsOpts) + return vless.StreamXTLSConn(ctx, conn, &xtlsOpts) } else if v.option.TLS { tlsOpts := vmess.TLSConfig{ @@ -234,7 +234,7 @@ func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error) tlsOpts.Host = v.option.ServerName } - return vmess.StreamTLSConn(conn, &tlsOpts) + return vmess.StreamTLSConn(ctx, conn, &tlsOpts) } return conn, nil @@ -283,7 +283,7 @@ func (v *Vless) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta safeConnClose(c, err) }(c) - c, err = v.StreamConn(c, metadata) + c, err = v.StreamConnContext(ctx, c, metadata) if err != nil { return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) } @@ -348,7 +348,7 @@ func (v *Vless) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met safeConnClose(c, err) }(c) - c, err = v.StreamConn(c, metadata) + c, err = v.StreamConnContext(ctx, c, metadata) if err != nil { return nil, fmt.Errorf("new vless client error: %v", err) } diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index c0063b3e..058ce49d 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -89,8 +89,8 @@ type WSOptions struct { EarlyDataHeaderName string `proxy:"early-data-header-name,omitempty"` } -// StreamConn implements C.ProxyAdapter -func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { +// StreamConnContext implements C.ProxyAdapter +func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { var err error if tlsC.HaveGlobalFingerprint() && (len(v.option.ClientFingerprint) == 0) { @@ -138,7 +138,7 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { wsOpts.TLSConfig.ServerName = host } } - c, err = clashVMess.StreamWebsocketConn(c, wsOpts) + c, err = clashVMess.StreamWebsocketConn(ctx, c, wsOpts) case "http": // readability first, so just copy default TLS logic if v.option.TLS { @@ -153,7 +153,7 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { if v.option.ServerName != "" { tlsOpts.Host = v.option.ServerName } - c, err = clashVMess.StreamTLSConn(c, tlsOpts) + c, err = clashVMess.StreamTLSConn(ctx, c, tlsOpts) if err != nil { return nil, err } @@ -182,7 +182,7 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { tlsOpts.Host = v.option.ServerName } - c, err = clashVMess.StreamTLSConn(c, &tlsOpts) + c, err = clashVMess.StreamTLSConn(ctx, c, &tlsOpts) if err != nil { return nil, err } @@ -210,7 +210,7 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { tlsOpts.Host = v.option.ServerName } - c, err = clashVMess.StreamTLSConn(c, tlsOpts) + c, err = clashVMess.StreamTLSConn(ctx, c, tlsOpts) } } @@ -294,7 +294,7 @@ func (v *Vmess) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta safeConnClose(c, err) }(c) - c, err = v.StreamConn(c, metadata) + c, err = v.StreamConnContext(ctx, c, metadata) return NewConn(c, v), err } @@ -355,7 +355,7 @@ func (v *Vmess) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met safeConnClose(c, err) }(c) - c, err = v.StreamConn(c, metadata) + c, err = v.StreamConnContext(ctx, c, metadata) if err != nil { return nil, fmt.Errorf("new vmess client error: %v", err) } diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index 67cd9092..38b5aa02 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -499,9 +499,9 @@ func (r *refProxyAdapter) MarshalJSON() ([]byte, error) { return nil, C.ErrNotSupport } -func (r *refProxyAdapter) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { +func (r *refProxyAdapter) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { if r.proxyAdapter != nil { - return r.proxyAdapter.StreamConn(c, metadata) + return r.proxyAdapter.StreamConnContext(ctx, c, metadata) } return nil, C.ErrNotSupport } diff --git a/constant/adapters.go b/constant/adapters.go index 73877dec..12579685 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -106,11 +106,11 @@ type ProxyAdapter interface { // // Examples: // conn, _ := net.DialContext(context.Background(), "tcp", "host:port") - // conn, _ = adapter.StreamConn(conn, metadata) + // conn, _ = adapter.StreamConnContext(context.Background(), conn, metadata) // // It returns a C.Conn with protocol which start with // a new session (if any) - StreamConn(c net.Conn, metadata *Metadata) (net.Conn, error) + StreamConnContext(ctx context.Context, c net.Conn, metadata *Metadata) (net.Conn, error) // DialContext return a C.Conn with protocol which // contains multiplexing-related reuse logic (if any) diff --git a/transport/trojan/trojan.go b/transport/trojan/trojan.go index 8eae8237..8b4146c6 100644 --- a/transport/trojan/trojan.go +++ b/transport/trojan/trojan.go @@ -70,7 +70,7 @@ type Trojan struct { hexPassword []byte } -func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) { +func (t *Trojan) StreamConn(ctx context.Context, conn net.Conn) (net.Conn, error) { alpn := defaultALPN if len(t.option.ALPN) != 0 { alpn = t.option.ALPN @@ -149,7 +149,7 @@ func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) { } } -func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) { +func (t *Trojan) StreamWebsocketConn(ctx context.Context, conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) { alpn := defaultWebsocketALPN if len(t.option.ALPN) != 0 { alpn = t.option.ALPN @@ -162,7 +162,7 @@ func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) ServerName: t.option.ServerName, } - return vmess.StreamWebsocketConn(conn, &vmess.WebsocketConfig{ + return vmess.StreamWebsocketConn(ctx, conn, &vmess.WebsocketConfig{ Host: wsOptions.Host, Port: wsOptions.Port, Path: wsOptions.Path, diff --git a/transport/v2ray-plugin/websocket.go b/transport/v2ray-plugin/websocket.go index 7c2c8a88..25483670 100644 --- a/transport/v2ray-plugin/websocket.go +++ b/transport/v2ray-plugin/websocket.go @@ -1,6 +1,7 @@ package obfs import ( + "context" "crypto/tls" "net" "net/http" @@ -22,7 +23,7 @@ type Option struct { } // NewV2rayObfs return a HTTPObfs -func NewV2rayObfs(conn net.Conn, option *Option) (net.Conn, error) { +func NewV2rayObfs(ctx context.Context, conn net.Conn, option *Option) (net.Conn, error) { header := http.Header{} for k, v := range option.Headers { header.Add(k, v) @@ -57,7 +58,7 @@ func NewV2rayObfs(conn net.Conn, option *Option) (net.Conn, error) { } var err error - conn, err = vmess.StreamWebsocketConn(conn, config) + conn, err = vmess.StreamWebsocketConn(ctx, conn, config) if err != nil { return nil, err } diff --git a/transport/vless/xtls.go b/transport/vless/xtls.go index 3a319568..09929fc3 100644 --- a/transport/vless/xtls.go +++ b/transport/vless/xtls.go @@ -6,7 +6,6 @@ import ( "net" tlsC "github.com/Dreamacro/clash/component/tls" - C "github.com/Dreamacro/clash/constant" xtls "github.com/xtls/go" ) @@ -21,7 +20,7 @@ type XTLSConfig struct { NextProtos []string } -func StreamXTLSConn(conn net.Conn, cfg *XTLSConfig) (net.Conn, error) { +func StreamXTLSConn(ctx context.Context, conn net.Conn, cfg *XTLSConfig) (net.Conn, error) { xtlsConfig := &xtls.Config{ ServerName: cfg.Host, InsecureSkipVerify: cfg.SkipCertVerify, @@ -38,9 +37,6 @@ func StreamXTLSConn(conn net.Conn, cfg *XTLSConfig) (net.Conn, error) { xtlsConn := xtls.Client(conn, xtlsConfig) - // fix xtls handshake not timeout - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() err := xtlsConn.HandshakeContext(ctx) return xtlsConn, err } diff --git a/transport/vmess/tls.go b/transport/vmess/tls.go index f020d273..54813029 100644 --- a/transport/vmess/tls.go +++ b/transport/vmess/tls.go @@ -7,7 +7,6 @@ import ( "net" tlsC "github.com/Dreamacro/clash/component/tls" - C "github.com/Dreamacro/clash/constant" ) type TLSConfig struct { @@ -19,7 +18,7 @@ type TLSConfig struct { Reality *tlsC.RealityConfig } -func StreamTLSConn(conn net.Conn, cfg *TLSConfig) (net.Conn, error) { +func StreamTLSConn(ctx context.Context, conn net.Conn, cfg *TLSConfig) (net.Conn, error) { tlsConfig := &tls.Config{ ServerName: cfg.Host, InsecureSkipVerify: cfg.SkipCertVerify, @@ -39,15 +38,10 @@ func StreamTLSConn(conn net.Conn, cfg *TLSConfig) (net.Conn, error) { if cfg.Reality == nil { utlsConn, valid := GetUTLSConn(conn, cfg.ClientFingerprint, tlsConfig) if valid { - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() - err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx) return utlsConn, err } } else { - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() return tlsC.GetRealityConn(ctx, conn, cfg.ClientFingerprint, tlsConfig, cfg.Reality) } } @@ -57,9 +51,6 @@ func StreamTLSConn(conn net.Conn, cfg *TLSConfig) (net.Conn, error) { tlsConn := tls.Client(conn, tlsConfig) - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() - err := tlsConn.HandshakeContext(ctx) return tlsConn, err } diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index e7335d84..a4ce99a9 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -194,17 +194,17 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { earlyDataBuf := bytes.NewBuffer(earlyData) if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil { - return errors.New("failed to encode early data: " + err.Error()) + return fmt.Errorf("failed to encode early data: %w", err) } if errc := base64EarlyDataEncoder.Close(); errc != nil { - return errors.New("failed to encode early data tail: " + errc.Error()) + return fmt.Errorf("failed to encode early data tail: %w", errc) } var err error - if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, base64DataBuf); err != nil { + if wsedc.Conn, err = streamWebsocketConn(wsedc.ctx, wsedc.underlay, wsedc.config, base64DataBuf); err != nil { wsedc.Close() - return errors.New("failed to dial WebSocket: " + err.Error()) + return fmt.Errorf("failed to dial WebSocket: %w", err) } wsedc.dialed <- true @@ -340,7 +340,7 @@ func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Co return N.NewDeadlineConn(conn), nil } -func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { +func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { dialer := &websocket.Dialer{ NetDial: func(network, addr string) (net.Conn, error) { @@ -396,13 +396,13 @@ func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buf } } - wsConn, resp, err := dialer.Dial(uri.String(), headers) + wsConn, resp, err := dialer.DialContext(ctx, uri.String(), headers) if err != nil { - reason := err.Error() + reason := err if resp != nil { - reason = resp.Status + reason = errors.New(resp.Status) } - return nil, fmt.Errorf("dial %s error: %s", uri.Host, reason) + return nil, fmt.Errorf("dial %s error: %w", uri.Host, reason) } conn = &websocketConn{ @@ -417,7 +417,7 @@ func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buf return N.NewDeadlineConn(conn), nil } -func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { +func StreamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig) (net.Conn, error) { if u, err := url.Parse(c.Path); err == nil { if q := u.Query(); q.Get("ed") != "" { if ed, err := strconv.Atoi(q.Get("ed")); err == nil { @@ -434,5 +434,5 @@ func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { return streamWebsocketWithEarlyDataConn(conn, c) } - return streamWebsocketConn(conn, c, nil) + return streamWebsocketConn(ctx, conn, c, nil) }