diff --git a/go.mod b/go.mod index 826eb0ba..d1503b48 100644 --- a/go.mod +++ b/go.mod @@ -19,8 +19,8 @@ require ( github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 github.com/mdlayher/netlink v1.7.2 github.com/metacubex/gopacket v1.1.20-0.20230608035415-7e2f98a3e759 - github.com/metacubex/quic-go v0.39.1-0.20230930051114-b486c7799a55 - github.com/metacubex/sing-quic v0.0.0-20230930052455-ae588c275b9c + github.com/metacubex/quic-go v0.39.1-0.20231001052253-5776efe31623 + github.com/metacubex/sing-quic v0.0.0-20230926004739-7c7c534c2255 github.com/metacubex/sing-shadowsocks v0.2.5 github.com/metacubex/sing-shadowsocks2 v0.1.4 github.com/metacubex/sing-tun v0.1.13-0.20230926010214-4e9d1add2aee @@ -32,7 +32,7 @@ require ( github.com/oschwald/maxminddb-golang v1.12.0 github.com/puzpuzpuz/xsync/v2 v2.5.0 github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 - github.com/sagernet/sing v0.2.11 + github.com/sagernet/sing v0.2.12 github.com/sagernet/sing-mux v0.1.3 github.com/sagernet/sing-shadowtls v0.1.4 github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6 @@ -105,4 +105,4 @@ require ( golang.org/x/tools v0.13.0 // indirect ) -replace github.com/sagernet/sing => github.com/metacubex/sing v0.0.0-20230926010351-b23b466642d1 +replace github.com/sagernet/sing => github.com/metacubex/sing v0.0.0-20231001053806-1230641572b9 diff --git a/go.sum b/go.sum index d6c96aaa..3aa28a7f 100644 --- a/go.sum +++ b/go.sum @@ -93,12 +93,12 @@ github.com/metacubex/gopacket v1.1.20-0.20230608035415-7e2f98a3e759 h1:cjd4biTvO github.com/metacubex/gopacket v1.1.20-0.20230608035415-7e2f98a3e759/go.mod h1:UHOv2xu+RIgLwpXca7TLrXleEd4oR3sPatW6IF8wU88= github.com/metacubex/gvisor v0.0.0-20230611153922-78842f086475 h1:qSEOvPPaMrWggFyFhFYGyMR8i1HKyhXjdi1QYUAa2ww= github.com/metacubex/gvisor v0.0.0-20230611153922-78842f086475/go.mod h1:wehEpqiogdeyncfhckJP5gD2LtBgJW0wnDC24mJ+8Jg= -github.com/metacubex/quic-go v0.39.1-0.20230930051114-b486c7799a55 h1:cAqp0BFOTr/1TpFicH1dA1q/6fp7E/JkqHBORfohqr4= -github.com/metacubex/quic-go v0.39.1-0.20230930051114-b486c7799a55/go.mod h1:4pe6cY+nAMFU/Uxn1rfnxNIowsaJGDQ3uyy4VuiPkP4= -github.com/metacubex/sing v0.0.0-20230926010351-b23b466642d1 h1:MkYAvDyhb7cwuqL4ZLKU3Oi6tYjFnz1sz5LS82JmtDo= -github.com/metacubex/sing v0.0.0-20230926010351-b23b466642d1/go.mod h1:GQ673iPfUnkbK/dIPkfd1Xh1MjOGo36gkl/mkiHY7Jg= -github.com/metacubex/sing-quic v0.0.0-20230930052455-ae588c275b9c h1:j7PKIUUhOAxJaLf/NmUKuIs9R06xNoYizwYgqf5HSrA= -github.com/metacubex/sing-quic v0.0.0-20230930052455-ae588c275b9c/go.mod h1:TPAXFCHCtzW9Dm+wq1l1R/p0v/S/xmuRU0qfPR7WlOA= +github.com/metacubex/quic-go v0.39.1-0.20231001052253-5776efe31623 h1:lxXUXdS2GB4Ktn3ocnzQ53v1lqd6LYYfYIKICugTaJM= +github.com/metacubex/quic-go v0.39.1-0.20231001052253-5776efe31623/go.mod h1:4pe6cY+nAMFU/Uxn1rfnxNIowsaJGDQ3uyy4VuiPkP4= +github.com/metacubex/sing v0.0.0-20231001053806-1230641572b9 h1:F0+IuW0tZ96QHEmrebXAdYnz7ab7Gz4l5yYC4g6Cg8k= +github.com/metacubex/sing v0.0.0-20231001053806-1230641572b9/go.mod h1:GQ673iPfUnkbK/dIPkfd1Xh1MjOGo36gkl/mkiHY7Jg= +github.com/metacubex/sing-quic v0.0.0-20230926004739-7c7c534c2255 h1:NfdM4hDFIhq9QxDStJ9Rz1h73sRUO/2L4pRZ6lGWRz8= +github.com/metacubex/sing-quic v0.0.0-20230926004739-7c7c534c2255/go.mod h1:asoMecRyaA6pLSLVR+qFdp/vD24m8KZ1O/QDxWa7RsM= github.com/metacubex/sing-shadowsocks v0.2.5 h1:O2RRSHlKGEpAVG/OHJQxyHqDy8uvvdCW/oW2TDBOIhc= github.com/metacubex/sing-shadowsocks v0.2.5/go.mod h1:Xz2uW9BEYGEoA8B4XEpoxt7ERHClFCwsMAvWaruoyMo= github.com/metacubex/sing-shadowsocks2 v0.1.4 h1:OOCf8lgsVcpTOJUeaFAMzyKVebaQOBnKirDdUdBoKIE= diff --git a/transport/hysteria/congestion/brutal.go b/transport/hysteria/congestion/brutal.go index 67067917..88bf6f34 100644 --- a/transport/hysteria/congestion/brutal.go +++ b/transport/hysteria/congestion/brutal.go @@ -100,10 +100,6 @@ func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostByt b.updateAckRate(currentTimestamp) } -func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { - // Stub -} - func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) { b.maxDatagramSize = size b.pacer.SetMaxDatagramSize(size) diff --git a/transport/tuic/common/congestion.go b/transport/tuic/common/congestion.go index 1176ebd6..13c1300f 100644 --- a/transport/tuic/common/congestion.go +++ b/transport/tuic/common/congestion.go @@ -18,6 +18,22 @@ func SetCongestionController(quicConn quic.Connection, cc string, cwnd int) { cwnd = 32 } switch cc { + case "cubic": + quicConn.SetCongestionControl( + congestion.NewCubicSender( + congestion.DefaultClock{}, + congestion.GetInitialPacketSize(quicConn.RemoteAddr()), + false, + ), + ) + case "new_reno": + quicConn.SetCongestionControl( + congestion.NewCubicSender( + congestion.DefaultClock{}, + congestion.GetInitialPacketSize(quicConn.RemoteAddr()), + true, + ), + ) case "bbr_meta_v1": quicConn.SetCongestionControl( congestion.NewBBRSender( diff --git a/transport/tuic/congestion/cubic.go b/transport/tuic/congestion/cubic.go new file mode 100644 index 00000000..dd491a32 --- /dev/null +++ b/transport/tuic/congestion/cubic.go @@ -0,0 +1,213 @@ +package congestion + +import ( + "math" + "time" + + "github.com/metacubex/quic-go/congestion" +) + +// This cubic implementation is based on the one found in Chromiums's QUIC +// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}. + +// Constants based on TCP defaults. +// The following constants are in 2^10 fractions of a second instead of ms to +// allow a 10 shift right to divide. + +// 1024*1024^3 (first 1024 is from 0.100^3) +// where 0.100 is 100 ms which is the scaling round trip time. +const ( + cubeScale = 40 + cubeCongestionWindowScale = 410 + cubeFactor congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize + // TODO: when re-enabling cubic, make sure to use the actual packet size here + maxDatagramSize = congestion.ByteCount(InitialPacketSizeIPv4) +) + +const defaultNumConnections = 1 + +// Default Cubic backoff factor +const beta float32 = 0.7 + +// Additional backoff factor when loss occurs in the concave part of the Cubic +// curve. This additional backoff factor is expected to give up bandwidth to +// new concurrent flows and speed up convergence. +const betaLastMax float32 = 0.85 + +// Cubic implements the cubic algorithm from TCP +type Cubic struct { + clock Clock + + // Number of connections to simulate. + numConnections int + + // Time when this cycle started, after last loss event. + epoch time.Time + + // Max congestion window used just before last loss event. + // Note: to improve fairness to other streams an additional back off is + // applied to this value if the new value is below our latest value. + lastMaxCongestionWindow congestion.ByteCount + + // Number of acked bytes since the cycle started (epoch). + ackedBytesCount congestion.ByteCount + + // TCP Reno equivalent congestion window in packets. + estimatedTCPcongestionWindow congestion.ByteCount + + // Origin point of cubic function. + originPointCongestionWindow congestion.ByteCount + + // Time to origin point of cubic function in 2^10 fractions of a second. + timeToOriginPoint uint32 + + // Last congestion window in packets computed by cubic function. + lastTargetCongestionWindow congestion.ByteCount +} + +// NewCubic returns a new Cubic instance +func NewCubic(clock Clock) *Cubic { + c := &Cubic{ + clock: clock, + numConnections: defaultNumConnections, + } + c.Reset() + return c +} + +// Reset is called after a timeout to reset the cubic state +func (c *Cubic) Reset() { + c.epoch = time.Time{} + c.lastMaxCongestionWindow = 0 + c.ackedBytesCount = 0 + c.estimatedTCPcongestionWindow = 0 + c.originPointCongestionWindow = 0 + c.timeToOriginPoint = 0 + c.lastTargetCongestionWindow = 0 +} + +func (c *Cubic) alpha() float32 { + // TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that + // beta here is a cwnd multiplier, and is equal to 1-beta from the paper. + // We derive the equivalent alpha for an N-connection emulation as: + b := c.beta() + return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b) +} + +func (c *Cubic) beta() float32 { + // kNConnectionBeta is the backoff factor after loss for our N-connection + // emulation, which emulates the effective backoff of an ensemble of N + // TCP-Reno connections on a single loss event. The effective multiplier is + // computed as: + return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) +} + +func (c *Cubic) betaLastMax() float32 { + // betaLastMax is the additional backoff factor after loss for our + // N-connection emulation, which emulates the additional backoff of + // an ensemble of N TCP-Reno connections on a single loss event. The + // effective multiplier is computed as: + return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) +} + +// OnApplicationLimited is called on ack arrival when sender is unable to use +// the available congestion window. Resets Cubic state during quiescence. +func (c *Cubic) OnApplicationLimited() { + // When sender is not using the available congestion window, the window does + // not grow. But to be RTT-independent, Cubic assumes that the sender has been + // using the entire window during the time since the beginning of the current + // "epoch" (the end of the last loss recovery period). Since + // application-limited periods break this assumption, we reset the epoch when + // in such a period. This reset effectively freezes congestion window growth + // through application-limited periods and allows Cubic growth to continue + // when the entire window is being used. + c.epoch = time.Time{} +} + +// CongestionWindowAfterPacketLoss computes a new congestion window to use after +// a loss event. Returns the new congestion window in packets. The new +// congestion window is a multiplicative decrease of our current window. +func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount { + if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow { + // We never reached the old max, so assume we are competing with another + // flow. Use our extra back off factor to allow the other flow to go up. + c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) + } else { + c.lastMaxCongestionWindow = currentCongestionWindow + } + c.epoch = time.Time{} // Reset time. + return congestion.ByteCount(float32(currentCongestionWindow) * c.beta()) +} + +// CongestionWindowAfterAck computes a new congestion window to use after a received ACK. +// Returns the new congestion window in packets. The new congestion window +// follows a cubic function that depends on the time passed since last +// packet loss. +func (c *Cubic) CongestionWindowAfterAck( + ackedBytes congestion.ByteCount, + currentCongestionWindow congestion.ByteCount, + delayMin time.Duration, + eventTime time.Time, +) congestion.ByteCount { + c.ackedBytesCount += ackedBytes + + if c.epoch.IsZero() { + // First ACK after a loss event. + c.epoch = eventTime // Start of epoch. + c.ackedBytesCount = ackedBytes // Reset count. + // Reset estimated_tcp_congestion_window_ to be in sync with cubic. + c.estimatedTCPcongestionWindow = currentCongestionWindow + if c.lastMaxCongestionWindow <= currentCongestionWindow { + c.timeToOriginPoint = 0 + c.originPointCongestionWindow = currentCongestionWindow + } else { + c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) + c.originPointCongestionWindow = c.lastMaxCongestionWindow + } + } + + // Change the time unit from microseconds to 2^10 fractions per second. Take + // the round trip time in account. This is done to allow us to use shift as a + // divide operator. + elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000) + + // Right-shifts of negative, signed numbers have implementation-dependent + // behavior, so force the offset to be positive, as is done in the kernel. + offset := int64(c.timeToOriginPoint) - elapsedTime + if offset < 0 { + offset = -offset + } + + deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale + var targetCongestionWindow congestion.ByteCount + if elapsedTime > int64(c.timeToOriginPoint) { + targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow + } else { + targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow + } + // Limit the CWND increase to half the acked bytes. + targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) + + // Increase the window by approximately Alpha * 1 MSS of bytes every + // time we ack an estimated tcp window of bytes. For small + // congestion windows (less than 25), the formula below will + // increase slightly slower than linearly per estimated tcp window + // of bytes. + c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow)) + c.ackedBytesCount = 0 + + // We have a new cubic congestion window. + c.lastTargetCongestionWindow = targetCongestionWindow + + // Compute target congestion_window based on cubic target and estimated TCP + // congestion_window, use highest (fastest). + if targetCongestionWindow < c.estimatedTCPcongestionWindow { + targetCongestionWindow = c.estimatedTCPcongestionWindow + } + return targetCongestionWindow +} + +// SetNumConnections sets the number of emulated connections +func (c *Cubic) SetNumConnections(n int) { + c.numConnections = n +} diff --git a/transport/tuic/congestion/cubic_sender.go b/transport/tuic/congestion/cubic_sender.go new file mode 100644 index 00000000..f544cd74 --- /dev/null +++ b/transport/tuic/congestion/cubic_sender.go @@ -0,0 +1,297 @@ +package congestion + +import ( + "fmt" + "time" + + "github.com/metacubex/quic-go/congestion" +) + +const ( + maxBurstPackets = 3 + renoBeta = 0.7 // Reno backoff factor. + minCongestionWindowPackets = 2 + initialCongestionWindow = 32 +) + +const InvalidPacketNumber congestion.PacketNumber = -1 +const MaxCongestionWindowPackets = 20000 +const MaxByteCount = congestion.ByteCount(1<<62 - 1) + +type cubicSender struct { + hybridSlowStart HybridSlowStart + rttStats congestion.RTTStatsProvider + cubic *Cubic + pacer *pacer + clock Clock + + reno bool + + // Track the largest packet that has been sent. + largestSentPacketNumber congestion.PacketNumber + + // Track the largest packet that has been acked. + largestAckedPacketNumber congestion.PacketNumber + + // Track the largest packet number outstanding when a CWND cutback occurs. + largestSentAtLastCutback congestion.PacketNumber + + // Whether the last loss event caused us to exit slowstart. + // Used for stats collection of slowstartPacketsLost + lastCutbackExitedSlowstart bool + + // Congestion window in bytes. + congestionWindow congestion.ByteCount + + // Slow start congestion window in bytes, aka ssthresh. + slowStartThreshold congestion.ByteCount + + // ACK counter for the Reno implementation. + numAckedPackets uint64 + + initialCongestionWindow congestion.ByteCount + initialMaxCongestionWindow congestion.ByteCount + + maxDatagramSize congestion.ByteCount +} + +var ( + _ congestion.CongestionControl = &cubicSender{} +) + +// NewCubicSender makes a new cubic sender +func NewCubicSender( + clock Clock, + initialMaxDatagramSize congestion.ByteCount, + reno bool, +) *cubicSender { + return newCubicSender( + clock, + reno, + initialMaxDatagramSize, + initialCongestionWindow*initialMaxDatagramSize, + MaxCongestionWindowPackets*initialMaxDatagramSize, + ) +} + +func newCubicSender( + clock Clock, + reno bool, + initialMaxDatagramSize, + initialCongestionWindow, + initialMaxCongestionWindow congestion.ByteCount, +) *cubicSender { + c := &cubicSender{ + largestSentPacketNumber: InvalidPacketNumber, + largestAckedPacketNumber: InvalidPacketNumber, + largestSentAtLastCutback: InvalidPacketNumber, + initialCongestionWindow: initialCongestionWindow, + initialMaxCongestionWindow: initialMaxCongestionWindow, + congestionWindow: initialCongestionWindow, + slowStartThreshold: MaxByteCount, + cubic: NewCubic(clock), + clock: clock, + reno: reno, + maxDatagramSize: initialMaxDatagramSize, + } + c.pacer = newPacer(c.BandwidthEstimate) + return c +} + +func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { + c.rttStats = provider +} + +// TimeUntilSend returns when the next packet should be sent. +func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time { + return c.pacer.TimeUntilSend() +} + +func (c *cubicSender) HasPacingBudget(now time.Time) bool { + return c.pacer.Budget(now) >= c.maxDatagramSize +} + +func (c *cubicSender) maxCongestionWindow() congestion.ByteCount { + return c.maxDatagramSize * MaxCongestionWindowPackets +} + +func (c *cubicSender) minCongestionWindow() congestion.ByteCount { + return c.maxDatagramSize * minCongestionWindowPackets +} + +func (c *cubicSender) OnPacketSent( + sentTime time.Time, + _ congestion.ByteCount, + packetNumber congestion.PacketNumber, + bytes congestion.ByteCount, + isRetransmittable bool, +) { + c.pacer.SentPacket(sentTime, bytes) + if !isRetransmittable { + return + } + c.largestSentPacketNumber = packetNumber + c.hybridSlowStart.OnPacketSent(packetNumber) +} + +func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight < c.GetCongestionWindow() +} + +func (c *cubicSender) InRecovery() bool { + return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback +} + +func (c *cubicSender) InSlowStart() bool { + return c.GetCongestionWindow() < c.slowStartThreshold +} + +func (c *cubicSender) GetCongestionWindow() congestion.ByteCount { + return c.congestionWindow +} + +func (c *cubicSender) MaybeExitSlowStart() { + if c.InSlowStart() && + c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) { + // exit slow start + c.slowStartThreshold = c.congestionWindow + } +} + +func (c *cubicSender) OnPacketAcked( + ackedPacketNumber congestion.PacketNumber, + ackedBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, + eventTime time.Time, +) { + c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber) + if c.InRecovery() { + return + } + c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime) + if c.InSlowStart() { + c.hybridSlowStart.OnPacketAcked(ackedPacketNumber) + } +} + +func (c *cubicSender) OnCongestionEvent(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { + // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets + // already sent should be treated as a single loss event, since it's expected. + if packetNumber <= c.largestSentAtLastCutback { + return + } + c.lastCutbackExitedSlowstart = c.InSlowStart() + + if c.reno { + c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta) + } else { + c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) + } + if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd { + c.congestionWindow = minCwnd + } + c.slowStartThreshold = c.congestionWindow + c.largestSentAtLastCutback = c.largestSentPacketNumber + // reset packet count from congestion avoidance mode. We start + // counting again when we're out of recovery. + c.numAckedPackets = 0 +} + +func (b *cubicSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { + // Stub +} + +// Called when we receive an ack. Normal TCP tracks how many packets one ack +// represents, but quic has a separate ack for each packet. +func (c *cubicSender) maybeIncreaseCwnd( + _ congestion.PacketNumber, + ackedBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, + eventTime time.Time, +) { + // Do not increase the congestion window unless the sender is close to using + // the current window. + if !c.isCwndLimited(priorInFlight) { + c.cubic.OnApplicationLimited() + return + } + if c.congestionWindow >= c.maxCongestionWindow() { + return + } + if c.InSlowStart() { + // TCP slow start, exponential growth, increase by one for each ACK. + c.congestionWindow += c.maxDatagramSize + return + } + // Congestion avoidance + if c.reno { + // Classic Reno congestion avoidance. + c.numAckedPackets++ + if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) { + c.congestionWindow += c.maxDatagramSize + c.numAckedPackets = 0 + } + } else { + c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) + } +} + +func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool { + congestionWindow := c.GetCongestionWindow() + if bytesInFlight >= congestionWindow { + return true + } + availableBytes := congestionWindow - bytesInFlight + slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2 + return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize +} + +// BandwidthEstimate returns the current bandwidth estimate +func (c *cubicSender) BandwidthEstimate() Bandwidth { + if c.rttStats == nil { + return infBandwidth + } + srtt := c.rttStats.SmoothedRTT() + if srtt == 0 { + // If we haven't measured an rtt, the bandwidth estimate is unknown. + return infBandwidth + } + return BandwidthFromDelta(c.GetCongestionWindow(), srtt) +} + +// OnRetransmissionTimeout is called on an retransmission timeout +func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { + c.largestSentAtLastCutback = InvalidPacketNumber + if !packetsRetransmitted { + return + } + c.hybridSlowStart.Restart() + c.cubic.Reset() + c.slowStartThreshold = c.congestionWindow / 2 + c.congestionWindow = c.minCongestionWindow() +} + +// OnConnectionMigration is called when the connection is migrated (?) +func (c *cubicSender) OnConnectionMigration() { + c.hybridSlowStart.Restart() + c.largestSentPacketNumber = InvalidPacketNumber + c.largestAckedPacketNumber = InvalidPacketNumber + c.largestSentAtLastCutback = InvalidPacketNumber + c.lastCutbackExitedSlowstart = false + c.cubic.Reset() + c.numAckedPackets = 0 + c.congestionWindow = c.initialCongestionWindow + c.slowStartThreshold = c.initialMaxCongestionWindow +} + +func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) { + if s < c.maxDatagramSize { + panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s)) + } + cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow() + c.maxDatagramSize = s + if cwndIsMinCwnd { + c.congestionWindow = c.minCongestionWindow() + } + c.pacer.SetMaxDatagramSize(s) +} diff --git a/transport/tuic/congestion/hybrid_slow_start.go b/transport/tuic/congestion/hybrid_slow_start.go new file mode 100644 index 00000000..7586774e --- /dev/null +++ b/transport/tuic/congestion/hybrid_slow_start.go @@ -0,0 +1,112 @@ +package congestion + +import ( + "time" + + "github.com/metacubex/quic-go/congestion" +) + +// Note(pwestin): the magic clamping numbers come from the original code in +// tcp_cubic.c. +const hybridStartLowWindow = congestion.ByteCount(16) + +// Number of delay samples for detecting the increase of delay. +const hybridStartMinSamples = uint32(8) + +// Exit slow start if the min rtt has increased by more than 1/8th. +const hybridStartDelayFactorExp = 3 // 2^3 = 8 +// The original paper specifies 2 and 8ms, but those have changed over time. +const ( + hybridStartDelayMinThresholdUs = int64(4000) + hybridStartDelayMaxThresholdUs = int64(16000) +) + +// HybridSlowStart implements the TCP hybrid slow start algorithm +type HybridSlowStart struct { + endPacketNumber congestion.PacketNumber + lastSentPacketNumber congestion.PacketNumber + started bool + currentMinRTT time.Duration + rttSampleCount uint32 + hystartFound bool +} + +// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase. +func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) { + s.endPacketNumber = lastSent + s.currentMinRTT = 0 + s.rttSampleCount = 0 + s.started = true +} + +// IsEndOfRound returns true if this ack is the last packet number of our current slow start round. +func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool { + return s.endPacketNumber < ack +} + +// ShouldExitSlowStart should be called on every new ack frame, since a new +// RTT measurement can be made then. +// rtt: the RTT for this ack packet. +// minRTT: is the lowest delay (RTT) we have seen during the session. +// congestionWindow: the congestion window in packets. +func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool { + if !s.started { + // Time to start the hybrid slow start. + s.StartReceiveRound(s.lastSentPacketNumber) + } + if s.hystartFound { + return true + } + // Second detection parameter - delay increase detection. + // Compare the minimum delay (s.currentMinRTT) of the current + // burst of packets relative to the minimum delay during the session. + // Note: we only look at the first few(8) packets in each burst, since we + // only want to compare the lowest RTT of the burst relative to previous + // bursts. + s.rttSampleCount++ + if s.rttSampleCount <= hybridStartMinSamples { + if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT { + s.currentMinRTT = latestRTT + } + } + // We only need to check this once per round. + if s.rttSampleCount == hybridStartMinSamples { + // Divide minRTT by 8 to get a rtt increase threshold for exiting. + minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) + // Ensure the rtt threshold is never less than 2ms or more than 16ms. + minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) + minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond + + if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { + s.hystartFound = true + } + } + // Exit from slow start if the cwnd is greater than 16 and + // increasing delay is found. + return congestionWindow >= hybridStartLowWindow && s.hystartFound +} + +// OnPacketSent is called when a packet was sent +func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) { + s.lastSentPacketNumber = packetNumber +} + +// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end +// the round when the final packet of the burst is received and start it on +// the next incoming ack. +func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) { + if s.IsEndOfRound(ackedPacketNumber) { + s.started = false + } +} + +// Started returns true if started +func (s *HybridSlowStart) Started() bool { + return s.started +} + +// Restart the slow start phase +func (s *HybridSlowStart) Restart() { + s.started = false + s.hystartFound = false +} diff --git a/transport/tuic/congestion/minmax.go b/transport/tuic/congestion/minmax.go new file mode 100644 index 00000000..0a8f4ad4 --- /dev/null +++ b/transport/tuic/congestion/minmax.go @@ -0,0 +1,56 @@ +package congestion + +import ( + "math" + "time" +) + +// InfDuration is a duration of infinite length +const InfDuration = time.Duration(math.MaxInt64) + +// MinNonZeroDuration return the minimum duration that's not zero. +func MinNonZeroDuration(a, b time.Duration) time.Duration { + if a == 0 { + return b + } + if b == 0 { + return a + } + return Min(a, b) +} + +// AbsDuration returns the absolute value of a time duration +func AbsDuration(d time.Duration) time.Duration { + if d >= 0 { + return d + } + return -d +} + +// MinTime returns the earlier time +func MinTime(a, b time.Time) time.Time { + if a.After(b) { + return b + } + return a +} + +// MinNonZeroTime returns the earlist time that is not time.Time{} +// If both a and b are time.Time{}, it returns time.Time{} +func MinNonZeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b + } + if b.IsZero() { + return a + } + return MinTime(a, b) +} + +// MaxTime returns the later time +func MaxTime(a, b time.Time) time.Time { + if a.After(b) { + return a + } + return b +}