From b699fb046b7ed08d028d452ca9c218eb69f01e8a Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Wed, 9 Nov 2022 19:35:03 +0800 Subject: [PATCH] fix: wireguard's dns resolve --- adapter/outbound/wireguard.go | 27 +++++++++++++++------------ component/resolver/resolver.go | 34 +++++++++++++++++++++++++++------- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index fe26285e..0dd13515 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -43,7 +43,7 @@ type WireGuardOption struct { Name string `proxy:"name"` Server string `proxy:"server"` Port int `proxy:"port"` - Ip string `proxy:"ip"` + Ip string `proxy:"ip,omitempty"` Ipv6 string `proxy:"ipv6,omitempty"` PrivateKey string `proxy:"private-key"` PublicKey string `proxy:"public-key"` @@ -94,16 +94,15 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { peerAddr.Port = uint16(option.Port) outbound.bind = wireguard.NewClientBind(context.Background(), outbound.dialer, peerAddr, reserved) localPrefixes := make([]netip.Prefix, 0, 2) - if len(option.Ip) == 0 { - return nil, E.New("missing local address") - } - if !strings.Contains(option.Ip, "/") { - option.Ip = option.Ip + "/32" - } - if prefix, err := netip.ParsePrefix(option.Ip); err == nil { - localPrefixes = append(localPrefixes, prefix) - } else { - return nil, E.Cause(err, "ip address parse error") + if len(option.Ip) > 0 { + if !strings.Contains(option.Ip, "/") { + option.Ip = option.Ip + "/32" + } + if prefix, err := netip.ParsePrefix(option.Ip); err == nil { + localPrefixes = append(localPrefixes, prefix) + } else { + return nil, E.Cause(err, "ip address parse error") + } } if len(option.Ipv6) > 0 { if !strings.Contains(option.Ipv6, "/") { @@ -115,6 +114,9 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { return nil, E.Cause(err, "ipv6 address parse error") } } + if len(localPrefixes) == 0 { + return nil, E.New("missing local address") + } var privateKey, peerPublicKey, preSharedKey string { bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey) @@ -202,7 +204,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts return nil, err } if !metadata.Resolved() { - addrs, err := resolver.ResolveAllIP(metadata.Host) + var addrs []netip.Addr + addrs, err = resolver.ResolveAllIP(metadata.Host) if err != nil { return nil, err } diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index af32cc94..51df094f 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -154,7 +154,15 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) { return []netip.Addr{}, ErrIPNotFound } - return []netip.Addr{netip.AddrFrom16(*(*[16]byte)(ipAddrs[rand.Intn(len(ipAddrs))]))}, nil + addrs := make([]netip.Addr, 0, len(ipAddrs)) + for _, ipAddr := range ipAddrs { + addrs = append(addrs, nnip.IpToAddr(ipAddr)) + } + + rand.Shuffle(len(addrs), func(i, j int) { + addrs[i], addrs[j] = addrs[j], addrs[i] + }) + return addrs, nil } return []netip.Addr{}, ErrIPNotFound } @@ -188,12 +196,15 @@ func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) { return []netip.Addr{}, ErrIPNotFound } - ip := ipAddrs[rand.Intn(len(ipAddrs))].To4() - if ip == nil { - return []netip.Addr{}, ErrIPVersion + addrs := make([]netip.Addr, 0, len(ipAddrs)) + for _, ipAddr := range ipAddrs { + addrs = append(addrs, nnip.IpToAddr(ipAddr)) } - return []netip.Addr{netip.AddrFrom4(*(*[4]byte)(ip))}, nil + rand.Shuffle(len(addrs), func(i, j int) { + addrs[i], addrs[j] = addrs[j], addrs[i] + }) + return addrs, nil } return []netip.Addr{}, ErrIPNotFound } @@ -219,12 +230,21 @@ func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) { } if DefaultResolver == nil { - ipAddr, err := net.ResolveIPAddr("ip", host) + ipAddrs, err := net.DefaultResolver.LookupIP(context.Background(), "ip", host) if err != nil { return []netip.Addr{}, err + } else if len(ipAddrs) == 0 { + return []netip.Addr{}, ErrIPNotFound + } + addrs := make([]netip.Addr, 0, len(ipAddrs)) + for _, ipAddr := range ipAddrs { + addrs = append(addrs, nnip.IpToAddr(ipAddr)) } - return []netip.Addr{nnip.IpToAddr(ipAddr.IP)}, nil + rand.Shuffle(len(addrs), func(i, j int) { + addrs[i], addrs[j] = addrs[j], addrs[i] + }) + return addrs, nil } return []netip.Addr{}, ErrIPNotFound }