Feature: add vmess WebSocket early data (#1505)

Co-authored-by: ShinyGwyn <79344143+ShinyGwyn@users.noreply.github.com>
This commit is contained in:
gVisor bot 2021-08-22 00:25:29 +08:00
parent 5046f3beab
commit d99920a3e6
3 changed files with 181 additions and 18 deletions

View file

@ -43,6 +43,7 @@ type VmessOption struct {
HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"` HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"`
HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"` HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"`
GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"` GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"`
WSOpts WSOptions `proxy:"ws-opts,omitempty"`
WSPath string `proxy:"ws-path,omitempty"` WSPath string `proxy:"ws-path,omitempty"`
WSHeaders map[string]string `proxy:"ws-headers,omitempty"` WSHeaders map[string]string `proxy:"ws-headers,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
@ -64,19 +65,35 @@ type GrpcOptions struct {
GrpcServiceName string `proxy:"grpc-service-name,omitempty"` GrpcServiceName string `proxy:"grpc-service-name,omitempty"`
} }
type WSOptions struct {
Path string `proxy:"path,omitempty"`
Headers map[string]string `proxy:"headers,omitempty"`
MaxEarlyData int `proxy:"max-early-data,omitempty"`
EarlyDataHeaderName string `proxy:"early-data-header-name,omitempty"`
}
// StreamConn implements C.ProxyAdapter // StreamConn implements C.ProxyAdapter
func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
var err error var err error
switch v.option.Network { switch v.option.Network {
case "ws": case "ws":
if v.option.WSOpts.Path == "" {
v.option.WSOpts.Path = v.option.WSPath
}
if len(v.option.WSOpts.Headers) == 0 {
v.option.WSOpts.Headers = v.option.WSHeaders
}
host, port, _ := net.SplitHostPort(v.addr) host, port, _ := net.SplitHostPort(v.addr)
wsOpts := &vmess.WebsocketConfig{ wsOpts := &vmess.WebsocketConfig{
Host: host, Host: host,
Port: port, Port: port,
Path: v.option.WSPath, Path: v.option.WSOpts.Path,
MaxEarlyData: v.option.WSOpts.MaxEarlyData,
EarlyDataHeaderName: v.option.WSOpts.EarlyDataHeaderName,
} }
if len(v.option.WSHeaders) != 0 { if len(v.option.WSOpts.Headers) != 0 {
header := http.Header{} header := http.Header{}
for key, value := range v.option.WSHeaders { for key, value := range v.option.WSHeaders {
header.Add(key, value) header.Add(key, value)

View file

@ -59,12 +59,12 @@ func (vc *Conn) Read(b []byte) (int, error) {
func (vc *Conn) sendRequest() error { func (vc *Conn) sendRequest() error {
timestamp := time.Now() timestamp := time.Now()
mbuf := &bytes.Buffer{}
if !vc.isAead { if !vc.isAead {
h := hmac.New(md5.New, vc.id.UUID.Bytes()) h := hmac.New(md5.New, vc.id.UUID.Bytes())
binary.Write(h, binary.BigEndian, uint64(timestamp.Unix())) binary.Write(h, binary.BigEndian, uint64(timestamp.Unix()))
if _, err := vc.Conn.Write(h.Sum(nil)); err != nil { mbuf.Write(h.Sum(nil))
return err
}
} }
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
@ -110,7 +110,8 @@ func (vc *Conn) sendRequest() error {
stream := cipher.NewCFBEncrypter(block, hashTimestamp(timestamp)) stream := cipher.NewCFBEncrypter(block, hashTimestamp(timestamp))
stream.XORKeyStream(buf.Bytes(), buf.Bytes()) stream.XORKeyStream(buf.Bytes(), buf.Bytes())
_, err = vc.Conn.Write(buf.Bytes()) mbuf.Write(buf.Bytes())
_, err = vc.Conn.Write(mbuf.Bytes())
return err return err
} }

View file

@ -1,7 +1,11 @@
package vmess package vmess
import ( import (
"bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/base64"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -23,6 +27,15 @@ type websocketConn struct {
rMux sync.Mutex rMux sync.Mutex
wMux sync.Mutex wMux sync.Mutex
} }
type websocketWithEarlyDataConn struct {
net.Conn
underlay net.Conn
closed bool
dialed chan bool
cancel context.CancelFunc
ctx context.Context
config *WebsocketConfig
}
type WebsocketConfig struct { type WebsocketConfig struct {
Host string Host string
@ -32,6 +45,8 @@ type WebsocketConfig struct {
TLS bool TLS bool
SkipCertVerify bool SkipCertVerify bool
ServerName string ServerName string
MaxEarlyData int
EarlyDataHeaderName string
} }
// Read implements net.Conn.Read() // Read implements net.Conn.Read()
@ -113,7 +128,121 @@ func (wsc *websocketConn) SetWriteDeadline(t time.Time) error {
return wsc.conn.SetWriteDeadline(t) return wsc.conn.SetWriteDeadline(t)
} }
func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
earlyDataBuf := bytes.NewBuffer(nil)
base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf)
earlydata := bytes.NewReader(earlyData)
limitedEarlyDatareader := io.LimitReader(earlydata, int64(wsedc.config.MaxEarlyData))
n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader)
if encerr != nil {
return errors.New("failed to encode early data: " + encerr.Error())
}
if errc := base64EarlyDataEncoder.Close(); errc != nil {
return errors.New("failed to encode early data tail: " + errc.Error())
}
var err error
if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, earlyDataBuf); err != nil {
wsedc.Close()
return errors.New("failed to dial WebSocket: " + err.Error())
}
wsedc.dialed <- true
if n != int64(len(earlyData)) {
_, err = wsedc.Conn.Write(earlyData[n:])
}
return err
}
func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
if wsedc.closed {
return 0, io.ErrClosedPipe
}
if wsedc.Conn == nil {
if err := wsedc.Dial(b); err != nil {
return 0, err
}
return len(b), nil
}
return wsedc.Conn.Write(b)
}
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
if wsedc.closed {
return 0, io.ErrClosedPipe
}
if wsedc.Conn == nil {
select {
case <-wsedc.ctx.Done():
return 0, io.ErrUnexpectedEOF
case <-wsedc.dialed:
}
}
return wsedc.Conn.Read(b)
}
func (wsedc *websocketWithEarlyDataConn) Close() error {
wsedc.closed = true
wsedc.cancel()
if wsedc.Conn == nil {
return nil
}
return wsedc.Conn.Close()
}
func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr {
if wsedc.Conn == nil {
return wsedc.underlay.LocalAddr()
}
return wsedc.Conn.LocalAddr()
}
func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr {
if wsedc.Conn == nil {
return wsedc.underlay.RemoteAddr()
}
return wsedc.Conn.RemoteAddr()
}
func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
if err := wsedc.SetReadDeadline(t); err != nil {
return err
}
return wsedc.SetWriteDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error {
if wsedc.Conn == nil {
return nil
}
return wsedc.Conn.SetReadDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
if wsedc.Conn == nil {
return nil
}
return wsedc.Conn.SetWriteDeadline(t)
}
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
ctx, cancel := context.WithCancel(context.Background())
conn = &websocketWithEarlyDataConn{
dialed: make(chan bool, 1),
cancel: cancel,
ctx: ctx,
underlay: conn,
config: c,
}
return conn, nil
}
func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
dialer := &websocket.Dialer{ dialer := &websocket.Dialer{
NetDial: func(network, addr string) (net.Conn, error) { NetDial: func(network, addr string) (net.Conn, error) {
return conn, nil return conn, nil
@ -152,6 +281,14 @@ func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
} }
} }
if earlyData != nil {
if c.EarlyDataHeaderName == "" {
uri.Path += earlyData.String()
} else {
headers.Set(c.EarlyDataHeaderName, earlyData.String())
}
}
wsConn, resp, err := dialer.Dial(uri.String(), headers) wsConn, resp, err := dialer.Dial(uri.String(), headers)
if err != nil { if err != nil {
reason := err.Error() reason := err.Error()
@ -166,3 +303,11 @@ func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
remoteAddr: conn.RemoteAddr(), remoteAddr: conn.RemoteAddr(),
}, nil }, nil
} }
func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
if c.MaxEarlyData > 0 {
return streamWebsocketWithEarlyDataConn(conn, c)
}
return streamWebsocketConn(conn, c, nil)
}