chore: wireguard's reserved support base64 input

This commit is contained in:
wwqgtxx 2022-12-13 20:40:07 +08:00
parent 9711390c18
commit afb2364ca2
2 changed files with 58 additions and 15 deletions

View file

@ -41,19 +41,19 @@ type WireGuard struct {
type WireGuardOption struct { type WireGuardOption struct {
BasicOption BasicOption
Name string `proxy:"name"` Name string `proxy:"name"`
Server string `proxy:"server"` Server string `proxy:"server"`
Port int `proxy:"port"` Port int `proxy:"port"`
Ip string `proxy:"ip,omitempty"` Ip string `proxy:"ip,omitempty"`
Ipv6 string `proxy:"ipv6,omitempty"` Ipv6 string `proxy:"ipv6,omitempty"`
PrivateKey string `proxy:"private-key"` PrivateKey string `proxy:"private-key"`
PublicKey string `proxy:"public-key"` PublicKey string `proxy:"public-key"`
PreSharedKey string `proxy:"pre-shared-key,omitempty"` PreSharedKey string `proxy:"pre-shared-key,omitempty"`
Reserved []int `proxy:"reserved,omitempty"` Reserved []uint8 `proxy:"reserved,omitempty"`
Workers int `proxy:"workers,omitempty"` Workers int `proxy:"workers,omitempty"`
MTU int `proxy:"mtu,omitempty"` MTU int `proxy:"mtu,omitempty"`
UDP bool `proxy:"udp,omitempty"` UDP bool `proxy:"udp,omitempty"`
PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"` PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"`
} }
type wgDialer struct { type wgDialer struct {

View file

@ -3,6 +3,7 @@ package structure
// references: https://github.com/mitchellh/mapstructure // references: https://github.com/mitchellh/mapstructure
import ( import (
"encoding/base64"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
@ -86,8 +87,10 @@ func (d *Decoder) Decode(src map[string]any, dst any) error {
func (d *Decoder) decode(name string, data any, val reflect.Value) error { func (d *Decoder) decode(name string, data any, val reflect.Value) error {
switch val.Kind() { switch val.Kind() {
case reflect.Int: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return d.decodeInt(name, data, val) return d.decodeInt(name, data, val)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return d.decodeUint(name, data, val)
case reflect.String: case reflect.String:
return d.decodeString(name, data, val) return d.decodeString(name, data, val)
case reflect.Bool: case reflect.Bool:
@ -109,8 +112,10 @@ func (d *Decoder) decodeInt(name string, data any, val reflect.Value) (err error
dataVal := reflect.ValueOf(data) dataVal := reflect.ValueOf(data)
kind := dataVal.Kind() kind := dataVal.Kind()
switch { switch {
case kind == reflect.Int: case kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 || kind == reflect.Int32 || kind == reflect.Int64:
val.SetInt(dataVal.Int()) val.SetInt(dataVal.Int())
case (kind == reflect.Uint || kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || kind == reflect.Uint64) && d.option.WeaklyTypedInput:
val.SetInt(int64(dataVal.Uint()))
case kind == reflect.Float64 && d.option.WeaklyTypedInput: case kind == reflect.Float64 && d.option.WeaklyTypedInput:
val.SetInt(int64(dataVal.Float())) val.SetInt(int64(dataVal.Float()))
case kind == reflect.String && d.option.WeaklyTypedInput: case kind == reflect.String && d.option.WeaklyTypedInput:
@ -130,6 +135,33 @@ func (d *Decoder) decodeInt(name string, data any, val reflect.Value) (err error
return err return err
} }
func (d *Decoder) decodeUint(name string, data any, val reflect.Value) (err error) {
dataVal := reflect.ValueOf(data)
kind := dataVal.Kind()
switch {
case kind == reflect.Uint || kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || kind == reflect.Uint64:
val.SetUint(dataVal.Uint())
case (kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 || kind == reflect.Int32 || kind == reflect.Int64) && d.option.WeaklyTypedInput:
val.SetUint(uint64(dataVal.Int()))
case kind == reflect.Float64 && d.option.WeaklyTypedInput:
val.SetUint(uint64(int64(dataVal.Float())))
case kind == reflect.String && d.option.WeaklyTypedInput:
var i uint64
i, err = strconv.ParseUint(dataVal.String(), 0, val.Type().Bits())
if err == nil {
val.SetUint(i)
} else {
err = fmt.Errorf("cannot parse '%s' as int: %s", name, err)
}
default:
err = fmt.Errorf(
"'%s' expected type '%s', got unconvertible type '%s'",
name, val.Type(), dataVal.Type(),
)
}
return err
}
func (d *Decoder) decodeString(name string, data any, val reflect.Value) (err error) { func (d *Decoder) decodeString(name string, data any, val reflect.Value) (err error) {
dataVal := reflect.ValueOf(data) dataVal := reflect.ValueOf(data)
kind := dataVal.Kind() kind := dataVal.Kind()
@ -169,6 +201,17 @@ func (d *Decoder) decodeSlice(name string, data any, val reflect.Value) error {
valType := val.Type() valType := val.Type()
valElemType := valType.Elem() valElemType := valType.Elem()
if dataVal.Kind() == reflect.String && valElemType.Kind() == reflect.Uint8 {
s := []byte(dataVal.String())
b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
n, err := base64.StdEncoding.Decode(b, s)
if err != nil {
return fmt.Errorf("try decode '%s' by base64 error: %w", name, err)
}
val.SetBytes(b[:n])
return nil
}
if dataVal.Kind() != reflect.Slice { if dataVal.Kind() != reflect.Slice {
return fmt.Errorf("'%s' is not a slice", name) return fmt.Errorf("'%s' is not a slice", name)
} }