chore: do websocket client upgrade directly instead of gobwas/ws

This commit is contained in:
wwqgtxx 2023-11-03 11:02:19 +08:00
parent ee3038d5e4
commit 665ba7f9f1

View file

@ -3,6 +3,7 @@ package vmess
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/sha1"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
@ -19,6 +20,7 @@ import (
"github.com/Dreamacro/clash/common/buf" "github.com/Dreamacro/clash/common/buf"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
tlsC "github.com/Dreamacro/clash/component/tls" tlsC "github.com/Dreamacro/clash/component/tls"
"github.com/Dreamacro/clash/log"
"github.com/gobwas/ws" "github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
@ -317,33 +319,33 @@ func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Co
} }
func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
dialer := ws.Dialer{ u, err := url.Parse(c.Path)
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { if err != nil {
return conn, nil return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
},
TLSConfig: c.TLSConfig,
} }
scheme := "ws" scheme := "ws"
if c.TLS { if c.TLS {
scheme = "wss" scheme = "wss"
if len(c.ClientFingerprint) != 0 { if len(c.ClientFingerprint) != 0 {
if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists { if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists {
utlsConn := tlsC.UClient(conn, c.TLSConfig, fingerprint) utlsConn := tlsC.UClient(conn, c.TLSConfig, fingerprint)
if err = utlsConn.BuildWebsocketHandshakeState(); err != nil {
if err := utlsConn.BuildWebsocketHandshakeState(); err != nil {
return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
} }
conn = utlsConn
dialer.TLSClient = func(conn net.Conn, hostname string) net.Conn {
return utlsConn
}
}
} }
} else {
conn = tls.Client(conn, c.TLSConfig)
} }
u, err := url.Parse(c.Path) if tlsConn, ok := conn.(interface {
if err != nil { HandshakeContext(ctx context.Context) error
return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) }); ok {
if err = tlsConn.HandshakeContext(ctx); err != nil {
return nil, err
}
}
} }
uri := url.URL{ uri := url.URL{
@ -353,33 +355,52 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig,
RawQuery: u.RawQuery, RawQuery: u.RawQuery,
} }
if c.V2rayHttpUpgrade {
if c.TLS {
if dialer.TLSClient != nil {
conn = dialer.TLSClient(conn, uri.Host)
} else {
conn = tls.Client(conn, dialer.TLSConfig)
}
if tlsConn, ok := conn.(interface {
HandshakeContext(ctx context.Context) error
}); ok {
if err = tlsConn.HandshakeContext(ctx); err != nil {
return nil, err
}
}
}
request := &http.Request{ request := &http.Request{
Method: http.MethodGet, Method: http.MethodGet,
URL: &uri, URL: &uri,
Header: c.Headers.Clone(), Header: c.Headers.Clone(),
Host: c.Host, Host: c.Host,
} }
request.Header.Set("Connection", "Upgrade") request.Header.Set("Connection", "Upgrade")
request.Header.Set("Upgrade", "websocket") request.Header.Set("Upgrade", "websocket")
if host := request.Header.Get("Host"); host != "" { if host := request.Header.Get("Host"); host != "" {
request.Header.Del("Host") // For client requests, Host optionally overrides the Host
// header to send. If empty, the Request.Write method uses
// the value of URL.Host. Host may contain an international
// domain name.
request.Host = host request.Host = host
} }
request.Header.Del("Host")
var nonce string
if !c.V2rayHttpUpgrade {
const nonceKeySize = 16
// NOTE: bts does not escape.
bts := make([]byte, nonceKeySize)
if _, err = fastrand.Read(bts); err != nil {
return nil, fmt.Errorf("rand read error: %w", err)
}
nonce = base64.StdEncoding.EncodeToString(bts)
request.Header.Set("Sec-WebSocket-Version", "13")
request.Header.Set("Sec-WebSocket-Key", nonce)
}
if earlyData != nil {
earlyDataString := earlyData.String()
if c.EarlyDataHeaderName == "" {
uri.Path += earlyDataString
} else {
request.Header.Set(c.EarlyDataHeaderName, earlyDataString)
}
}
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, conn)
defer done(&err)
}
err = request.Write(conn) err = request.Write(conn)
if err != nil { if err != nil {
return nil, err return nil, err
@ -389,58 +410,34 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig,
if err != nil { if err != nil {
return nil, err return nil, err
} }
if response.StatusCode != 101 || if response.StatusCode != http.StatusSwitchingProtocols ||
!strings.EqualFold(response.Header.Get("Connection"), "upgrade") || !strings.EqualFold(response.Header.Get("Connection"), "upgrade") ||
!strings.EqualFold(response.Header.Get("Upgrade"), "websocket") { !strings.EqualFold(response.Header.Get("Upgrade"), "websocket") {
return nil, fmt.Errorf("unexpected status: %s", response.Status) return nil, fmt.Errorf("unexpected status: %s", response.Status)
} }
if c.V2rayHttpUpgrade {
return bufferedConn, nil return bufferedConn, nil
} }
headers := http.Header{} if log.Level() == log.DEBUG { // we might not check this for performance
headers.Set("User-Agent", "Go-http-client/1.1") // match golang's net/http secAccept := response.Header.Get("Sec-Websocket-Accept")
if c.Headers != nil { const acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size)
for k := range c.Headers { if lenSecAccept := len(secAccept); lenSecAccept != acceptSize {
headers.Add(k, c.Headers.Get(k)) return nil, fmt.Errorf("unexpected Sec-Websocket-Accept length: %d", lenSecAccept)
}
} }
if earlyData != nil { const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
earlyDataString := earlyData.String() const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize)
if c.EarlyDataHeaderName == "" { p := make([]byte, nonceSize+len(magic))
uri.Path += earlyDataString copy(p[:nonceSize], nonce)
} else { copy(p[nonceSize:], magic)
headers.Set(c.EarlyDataHeaderName, earlyDataString) sum := sha1.Sum(p)
if accept := base64.StdEncoding.EncodeToString(sum[:]); accept != secAccept {
return nil, errors.New("unexpected Sec-Websocket-Accept")
} }
} }
// gobwas/ws will check server's response "Sec-Websocket-Protocol" so must add Protocols to ws.Dialer
// if not will cause ws.ErrHandshakeBadSubProtocol
if secProtocol := headers.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 {
// gobwas/ws will set "Sec-Websocket-Protocol" according dialer.Protocols
// to avoid send repeatedly don't set it to headers
dialer.Protocols = []string{secProtocol}
}
headers.Del("Sec-WebSocket-Protocol")
// gobwas/ws send "Host" directly in Upgrade() by `httpWriteHeader(bw, headerHost, u.Host)`
// if headers has "Host" will send repeatedly
if host := headers.Get("Host"); host != "" {
uri.Host = host
}
headers.Del("Host")
dialer.Header = ws.HandshakeHeaderHTTP(headers)
conn, reader, _, err := dialer.Dial(ctx, uri.String())
if err != nil {
return nil, fmt.Errorf("dial %s error: %w", uri.Host, err)
}
// some bytes which could be written by the peer right after response and be caught by us during buffered read,
// so we need warp Conn with bio.Reader
conn = N.WarpConnWithBioReader(conn, reader)
conn = newWebsocketConn(conn, ws.StateClientSide) conn = newWebsocketConn(conn, ws.StateClientSide)
// websocketConn can't correct handle ReadDeadline // websocketConn can't correct handle ReadDeadline
// so call N.NewDeadlineConn to add a safe wrapper // so call N.NewDeadlineConn to add a safe wrapper