Feature: add vmess WebSocket early data (#1505)

Co-authored-by: ShinyGwyn <79344143+ShinyGwyn@users.noreply.github.com>
This commit is contained in:
gVisor bot 2021-08-22 00:25:29 +08:00
parent 5046f3beab
commit d99920a3e6
3 changed files with 181 additions and 18 deletions

View file

@ -43,6 +43,7 @@ type VmessOption struct {
HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"` HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"`
HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"` HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"`
GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"` GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"`
WSOpts WSOptions `proxy:"ws-opts,omitempty"`
WSPath string `proxy:"ws-path,omitempty"` WSPath string `proxy:"ws-path,omitempty"`
WSHeaders map[string]string `proxy:"ws-headers,omitempty"` WSHeaders map[string]string `proxy:"ws-headers,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
@ -64,19 +65,35 @@ type GrpcOptions struct {
GrpcServiceName string `proxy:"grpc-service-name,omitempty"` GrpcServiceName string `proxy:"grpc-service-name,omitempty"`
} }
type WSOptions struct {
Path string `proxy:"path,omitempty"`
Headers map[string]string `proxy:"headers,omitempty"`
MaxEarlyData int `proxy:"max-early-data,omitempty"`
EarlyDataHeaderName string `proxy:"early-data-header-name,omitempty"`
}
// StreamConn implements C.ProxyAdapter // StreamConn implements C.ProxyAdapter
func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
var err error var err error
switch v.option.Network { switch v.option.Network {
case "ws": case "ws":
host, port, _ := net.SplitHostPort(v.addr) if v.option.WSOpts.Path == "" {
wsOpts := &vmess.WebsocketConfig{ v.option.WSOpts.Path = v.option.WSPath
Host: host, }
Port: port, if len(v.option.WSOpts.Headers) == 0 {
Path: v.option.WSPath, v.option.WSOpts.Headers = v.option.WSHeaders
} }
if len(v.option.WSHeaders) != 0 { host, port, _ := net.SplitHostPort(v.addr)
wsOpts := &vmess.WebsocketConfig{
Host: host,
Port: port,
Path: v.option.WSOpts.Path,
MaxEarlyData: v.option.WSOpts.MaxEarlyData,
EarlyDataHeaderName: v.option.WSOpts.EarlyDataHeaderName,
}
if len(v.option.WSOpts.Headers) != 0 {
header := http.Header{} header := http.Header{}
for key, value := range v.option.WSHeaders { for key, value := range v.option.WSHeaders {
header.Add(key, value) header.Add(key, value)

View file

@ -59,12 +59,12 @@ func (vc *Conn) Read(b []byte) (int, error) {
func (vc *Conn) sendRequest() error { func (vc *Conn) sendRequest() error {
timestamp := time.Now() timestamp := time.Now()
mbuf := &bytes.Buffer{}
if !vc.isAead { if !vc.isAead {
h := hmac.New(md5.New, vc.id.UUID.Bytes()) h := hmac.New(md5.New, vc.id.UUID.Bytes())
binary.Write(h, binary.BigEndian, uint64(timestamp.Unix())) binary.Write(h, binary.BigEndian, uint64(timestamp.Unix()))
if _, err := vc.Conn.Write(h.Sum(nil)); err != nil { mbuf.Write(h.Sum(nil))
return err
}
} }
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
@ -110,7 +110,8 @@ func (vc *Conn) sendRequest() error {
stream := cipher.NewCFBEncrypter(block, hashTimestamp(timestamp)) stream := cipher.NewCFBEncrypter(block, hashTimestamp(timestamp))
stream.XORKeyStream(buf.Bytes(), buf.Bytes()) stream.XORKeyStream(buf.Bytes(), buf.Bytes())
_, err = vc.Conn.Write(buf.Bytes()) mbuf.Write(buf.Bytes())
_, err = vc.Conn.Write(mbuf.Bytes())
return err return err
} }

View file

@ -1,7 +1,11 @@
package vmess package vmess
import ( import (
"bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/base64"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -23,15 +27,26 @@ type websocketConn struct {
rMux sync.Mutex rMux sync.Mutex
wMux sync.Mutex wMux sync.Mutex
} }
type websocketWithEarlyDataConn struct {
net.Conn
underlay net.Conn
closed bool
dialed chan bool
cancel context.CancelFunc
ctx context.Context
config *WebsocketConfig
}
type WebsocketConfig struct { type WebsocketConfig struct {
Host string Host string
Port string Port string
Path string Path string
Headers http.Header Headers http.Header
TLS bool TLS bool
SkipCertVerify bool SkipCertVerify bool
ServerName string ServerName string
MaxEarlyData int
EarlyDataHeaderName string
} }
// Read implements net.Conn.Read() // Read implements net.Conn.Read()
@ -113,7 +128,121 @@ func (wsc *websocketConn) SetWriteDeadline(t time.Time) error {
return wsc.conn.SetWriteDeadline(t) return wsc.conn.SetWriteDeadline(t)
} }
func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
earlyDataBuf := bytes.NewBuffer(nil)
base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf)
earlydata := bytes.NewReader(earlyData)
limitedEarlyDatareader := io.LimitReader(earlydata, int64(wsedc.config.MaxEarlyData))
n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader)
if encerr != nil {
return errors.New("failed to encode early data: " + encerr.Error())
}
if errc := base64EarlyDataEncoder.Close(); errc != nil {
return errors.New("failed to encode early data tail: " + errc.Error())
}
var err error
if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, earlyDataBuf); err != nil {
wsedc.Close()
return errors.New("failed to dial WebSocket: " + err.Error())
}
wsedc.dialed <- true
if n != int64(len(earlyData)) {
_, err = wsedc.Conn.Write(earlyData[n:])
}
return err
}
func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
if wsedc.closed {
return 0, io.ErrClosedPipe
}
if wsedc.Conn == nil {
if err := wsedc.Dial(b); err != nil {
return 0, err
}
return len(b), nil
}
return wsedc.Conn.Write(b)
}
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
if wsedc.closed {
return 0, io.ErrClosedPipe
}
if wsedc.Conn == nil {
select {
case <-wsedc.ctx.Done():
return 0, io.ErrUnexpectedEOF
case <-wsedc.dialed:
}
}
return wsedc.Conn.Read(b)
}
func (wsedc *websocketWithEarlyDataConn) Close() error {
wsedc.closed = true
wsedc.cancel()
if wsedc.Conn == nil {
return nil
}
return wsedc.Conn.Close()
}
func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr {
if wsedc.Conn == nil {
return wsedc.underlay.LocalAddr()
}
return wsedc.Conn.LocalAddr()
}
func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr {
if wsedc.Conn == nil {
return wsedc.underlay.RemoteAddr()
}
return wsedc.Conn.RemoteAddr()
}
func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
if err := wsedc.SetReadDeadline(t); err != nil {
return err
}
return wsedc.SetWriteDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error {
if wsedc.Conn == nil {
return nil
}
return wsedc.Conn.SetReadDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
if wsedc.Conn == nil {
return nil
}
return wsedc.Conn.SetWriteDeadline(t)
}
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
ctx, cancel := context.WithCancel(context.Background())
conn = &websocketWithEarlyDataConn{
dialed: make(chan bool, 1),
cancel: cancel,
ctx: ctx,
underlay: conn,
config: c,
}
return conn, nil
}
func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
dialer := &websocket.Dialer{ dialer := &websocket.Dialer{
NetDial: func(network, addr string) (net.Conn, error) { NetDial: func(network, addr string) (net.Conn, error) {
return conn, nil return conn, nil
@ -152,6 +281,14 @@ func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
} }
} }
if earlyData != nil {
if c.EarlyDataHeaderName == "" {
uri.Path += earlyData.String()
} else {
headers.Set(c.EarlyDataHeaderName, earlyData.String())
}
}
wsConn, resp, err := dialer.Dial(uri.String(), headers) wsConn, resp, err := dialer.Dial(uri.String(), headers)
if err != nil { if err != nil {
reason := err.Error() reason := err.Error()
@ -166,3 +303,11 @@ func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
remoteAddr: conn.RemoteAddr(), remoteAddr: conn.RemoteAddr(),
}, nil }, nil
} }
func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
if c.MaxEarlyData > 0 {
return streamWebsocketWithEarlyDataConn(conn, c)
}
return streamWebsocketConn(conn, c, nil)
}