From 8c3557e96be1acbccd318584c48b587b59b3e7fb Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 3 Nov 2023 13:58:53 +0800 Subject: [PATCH] chore: support v2ray http upgrade server too --- common/net/cached.go | 49 +++++++++++++++++++++ transport/vmess/websocket.go | 82 ++++++++++++++++++++++++------------ 2 files changed, 105 insertions(+), 26 deletions(-) create mode 100644 common/net/cached.go diff --git a/common/net/cached.go b/common/net/cached.go new file mode 100644 index 00000000..3b7da44c --- /dev/null +++ b/common/net/cached.go @@ -0,0 +1,49 @@ +package net + +import ( + "net" + + "github.com/Dreamacro/clash/common/buf" +) + +var _ ExtendedConn = (*CachedConn)(nil) + +type CachedConn struct { + ExtendedConn + data []byte +} + +func NewCachedConn(c net.Conn, data []byte) *CachedConn { + return &CachedConn{NewExtendedConn(c), data} +} + +func (c *CachedConn) Read(b []byte) (n int, err error) { + if len(c.data) > 0 { + n = copy(b, c.data) + c.data = c.data[n:] + return + } + return c.ExtendedConn.Read(b) +} + +func (c *CachedConn) ReadCached() *buf.Buffer { // call in sing/common/bufio.Copy + if len(c.data) > 0 { + return buf.As(c.data) + } + return nil +} + +func (c *CachedConn) Upstream() any { + return c.ExtendedConn +} + +func (c *CachedConn) ReaderReplaceable() bool { + if len(c.data) > 0 { + return false + } + return true +} + +func (c *CachedConn) WriterReplaceable() bool { + return true +} diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index 0c2a3a16..9f09185b 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -1,6 +1,7 @@ package vmess import ( + "bufio" "bytes" "context" "crypto/sha1" @@ -382,7 +383,7 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, } request.Header.Del("Host") - var nonce string + var secKey string if !c.V2rayHttpUpgrade { const nonceKeySize = 16 // NOTE: bts does not escape. @@ -390,9 +391,9 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, if _, err = fastrand.Read(bts); err != nil { return nil, fmt.Errorf("rand read error: %w", err) } - nonce = base64.StdEncoding.EncodeToString(bts) + secKey = base64.StdEncoding.EncodeToString(bts) request.Header.Set("Sec-WebSocket-Version", "13") - request.Header.Set("Sec-WebSocket-Key", nonce) + request.Header.Set("Sec-WebSocket-Key", secKey) } if earlyData != nil { @@ -434,14 +435,7 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, if lenSecAccept := len(secAccept); lenSecAccept != acceptSize { return nil, fmt.Errorf("unexpected Sec-Websocket-Accept length: %d", lenSecAccept) } - - const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) - p := make([]byte, nonceSize+len(magic)) - copy(p[:nonceSize], nonce) - copy(p[nonceSize:], magic) - sum := sha1.Sum(p) - if accept := base64.StdEncoding.EncodeToString(sum[:]); accept != secAccept { + if getSecAccept(secKey) != secAccept { return nil, errors.New("unexpected Sec-Websocket-Accept") } } @@ -452,6 +446,16 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, return N.NewDeadlineConn(conn), nil } +func getSecAccept(secKey string) string { + const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) + p := make([]byte, nonceSize+len(magic)) + copy(p[:nonceSize], secKey) + copy(p[nonceSize:], magic) + sum := sha1.Sum(p) + return base64.StdEncoding.EncodeToString(sum[:]) +} + func StreamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig) (net.Conn, error) { if u, err := url.Parse(c.Path); err == nil { if q := u.Query(); q.Get("ed") != "" { @@ -505,27 +509,53 @@ func decodeXray0rtt(requestHeader http.Header) []byte { return nil } +func IsWebSocketUpgrade(r *http.Request) bool { + return r.Header.Get("Upgrade") == "websocket" +} + +func IsV2rayHttpUpdate(r *http.Request) bool { + return IsWebSocketUpgrade(r) && r.Header.Get("Sec-WebSocket-Key") == "" +} + func StreamUpgradedWebsocketConn(w http.ResponseWriter, r *http.Request) (net.Conn, error) { - wsConn, rw, _, err := ws.UpgradeHTTP(r, w) + var conn net.Conn + var rw *bufio.ReadWriter + var err error + isRaw := IsV2rayHttpUpdate(r) + w.Header().Set("Connection", "upgrade") + w.Header().Set("Upgrade", "websocket") + if !isRaw { + w.Header().Set("Sec-Websocket-Accept", getSecAccept(r.Header.Get("Sec-WebSocket-Key"))) + } + w.WriteHeader(http.StatusSwitchingProtocols) + if flusher, isFlusher := w.(interface{ FlushError() error }); isFlusher { + err = flusher.FlushError() + if err != nil { + return nil, fmt.Errorf("flush response: %w", err) + } + } + hijacker, canHijack := w.(http.Hijacker) + if !canHijack { + return nil, errors.New("invalid connection, maybe HTTP/2") + } + conn, rw, err = hijacker.Hijack() if err != nil { - return nil, err + return nil, fmt.Errorf("hijack failed: %w", err) } - // gobwas/ws will flush rw.Writer, so we only need warp rw.Reader - wsConn = N.WarpConnWithBioReader(wsConn, rw.Reader) + // rw.Writer was flushed, so we only need warp rw.Reader + conn = N.WarpConnWithBioReader(conn, rw.Reader) + + if !isRaw { + conn = newWebsocketConn(conn, ws.StateServerSide) + // websocketConn can't correct handle ReadDeadline + // so call N.NewDeadlineConn to add a safe wrapper + conn = N.NewDeadlineConn(conn) + } - conn := newWebsocketConn(wsConn, ws.StateServerSide) if edBuf := decodeXray0rtt(r.Header); len(edBuf) > 0 { - return N.NewDeadlineConn(&websocketWithReaderConn{conn, io.MultiReader(bytes.NewReader(edBuf), conn)}), nil + conn = N.NewCachedConn(conn, edBuf) } - return N.NewDeadlineConn(conn), nil -} -type websocketWithReaderConn struct { - *websocketConn - reader io.Reader -} - -func (ws *websocketWithReaderConn) Read(b []byte) (n int, err error) { - return ws.reader.Read(b) + return conn, nil }