From c675e82fbf82edc987f2d512d97319d7c6e62df4 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 6 Oct 2023 17:44:36 +0800 Subject: [PATCH] chore: migrate from gorilla/websocket to gobwas/ws --- component/tls/utls.go | 9 +- go.mod | 4 +- go.sum | 9 +- hub/route/connections.go | 9 +- hub/route/server.go | 36 +++-- transport/vmess/websocket.go | 263 ++++++++++++++++++++--------------- 6 files changed, 187 insertions(+), 143 deletions(-) diff --git a/component/tls/utls.go b/component/tls/utls.go index 7ea2ad06..e3d101dc 100644 --- a/component/tls/utls.go +++ b/component/tls/utls.go @@ -99,10 +99,9 @@ func copyConfig(c *tls.Config) *utls.Config { } } -// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send -// http/1.1 in its ALPN. +// BuildWebsocketHandshakeState it will only send http/1.1 in its ALPN. // Copy from https://github.com/XTLS/Xray-core/blob/main/transport/internet/tls/tls.go -func (c *UConn) WebsocketHandshake() error { +func (c *UConn) BuildWebsocketHandshakeState() error { // Build the handshake state. This will apply every variable of the TLS of the // fingerprint in the UConn if err := c.BuildHandshakeState(); err != nil { @@ -120,11 +119,11 @@ func (c *UConn) WebsocketHandshake() error { if !hasALPNExtension { // Append extension if doesn't exists c.Extensions = append(c.Extensions, &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}}) } - // Rebuild the client hello and do the handshake + // Rebuild the client hello if err := c.BuildHandshakeState(); err != nil { return err } - return c.Handshake() + return nil } func SetGlobalUtlsClient(Client string) { diff --git a/go.mod b/go.mod index abcf0a2c..e14ccda1 100644 --- a/go.mod +++ b/go.mod @@ -11,8 +11,8 @@ require ( github.com/go-chi/chi/v5 v5.0.10 github.com/go-chi/cors v1.2.1 github.com/go-chi/render v1.0.3 + github.com/gobwas/ws v1.3.0 github.com/gofrs/uuid/v5 v5.0.0 - github.com/gorilla/websocket v1.5.0 github.com/insomniacslk/dhcp v0.0.0-20230908212754-65c27093e38a github.com/jpillora/backoff v1.0.0 github.com/klauspost/cpuid/v2 v2.2.5 @@ -69,6 +69,8 @@ require ( github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect diff --git a/go.sum b/go.sum index cd1a44a2..0c6b9e50 100644 --- a/go.sum +++ b/go.sum @@ -49,6 +49,12 @@ github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0= +github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= github.com/gofrs/uuid/v5 v5.0.0 h1:p544++a97kEL+svbcFbCQVM9KFu0Yo25UoISXGNNH9M= github.com/gofrs/uuid/v5 v5.0.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -63,8 +69,6 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/tink/go v1.6.1 h1:t7JHqO8Ath2w2ig5vjwQYJzhGEZymedQc90lQXUBa4I= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= @@ -237,6 +241,7 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= diff --git a/hub/route/connections.go b/hub/route/connections.go index b123ecae..67d5afa3 100644 --- a/hub/route/connections.go +++ b/hub/route/connections.go @@ -11,7 +11,8 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/render" - "github.com/gorilla/websocket" + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" ) func connectionRouter() http.Handler { @@ -23,13 +24,13 @@ func connectionRouter() http.Handler { } func getConnections(w http.ResponseWriter, r *http.Request) { - if !websocket.IsWebSocketUpgrade(r) { + if !(r.Header.Get("Upgrade") == "websocket") { snapshot := statistic.DefaultManager.Snapshot() render.JSON(w, r, snapshot) return } - conn, err := upgrader.Upgrade(w, r, nil) + conn, _, _, err := ws.UpgradeHTTP(r, w) if err != nil { return } @@ -55,7 +56,7 @@ func getConnections(w http.ResponseWriter, r *http.Request) { return err } - return conn.WriteMessage(websocket.TextMessage, buf.Bytes()) + return wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpText, buf.Bytes()) } if err := sendSnapshot(); err != nil { diff --git a/hub/route/server.go b/hub/route/server.go index aa2d03b8..93afd989 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -5,6 +5,7 @@ import ( "crypto/subtle" "crypto/tls" "encoding/json" + "net" "net/http" "runtime/debug" "strings" @@ -21,7 +22,8 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/go-chi/render" - "github.com/gorilla/websocket" + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" ) var ( @@ -29,12 +31,6 @@ var ( serverAddr = "" uiPath = "" - - upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, - } ) type Traffic struct { @@ -166,7 +162,7 @@ func authentication(next http.Handler) http.Handler { } // Browser websocket not support custom header - if websocket.IsWebSocketUpgrade(r) && r.URL.Query().Get("token") != "" { + if r.Header.Get("Upgrade") == "websocket" && r.URL.Query().Get("token") != "" { token := r.URL.Query().Get("token") if !safeEuqal(token, serverSecret) { render.Status(r, http.StatusUnauthorized) @@ -197,10 +193,10 @@ func hello(w http.ResponseWriter, r *http.Request) { } func traffic(w http.ResponseWriter, r *http.Request) { - var wsConn *websocket.Conn - if websocket.IsWebSocketUpgrade(r) { + var wsConn net.Conn + if r.Header.Get("Upgrade") == "websocket" { var err error - wsConn, err = upgrader.Upgrade(w, r, nil) + wsConn, _, _, err = ws.UpgradeHTTP(r, w) if err != nil { return } @@ -230,7 +226,7 @@ func traffic(w http.ResponseWriter, r *http.Request) { _, err = w.Write(buf.Bytes()) w.(http.Flusher).Flush() } else { - err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes()) + err = wsutil.WriteMessage(wsConn, ws.StateServerSide, ws.OpText, buf.Bytes()) } if err != nil { @@ -240,10 +236,10 @@ func traffic(w http.ResponseWriter, r *http.Request) { } func memory(w http.ResponseWriter, r *http.Request) { - var wsConn *websocket.Conn - if websocket.IsWebSocketUpgrade(r) { + var wsConn net.Conn + if r.Header.Get("Upgrade") == "websocket" { var err error - wsConn, err = upgrader.Upgrade(w, r, nil) + wsConn, _, _, err = ws.UpgradeHTTP(r, w) if err != nil { return } @@ -280,7 +276,7 @@ func memory(w http.ResponseWriter, r *http.Request) { _, err = w.Write(buf.Bytes()) w.(http.Flusher).Flush() } else { - err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes()) + err = wsutil.WriteMessage(wsConn, ws.StateServerSide, ws.OpText, buf.Bytes()) } if err != nil { @@ -323,10 +319,10 @@ func getLogs(w http.ResponseWriter, r *http.Request) { return } - var wsConn *websocket.Conn - if websocket.IsWebSocketUpgrade(r) { + var wsConn net.Conn + if r.Header.Get("Upgrade") == "websocket" { var err error - wsConn, err = upgrader.Upgrade(w, r, nil) + wsConn, _, _, err = ws.UpgradeHTTP(r, w) if err != nil { return } @@ -385,7 +381,7 @@ func getLogs(w http.ResponseWriter, r *http.Request) { _, err = w.Write(buf.Bytes()) w.(http.Flusher).Flush() } else { - err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes()) + err = wsutil.WriteMessage(wsConn, ws.StateServerSide, ws.OpText, buf.Bytes()) } if err != nil { diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index a4ce99a9..2b26eb1a 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -1,6 +1,7 @@ package vmess import ( + "bufio" "bytes" "context" "crypto/tls" @@ -14,27 +15,24 @@ import ( "net/url" "strconv" "strings" - "sync" "time" "github.com/Dreamacro/clash/common/buf" N "github.com/Dreamacro/clash/common/net" tlsC "github.com/Dreamacro/clash/component/tls" - "github.com/gorilla/websocket" + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" "github.com/zhangyunhao116/fastrand" ) type websocketConn struct { - conn *websocket.Conn - reader io.Reader - remoteAddr net.Addr + net.Conn + state ws.State + reader *wsutil.Reader + controlHandler wsutil.FrameHandlerFunc rawWriter N.ExtendedWriter - - // https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency - rMux sync.Mutex - wMux sync.Mutex } type websocketWithEarlyDataConn struct { @@ -61,32 +59,48 @@ type WebsocketConfig struct { } // Read implements net.Conn.Read() -func (wsc *websocketConn) Read(b []byte) (int, error) { - wsc.rMux.Lock() - defer wsc.rMux.Unlock() +// modify from gobwas/ws/wsutil.readData +func (wsc *websocketConn) Read(b []byte) (n int, err error) { + var header ws.Header for { - reader, err := wsc.getReader() - if err != nil { - return 0, err + n, err = wsc.reader.Read(b) + // in gobwas/ws: "The error is io.EOF only if all of message bytes were read." + // but maybe next frame still have data, so drop it + if errors.Is(err, io.EOF) { + err = nil } - - nBytes, err := reader.Read(b) - if err == io.EOF { - wsc.reader = nil + if !errors.Is(err, wsutil.ErrNoFrameAdvance) { + return + } + header, err = wsc.reader.NextFrame() + if err != nil { + return + } + if header.OpCode.IsControl() { + err = wsc.controlHandler(header, wsc.reader) + if err != nil { + return + } + continue + } + if header.OpCode&(ws.OpBinary|ws.OpText) == 0 { + err = wsc.reader.Discard() + if err != nil { + return + } continue } - return nBytes, err } } // Write implements io.Writer. -func (wsc *websocketConn) Write(b []byte) (int, error) { - wsc.wMux.Lock() - defer wsc.wMux.Unlock() - if err := wsc.conn.WriteMessage(websocket.BinaryMessage, b); err != nil { - return 0, err +func (wsc *websocketConn) Write(b []byte) (n int, err error) { + err = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpBinary, b) + if err != nil { + return } - return len(b), nil + n = len(b) + return } func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error { @@ -108,7 +122,7 @@ func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error { header := buffer.ExtendHeader(headerLen) _ = header[2] // bounds check hint to compiler - header[0] = websocket.BinaryMessage | 1<<7 + header[0] = byte(ws.OpBinary) | 0x80 header[1] = 1 << 7 if dataLen < 126 { @@ -121,12 +135,12 @@ func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error { binary.BigEndian.PutUint64(header[2:], uint64(dataLen)) } - maskKey := fastrand.Uint32() - binary.LittleEndian.PutUint32(header[1+payloadBitLength:], maskKey) - N.MaskWebSocket(maskKey, data) + if wsc.state.ClientSide() { + maskKey := fastrand.Uint32() + binary.LittleEndian.PutUint32(header[1+payloadBitLength:], maskKey) + N.MaskWebSocket(maskKey, data) + } - wsc.wMux.Lock() - defer wsc.wMux.Unlock() return wsc.rawWriter.WriteBuffer(buffer) } @@ -135,59 +149,16 @@ func (wsc *websocketConn) FrontHeadroom() int { } func (wsc *websocketConn) Upstream() any { - return wsc.conn.UnderlyingConn() + return wsc.Conn } func (wsc *websocketConn) Close() error { - var e []string - if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil { - e = append(e, err.Error()) - } - if err := wsc.conn.Close(); err != nil { - e = append(e, err.Error()) - } - if len(e) > 0 { - return fmt.Errorf("failed to close connection: %s", strings.Join(e, ",")) - } + _ = wsc.Conn.SetWriteDeadline(time.Now().Add(time.Second * 5)) + _ = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpClose, ws.NewCloseFrameBody(ws.StatusNormalClosure, "")) + _ = wsc.Conn.Close() return nil } -func (wsc *websocketConn) getReader() (io.Reader, error) { - if wsc.reader != nil { - return wsc.reader, nil - } - - _, reader, err := wsc.conn.NextReader() - if err != nil { - return nil, err - } - wsc.reader = reader - return reader, nil -} - -func (wsc *websocketConn) LocalAddr() net.Addr { - return wsc.conn.LocalAddr() -} - -func (wsc *websocketConn) RemoteAddr() net.Addr { - return wsc.remoteAddr -} - -func (wsc *websocketConn) SetDeadline(t time.Time) error { - if err := wsc.SetReadDeadline(t); err != nil { - return err - } - return wsc.SetWriteDeadline(t) -} - -func (wsc *websocketConn) SetReadDeadline(t time.Time) error { - return wsc.conn.SetReadDeadline(t) -} - -func (wsc *websocketConn) SetWriteDeadline(t time.Time) error { - return wsc.conn.SetWriteDeadline(t) -} - func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { base64DataBuf := &bytes.Buffer{} base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf) @@ -341,29 +312,25 @@ func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Co } func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { - - dialer := &websocket.Dialer{ - NetDial: func(network, addr string) (net.Conn, error) { + dialer := ws.Dialer{ + NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { return conn, nil }, - ReadBufferSize: 4 * 1024, - WriteBufferSize: 4 * 1024, - HandshakeTimeout: time.Second * 8, + TLSConfig: c.TLSConfig, } - scheme := "ws" if c.TLS { scheme = "wss" - dialer.TLSClientConfig = c.TLSConfig if len(c.ClientFingerprint) != 0 { if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists { - dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (net.Conn, error) { - utlsConn := tlsC.UClient(conn, c.TLSConfig, fingerprint) + utlsConn := tlsC.UClient(conn, c.TLSConfig, fingerprint) - if err := utlsConn.(*tlsC.UConn).WebsocketHandshake(); err != nil { - return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) - } - return utlsConn, nil + if err := utlsConn.(*tlsC.UConn).BuildWebsocketHandshakeState(); err != nil { + return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) + } + + dialer.TLSClient = func(conn net.Conn, hostname string) net.Conn { + return utlsConn } } } @@ -381,38 +348,47 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, RawQuery: u.RawQuery, } - headers := http.Header{} + headers := http.Header{"User-Agent": []string{"Go-http-client/1.1"}} // match golang's net/http if c.Headers != nil { - for k := range c.Headers { - headers.Add(k, c.Headers.Get(k)) + cHeaders := c.Headers + // gobwas/ws send "Host" directly in Upgrade() by `httpWriteHeader(bw, headerHost, u.Host)` + // if headers has "Host" will send repeatedly + if host := cHeaders.Get("Host"); host != "" { + cHeaders.Del("Host") + uri.Host = host + } + for k := range cHeaders { + headers.Add(k, cHeaders.Get(k)) } } if earlyData != nil { + earlyDataString := earlyData.String() if c.EarlyDataHeaderName == "" { - uri.Path += earlyData.String() + uri.Path += earlyDataString } else { - headers.Set(c.EarlyDataHeaderName, earlyData.String()) + // gobwas/ws will check server's response "Sec-Websocket-Protocol" so must add Protocols to ws.Dialer + // if not will cause ws.ErrHandshakeBadSubProtocol + if c.EarlyDataHeaderName == "Sec-WebSocket-Protocol" { + // gobwas/ws will set "Sec-Websocket-Protocol" according dialer.Protocols + // to avoid send repeatedly don't set it to headers + dialer.Protocols = []string{earlyDataString} + } else { + headers.Set(c.EarlyDataHeaderName, earlyDataString) + } + } } - wsConn, resp, err := dialer.DialContext(ctx, uri.String(), headers) + dialer.Header = ws.HandshakeHeaderHTTP(headers) + + conn, reader, _, err := dialer.Dial(ctx, uri.String()) if err != nil { - reason := err - if resp != nil { - reason = errors.New(resp.Status) - } - return nil, fmt.Errorf("dial %s error: %w", uri.Host, reason) + return nil, fmt.Errorf("dial %s error: %w", uri.Host, err) } - conn = &websocketConn{ - conn: wsConn, - rawWriter: N.NewExtendedWriter(wsConn.UnderlyingConn()), - remoteAddr: conn.RemoteAddr(), - } + conn = newWebsocketConn(conn, reader, ws.StateClientSide) // websocketConn can't correct handle ReadDeadline - // gorilla/websocket will cache the os.ErrDeadlineExceeded from conn.Read() - // it will cause read fail and event panic in *websocket.Conn.NextReader() // so call N.NewDeadlineConn to add a safe wrapper return N.NewDeadlineConn(conn), nil } @@ -436,3 +412,68 @@ func StreamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig) return streamWebsocketConn(ctx, conn, c, nil) } + +func newWebsocketConn(conn net.Conn, br *bufio.Reader, state ws.State) *websocketConn { + controlHandler := wsutil.ControlFrameHandler(conn, state) + var reader io.Reader + if br != nil && br.Buffered() > 0 { + reader = br + } else { + reader = conn + } + return &websocketConn{ + Conn: conn, + state: state, + reader: &wsutil.Reader{ + Source: reader, + State: state, + SkipHeaderCheck: true, + CheckUTF8: false, + OnIntermediate: controlHandler, + }, + controlHandler: controlHandler, + rawWriter: N.NewExtendedWriter(conn), + } +} + +var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "") + +func decodeEd(s string) ([]byte, error) { + return base64.RawURLEncoding.DecodeString(replacer.Replace(s)) +} + +func decodeXray0rtt(requestHeader http.Header) ([]byte, http.Header) { + var edBuf []byte + responseHeader := http.Header{} + // read inHeader's `Sec-WebSocket-Protocol` for Xray's 0rtt ws + if secProtocol := requestHeader.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 { + if buf, err := decodeEd(secProtocol); err == nil { // sure could base64 decode + edBuf = buf + } + } + return edBuf, responseHeader +} + +func StreamUpgradedWebsocketConn(w http.ResponseWriter, r *http.Request) (net.Conn, error) { + edBuf, responseHeader := decodeXray0rtt(r.Header) + wsConn, rw, _, err := ws.HTTPUpgrader{ + Header: responseHeader, + }.Upgrade(r, w) + if err != nil { + return nil, err + } + conn := newWebsocketConn(wsConn, rw.Reader, ws.StateServerSide) + if len(edBuf) > 0 { + return &websocketWithReaderConn{conn, io.MultiReader(bytes.NewReader(edBuf), conn)}, nil + } + return conn, nil +} + +type websocketWithReaderConn struct { + *websocketConn + reader io.Reader +} + +func (ws *websocketWithReaderConn) Read(b []byte) (n int, err error) { + return ws.reader.Read(b) +}