diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index 6d679e29..83f5e3c2 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -353,17 +353,11 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, RawQuery: u.RawQuery, } - headers := http.Header{"User-Agent": []string{"Go-http-client/1.1"}} // match golang's net/http + headers := http.Header{} + headers.Set("User-Agent", "Go-http-client/1.1") // match golang's net/http if c.Headers != nil { - cHeaders := c.Headers - // gobwas/ws send "Host" directly in Upgrade() by `httpWriteHeader(bw, headerHost, u.Host)` - // if headers has "Host" will send repeatedly - if host := cHeaders.Get("Host"); host != "" { - cHeaders.Del("Host") - uri.Host = host - } - for k := range cHeaders { - headers.Add(k, cHeaders.Get(k)) + for k := range c.Headers { + headers.Add(k, c.Headers.Get(k)) } } @@ -372,19 +366,26 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, if c.EarlyDataHeaderName == "" { uri.Path += earlyDataString } else { - // gobwas/ws will check server's response "Sec-Websocket-Protocol" so must add Protocols to ws.Dialer - // if not will cause ws.ErrHandshakeBadSubProtocol - if c.EarlyDataHeaderName == "Sec-WebSocket-Protocol" { - // gobwas/ws will set "Sec-Websocket-Protocol" according dialer.Protocols - // to avoid send repeatedly don't set it to headers - dialer.Protocols = []string{earlyDataString} - } else { - headers.Set(c.EarlyDataHeaderName, earlyDataString) - } - + headers.Set(c.EarlyDataHeaderName, earlyDataString) } } + // gobwas/ws will check server's response "Sec-Websocket-Protocol" so must add Protocols to ws.Dialer + // if not will cause ws.ErrHandshakeBadSubProtocol + if secProtocol := headers.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 { + // gobwas/ws will set "Sec-Websocket-Protocol" according dialer.Protocols + // to avoid send repeatedly don't set it to headers + headers.Del("Sec-WebSocket-Protocol") + dialer.Protocols = []string{secProtocol} + } + + // gobwas/ws send "Host" directly in Upgrade() by `httpWriteHeader(bw, headerHost, u.Host)` + // if headers has "Host" will send repeatedly + if host := headers.Get("Host"); host != "" { + headers.Del("Host") + uri.Host = host + } + dialer.Header = ws.HandshakeHeaderHTTP(headers) conn, reader, _, err := dialer.Dial(ctx, uri.String()) @@ -447,26 +448,23 @@ func decodeEd(s string) ([]byte, error) { return base64.RawURLEncoding.DecodeString(replacer.Replace(s)) } -func decodeXray0rtt(requestHeader http.Header) ([]byte, http.Header) { - var edBuf []byte - responseHeader := http.Header{} +func decodeXray0rtt(requestHeader http.Header) []byte { // read inHeader's `Sec-WebSocket-Protocol` for Xray's 0rtt ws if secProtocol := requestHeader.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 { - if buf, err := decodeEd(secProtocol); err == nil { // sure could base64 decode - edBuf = buf + if edBuf, err := decodeEd(secProtocol); err == nil { // sure could base64 decode + return edBuf } } - return edBuf, responseHeader + return nil } func StreamUpgradedWebsocketConn(w http.ResponseWriter, r *http.Request) (net.Conn, error) { - edBuf, responseHeader := decodeXray0rtt(r.Header) - wsConn, rw, _, err := ws.HTTPUpgrader{Header: responseHeader}.Upgrade(r, w) + wsConn, rw, _, err := ws.UpgradeHTTP(r, w) if err != nil { return nil, err } conn := newWebsocketConn(wsConn, rw.Reader, ws.StateServerSide) - if len(edBuf) > 0 { + if edBuf := decodeXray0rtt(r.Header); len(edBuf) > 0 { return N.NewDeadlineConn(&websocketWithReaderConn{conn, io.MultiReader(bytes.NewReader(edBuf), conn)}), nil } return N.NewDeadlineConn(conn), nil