From 665ba7f9f179a69e4da357ef2eae186fa32cdbd8 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 3 Nov 2023 11:02:19 +0800 Subject: [PATCH] chore: do websocket client upgrade directly instead of gobwas/ws --- transport/vmess/websocket.go | 165 +++++++++++++++++------------------ 1 file changed, 81 insertions(+), 84 deletions(-) diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index ebafefa4..60353d5a 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -3,6 +3,7 @@ package vmess import ( "bytes" "context" + "crypto/sha1" "crypto/tls" "encoding/base64" "encoding/binary" @@ -19,6 +20,7 @@ import ( "github.com/Dreamacro/clash/common/buf" N "github.com/Dreamacro/clash/common/net" tlsC "github.com/Dreamacro/clash/component/tls" + "github.com/Dreamacro/clash/log" "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" @@ -317,35 +319,35 @@ func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Co } func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { - dialer := ws.Dialer{ - NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { - return conn, nil - }, - TLSConfig: c.TLSConfig, + u, err := url.Parse(c.Path) + if err != nil { + return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) } + scheme := "ws" if c.TLS { scheme = "wss" if len(c.ClientFingerprint) != 0 { if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists { utlsConn := tlsC.UClient(conn, c.TLSConfig, fingerprint) - - if err := utlsConn.BuildWebsocketHandshakeState(); err != nil { + if err = utlsConn.BuildWebsocketHandshakeState(); err != nil { return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) } + conn = utlsConn + } + } else { + conn = tls.Client(conn, c.TLSConfig) + } - dialer.TLSClient = func(conn net.Conn, hostname string) net.Conn { - return utlsConn - } + if tlsConn, ok := conn.(interface { + HandshakeContext(ctx context.Context) error + }); ok { + if err = tlsConn.HandshakeContext(ctx); err != nil { + return nil, err } } } - u, err := url.Parse(c.Path) - if err != nil { - return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) - } - uri := url.URL{ Scheme: scheme, Host: net.JoinHostPort(c.Host, c.Port), @@ -353,56 +355,36 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, RawQuery: u.RawQuery, } - if c.V2rayHttpUpgrade { - if c.TLS { - if dialer.TLSClient != nil { - conn = dialer.TLSClient(conn, uri.Host) - } else { - conn = tls.Client(conn, dialer.TLSConfig) - } - if tlsConn, ok := conn.(interface { - HandshakeContext(ctx context.Context) error - }); ok { - if err = tlsConn.HandshakeContext(ctx); err != nil { - return nil, err - } - } - } - request := &http.Request{ - Method: http.MethodGet, - URL: &uri, - Header: c.Headers.Clone(), - Host: c.Host, - } - request.Header.Set("Connection", "Upgrade") - request.Header.Set("Upgrade", "websocket") - if host := request.Header.Get("Host"); host != "" { - request.Header.Del("Host") - request.Host = host - } - err = request.Write(conn) - if err != nil { - return nil, err - } - bufferedConn := N.NewBufferedConn(conn) - response, err := http.ReadResponse(bufferedConn.Reader(), request) - if err != nil { - return nil, err - } - if response.StatusCode != 101 || - !strings.EqualFold(response.Header.Get("Connection"), "upgrade") || - !strings.EqualFold(response.Header.Get("Upgrade"), "websocket") { - return nil, fmt.Errorf("unexpected status: %s", response.Status) - } - return bufferedConn, nil + request := &http.Request{ + Method: http.MethodGet, + URL: &uri, + Header: c.Headers.Clone(), + Host: c.Host, } - headers := http.Header{} - headers.Set("User-Agent", "Go-http-client/1.1") // match golang's net/http - if c.Headers != nil { - for k := range c.Headers { - headers.Add(k, c.Headers.Get(k)) + request.Header.Set("Connection", "Upgrade") + request.Header.Set("Upgrade", "websocket") + + if host := request.Header.Get("Host"); host != "" { + // For client requests, Host optionally overrides the Host + // header to send. If empty, the Request.Write method uses + // the value of URL.Host. Host may contain an international + // domain name. + request.Host = host + } + request.Header.Del("Host") + + var nonce string + if !c.V2rayHttpUpgrade { + const nonceKeySize = 16 + // NOTE: bts does not escape. + bts := make([]byte, nonceKeySize) + if _, err = fastrand.Read(bts); err != nil { + return nil, fmt.Errorf("rand read error: %w", err) } + nonce = base64.StdEncoding.EncodeToString(bts) + request.Header.Set("Sec-WebSocket-Version", "13") + request.Header.Set("Sec-WebSocket-Key", nonce) } if earlyData != nil { @@ -410,36 +392,51 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, if c.EarlyDataHeaderName == "" { uri.Path += earlyDataString } else { - headers.Set(c.EarlyDataHeaderName, earlyDataString) + request.Header.Set(c.EarlyDataHeaderName, earlyDataString) } } - // gobwas/ws will check server's response "Sec-Websocket-Protocol" so must add Protocols to ws.Dialer - // if not will cause ws.ErrHandshakeBadSubProtocol - if secProtocol := headers.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 { - // gobwas/ws will set "Sec-Websocket-Protocol" according dialer.Protocols - // to avoid send repeatedly don't set it to headers - dialer.Protocols = []string{secProtocol} + if ctx.Done() != nil { + done := N.SetupContextForConn(ctx, conn) + defer done(&err) } - headers.Del("Sec-WebSocket-Protocol") - // gobwas/ws send "Host" directly in Upgrade() by `httpWriteHeader(bw, headerHost, u.Host)` - // if headers has "Host" will send repeatedly - if host := headers.Get("Host"); host != "" { - uri.Host = host - } - headers.Del("Host") - - dialer.Header = ws.HandshakeHeaderHTTP(headers) - - conn, reader, _, err := dialer.Dial(ctx, uri.String()) + err = request.Write(conn) if err != nil { - return nil, fmt.Errorf("dial %s error: %w", uri.Host, err) + return nil, err + } + bufferedConn := N.NewBufferedConn(conn) + response, err := http.ReadResponse(bufferedConn.Reader(), request) + if err != nil { + return nil, err + } + if response.StatusCode != http.StatusSwitchingProtocols || + !strings.EqualFold(response.Header.Get("Connection"), "upgrade") || + !strings.EqualFold(response.Header.Get("Upgrade"), "websocket") { + return nil, fmt.Errorf("unexpected status: %s", response.Status) } - // some bytes which could be written by the peer right after response and be caught by us during buffered read, - // so we need warp Conn with bio.Reader - conn = N.WarpConnWithBioReader(conn, reader) + if c.V2rayHttpUpgrade { + return bufferedConn, nil + } + + if log.Level() == log.DEBUG { // we might not check this for performance + secAccept := response.Header.Get("Sec-Websocket-Accept") + const acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size) + if lenSecAccept := len(secAccept); lenSecAccept != acceptSize { + return nil, fmt.Errorf("unexpected Sec-Websocket-Accept length: %d", lenSecAccept) + } + + const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) + p := make([]byte, nonceSize+len(magic)) + copy(p[:nonceSize], nonce) + copy(p[nonceSize:], magic) + sum := sha1.Sum(p) + if accept := base64.StdEncoding.EncodeToString(sum[:]); accept != secAccept { + return nil, errors.New("unexpected Sec-Websocket-Accept") + } + } conn = newWebsocketConn(conn, ws.StateClientSide) // websocketConn can't correct handle ReadDeadline