chore: do websocket client upgrade directly instead of gobwas/ws
This commit is contained in:
parent
ee3038d5e4
commit
665ba7f9f1
1 changed files with 81 additions and 84 deletions
|
@ -3,6 +3,7 @@ package vmess
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
|
@ -19,6 +20,7 @@ import (
|
|||
"github.com/Dreamacro/clash/common/buf"
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
tlsC "github.com/Dreamacro/clash/component/tls"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
|
@ -317,35 +319,35 @@ func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Co
|
|||
}
|
||||
|
||||
func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
|
||||
dialer := ws.Dialer{
|
||||
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return conn, nil
|
||||
},
|
||||
TLSConfig: c.TLSConfig,
|
||||
u, err := url.Parse(c.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
|
||||
}
|
||||
|
||||
scheme := "ws"
|
||||
if c.TLS {
|
||||
scheme = "wss"
|
||||
if len(c.ClientFingerprint) != 0 {
|
||||
if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists {
|
||||
utlsConn := tlsC.UClient(conn, c.TLSConfig, fingerprint)
|
||||
|
||||
if err := utlsConn.BuildWebsocketHandshakeState(); err != nil {
|
||||
if err = utlsConn.BuildWebsocketHandshakeState(); err != nil {
|
||||
return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
|
||||
}
|
||||
conn = utlsConn
|
||||
}
|
||||
} else {
|
||||
conn = tls.Client(conn, c.TLSConfig)
|
||||
}
|
||||
|
||||
dialer.TLSClient = func(conn net.Conn, hostname string) net.Conn {
|
||||
return utlsConn
|
||||
}
|
||||
if tlsConn, ok := conn.(interface {
|
||||
HandshakeContext(ctx context.Context) error
|
||||
}); ok {
|
||||
if err = tlsConn.HandshakeContext(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
u, err := url.Parse(c.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
|
||||
}
|
||||
|
||||
uri := url.URL{
|
||||
Scheme: scheme,
|
||||
Host: net.JoinHostPort(c.Host, c.Port),
|
||||
|
@ -353,56 +355,36 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig,
|
|||
RawQuery: u.RawQuery,
|
||||
}
|
||||
|
||||
if c.V2rayHttpUpgrade {
|
||||
if c.TLS {
|
||||
if dialer.TLSClient != nil {
|
||||
conn = dialer.TLSClient(conn, uri.Host)
|
||||
} else {
|
||||
conn = tls.Client(conn, dialer.TLSConfig)
|
||||
}
|
||||
if tlsConn, ok := conn.(interface {
|
||||
HandshakeContext(ctx context.Context) error
|
||||
}); ok {
|
||||
if err = tlsConn.HandshakeContext(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
request := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &uri,
|
||||
Header: c.Headers.Clone(),
|
||||
Host: c.Host,
|
||||
}
|
||||
request.Header.Set("Connection", "Upgrade")
|
||||
request.Header.Set("Upgrade", "websocket")
|
||||
if host := request.Header.Get("Host"); host != "" {
|
||||
request.Header.Del("Host")
|
||||
request.Host = host
|
||||
}
|
||||
err = request.Write(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bufferedConn := N.NewBufferedConn(conn)
|
||||
response, err := http.ReadResponse(bufferedConn.Reader(), request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response.StatusCode != 101 ||
|
||||
!strings.EqualFold(response.Header.Get("Connection"), "upgrade") ||
|
||||
!strings.EqualFold(response.Header.Get("Upgrade"), "websocket") {
|
||||
return nil, fmt.Errorf("unexpected status: %s", response.Status)
|
||||
}
|
||||
return bufferedConn, nil
|
||||
request := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &uri,
|
||||
Header: c.Headers.Clone(),
|
||||
Host: c.Host,
|
||||
}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("User-Agent", "Go-http-client/1.1") // match golang's net/http
|
||||
if c.Headers != nil {
|
||||
for k := range c.Headers {
|
||||
headers.Add(k, c.Headers.Get(k))
|
||||
request.Header.Set("Connection", "Upgrade")
|
||||
request.Header.Set("Upgrade", "websocket")
|
||||
|
||||
if host := request.Header.Get("Host"); host != "" {
|
||||
// For client requests, Host optionally overrides the Host
|
||||
// header to send. If empty, the Request.Write method uses
|
||||
// the value of URL.Host. Host may contain an international
|
||||
// domain name.
|
||||
request.Host = host
|
||||
}
|
||||
request.Header.Del("Host")
|
||||
|
||||
var nonce string
|
||||
if !c.V2rayHttpUpgrade {
|
||||
const nonceKeySize = 16
|
||||
// NOTE: bts does not escape.
|
||||
bts := make([]byte, nonceKeySize)
|
||||
if _, err = fastrand.Read(bts); err != nil {
|
||||
return nil, fmt.Errorf("rand read error: %w", err)
|
||||
}
|
||||
nonce = base64.StdEncoding.EncodeToString(bts)
|
||||
request.Header.Set("Sec-WebSocket-Version", "13")
|
||||
request.Header.Set("Sec-WebSocket-Key", nonce)
|
||||
}
|
||||
|
||||
if earlyData != nil {
|
||||
|
@ -410,36 +392,51 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig,
|
|||
if c.EarlyDataHeaderName == "" {
|
||||
uri.Path += earlyDataString
|
||||
} else {
|
||||
headers.Set(c.EarlyDataHeaderName, earlyDataString)
|
||||
request.Header.Set(c.EarlyDataHeaderName, earlyDataString)
|
||||
}
|
||||
}
|
||||
|
||||
// gobwas/ws will check server's response "Sec-Websocket-Protocol" so must add Protocols to ws.Dialer
|
||||
// if not will cause ws.ErrHandshakeBadSubProtocol
|
||||
if secProtocol := headers.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 {
|
||||
// gobwas/ws will set "Sec-Websocket-Protocol" according dialer.Protocols
|
||||
// to avoid send repeatedly don't set it to headers
|
||||
dialer.Protocols = []string{secProtocol}
|
||||
if ctx.Done() != nil {
|
||||
done := N.SetupContextForConn(ctx, conn)
|
||||
defer done(&err)
|
||||
}
|
||||
headers.Del("Sec-WebSocket-Protocol")
|
||||
|
||||
// gobwas/ws send "Host" directly in Upgrade() by `httpWriteHeader(bw, headerHost, u.Host)`
|
||||
// if headers has "Host" will send repeatedly
|
||||
if host := headers.Get("Host"); host != "" {
|
||||
uri.Host = host
|
||||
}
|
||||
headers.Del("Host")
|
||||
|
||||
dialer.Header = ws.HandshakeHeaderHTTP(headers)
|
||||
|
||||
conn, reader, _, err := dialer.Dial(ctx, uri.String())
|
||||
err = request.Write(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial %s error: %w", uri.Host, err)
|
||||
return nil, err
|
||||
}
|
||||
bufferedConn := N.NewBufferedConn(conn)
|
||||
response, err := http.ReadResponse(bufferedConn.Reader(), request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response.StatusCode != http.StatusSwitchingProtocols ||
|
||||
!strings.EqualFold(response.Header.Get("Connection"), "upgrade") ||
|
||||
!strings.EqualFold(response.Header.Get("Upgrade"), "websocket") {
|
||||
return nil, fmt.Errorf("unexpected status: %s", response.Status)
|
||||
}
|
||||
|
||||
// some bytes which could be written by the peer right after response and be caught by us during buffered read,
|
||||
// so we need warp Conn with bio.Reader
|
||||
conn = N.WarpConnWithBioReader(conn, reader)
|
||||
if c.V2rayHttpUpgrade {
|
||||
return bufferedConn, nil
|
||||
}
|
||||
|
||||
if log.Level() == log.DEBUG { // we might not check this for performance
|
||||
secAccept := response.Header.Get("Sec-Websocket-Accept")
|
||||
const acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size)
|
||||
if lenSecAccept := len(secAccept); lenSecAccept != acceptSize {
|
||||
return nil, fmt.Errorf("unexpected Sec-Websocket-Accept length: %d", lenSecAccept)
|
||||
}
|
||||
|
||||
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize)
|
||||
p := make([]byte, nonceSize+len(magic))
|
||||
copy(p[:nonceSize], nonce)
|
||||
copy(p[nonceSize:], magic)
|
||||
sum := sha1.Sum(p)
|
||||
if accept := base64.StdEncoding.EncodeToString(sum[:]); accept != secAccept {
|
||||
return nil, errors.New("unexpected Sec-Websocket-Accept")
|
||||
}
|
||||
}
|
||||
|
||||
conn = newWebsocketConn(conn, ws.StateClientSide)
|
||||
// websocketConn can't correct handle ReadDeadline
|
||||
|
|
Loading…
Reference in a new issue