From 1607d3253fbe2216f5e2614d5128863b4eea5b2e Mon Sep 17 00:00:00 2001
From: Dreamacro <305009791@qq.com>
Date: Tue, 11 Dec 2018 00:25:05 +0800
Subject: [PATCH] Feature: add websocket headers support in vmess
---
README.md | 6 +--
adapters/outbound/vmess.go | 42 +++++++++++----------
common/structure/structure.go | 71 +++++++++++++++++++++++++++++++++++
component/vmess/vmess.go | 22 ++++++-----
component/vmess/websocket.go | 11 +++++-
5 files changed, 118 insertions(+), 34 deletions(-)
diff --git a/README.md b/README.md
index 60917f46..509209ae 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
-
A rule based tunnel in Go.
+A rule-based tunnel in Go.
@@ -126,8 +126,8 @@ Proxy:
- { name: "vmess", type: vmess, server: server, port: 443, uuid: uuid, alterId: 32, cipher: auto, tls: true }
# with tls and skip-cert-verify
- { name: "vmess", type: vmess, server: server, port: 443, uuid: uuid, alterId: 32, cipher: auto, tls: true, skip-cert-verify: true }
-# with ws
-- { name: "vmess", type: vmess, server: server, port: 443, uuid: uuid, alterId: 32, cipher: auto, network: ws, ws-path: /path }
+# with ws-path and ws-headers
+- { name: "vmess", type: vmess, server: server, port: 443, uuid: uuid, alterId: 32, cipher: auto, network: ws, ws-path: /path, ws-headers: { Host: v2ray.com } }
# with ws + tls
- { name: "vmess", type: vmess, server: server, port: 443, uuid: uuid, alterId: 32, cipher: auto, network: ws, ws-path: /path, tls: true }
diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go
index 9894bf80..dc3eb2d2 100644
--- a/adapters/outbound/vmess.go
+++ b/adapters/outbound/vmess.go
@@ -32,16 +32,17 @@ type Vmess struct {
}
type VmessOption struct {
- Name string `proxy:"name"`
- Server string `proxy:"server"`
- Port int `proxy:"port"`
- UUID string `proxy:"uuid"`
- AlterID int `proxy:"alterId"`
- Cipher string `proxy:"cipher"`
- TLS bool `proxy:"tls,omitempty"`
- Network string `proxy:"network,omitempty"`
- WSPath string `proxy:"ws-path,omitempty"`
- SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
+ Name string `proxy:"name"`
+ Server string `proxy:"server"`
+ Port int `proxy:"port"`
+ UUID string `proxy:"uuid"`
+ AlterID int `proxy:"alterId"`
+ Cipher string `proxy:"cipher"`
+ TLS bool `proxy:"tls,omitempty"`
+ Network string `proxy:"network,omitempty"`
+ WSPath string `proxy:"ws-path,omitempty"`
+ WSHeaders map[string]string `proxy:"ws-headers,omitempty"`
+ SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
}
func (v *Vmess) Name() string {
@@ -71,16 +72,17 @@ func (v *Vmess) MarshalJSON() ([]byte, error) {
func NewVmess(option VmessOption) (*Vmess, error) {
security := strings.ToLower(option.Cipher)
client, err := vmess.NewClient(vmess.Config{
- UUID: option.UUID,
- AlterID: uint16(option.AlterID),
- Security: security,
- TLS: option.TLS,
- HostName: option.Server,
- Port: strconv.Itoa(option.Port),
- NetWork: option.Network,
- WebSocketPath: option.WSPath,
- SkipCertVerify: option.SkipCertVerify,
- SessionCacahe: getClientSessionCache(),
+ UUID: option.UUID,
+ AlterID: uint16(option.AlterID),
+ Security: security,
+ TLS: option.TLS,
+ HostName: option.Server,
+ Port: strconv.Itoa(option.Port),
+ NetWork: option.Network,
+ WebSocketPath: option.WSPath,
+ WebSocketHeaders: option.WSHeaders,
+ SkipCertVerify: option.SkipCertVerify,
+ SessionCacahe: getClientSessionCache(),
})
if err != nil {
return nil, err
diff --git a/common/structure/structure.go b/common/structure/structure.go
index ac824a07..600f264d 100644
--- a/common/structure/structure.go
+++ b/common/structure/structure.go
@@ -1,5 +1,7 @@
package structure
+// references: https://github.com/mitchellh/mapstructure
+
import (
"fmt"
"reflect"
@@ -70,6 +72,8 @@ func (d *Decoder) decode(name string, data interface{}, val reflect.Value) error
return d.decodeBool(name, data, val)
case reflect.Slice:
return d.decodeSlice(name, data, val)
+ case reflect.Map:
+ return d.decodeMap(name, data, val)
default:
return fmt.Errorf("type %s not support", val.Kind().String())
}
@@ -158,3 +162,70 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)
val.Set(valSlice)
return nil
}
+
+func (d *Decoder) decodeMap(name string, data interface{}, val reflect.Value) error {
+ valType := val.Type()
+ valKeyType := valType.Key()
+ valElemType := valType.Elem()
+
+ valMap := val
+
+ if valMap.IsNil() {
+ mapType := reflect.MapOf(valKeyType, valElemType)
+ valMap = reflect.MakeMap(mapType)
+ }
+
+ dataVal := reflect.Indirect(reflect.ValueOf(data))
+ if dataVal.Kind() != reflect.Map {
+ return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind())
+ }
+
+ return d.decodeMapFromMap(name, dataVal, val, valMap)
+}
+
+func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error {
+ valType := val.Type()
+ valKeyType := valType.Key()
+ valElemType := valType.Elem()
+
+ errors := make([]string, 0)
+
+ if dataVal.Len() == 0 {
+ if dataVal.IsNil() {
+ if !val.IsNil() {
+ val.Set(dataVal)
+ }
+ } else {
+ val.Set(valMap)
+ }
+
+ return nil
+ }
+
+ for _, k := range dataVal.MapKeys() {
+ fieldName := fmt.Sprintf("%s[%s]", name, k)
+
+ currentKey := reflect.Indirect(reflect.New(valKeyType))
+ if err := d.decode(fieldName, k.Interface(), currentKey); err != nil {
+ errors = append(errors, err.Error())
+ continue
+ }
+
+ v := dataVal.MapIndex(k).Interface()
+ currentVal := reflect.Indirect(reflect.New(valElemType))
+ if err := d.decode(fieldName, v, currentVal); err != nil {
+ errors = append(errors, err.Error())
+ continue
+ }
+
+ valMap.SetMapIndex(currentKey, currentVal)
+ }
+
+ val.Set(valMap)
+
+ if len(errors) > 0 {
+ return fmt.Errorf(strings.Join(errors, ","))
+ }
+
+ return nil
+}
diff --git a/component/vmess/vmess.go b/component/vmess/vmess.go
index 8273ea87..604b45f6 100644
--- a/component/vmess/vmess.go
+++ b/component/vmess/vmess.go
@@ -75,16 +75,17 @@ type Client struct {
// Config of vmess
type Config struct {
- UUID string
- AlterID uint16
- Security string
- TLS bool
- HostName string
- Port string
- NetWork string
- WebSocketPath string
- SkipCertVerify bool
- SessionCacahe tls.ClientSessionCache
+ UUID string
+ AlterID uint16
+ Security string
+ TLS bool
+ HostName string
+ Port string
+ NetWork string
+ WebSocketPath string
+ WebSocketHeaders map[string]string
+ SkipCertVerify bool
+ SessionCacahe tls.ClientSessionCache
}
// New return a Conn with net.Conn and DstAddr
@@ -149,6 +150,7 @@ func NewClient(config Config) (*Client, error) {
wsConfig = &websocketConfig{
host: host,
path: config.WebSocketPath,
+ headers: config.WebSocketHeaders,
tls: config.TLS,
tlsConfig: tlsConfig,
}
diff --git a/component/vmess/websocket.go b/component/vmess/websocket.go
index 7bef5e55..fc50b813 100644
--- a/component/vmess/websocket.go
+++ b/component/vmess/websocket.go
@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net"
+ "net/http"
"net/url"
"strings"
"time"
@@ -21,6 +22,7 @@ type websocketConn struct {
type websocketConfig struct {
host string
path string
+ headers map[string]string
tls bool
tlsConfig *tls.Config
}
@@ -127,7 +129,14 @@ func newWebsocketConn(conn net.Conn, c *websocketConfig) (net.Conn, error) {
Path: c.path,
}
- wsConn, resp, err := dialer.Dial(uri.String(), nil)
+ headers := http.Header{}
+ if c.headers != nil {
+ for k, v := range c.headers {
+ headers.Set(k, v)
+ }
+ }
+
+ wsConn, resp, err := dialer.Dial(uri.String(), headers)
if err != nil {
var reason string
if resp != nil {