From d99920a3e69041f8c8d83440576317a2739b7868 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Sun, 22 Aug 2021 00:25:29 +0800 Subject: [PATCH] Feature: add vmess WebSocket early data (#1505) Co-authored-by: ShinyGwyn <79344143+ShinyGwyn@users.noreply.github.com> --- adapter/outbound/vmess.go | 29 +++++-- transport/vmess/conn.go | 9 +- transport/vmess/websocket.go | 161 +++++++++++++++++++++++++++++++++-- 3 files changed, 181 insertions(+), 18 deletions(-) diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 5ee4abbc..445b1ef4 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -43,6 +43,7 @@ type VmessOption struct { HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"` HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"` GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"` + WSOpts WSOptions `proxy:"ws-opts,omitempty"` WSPath string `proxy:"ws-path,omitempty"` WSHeaders map[string]string `proxy:"ws-headers,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` @@ -64,19 +65,35 @@ type GrpcOptions struct { GrpcServiceName string `proxy:"grpc-service-name,omitempty"` } +type WSOptions struct { + Path string `proxy:"path,omitempty"` + Headers map[string]string `proxy:"headers,omitempty"` + MaxEarlyData int `proxy:"max-early-data,omitempty"` + EarlyDataHeaderName string `proxy:"early-data-header-name,omitempty"` +} + // StreamConn implements C.ProxyAdapter func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { var err error switch v.option.Network { case "ws": - host, port, _ := net.SplitHostPort(v.addr) - wsOpts := &vmess.WebsocketConfig{ - Host: host, - Port: port, - Path: v.option.WSPath, + if v.option.WSOpts.Path == "" { + v.option.WSOpts.Path = v.option.WSPath + } + if len(v.option.WSOpts.Headers) == 0 { + v.option.WSOpts.Headers = v.option.WSHeaders } - if len(v.option.WSHeaders) != 0 { + host, port, _ := net.SplitHostPort(v.addr) + wsOpts := &vmess.WebsocketConfig{ + Host: host, + Port: port, + Path: v.option.WSOpts.Path, + MaxEarlyData: v.option.WSOpts.MaxEarlyData, + EarlyDataHeaderName: v.option.WSOpts.EarlyDataHeaderName, + } + + if len(v.option.WSOpts.Headers) != 0 { header := http.Header{} for key, value := range v.option.WSHeaders { header.Add(key, value) diff --git a/transport/vmess/conn.go b/transport/vmess/conn.go index e6e57be6..cc3155ee 100644 --- a/transport/vmess/conn.go +++ b/transport/vmess/conn.go @@ -59,12 +59,12 @@ func (vc *Conn) Read(b []byte) (int, error) { func (vc *Conn) sendRequest() error { timestamp := time.Now() + mbuf := &bytes.Buffer{} + if !vc.isAead { h := hmac.New(md5.New, vc.id.UUID.Bytes()) binary.Write(h, binary.BigEndian, uint64(timestamp.Unix())) - if _, err := vc.Conn.Write(h.Sum(nil)); err != nil { - return err - } + mbuf.Write(h.Sum(nil)) } buf := &bytes.Buffer{} @@ -110,7 +110,8 @@ func (vc *Conn) sendRequest() error { stream := cipher.NewCFBEncrypter(block, hashTimestamp(timestamp)) stream.XORKeyStream(buf.Bytes(), buf.Bytes()) - _, err = vc.Conn.Write(buf.Bytes()) + mbuf.Write(buf.Bytes()) + _, err = vc.Conn.Write(mbuf.Bytes()) return err } diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index 6ed353e7..956b6d95 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -1,7 +1,11 @@ package vmess import ( + "bytes" + "context" "crypto/tls" + "encoding/base64" + "errors" "fmt" "io" "net" @@ -23,15 +27,26 @@ type websocketConn struct { rMux sync.Mutex wMux sync.Mutex } +type websocketWithEarlyDataConn struct { + net.Conn + underlay net.Conn + closed bool + dialed chan bool + cancel context.CancelFunc + ctx context.Context + config *WebsocketConfig +} type WebsocketConfig struct { - Host string - Port string - Path string - Headers http.Header - TLS bool - SkipCertVerify bool - ServerName string + Host string + Port string + Path string + Headers http.Header + TLS bool + SkipCertVerify bool + ServerName string + MaxEarlyData int + EarlyDataHeaderName string } // Read implements net.Conn.Read() @@ -113,7 +128,121 @@ func (wsc *websocketConn) SetWriteDeadline(t time.Time) error { return wsc.conn.SetWriteDeadline(t) } -func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { +func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { + earlyDataBuf := bytes.NewBuffer(nil) + base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf) + + earlydata := bytes.NewReader(earlyData) + limitedEarlyDatareader := io.LimitReader(earlydata, int64(wsedc.config.MaxEarlyData)) + n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader) + if encerr != nil { + return errors.New("failed to encode early data: " + encerr.Error()) + } + + if errc := base64EarlyDataEncoder.Close(); errc != nil { + return errors.New("failed to encode early data tail: " + errc.Error()) + } + + var err error + if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, earlyDataBuf); err != nil { + wsedc.Close() + return errors.New("failed to dial WebSocket: " + err.Error()) + } + + wsedc.dialed <- true + + if n != int64(len(earlyData)) { + _, err = wsedc.Conn.Write(earlyData[n:]) + } + + return err +} + +func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) { + if wsedc.closed { + return 0, io.ErrClosedPipe + } + if wsedc.Conn == nil { + if err := wsedc.Dial(b); err != nil { + return 0, err + } + return len(b), nil + } + + return wsedc.Conn.Write(b) +} + +func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) { + if wsedc.closed { + return 0, io.ErrClosedPipe + } + if wsedc.Conn == nil { + select { + case <-wsedc.ctx.Done(): + return 0, io.ErrUnexpectedEOF + case <-wsedc.dialed: + } + } + return wsedc.Conn.Read(b) +} + +func (wsedc *websocketWithEarlyDataConn) Close() error { + wsedc.closed = true + wsedc.cancel() + if wsedc.Conn == nil { + return nil + } + return wsedc.Conn.Close() +} + +func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr { + if wsedc.Conn == nil { + return wsedc.underlay.LocalAddr() + } + return wsedc.Conn.LocalAddr() +} + +func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr { + if wsedc.Conn == nil { + return wsedc.underlay.RemoteAddr() + } + return wsedc.Conn.RemoteAddr() +} + +func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error { + if err := wsedc.SetReadDeadline(t); err != nil { + return err + } + return wsedc.SetWriteDeadline(t) +} + +func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error { + if wsedc.Conn == nil { + return nil + } + return wsedc.Conn.SetReadDeadline(t) +} + +func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error { + if wsedc.Conn == nil { + return nil + } + return wsedc.Conn.SetWriteDeadline(t) +} + +func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { + ctx, cancel := context.WithCancel(context.Background()) + conn = &websocketWithEarlyDataConn{ + dialed: make(chan bool, 1), + cancel: cancel, + ctx: ctx, + underlay: conn, + config: c, + } + return conn, nil +} + +func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { dialer := &websocket.Dialer{ NetDial: func(network, addr string) (net.Conn, error) { return conn, nil @@ -152,6 +281,14 @@ func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { } } + if earlyData != nil { + if c.EarlyDataHeaderName == "" { + uri.Path += earlyData.String() + } else { + headers.Set(c.EarlyDataHeaderName, earlyData.String()) + } + } + wsConn, resp, err := dialer.Dial(uri.String(), headers) if err != nil { reason := err.Error() @@ -166,3 +303,11 @@ func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { remoteAddr: conn.RemoteAddr(), }, nil } + +func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { + if c.MaxEarlyData > 0 { + return streamWebsocketWithEarlyDataConn(conn, c) + } + + return streamWebsocketConn(conn, c, nil) +}