fix: wireguard's dns resolve

This commit is contained in:
wwqgtxx 2022-11-09 19:35:03 +08:00
parent ae08d13de4
commit b699fb046b
2 changed files with 42 additions and 19 deletions

View file

@ -43,7 +43,7 @@ type WireGuardOption struct {
Name string `proxy:"name"` Name string `proxy:"name"`
Server string `proxy:"server"` Server string `proxy:"server"`
Port int `proxy:"port"` Port int `proxy:"port"`
Ip string `proxy:"ip"` Ip string `proxy:"ip,omitempty"`
Ipv6 string `proxy:"ipv6,omitempty"` Ipv6 string `proxy:"ipv6,omitempty"`
PrivateKey string `proxy:"private-key"` PrivateKey string `proxy:"private-key"`
PublicKey string `proxy:"public-key"` PublicKey string `proxy:"public-key"`
@ -94,9 +94,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
peerAddr.Port = uint16(option.Port) peerAddr.Port = uint16(option.Port)
outbound.bind = wireguard.NewClientBind(context.Background(), outbound.dialer, peerAddr, reserved) outbound.bind = wireguard.NewClientBind(context.Background(), outbound.dialer, peerAddr, reserved)
localPrefixes := make([]netip.Prefix, 0, 2) localPrefixes := make([]netip.Prefix, 0, 2)
if len(option.Ip) == 0 { if len(option.Ip) > 0 {
return nil, E.New("missing local address")
}
if !strings.Contains(option.Ip, "/") { if !strings.Contains(option.Ip, "/") {
option.Ip = option.Ip + "/32" option.Ip = option.Ip + "/32"
} }
@ -105,6 +103,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
} else { } else {
return nil, E.Cause(err, "ip address parse error") return nil, E.Cause(err, "ip address parse error")
} }
}
if len(option.Ipv6) > 0 { if len(option.Ipv6) > 0 {
if !strings.Contains(option.Ipv6, "/") { if !strings.Contains(option.Ipv6, "/") {
option.Ipv6 = option.Ipv6 + "/128" option.Ipv6 = option.Ipv6 + "/128"
@ -115,6 +114,9 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
return nil, E.Cause(err, "ipv6 address parse error") return nil, E.Cause(err, "ipv6 address parse error")
} }
} }
if len(localPrefixes) == 0 {
return nil, E.New("missing local address")
}
var privateKey, peerPublicKey, preSharedKey string var privateKey, peerPublicKey, preSharedKey string
{ {
bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey) bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey)
@ -202,7 +204,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
return nil, err return nil, err
} }
if !metadata.Resolved() { if !metadata.Resolved() {
addrs, err := resolver.ResolveAllIP(metadata.Host) var addrs []netip.Addr
addrs, err = resolver.ResolveAllIP(metadata.Host)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -154,7 +154,15 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) {
return []netip.Addr{}, ErrIPNotFound return []netip.Addr{}, ErrIPNotFound
} }
return []netip.Addr{netip.AddrFrom16(*(*[16]byte)(ipAddrs[rand.Intn(len(ipAddrs))]))}, nil addrs := make([]netip.Addr, 0, len(ipAddrs))
for _, ipAddr := range ipAddrs {
addrs = append(addrs, nnip.IpToAddr(ipAddr))
}
rand.Shuffle(len(addrs), func(i, j int) {
addrs[i], addrs[j] = addrs[j], addrs[i]
})
return addrs, nil
} }
return []netip.Addr{}, ErrIPNotFound return []netip.Addr{}, ErrIPNotFound
} }
@ -188,12 +196,15 @@ func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
return []netip.Addr{}, ErrIPNotFound return []netip.Addr{}, ErrIPNotFound
} }
ip := ipAddrs[rand.Intn(len(ipAddrs))].To4() addrs := make([]netip.Addr, 0, len(ipAddrs))
if ip == nil { for _, ipAddr := range ipAddrs {
return []netip.Addr{}, ErrIPVersion addrs = append(addrs, nnip.IpToAddr(ipAddr))
} }
return []netip.Addr{netip.AddrFrom4(*(*[4]byte)(ip))}, nil rand.Shuffle(len(addrs), func(i, j int) {
addrs[i], addrs[j] = addrs[j], addrs[i]
})
return addrs, nil
} }
return []netip.Addr{}, ErrIPNotFound return []netip.Addr{}, ErrIPNotFound
} }
@ -219,12 +230,21 @@ func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) {
} }
if DefaultResolver == nil { if DefaultResolver == nil {
ipAddr, err := net.ResolveIPAddr("ip", host) ipAddrs, err := net.DefaultResolver.LookupIP(context.Background(), "ip", host)
if err != nil { if err != nil {
return []netip.Addr{}, err return []netip.Addr{}, err
} else if len(ipAddrs) == 0 {
return []netip.Addr{}, ErrIPNotFound
}
addrs := make([]netip.Addr, 0, len(ipAddrs))
for _, ipAddr := range ipAddrs {
addrs = append(addrs, nnip.IpToAddr(ipAddr))
} }
return []netip.Addr{nnip.IpToAddr(ipAddr.IP)}, nil rand.Shuffle(len(addrs), func(i, j int) {
addrs[i], addrs[j] = addrs[j], addrs[i]
})
return addrs, nil
} }
return []netip.Addr{}, ErrIPNotFound return []netip.Addr{}, ErrIPNotFound
} }