From 1607d3253fbe2216f5e2614d5128863b4eea5b2e Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Tue, 11 Dec 2018 00:25:05 +0800 Subject: [PATCH] Feature: add websocket headers support in vmess --- README.md | 6 +-- adapters/outbound/vmess.go | 42 +++++++++++---------- common/structure/structure.go | 71 +++++++++++++++++++++++++++++++++++ component/vmess/vmess.go | 22 ++++++----- component/vmess/websocket.go | 11 +++++- 5 files changed, 118 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 60917f46..509209ae 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@
-

A rule based tunnel in Go.

+

A rule-based tunnel in Go.

@@ -126,8 +126,8 @@ Proxy: - { name: "vmess", type: vmess, server: server, port: 443, uuid: uuid, alterId: 32, cipher: auto, tls: true } # with tls and skip-cert-verify - { name: "vmess", type: vmess, server: server, port: 443, uuid: uuid, alterId: 32, cipher: auto, tls: true, skip-cert-verify: true } -# with ws -- { name: "vmess", type: vmess, server: server, port: 443, uuid: uuid, alterId: 32, cipher: auto, network: ws, ws-path: /path } +# with ws-path and ws-headers +- { name: "vmess", type: vmess, server: server, port: 443, uuid: uuid, alterId: 32, cipher: auto, network: ws, ws-path: /path, ws-headers: { Host: v2ray.com } } # with ws + tls - { name: "vmess", type: vmess, server: server, port: 443, uuid: uuid, alterId: 32, cipher: auto, network: ws, ws-path: /path, tls: true } diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go index 9894bf80..dc3eb2d2 100644 --- a/adapters/outbound/vmess.go +++ b/adapters/outbound/vmess.go @@ -32,16 +32,17 @@ type Vmess struct { } type VmessOption struct { - Name string `proxy:"name"` - Server string `proxy:"server"` - Port int `proxy:"port"` - UUID string `proxy:"uuid"` - AlterID int `proxy:"alterId"` - Cipher string `proxy:"cipher"` - TLS bool `proxy:"tls,omitempty"` - Network string `proxy:"network,omitempty"` - WSPath string `proxy:"ws-path,omitempty"` - SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` + Name string `proxy:"name"` + Server string `proxy:"server"` + Port int `proxy:"port"` + UUID string `proxy:"uuid"` + AlterID int `proxy:"alterId"` + Cipher string `proxy:"cipher"` + TLS bool `proxy:"tls,omitempty"` + Network string `proxy:"network,omitempty"` + WSPath string `proxy:"ws-path,omitempty"` + WSHeaders map[string]string `proxy:"ws-headers,omitempty"` + SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` } func (v *Vmess) Name() string { @@ -71,16 +72,17 @@ func (v *Vmess) MarshalJSON() ([]byte, error) { func NewVmess(option VmessOption) (*Vmess, error) { security := strings.ToLower(option.Cipher) client, err := vmess.NewClient(vmess.Config{ - UUID: option.UUID, - AlterID: uint16(option.AlterID), - Security: security, - TLS: option.TLS, - HostName: option.Server, - Port: strconv.Itoa(option.Port), - NetWork: option.Network, - WebSocketPath: option.WSPath, - SkipCertVerify: option.SkipCertVerify, - SessionCacahe: getClientSessionCache(), + UUID: option.UUID, + AlterID: uint16(option.AlterID), + Security: security, + TLS: option.TLS, + HostName: option.Server, + Port: strconv.Itoa(option.Port), + NetWork: option.Network, + WebSocketPath: option.WSPath, + WebSocketHeaders: option.WSHeaders, + SkipCertVerify: option.SkipCertVerify, + SessionCacahe: getClientSessionCache(), }) if err != nil { return nil, err diff --git a/common/structure/structure.go b/common/structure/structure.go index ac824a07..600f264d 100644 --- a/common/structure/structure.go +++ b/common/structure/structure.go @@ -1,5 +1,7 @@ package structure +// references: https://github.com/mitchellh/mapstructure + import ( "fmt" "reflect" @@ -70,6 +72,8 @@ func (d *Decoder) decode(name string, data interface{}, val reflect.Value) error return d.decodeBool(name, data, val) case reflect.Slice: return d.decodeSlice(name, data, val) + case reflect.Map: + return d.decodeMap(name, data, val) default: return fmt.Errorf("type %s not support", val.Kind().String()) } @@ -158,3 +162,70 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) val.Set(valSlice) return nil } + +func (d *Decoder) decodeMap(name string, data interface{}, val reflect.Value) error { + valType := val.Type() + valKeyType := valType.Key() + valElemType := valType.Elem() + + valMap := val + + if valMap.IsNil() { + mapType := reflect.MapOf(valKeyType, valElemType) + valMap = reflect.MakeMap(mapType) + } + + dataVal := reflect.Indirect(reflect.ValueOf(data)) + if dataVal.Kind() != reflect.Map { + return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) + } + + return d.decodeMapFromMap(name, dataVal, val, valMap) +} + +func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { + valType := val.Type() + valKeyType := valType.Key() + valElemType := valType.Elem() + + errors := make([]string, 0) + + if dataVal.Len() == 0 { + if dataVal.IsNil() { + if !val.IsNil() { + val.Set(dataVal) + } + } else { + val.Set(valMap) + } + + return nil + } + + for _, k := range dataVal.MapKeys() { + fieldName := fmt.Sprintf("%s[%s]", name, k) + + currentKey := reflect.Indirect(reflect.New(valKeyType)) + if err := d.decode(fieldName, k.Interface(), currentKey); err != nil { + errors = append(errors, err.Error()) + continue + } + + v := dataVal.MapIndex(k).Interface() + currentVal := reflect.Indirect(reflect.New(valElemType)) + if err := d.decode(fieldName, v, currentVal); err != nil { + errors = append(errors, err.Error()) + continue + } + + valMap.SetMapIndex(currentKey, currentVal) + } + + val.Set(valMap) + + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, ",")) + } + + return nil +} diff --git a/component/vmess/vmess.go b/component/vmess/vmess.go index 8273ea87..604b45f6 100644 --- a/component/vmess/vmess.go +++ b/component/vmess/vmess.go @@ -75,16 +75,17 @@ type Client struct { // Config of vmess type Config struct { - UUID string - AlterID uint16 - Security string - TLS bool - HostName string - Port string - NetWork string - WebSocketPath string - SkipCertVerify bool - SessionCacahe tls.ClientSessionCache + UUID string + AlterID uint16 + Security string + TLS bool + HostName string + Port string + NetWork string + WebSocketPath string + WebSocketHeaders map[string]string + SkipCertVerify bool + SessionCacahe tls.ClientSessionCache } // New return a Conn with net.Conn and DstAddr @@ -149,6 +150,7 @@ func NewClient(config Config) (*Client, error) { wsConfig = &websocketConfig{ host: host, path: config.WebSocketPath, + headers: config.WebSocketHeaders, tls: config.TLS, tlsConfig: tlsConfig, } diff --git a/component/vmess/websocket.go b/component/vmess/websocket.go index 7bef5e55..fc50b813 100644 --- a/component/vmess/websocket.go +++ b/component/vmess/websocket.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "net/http" "net/url" "strings" "time" @@ -21,6 +22,7 @@ type websocketConn struct { type websocketConfig struct { host string path string + headers map[string]string tls bool tlsConfig *tls.Config } @@ -127,7 +129,14 @@ func newWebsocketConn(conn net.Conn, c *websocketConfig) (net.Conn, error) { Path: c.path, } - wsConn, resp, err := dialer.Dial(uri.String(), nil) + headers := http.Header{} + if c.headers != nil { + for k, v := range c.headers { + headers.Set(k, v) + } + } + + wsConn, resp, err := dialer.Dial(uri.String(), headers) if err != nil { var reason string if resp != nil {