From a0d15df88041921a0fcedb2f59215a400cdb1c01 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 23 Dec 2022 11:00:55 +0800 Subject: [PATCH] fix: trying to let hysteria's port hopping work --- adapter/outbound/hysteria.go | 12 +++----- transport/hysteria/conns/udp/hop.go | 39 ++++++++++++++------------ transport/hysteria/core/client.go | 6 ++-- transport/hysteria/transport/client.go | 14 ++++----- 4 files changed, 36 insertions(+), 35 deletions(-) diff --git a/adapter/outbound/hysteria.go b/adapter/outbound/hysteria.go index e019f7fa..a1276415 100644 --- a/adapter/outbound/hysteria.go +++ b/adapter/outbound/hysteria.go @@ -89,7 +89,7 @@ type HysteriaOption struct { BasicOption Name string `proxy:"name"` Server string `proxy:"server"` - Port int `proxy:"port"` + Port int `proxy:"port,omitempty"` Ports string `proxy:"ports,omitempty"` Protocol string `proxy:"protocol,omitempty"` ObfsProtocol string `proxy:"obfs-protocol,omitempty"` // compatible with Stash @@ -134,12 +134,8 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) { Timeout: 8 * time.Second, }, } - var addr string - if len(option.Ports) == 0 { - addr = net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) - } else { - addr = net.JoinHostPort(option.Server, option.Ports) - } + addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) + ports := option.Ports serverName := option.Server if option.SNI != "" { @@ -244,7 +240,7 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) { down = uint64(option.DownSpeed * mbpsToBps) } client, err := core.NewClient( - addr, option.Protocol, auth, tlsConfig, quicConfig, clientTransport, up, down, func(refBPS uint64) congestion.CongestionControl { + addr, ports, option.Protocol, auth, tlsConfig, quicConfig, clientTransport, up, down, func(refBPS uint64) congestion.CongestionControl { return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) }, obfuscator, hopInterval, option.FastOpen, ) diff --git a/transport/hysteria/conns/udp/hop.go b/transport/hysteria/conns/udp/hop.go index 4097b692..51d85ef2 100644 --- a/transport/hysteria/conns/udp/hop.go +++ b/transport/hysteria/conns/udp/hop.go @@ -58,20 +58,24 @@ type udpPacket struct { addr net.Addr } -func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obfs obfs.Obfuscator, dialer utils.PacketDialer) (*ObfsUDPHopClientPacketConn, error) { - host, ports, err := parseAddr(server) +func NewObfsUDPHopClientPacketConn(server string, serverPorts string, hopInterval time.Duration, obfs obfs.Obfuscator, dialer utils.PacketDialer) (*ObfsUDPHopClientPacketConn, error) { + ports, err := parsePorts(serverPorts) if err != nil { return nil, err } // Resolve the server IP address, then attach the ports to UDP addresses - ip, err := dialer.RemoteAddr(host) + rAddr, err := dialer.RemoteAddr(server) + if err != nil { + return nil, err + } + ip, _, err := net.SplitHostPort(rAddr.String()) if err != nil { return nil, err } serverAddrs := make([]net.Addr, len(ports)) for i, port := range ports { serverAddrs[i] = &net.UDPAddr{ - IP: net.ParseIP(ip.String()), + IP: net.ParseIP(ip), Port: int(port), } } @@ -90,7 +94,7 @@ func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obf }, }, } - curConn, err := dialer.ListenPacket(ip) + curConn, err := dialer.ListenPacket(rAddr) if err != nil { return nil, err } @@ -100,7 +104,7 @@ func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obf conn.currentConn = curConn } go conn.recvRoutine(conn.currentConn) - go conn.hopRoutine(dialer, ip) + go conn.hopRoutine(dialer, rAddr) return conn, nil } @@ -307,29 +311,25 @@ func trySetPacketConnWriteBuffer(pc net.PacketConn, bytes int) error { return nil } -// parseAddr parses the multi-port server address and returns the host and ports. +// parsePorts parses the multi-port server address and returns the host and ports. // Supports both comma-separated single ports and dash-separated port ranges. // Format: "host:port1,port2-port3,port4" -func parseAddr(addr string) (host string, ports []uint16, err error) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - return "", nil, err - } - portStrs := strings.Split(portStr, ",") +func parsePorts(serverPorts string) (ports []uint16, err error) { + portStrs := strings.Split(serverPorts, ",") for _, portStr := range portStrs { if strings.Contains(portStr, "-") { // Port range portRange := strings.Split(portStr, "-") if len(portRange) != 2 { - return "", nil, net.InvalidAddrError("invalid port range") + return nil, net.InvalidAddrError("invalid port range") } start, err := strconv.ParseUint(portRange[0], 10, 16) if err != nil { - return "", nil, net.InvalidAddrError("invalid port range") + return nil, net.InvalidAddrError("invalid port range") } end, err := strconv.ParseUint(portRange[1], 10, 16) if err != nil { - return "", nil, net.InvalidAddrError("invalid port range") + return nil, net.InvalidAddrError("invalid port range") } if start > end { start, end = end, start @@ -341,10 +341,13 @@ func parseAddr(addr string) (host string, ports []uint16, err error) { // Single port port, err := strconv.ParseUint(portStr, 10, 16) if err != nil { - return "", nil, net.InvalidAddrError("invalid port") + return nil, net.InvalidAddrError("invalid port") } ports = append(ports, uint16(port)) } } - return host, ports, nil + if len(ports) == 0 { + return nil, net.InvalidAddrError("invalid port") + } + return ports, nil } diff --git a/transport/hysteria/core/client.go b/transport/hysteria/core/client.go index e98a0c6b..5465e3d2 100644 --- a/transport/hysteria/core/client.go +++ b/transport/hysteria/core/client.go @@ -31,6 +31,7 @@ type CongestionFactory func(refBPS uint64) congestion.CongestionControl type Client struct { transport *transport.ClientTransport serverAddr string + serverPorts string protocol string sendBPS, recvBPS uint64 auth []byte @@ -51,13 +52,14 @@ type Client struct { fastOpen bool } -func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, +func NewClient(serverAddr string, serverPorts string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, obfuscator obfs.Obfuscator, hopInterval time.Duration, fastOpen bool) (*Client, error) { quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery c := &Client{ transport: transport, serverAddr: serverAddr, + serverPorts: serverPorts, protocol: protocol, sendBPS: sendBPS, recvBPS: recvBPS, @@ -73,7 +75,7 @@ func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.C } func (c *Client) connectToServer(dialer utils.PacketDialer) error { - qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.tlsConfig, c.quicConfig, c.obfuscator, c.hopInterval, dialer) + qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.serverPorts, c.tlsConfig, c.quicConfig, c.obfuscator, c.hopInterval, dialer) if err != nil { return err } diff --git a/transport/hysteria/transport/client.go b/transport/hysteria/transport/client.go index af0c3ea2..e65e5016 100644 --- a/transport/hysteria/transport/client.go +++ b/transport/hysteria/transport/client.go @@ -20,7 +20,7 @@ type ClientTransport struct { Dialer *net.Dialer } -func (ct *ClientTransport) quicPacketConn(proto string, rAddr net.Addr, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (net.PacketConn, error) { +func (ct *ClientTransport) quicPacketConn(proto string, rAddr net.Addr, serverPorts string, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (net.PacketConn, error) { server := rAddr.String() if len(proto) == 0 || proto == "udp" { conn, err := dialer.ListenPacket(rAddr) @@ -28,14 +28,14 @@ func (ct *ClientTransport) quicPacketConn(proto string, rAddr net.Addr, obfs obf return nil, err } if obfs != nil { - if isMultiPortAddr(server) { - return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, obfs, dialer) + if serverPorts != "" { + return udp.NewObfsUDPHopClientPacketConn(server, serverPorts, hopInterval, obfs, dialer) } oc := udp.NewObfsUDPConn(conn, obfs) return oc, nil } else { - if isMultiPortAddr(server) { - return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil, dialer) + if serverPorts != "" { + return udp.NewObfsUDPHopClientPacketConn(server, serverPorts, hopInterval, nil, dialer) } return conn, nil } @@ -65,13 +65,13 @@ func (ct *ClientTransport) quicPacketConn(proto string, rAddr net.Addr, obfs obf } } -func (ct *ClientTransport) QUICDial(proto string, server string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (quic.Connection, error) { +func (ct *ClientTransport) QUICDial(proto string, server string, serverPorts string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (quic.Connection, error) { serverUDPAddr, err := dialer.RemoteAddr(server) if err != nil { return nil, err } - pktConn, err := ct.quicPacketConn(proto, serverUDPAddr, obfs, hopInterval, dialer) + pktConn, err := ct.quicPacketConn(proto, serverUDPAddr, serverPorts, obfs, hopInterval, dialer) if err != nil { return nil, err }