diff --git a/transport/tuic/common/congestion.go b/transport/tuic/common/congestion.go index 158f22d6..1176ebd6 100644 --- a/transport/tuic/common/congestion.go +++ b/transport/tuic/common/congestion.go @@ -2,6 +2,7 @@ package common import ( "github.com/Dreamacro/clash/transport/tuic/congestion" + congestionv2 "github.com/Dreamacro/clash/transport/tuic/congestion_v2" "github.com/metacubex/quic-go" c "github.com/metacubex/quic-go/congestion" @@ -17,23 +18,7 @@ func SetCongestionController(quicConn quic.Connection, cc string, cwnd int) { cwnd = 32 } switch cc { - case "cubic": - quicConn.SetCongestionControl( - congestion.NewCubicSender( - congestion.DefaultClock{}, - congestion.GetInitialPacketSize(quicConn.RemoteAddr()), - false, - ), - ) - case "new_reno": - quicConn.SetCongestionControl( - congestion.NewCubicSender( - congestion.DefaultClock{}, - congestion.GetInitialPacketSize(quicConn.RemoteAddr()), - true, - ), - ) - case "bbr": + case "bbr_meta_v1": quicConn.SetCongestionControl( congestion.NewBBRSender( congestion.DefaultClock{}, @@ -42,5 +27,15 @@ func SetCongestionController(quicConn quic.Connection, cc string, cwnd int) { congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize, ), ) + case "bbr_meta_v2": + fallthrough + case "bbr": + quicConn.SetCongestionControl( + congestionv2.NewBbrSender( + congestionv2.DefaultClock{}, + congestionv2.GetInitialPacketSize(quicConn.RemoteAddr()), + c.ByteCount(cwnd), + ), + ) } } diff --git a/transport/tuic/congestion/cubic.go b/transport/tuic/congestion/cubic.go deleted file mode 100644 index dd491a32..00000000 --- a/transport/tuic/congestion/cubic.go +++ /dev/null @@ -1,213 +0,0 @@ -package congestion - -import ( - "math" - "time" - - "github.com/metacubex/quic-go/congestion" -) - -// This cubic implementation is based on the one found in Chromiums's QUIC -// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}. - -// Constants based on TCP defaults. -// The following constants are in 2^10 fractions of a second instead of ms to -// allow a 10 shift right to divide. - -// 1024*1024^3 (first 1024 is from 0.100^3) -// where 0.100 is 100 ms which is the scaling round trip time. -const ( - cubeScale = 40 - cubeCongestionWindowScale = 410 - cubeFactor congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize - // TODO: when re-enabling cubic, make sure to use the actual packet size here - maxDatagramSize = congestion.ByteCount(InitialPacketSizeIPv4) -) - -const defaultNumConnections = 1 - -// Default Cubic backoff factor -const beta float32 = 0.7 - -// Additional backoff factor when loss occurs in the concave part of the Cubic -// curve. This additional backoff factor is expected to give up bandwidth to -// new concurrent flows and speed up convergence. -const betaLastMax float32 = 0.85 - -// Cubic implements the cubic algorithm from TCP -type Cubic struct { - clock Clock - - // Number of connections to simulate. - numConnections int - - // Time when this cycle started, after last loss event. - epoch time.Time - - // Max congestion window used just before last loss event. - // Note: to improve fairness to other streams an additional back off is - // applied to this value if the new value is below our latest value. - lastMaxCongestionWindow congestion.ByteCount - - // Number of acked bytes since the cycle started (epoch). - ackedBytesCount congestion.ByteCount - - // TCP Reno equivalent congestion window in packets. - estimatedTCPcongestionWindow congestion.ByteCount - - // Origin point of cubic function. - originPointCongestionWindow congestion.ByteCount - - // Time to origin point of cubic function in 2^10 fractions of a second. - timeToOriginPoint uint32 - - // Last congestion window in packets computed by cubic function. - lastTargetCongestionWindow congestion.ByteCount -} - -// NewCubic returns a new Cubic instance -func NewCubic(clock Clock) *Cubic { - c := &Cubic{ - clock: clock, - numConnections: defaultNumConnections, - } - c.Reset() - return c -} - -// Reset is called after a timeout to reset the cubic state -func (c *Cubic) Reset() { - c.epoch = time.Time{} - c.lastMaxCongestionWindow = 0 - c.ackedBytesCount = 0 - c.estimatedTCPcongestionWindow = 0 - c.originPointCongestionWindow = 0 - c.timeToOriginPoint = 0 - c.lastTargetCongestionWindow = 0 -} - -func (c *Cubic) alpha() float32 { - // TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that - // beta here is a cwnd multiplier, and is equal to 1-beta from the paper. - // We derive the equivalent alpha for an N-connection emulation as: - b := c.beta() - return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b) -} - -func (c *Cubic) beta() float32 { - // kNConnectionBeta is the backoff factor after loss for our N-connection - // emulation, which emulates the effective backoff of an ensemble of N - // TCP-Reno connections on a single loss event. The effective multiplier is - // computed as: - return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) -} - -func (c *Cubic) betaLastMax() float32 { - // betaLastMax is the additional backoff factor after loss for our - // N-connection emulation, which emulates the additional backoff of - // an ensemble of N TCP-Reno connections on a single loss event. The - // effective multiplier is computed as: - return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) -} - -// OnApplicationLimited is called on ack arrival when sender is unable to use -// the available congestion window. Resets Cubic state during quiescence. -func (c *Cubic) OnApplicationLimited() { - // When sender is not using the available congestion window, the window does - // not grow. But to be RTT-independent, Cubic assumes that the sender has been - // using the entire window during the time since the beginning of the current - // "epoch" (the end of the last loss recovery period). Since - // application-limited periods break this assumption, we reset the epoch when - // in such a period. This reset effectively freezes congestion window growth - // through application-limited periods and allows Cubic growth to continue - // when the entire window is being used. - c.epoch = time.Time{} -} - -// CongestionWindowAfterPacketLoss computes a new congestion window to use after -// a loss event. Returns the new congestion window in packets. The new -// congestion window is a multiplicative decrease of our current window. -func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount { - if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow { - // We never reached the old max, so assume we are competing with another - // flow. Use our extra back off factor to allow the other flow to go up. - c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) - } else { - c.lastMaxCongestionWindow = currentCongestionWindow - } - c.epoch = time.Time{} // Reset time. - return congestion.ByteCount(float32(currentCongestionWindow) * c.beta()) -} - -// CongestionWindowAfterAck computes a new congestion window to use after a received ACK. -// Returns the new congestion window in packets. The new congestion window -// follows a cubic function that depends on the time passed since last -// packet loss. -func (c *Cubic) CongestionWindowAfterAck( - ackedBytes congestion.ByteCount, - currentCongestionWindow congestion.ByteCount, - delayMin time.Duration, - eventTime time.Time, -) congestion.ByteCount { - c.ackedBytesCount += ackedBytes - - if c.epoch.IsZero() { - // First ACK after a loss event. - c.epoch = eventTime // Start of epoch. - c.ackedBytesCount = ackedBytes // Reset count. - // Reset estimated_tcp_congestion_window_ to be in sync with cubic. - c.estimatedTCPcongestionWindow = currentCongestionWindow - if c.lastMaxCongestionWindow <= currentCongestionWindow { - c.timeToOriginPoint = 0 - c.originPointCongestionWindow = currentCongestionWindow - } else { - c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) - c.originPointCongestionWindow = c.lastMaxCongestionWindow - } - } - - // Change the time unit from microseconds to 2^10 fractions per second. Take - // the round trip time in account. This is done to allow us to use shift as a - // divide operator. - elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000) - - // Right-shifts of negative, signed numbers have implementation-dependent - // behavior, so force the offset to be positive, as is done in the kernel. - offset := int64(c.timeToOriginPoint) - elapsedTime - if offset < 0 { - offset = -offset - } - - deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale - var targetCongestionWindow congestion.ByteCount - if elapsedTime > int64(c.timeToOriginPoint) { - targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow - } else { - targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow - } - // Limit the CWND increase to half the acked bytes. - targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) - - // Increase the window by approximately Alpha * 1 MSS of bytes every - // time we ack an estimated tcp window of bytes. For small - // congestion windows (less than 25), the formula below will - // increase slightly slower than linearly per estimated tcp window - // of bytes. - c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow)) - c.ackedBytesCount = 0 - - // We have a new cubic congestion window. - c.lastTargetCongestionWindow = targetCongestionWindow - - // Compute target congestion_window based on cubic target and estimated TCP - // congestion_window, use highest (fastest). - if targetCongestionWindow < c.estimatedTCPcongestionWindow { - targetCongestionWindow = c.estimatedTCPcongestionWindow - } - return targetCongestionWindow -} - -// SetNumConnections sets the number of emulated connections -func (c *Cubic) SetNumConnections(n int) { - c.numConnections = n -} diff --git a/transport/tuic/congestion/cubic_sender.go b/transport/tuic/congestion/cubic_sender.go deleted file mode 100644 index f544cd74..00000000 --- a/transport/tuic/congestion/cubic_sender.go +++ /dev/null @@ -1,297 +0,0 @@ -package congestion - -import ( - "fmt" - "time" - - "github.com/metacubex/quic-go/congestion" -) - -const ( - maxBurstPackets = 3 - renoBeta = 0.7 // Reno backoff factor. - minCongestionWindowPackets = 2 - initialCongestionWindow = 32 -) - -const InvalidPacketNumber congestion.PacketNumber = -1 -const MaxCongestionWindowPackets = 20000 -const MaxByteCount = congestion.ByteCount(1<<62 - 1) - -type cubicSender struct { - hybridSlowStart HybridSlowStart - rttStats congestion.RTTStatsProvider - cubic *Cubic - pacer *pacer - clock Clock - - reno bool - - // Track the largest packet that has been sent. - largestSentPacketNumber congestion.PacketNumber - - // Track the largest packet that has been acked. - largestAckedPacketNumber congestion.PacketNumber - - // Track the largest packet number outstanding when a CWND cutback occurs. - largestSentAtLastCutback congestion.PacketNumber - - // Whether the last loss event caused us to exit slowstart. - // Used for stats collection of slowstartPacketsLost - lastCutbackExitedSlowstart bool - - // Congestion window in bytes. - congestionWindow congestion.ByteCount - - // Slow start congestion window in bytes, aka ssthresh. - slowStartThreshold congestion.ByteCount - - // ACK counter for the Reno implementation. - numAckedPackets uint64 - - initialCongestionWindow congestion.ByteCount - initialMaxCongestionWindow congestion.ByteCount - - maxDatagramSize congestion.ByteCount -} - -var ( - _ congestion.CongestionControl = &cubicSender{} -) - -// NewCubicSender makes a new cubic sender -func NewCubicSender( - clock Clock, - initialMaxDatagramSize congestion.ByteCount, - reno bool, -) *cubicSender { - return newCubicSender( - clock, - reno, - initialMaxDatagramSize, - initialCongestionWindow*initialMaxDatagramSize, - MaxCongestionWindowPackets*initialMaxDatagramSize, - ) -} - -func newCubicSender( - clock Clock, - reno bool, - initialMaxDatagramSize, - initialCongestionWindow, - initialMaxCongestionWindow congestion.ByteCount, -) *cubicSender { - c := &cubicSender{ - largestSentPacketNumber: InvalidPacketNumber, - largestAckedPacketNumber: InvalidPacketNumber, - largestSentAtLastCutback: InvalidPacketNumber, - initialCongestionWindow: initialCongestionWindow, - initialMaxCongestionWindow: initialMaxCongestionWindow, - congestionWindow: initialCongestionWindow, - slowStartThreshold: MaxByteCount, - cubic: NewCubic(clock), - clock: clock, - reno: reno, - maxDatagramSize: initialMaxDatagramSize, - } - c.pacer = newPacer(c.BandwidthEstimate) - return c -} - -func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { - c.rttStats = provider -} - -// TimeUntilSend returns when the next packet should be sent. -func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time { - return c.pacer.TimeUntilSend() -} - -func (c *cubicSender) HasPacingBudget(now time.Time) bool { - return c.pacer.Budget(now) >= c.maxDatagramSize -} - -func (c *cubicSender) maxCongestionWindow() congestion.ByteCount { - return c.maxDatagramSize * MaxCongestionWindowPackets -} - -func (c *cubicSender) minCongestionWindow() congestion.ByteCount { - return c.maxDatagramSize * minCongestionWindowPackets -} - -func (c *cubicSender) OnPacketSent( - sentTime time.Time, - _ congestion.ByteCount, - packetNumber congestion.PacketNumber, - bytes congestion.ByteCount, - isRetransmittable bool, -) { - c.pacer.SentPacket(sentTime, bytes) - if !isRetransmittable { - return - } - c.largestSentPacketNumber = packetNumber - c.hybridSlowStart.OnPacketSent(packetNumber) -} - -func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool { - return bytesInFlight < c.GetCongestionWindow() -} - -func (c *cubicSender) InRecovery() bool { - return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback -} - -func (c *cubicSender) InSlowStart() bool { - return c.GetCongestionWindow() < c.slowStartThreshold -} - -func (c *cubicSender) GetCongestionWindow() congestion.ByteCount { - return c.congestionWindow -} - -func (c *cubicSender) MaybeExitSlowStart() { - if c.InSlowStart() && - c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) { - // exit slow start - c.slowStartThreshold = c.congestionWindow - } -} - -func (c *cubicSender) OnPacketAcked( - ackedPacketNumber congestion.PacketNumber, - ackedBytes congestion.ByteCount, - priorInFlight congestion.ByteCount, - eventTime time.Time, -) { - c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber) - if c.InRecovery() { - return - } - c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime) - if c.InSlowStart() { - c.hybridSlowStart.OnPacketAcked(ackedPacketNumber) - } -} - -func (c *cubicSender) OnCongestionEvent(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { - // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets - // already sent should be treated as a single loss event, since it's expected. - if packetNumber <= c.largestSentAtLastCutback { - return - } - c.lastCutbackExitedSlowstart = c.InSlowStart() - - if c.reno { - c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta) - } else { - c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) - } - if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd { - c.congestionWindow = minCwnd - } - c.slowStartThreshold = c.congestionWindow - c.largestSentAtLastCutback = c.largestSentPacketNumber - // reset packet count from congestion avoidance mode. We start - // counting again when we're out of recovery. - c.numAckedPackets = 0 -} - -func (b *cubicSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { - // Stub -} - -// Called when we receive an ack. Normal TCP tracks how many packets one ack -// represents, but quic has a separate ack for each packet. -func (c *cubicSender) maybeIncreaseCwnd( - _ congestion.PacketNumber, - ackedBytes congestion.ByteCount, - priorInFlight congestion.ByteCount, - eventTime time.Time, -) { - // Do not increase the congestion window unless the sender is close to using - // the current window. - if !c.isCwndLimited(priorInFlight) { - c.cubic.OnApplicationLimited() - return - } - if c.congestionWindow >= c.maxCongestionWindow() { - return - } - if c.InSlowStart() { - // TCP slow start, exponential growth, increase by one for each ACK. - c.congestionWindow += c.maxDatagramSize - return - } - // Congestion avoidance - if c.reno { - // Classic Reno congestion avoidance. - c.numAckedPackets++ - if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) { - c.congestionWindow += c.maxDatagramSize - c.numAckedPackets = 0 - } - } else { - c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) - } -} - -func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool { - congestionWindow := c.GetCongestionWindow() - if bytesInFlight >= congestionWindow { - return true - } - availableBytes := congestionWindow - bytesInFlight - slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2 - return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize -} - -// BandwidthEstimate returns the current bandwidth estimate -func (c *cubicSender) BandwidthEstimate() Bandwidth { - if c.rttStats == nil { - return infBandwidth - } - srtt := c.rttStats.SmoothedRTT() - if srtt == 0 { - // If we haven't measured an rtt, the bandwidth estimate is unknown. - return infBandwidth - } - return BandwidthFromDelta(c.GetCongestionWindow(), srtt) -} - -// OnRetransmissionTimeout is called on an retransmission timeout -func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { - c.largestSentAtLastCutback = InvalidPacketNumber - if !packetsRetransmitted { - return - } - c.hybridSlowStart.Restart() - c.cubic.Reset() - c.slowStartThreshold = c.congestionWindow / 2 - c.congestionWindow = c.minCongestionWindow() -} - -// OnConnectionMigration is called when the connection is migrated (?) -func (c *cubicSender) OnConnectionMigration() { - c.hybridSlowStart.Restart() - c.largestSentPacketNumber = InvalidPacketNumber - c.largestAckedPacketNumber = InvalidPacketNumber - c.largestSentAtLastCutback = InvalidPacketNumber - c.lastCutbackExitedSlowstart = false - c.cubic.Reset() - c.numAckedPackets = 0 - c.congestionWindow = c.initialCongestionWindow - c.slowStartThreshold = c.initialMaxCongestionWindow -} - -func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) { - if s < c.maxDatagramSize { - panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s)) - } - cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow() - c.maxDatagramSize = s - if cwndIsMinCwnd { - c.congestionWindow = c.minCongestionWindow() - } - c.pacer.SetMaxDatagramSize(s) -} diff --git a/transport/tuic/congestion/hybrid_slow_start.go b/transport/tuic/congestion/hybrid_slow_start.go deleted file mode 100644 index 7586774e..00000000 --- a/transport/tuic/congestion/hybrid_slow_start.go +++ /dev/null @@ -1,112 +0,0 @@ -package congestion - -import ( - "time" - - "github.com/metacubex/quic-go/congestion" -) - -// Note(pwestin): the magic clamping numbers come from the original code in -// tcp_cubic.c. -const hybridStartLowWindow = congestion.ByteCount(16) - -// Number of delay samples for detecting the increase of delay. -const hybridStartMinSamples = uint32(8) - -// Exit slow start if the min rtt has increased by more than 1/8th. -const hybridStartDelayFactorExp = 3 // 2^3 = 8 -// The original paper specifies 2 and 8ms, but those have changed over time. -const ( - hybridStartDelayMinThresholdUs = int64(4000) - hybridStartDelayMaxThresholdUs = int64(16000) -) - -// HybridSlowStart implements the TCP hybrid slow start algorithm -type HybridSlowStart struct { - endPacketNumber congestion.PacketNumber - lastSentPacketNumber congestion.PacketNumber - started bool - currentMinRTT time.Duration - rttSampleCount uint32 - hystartFound bool -} - -// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase. -func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) { - s.endPacketNumber = lastSent - s.currentMinRTT = 0 - s.rttSampleCount = 0 - s.started = true -} - -// IsEndOfRound returns true if this ack is the last packet number of our current slow start round. -func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool { - return s.endPacketNumber < ack -} - -// ShouldExitSlowStart should be called on every new ack frame, since a new -// RTT measurement can be made then. -// rtt: the RTT for this ack packet. -// minRTT: is the lowest delay (RTT) we have seen during the session. -// congestionWindow: the congestion window in packets. -func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool { - if !s.started { - // Time to start the hybrid slow start. - s.StartReceiveRound(s.lastSentPacketNumber) - } - if s.hystartFound { - return true - } - // Second detection parameter - delay increase detection. - // Compare the minimum delay (s.currentMinRTT) of the current - // burst of packets relative to the minimum delay during the session. - // Note: we only look at the first few(8) packets in each burst, since we - // only want to compare the lowest RTT of the burst relative to previous - // bursts. - s.rttSampleCount++ - if s.rttSampleCount <= hybridStartMinSamples { - if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT { - s.currentMinRTT = latestRTT - } - } - // We only need to check this once per round. - if s.rttSampleCount == hybridStartMinSamples { - // Divide minRTT by 8 to get a rtt increase threshold for exiting. - minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) - // Ensure the rtt threshold is never less than 2ms or more than 16ms. - minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) - minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond - - if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { - s.hystartFound = true - } - } - // Exit from slow start if the cwnd is greater than 16 and - // increasing delay is found. - return congestionWindow >= hybridStartLowWindow && s.hystartFound -} - -// OnPacketSent is called when a packet was sent -func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) { - s.lastSentPacketNumber = packetNumber -} - -// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end -// the round when the final packet of the burst is received and start it on -// the next incoming ack. -func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) { - if s.IsEndOfRound(ackedPacketNumber) { - s.started = false - } -} - -// Started returns true if started -func (s *HybridSlowStart) Started() bool { - return s.started -} - -// Restart the slow start phase -func (s *HybridSlowStart) Restart() { - s.started = false - s.hystartFound = false -} diff --git a/transport/tuic/congestion/minmax.go b/transport/tuic/congestion/minmax.go deleted file mode 100644 index 0a8f4ad4..00000000 --- a/transport/tuic/congestion/minmax.go +++ /dev/null @@ -1,56 +0,0 @@ -package congestion - -import ( - "math" - "time" -) - -// InfDuration is a duration of infinite length -const InfDuration = time.Duration(math.MaxInt64) - -// MinNonZeroDuration return the minimum duration that's not zero. -func MinNonZeroDuration(a, b time.Duration) time.Duration { - if a == 0 { - return b - } - if b == 0 { - return a - } - return Min(a, b) -} - -// AbsDuration returns the absolute value of a time duration -func AbsDuration(d time.Duration) time.Duration { - if d >= 0 { - return d - } - return -d -} - -// MinTime returns the earlier time -func MinTime(a, b time.Time) time.Time { - if a.After(b) { - return b - } - return a -} - -// MinNonZeroTime returns the earlist time that is not time.Time{} -// If both a and b are time.Time{}, it returns time.Time{} -func MinNonZeroTime(a, b time.Time) time.Time { - if a.IsZero() { - return b - } - if b.IsZero() { - return a - } - return MinTime(a, b) -} - -// MaxTime returns the later time -func MaxTime(a, b time.Time) time.Time { - if a.After(b) { - return a - } - return b -} diff --git a/transport/tuic/congestion_v2/bandwidth.go b/transport/tuic/congestion_v2/bandwidth.go new file mode 100644 index 00000000..df39a077 --- /dev/null +++ b/transport/tuic/congestion_v2/bandwidth.go @@ -0,0 +1,27 @@ +package congestion + +import ( + "math" + "time" + + "github.com/metacubex/quic-go/congestion" +) + +const ( + infBandwidth = Bandwidth(math.MaxUint64) +) + +// Bandwidth of a connection +type Bandwidth uint64 + +const ( + // BitsPerSecond is 1 bit per second + BitsPerSecond Bandwidth = 1 + // BytesPerSecond is 1 byte per second + BytesPerSecond = 8 * BitsPerSecond +) + +// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta +func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth { + return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond +} diff --git a/transport/tuic/congestion_v2/bandwidth_sampler.go b/transport/tuic/congestion_v2/bandwidth_sampler.go new file mode 100644 index 00000000..9028df64 --- /dev/null +++ b/transport/tuic/congestion_v2/bandwidth_sampler.go @@ -0,0 +1,874 @@ +package congestion + +import ( + "math" + "time" + + "github.com/metacubex/quic-go/congestion" +) + +const ( + infRTT = time.Duration(math.MaxInt64) + defaultConnectionStateMapQueueSize = 256 + defaultCandidatesBufferSize = 256 +) + +type roundTripCount uint64 + +// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned +// to the caller when the packet is acked or lost. +type sendTimeState struct { + // Whether other states in this object is valid. + isValid bool + // Whether the sender is app limited at the time the packet was sent. + // App limited bandwidth sample might be artificially low because the sender + // did not have enough data to send in order to saturate the link. + isAppLimited bool + // Total number of sent bytes at the time the packet was sent. + // Includes the packet itself. + totalBytesSent congestion.ByteCount + // Total number of acked bytes at the time the packet was sent. + totalBytesAcked congestion.ByteCount + // Total number of lost bytes at the time the packet was sent. + totalBytesLost congestion.ByteCount + // Total number of inflight bytes at the time the packet was sent. + // Includes the packet itself. + // It should be equal to |total_bytes_sent| minus the sum of + // |total_bytes_acked|, |total_bytes_lost| and total neutered bytes. + bytesInFlight congestion.ByteCount +} + +func newSendTimeState( + isAppLimited bool, + totalBytesSent congestion.ByteCount, + totalBytesAcked congestion.ByteCount, + totalBytesLost congestion.ByteCount, + bytesInFlight congestion.ByteCount, +) *sendTimeState { + return &sendTimeState{ + isValid: true, + isAppLimited: isAppLimited, + totalBytesSent: totalBytesSent, + totalBytesAcked: totalBytesAcked, + totalBytesLost: totalBytesLost, + bytesInFlight: bytesInFlight, + } +} + +type extraAckedEvent struct { + // The excess bytes acknowlwedged in the time delta for this event. + extraAcked congestion.ByteCount + + // The bytes acknowledged and time delta from the event. + bytesAcked congestion.ByteCount + timeDelta time.Duration + // The round trip of the event. + round roundTripCount +} + +func maxExtraAckedEventFunc(a, b extraAckedEvent) int { + if a.extraAcked > b.extraAcked { + return 1 + } else if a.extraAcked < b.extraAcked { + return -1 + } + return 0 +} + +// BandwidthSample +type bandwidthSample struct { + // The bandwidth at that particular sample. Zero if no valid bandwidth sample + // is available. + bandwidth Bandwidth + // The RTT measurement at this particular sample. Zero if no RTT sample is + // available. Does not correct for delayed ack time. + rtt time.Duration + // |send_rate| is computed from the current packet being acked('P') and an + // earlier packet that is acked before P was sent. + sendRate Bandwidth + // States captured when the packet was sent. + stateAtSend sendTimeState +} + +func newBandwidthSample() *bandwidthSample { + return &bandwidthSample{ + sendRate: infBandwidth, + } +} + +// MaxAckHeightTracker is part of the BandwidthSampler. It is called after every +// ack event to keep track the degree of ack aggregation(a.k.a "ack height"). +type maxAckHeightTracker struct { + // Tracks the maximum number of bytes acked faster than the estimated + // bandwidth. + maxAckHeightFilter *WindowedFilter[extraAckedEvent, roundTripCount] + // The time this aggregation started and the number of bytes acked during it. + aggregationEpochStartTime time.Time + aggregationEpochBytes congestion.ByteCount + // The last sent packet number before the current aggregation epoch started. + lastSentPacketNumberBeforeEpoch congestion.PacketNumber + // The number of ack aggregation epochs ever started, including the ongoing + // one. Stats only. + numAckAggregationEpochs uint64 + ackAggregationBandwidthThreshold float64 + startNewAggregationEpochAfterFullRound bool + reduceExtraAckedOnBandwidthIncrease bool +} + +func newMaxAckHeightTracker(windowLength roundTripCount) *maxAckHeightTracker { + return &maxAckHeightTracker{ + maxAckHeightFilter: NewWindowedFilter(windowLength, maxExtraAckedEventFunc), + lastSentPacketNumberBeforeEpoch: invalidPacketNumber, + ackAggregationBandwidthThreshold: 1.0, + } +} + +func (m *maxAckHeightTracker) Get() congestion.ByteCount { + return m.maxAckHeightFilter.GetBest().extraAcked +} + +func (m *maxAckHeightTracker) Update( + bandwidthEstimate Bandwidth, + isNewMaxBandwidth bool, + roundTripCount roundTripCount, + lastSentPacketNumber congestion.PacketNumber, + lastAckedPacketNumber congestion.PacketNumber, + ackTime time.Time, + bytesAcked congestion.ByteCount, +) congestion.ByteCount { + forceNewEpoch := false + + if m.reduceExtraAckedOnBandwidthIncrease && isNewMaxBandwidth { + // Save and clear existing entries. + best := m.maxAckHeightFilter.GetBest() + secondBest := m.maxAckHeightFilter.GetSecondBest() + thirdBest := m.maxAckHeightFilter.GetThirdBest() + m.maxAckHeightFilter.Clear() + + // Reinsert the heights into the filter after recalculating. + expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, best.timeDelta) + if expectedBytesAcked < best.bytesAcked { + best.extraAcked = best.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(best, best.round) + } + expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, secondBest.timeDelta) + if expectedBytesAcked < secondBest.bytesAcked { + secondBest.extraAcked = secondBest.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(secondBest, secondBest.round) + } + expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, thirdBest.timeDelta) + if expectedBytesAcked < thirdBest.bytesAcked { + thirdBest.extraAcked = thirdBest.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(thirdBest, thirdBest.round) + } + } + + // If any packet sent after the start of the epoch has been acked, start a new + // epoch. + if m.startNewAggregationEpochAfterFullRound && + m.lastSentPacketNumberBeforeEpoch != invalidPacketNumber && + lastAckedPacketNumber != invalidPacketNumber && + lastAckedPacketNumber > m.lastSentPacketNumberBeforeEpoch { + forceNewEpoch = true + } + if m.aggregationEpochStartTime.IsZero() || forceNewEpoch { + m.aggregationEpochBytes = bytesAcked + m.aggregationEpochStartTime = ackTime + m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber + m.numAckAggregationEpochs++ + return 0 + } + + // Compute how many bytes are expected to be delivered, assuming max bandwidth + // is correct. + aggregationDelta := ackTime.Sub(m.aggregationEpochStartTime) + expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, aggregationDelta) + // Reset the current aggregation epoch as soon as the ack arrival rate is less + // than or equal to the max bandwidth. + if m.aggregationEpochBytes <= congestion.ByteCount(m.ackAggregationBandwidthThreshold*float64(expectedBytesAcked)) { + // Reset to start measuring a new aggregation epoch. + m.aggregationEpochBytes = bytesAcked + m.aggregationEpochStartTime = ackTime + m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber + m.numAckAggregationEpochs++ + return 0 + } + + m.aggregationEpochBytes += bytesAcked + + // Compute how many extra bytes were delivered vs max bandwidth. + extraBytesAcked := m.aggregationEpochBytes - expectedBytesAcked + newEvent := extraAckedEvent{ + extraAcked: expectedBytesAcked, + bytesAcked: m.aggregationEpochBytes, + timeDelta: aggregationDelta, + } + m.maxAckHeightFilter.Update(newEvent, roundTripCount) + return extraBytesAcked +} + +func (m *maxAckHeightTracker) SetFilterWindowLength(length roundTripCount) { + m.maxAckHeightFilter.SetWindowLength(length) +} + +func (m *maxAckHeightTracker) Reset(newHeight congestion.ByteCount, newTime roundTripCount) { + newEvent := extraAckedEvent{ + extraAcked: newHeight, + round: newTime, + } + m.maxAckHeightFilter.Reset(newEvent, newTime) +} + +func (m *maxAckHeightTracker) SetAckAggregationBandwidthThreshold(threshold float64) { + m.ackAggregationBandwidthThreshold = threshold +} + +func (m *maxAckHeightTracker) SetStartNewAggregationEpochAfterFullRound(value bool) { + m.startNewAggregationEpochAfterFullRound = value +} + +func (m *maxAckHeightTracker) SetReduceExtraAckedOnBandwidthIncrease(value bool) { + m.reduceExtraAckedOnBandwidthIncrease = value +} + +func (m *maxAckHeightTracker) AckAggregationBandwidthThreshold() float64 { + return m.ackAggregationBandwidthThreshold +} + +func (m *maxAckHeightTracker) NumAckAggregationEpochs() uint64 { + return m.numAckAggregationEpochs +} + +// AckPoint represents a point on the ack line. +type ackPoint struct { + ackTime time.Time + totalBytesAcked congestion.ByteCount +} + +// RecentAckPoints maintains the most recent 2 ack points at distinct times. +type recentAckPoints struct { + ackPoints [2]ackPoint +} + +func (r *recentAckPoints) Update(ackTime time.Time, totalBytesAcked congestion.ByteCount) { + if ackTime.Before(r.ackPoints[1].ackTime) { + r.ackPoints[1].ackTime = ackTime + } else if ackTime.After(r.ackPoints[1].ackTime) { + r.ackPoints[0] = r.ackPoints[1] + r.ackPoints[1].ackTime = ackTime + } + + r.ackPoints[1].totalBytesAcked = totalBytesAcked +} + +func (r *recentAckPoints) Clear() { + r.ackPoints[0] = ackPoint{} + r.ackPoints[1] = ackPoint{} +} + +func (r *recentAckPoints) MostRecentPoint() *ackPoint { + return &r.ackPoints[1] +} + +func (r *recentAckPoints) LessRecentPoint() *ackPoint { + if r.ackPoints[0].totalBytesAcked != 0 { + return &r.ackPoints[0] + } + + return &r.ackPoints[1] +} + +// ConnectionStateOnSentPacket represents the information about a sent packet +// and the state of the connection at the moment the packet was sent, +// specifically the information about the most recently acknowledged packet at +// that moment. +type connectionStateOnSentPacket struct { + // Time at which the packet is sent. + sentTime time.Time + // Size of the packet. + size congestion.ByteCount + // The value of |totalBytesSentAtLastAckedPacket| at the time the + // packet was sent. + totalBytesSentAtLastAckedPacket congestion.ByteCount + // The value of |lastAckedPacketSentTime| at the time the packet was + // sent. + lastAckedPacketSentTime time.Time + // The value of |lastAckedPacketAckTime| at the time the packet was + // sent. + lastAckedPacketAckTime time.Time + // Send time states that are returned to the congestion controller when the + // packet is acked or lost. + sendTimeState sendTimeState +} + +// Snapshot constructor. Records the current state of the bandwidth +// sampler. +// |bytes_in_flight| is the bytes in flight right after the packet is sent. +func newConnectionStateOnSentPacket( + sentTime time.Time, + size congestion.ByteCount, + bytesInFlight congestion.ByteCount, + sampler *bandwidthSampler, +) *connectionStateOnSentPacket { + return &connectionStateOnSentPacket{ + sentTime: sentTime, + size: size, + totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket, + lastAckedPacketSentTime: sampler.lastAckedPacketSentTime, + lastAckedPacketAckTime: sampler.lastAckedPacketAckTime, + sendTimeState: *newSendTimeState( + sampler.isAppLimited, + sampler.totalBytesSent, + sampler.totalBytesAcked, + sampler.totalBytesLost, + bytesInFlight, + ), + } +} + +// BandwidthSampler keeps track of sent and acknowledged packets and outputs a +// bandwidth sample for every packet acknowledged. The samples are taken for +// individual packets, and are not filtered; the consumer has to filter the +// bandwidth samples itself. In certain cases, the sampler will locally severely +// underestimate the bandwidth, hence a maximum filter with a size of at least +// one RTT is recommended. +// +// This class bases its samples on the slope of two curves: the number of bytes +// sent over time, and the number of bytes acknowledged as received over time. +// It produces a sample of both slopes for every packet that gets acknowledged, +// based on a slope between two points on each of the corresponding curves. Note +// that due to the packet loss, the number of bytes on each curve might get +// further and further away from each other, meaning that it is not feasible to +// compare byte values coming from different curves with each other. +// +// The obvious points for measuring slope sample are the ones corresponding to +// the packet that was just acknowledged. Let us denote them as S_1 (point at +// which the current packet was sent) and A_1 (point at which the current packet +// was acknowledged). However, taking a slope requires two points on each line, +// so estimating bandwidth requires picking a packet in the past with respect to +// which the slope is measured. +// +// For that purpose, BandwidthSampler always keeps track of the most recently +// acknowledged packet, and records it together with every outgoing packet. +// When a packet gets acknowledged (A_1), it has not only information about when +// it itself was sent (S_1), but also the information about the latest +// acknowledged packet right before it was sent (S_0 and A_0). +// +// Based on that data, send and ack rate are estimated as: +// +// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0)) +// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0)) +// +// Here, the ack rate is intuitively the rate we want to treat as bandwidth. +// However, in certain cases (e.g. ack compression) the ack rate at a point may +// end up higher than the rate at which the data was originally sent, which is +// not indicative of the real bandwidth. Hence, we use the send rate as an upper +// bound, and the sample value is +// +// rate_sample = Min(send_rate, ack_rate) +// +// An important edge case handled by the sampler is tracking the app-limited +// samples. There are multiple meaning of "app-limited" used interchangeably, +// hence it is important to understand and to be able to distinguish between +// them. +// +// Meaning 1: connection state. The connection is said to be app-limited when +// there is no outstanding data to send. This means that certain bandwidth +// samples in the future would not be an accurate indication of the link +// capacity, and it is important to inform consumer about that. Whenever +// connection becomes app-limited, the sampler is notified via OnAppLimited() +// method. +// +// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth +// sampler becomes notified about the connection being app-limited, it enters +// app-limited phase. In that phase, all *sent* packets are marked as +// app-limited. Note that the connection itself does not have to be +// app-limited during the app-limited phase, and in fact it will not be +// (otherwise how would it send packets?). The boolean flag below indicates +// whether the sampler is in that phase. +// +// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is +// sent during the app-limited phase, the resulting sample related to the +// packet will be marked as app-limited. +// +// With the terminology issue out of the way, let us consider the question of +// what kind of situation it addresses. +// +// Consider a scenario where we first send packets 1 to 20 at a regular +// bandwidth, and then immediately run out of data. After a few seconds, we send +// packets 21 to 60, and only receive ack for 21 between sending packets 40 and +// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0 +// we use to compute the slope is going to be packet 20, a few seconds apart +// from the current packet, hence the resulting estimate would be extremely low +// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21, +// meaning that the bandwidth sample would exclude the quiescence. +// +// Based on the analysis of that scenario, we implement the following rule: once +// OnAppLimited() is called, all sent packets will produce app-limited samples +// up until an ack for a packet that was sent after OnAppLimited() was called. +// Note that while the scenario above is not the only scenario when the +// connection is app-limited, the approach works in other cases too. + +type congestionEventSample struct { + // The maximum bandwidth sample from all acked packets. + // QuicBandwidth::Zero() if no samples are available. + sampleMaxBandwidth Bandwidth + // Whether |sample_max_bandwidth| is from a app-limited sample. + sampleIsAppLimited bool + // The minimum rtt sample from all acked packets. + // QuicTime::Delta::Infinite() if no samples are available. + sampleRtt time.Duration + // For each packet p in acked packets, this is the max value of INFLIGHT(p), + // where INFLIGHT(p) is the number of bytes acked while p is inflight. + sampleMaxInflight congestion.ByteCount + // The send state of the largest packet in acked_packets, unless it is + // empty. If acked_packets is empty, it's the send state of the largest + // packet in lost_packets. + lastPacketSendState sendTimeState + // The number of extra bytes acked from this ack event, compared to what is + // expected from the flow's bandwidth. Larger value means more ack + // aggregation. + extraAcked congestion.ByteCount +} + +func newCongestionEventSample() *congestionEventSample { + return &congestionEventSample{ + sampleRtt: infRTT, + } +} + +type bandwidthSampler struct { + // The total number of congestion controlled bytes sent during the connection. + totalBytesSent congestion.ByteCount + + // The total number of congestion controlled bytes which were acknowledged. + totalBytesAcked congestion.ByteCount + + // The total number of congestion controlled bytes which were lost. + totalBytesLost congestion.ByteCount + + // The total number of congestion controlled bytes which have been neutered. + totalBytesNeutered congestion.ByteCount + + // The value of |total_bytes_sent_| at the time the last acknowledged packet + // was sent. Valid only when |last_acked_packet_sent_time_| is valid. + totalBytesSentAtLastAckedPacket congestion.ByteCount + + // The time at which the last acknowledged packet was sent. Set to + // QuicTime::Zero() if no valid timestamp is available. + lastAckedPacketSentTime time.Time + + // The time at which the most recent packet was acknowledged. + lastAckedPacketAckTime time.Time + + // The most recently sent packet. + lastSentPacket congestion.PacketNumber + + // The most recently acked packet. + lastAckedPacket congestion.PacketNumber + + // Indicates whether the bandwidth sampler is currently in an app-limited + // phase. + isAppLimited bool + + // The packet that will be acknowledged after this one will cause the sampler + // to exit the app-limited phase. + endOfAppLimitedPhase congestion.PacketNumber + + // Record of the connection state at the point where each packet in flight was + // sent, indexed by the packet number. + connectionStateMap *packetNumberIndexedQueue[connectionStateOnSentPacket] + + recentAckPoints recentAckPoints + a0Candidates RingBuffer[ackPoint] + + // Maximum number of tracked packets. + maxTrackedPackets congestion.ByteCount + + maxAckHeightTracker *maxAckHeightTracker + totalBytesAckedAfterLastAckEvent congestion.ByteCount + + // True if connection option 'BSAO' is set. + overestimateAvoidance bool + + // True if connection option 'BBRB' is set. + limitMaxAckHeightTrackerBySendRate bool +} + +func newBandwidthSampler(maxAckHeightTrackerWindowLength roundTripCount) *bandwidthSampler { + b := &bandwidthSampler{ + maxAckHeightTracker: newMaxAckHeightTracker(maxAckHeightTrackerWindowLength), + connectionStateMap: newPacketNumberIndexedQueue[connectionStateOnSentPacket](defaultConnectionStateMapQueueSize), + lastSentPacket: invalidPacketNumber, + lastAckedPacket: invalidPacketNumber, + endOfAppLimitedPhase: invalidPacketNumber, + } + + b.a0Candidates.Init(defaultCandidatesBufferSize) + + return b +} + +func (b *bandwidthSampler) MaxAckHeight() congestion.ByteCount { + return b.maxAckHeightTracker.Get() +} + +func (b *bandwidthSampler) NumAckAggregationEpochs() uint64 { + return b.maxAckHeightTracker.NumAckAggregationEpochs() +} + +func (b *bandwidthSampler) SetMaxAckHeightTrackerWindowLength(length roundTripCount) { + b.maxAckHeightTracker.SetFilterWindowLength(length) +} + +func (b *bandwidthSampler) ResetMaxAckHeightTracker(newHeight congestion.ByteCount, newTime roundTripCount) { + b.maxAckHeightTracker.Reset(newHeight, newTime) +} + +func (b *bandwidthSampler) SetStartNewAggregationEpochAfterFullRound(value bool) { + b.maxAckHeightTracker.SetStartNewAggregationEpochAfterFullRound(value) +} + +func (b *bandwidthSampler) SetLimitMaxAckHeightTrackerBySendRate(value bool) { + b.limitMaxAckHeightTrackerBySendRate = value +} + +func (b *bandwidthSampler) SetReduceExtraAckedOnBandwidthIncrease(value bool) { + b.maxAckHeightTracker.SetReduceExtraAckedOnBandwidthIncrease(value) +} + +func (b *bandwidthSampler) EnableOverestimateAvoidance() { + if b.overestimateAvoidance { + return + } + + b.overestimateAvoidance = true + b.maxAckHeightTracker.SetAckAggregationBandwidthThreshold(2.0) +} + +func (b *bandwidthSampler) IsOverestimateAvoidanceEnabled() bool { + return b.overestimateAvoidance +} + +func (b *bandwidthSampler) OnPacketSent( + sentTime time.Time, + packetNumber congestion.PacketNumber, + bytes congestion.ByteCount, + bytesInFlight congestion.ByteCount, + isRetransmittable bool, +) { + b.lastSentPacket = packetNumber + + if !isRetransmittable { + return + } + + b.totalBytesSent += bytes + + // If there are no packets in flight, the time at which the new transmission + // opens can be treated as the A_0 point for the purpose of bandwidth + // sampling. This underestimates bandwidth to some extent, and produces some + // artificially low samples for most packets in flight, but it provides with + // samples at important points where we would not have them otherwise, most + // importantly at the beginning of the connection. + if bytesInFlight == 0 { + b.lastAckedPacketAckTime = sentTime + if b.overestimateAvoidance { + b.recentAckPoints.Clear() + b.recentAckPoints.Update(sentTime, b.totalBytesAcked) + b.a0Candidates.Clear() + b.a0Candidates.PushBack(*b.recentAckPoints.MostRecentPoint()) + } + b.totalBytesSentAtLastAckedPacket = b.totalBytesSent + + // In this situation ack compression is not a concern, set send rate to + // effectively infinite. + b.lastAckedPacketSentTime = sentTime + } + + b.connectionStateMap.Emplace(packetNumber, newConnectionStateOnSentPacket( + sentTime, + bytes, + bytesInFlight+bytes, + b, + )) +} + +func (b *bandwidthSampler) OnCongestionEvent( + ackTime time.Time, + ackedPackets []congestion.AckedPacketInfo, + lostPackets []congestion.LostPacketInfo, + maxBandwidth Bandwidth, + estBandwidthUpperBound Bandwidth, + roundTripCount roundTripCount, +) congestionEventSample { + eventSample := newCongestionEventSample() + + var lastLostPacketSendState sendTimeState + + for _, p := range lostPackets { + sendState := b.OnPacketLost(p.PacketNumber, p.BytesLost) + if sendState.isValid { + lastLostPacketSendState = sendState + } + } + + if len(ackedPackets) == 0 { + // Only populate send state for a loss-only event. + eventSample.lastPacketSendState = lastLostPacketSendState + return *eventSample + } + + var lastAckedPacketSendState sendTimeState + var maxSendRate Bandwidth + + for _, p := range ackedPackets { + sample := b.onPacketAcknowledged(ackTime, p.PacketNumber) + if !sample.stateAtSend.isValid { + continue + } + + lastAckedPacketSendState = sample.stateAtSend + + if sample.rtt != 0 { + eventSample.sampleRtt = Min(eventSample.sampleRtt, sample.rtt) + } + if sample.bandwidth > eventSample.sampleMaxBandwidth { + eventSample.sampleMaxBandwidth = sample.bandwidth + eventSample.sampleIsAppLimited = sample.stateAtSend.isAppLimited + } + if sample.sendRate != infBandwidth { + maxSendRate = Max(maxSendRate, sample.sendRate) + } + inflightSample := b.totalBytesAcked - lastAckedPacketSendState.totalBytesAcked + if inflightSample > eventSample.sampleMaxInflight { + eventSample.sampleMaxInflight = inflightSample + } + } + + if !lastLostPacketSendState.isValid { + eventSample.lastPacketSendState = lastAckedPacketSendState + } else if !lastAckedPacketSendState.isValid { + eventSample.lastPacketSendState = lastLostPacketSendState + } else { + // If two packets are inflight and an alarm is armed to lose a packet and it + // wakes up late, then the first of two in flight packets could have been + // acknowledged before the wakeup, which re-evaluates loss detection, and + // could declare the later of the two lost. + if lostPackets[len(lostPackets)-1].PacketNumber > ackedPackets[len(ackedPackets)-1].PacketNumber { + eventSample.lastPacketSendState = lastLostPacketSendState + } else { + eventSample.lastPacketSendState = lastAckedPacketSendState + } + } + + isNewMaxBandwidth := eventSample.sampleMaxBandwidth > maxBandwidth + maxBandwidth = Max(maxBandwidth, eventSample.sampleMaxBandwidth) + if b.limitMaxAckHeightTrackerBySendRate { + maxBandwidth = Max(maxBandwidth, maxSendRate) + } + + eventSample.extraAcked = b.onAckEventEnd(Min(estBandwidthUpperBound, maxBandwidth), isNewMaxBandwidth, roundTripCount) + + return *eventSample +} + +func (b *bandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber, bytesLost congestion.ByteCount) (s sendTimeState) { + b.totalBytesLost += bytesLost + if sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber); sentPacketPointer != nil { + sentPacketToSendTimeState(sentPacketPointer, &s) + } + return s +} + +func (b *bandwidthSampler) OnPacketNeutered(packetNumber congestion.PacketNumber) { + b.connectionStateMap.Remove(packetNumber, func(sentPacket connectionStateOnSentPacket) { + b.totalBytesNeutered += sentPacket.size + }) +} + +func (b *bandwidthSampler) OnAppLimited() { + b.isAppLimited = true + b.endOfAppLimitedPhase = b.lastSentPacket +} + +func (b *bandwidthSampler) RemoveObsoletePackets(leastUnacked congestion.PacketNumber) { + // A packet can become obsolete when it is removed from QuicUnackedPacketMap's + // view of inflight before it is acked or marked as lost. For example, when + // QuicSentPacketManager::RetransmitCryptoPackets retransmits a crypto packet, + // the packet is removed from QuicUnackedPacketMap's inflight, but is not + // marked as acked or lost in the BandwidthSampler. + b.connectionStateMap.RemoveUpTo(leastUnacked) +} + +func (b *bandwidthSampler) TotalBytesSent() congestion.ByteCount { + return b.totalBytesSent +} + +func (b *bandwidthSampler) TotalBytesLost() congestion.ByteCount { + return b.totalBytesLost +} + +func (b *bandwidthSampler) TotalBytesAcked() congestion.ByteCount { + return b.totalBytesAcked +} + +func (b *bandwidthSampler) TotalBytesNeutered() congestion.ByteCount { + return b.totalBytesNeutered +} + +func (b *bandwidthSampler) IsAppLimited() bool { + return b.isAppLimited +} + +func (b *bandwidthSampler) EndOfAppLimitedPhase() congestion.PacketNumber { + return b.endOfAppLimitedPhase +} + +func (b *bandwidthSampler) max_ack_height() congestion.ByteCount { + return b.maxAckHeightTracker.Get() +} + +func (b *bandwidthSampler) chooseA0Point(totalBytesAcked congestion.ByteCount, a0 *ackPoint) bool { + if b.a0Candidates.Empty() { + return false + } + + if b.a0Candidates.Len() == 1 { + *a0 = *b.a0Candidates.Front() + return true + } + + for i := 1; i < b.a0Candidates.Len(); i++ { + if b.a0Candidates.Offset(i).totalBytesAcked > totalBytesAcked { + *a0 = *b.a0Candidates.Offset(i - 1) + if i > 1 { + for j := 0; j < i-1; j++ { + b.a0Candidates.PopFront() + } + } + return true + } + } + + *a0 = *b.a0Candidates.Back() + for k := 0; k < b.a0Candidates.Len()-1; k++ { + b.a0Candidates.PopFront() + } + return true +} + +func (b *bandwidthSampler) onPacketAcknowledged(ackTime time.Time, packetNumber congestion.PacketNumber) bandwidthSample { + sample := newBandwidthSample() + b.lastAckedPacket = packetNumber + sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber) + if sentPacketPointer == nil { + return *sample + } + + // OnPacketAcknowledgedInner + b.totalBytesAcked += sentPacketPointer.size + b.totalBytesSentAtLastAckedPacket = sentPacketPointer.sendTimeState.totalBytesSent + b.lastAckedPacketSentTime = sentPacketPointer.sentTime + b.lastAckedPacketAckTime = ackTime + if b.overestimateAvoidance { + b.recentAckPoints.Update(ackTime, b.totalBytesAcked) + } + + if b.isAppLimited { + // Exit app-limited phase in two cases: + // (1) end_of_app_limited_phase_ is not initialized, i.e., so far all + // packets are sent while there are buffered packets or pending data. + // (2) The current acked packet is after the sent packet marked as the end + // of the app limit phase. + if b.endOfAppLimitedPhase == invalidPacketNumber || + packetNumber > b.endOfAppLimitedPhase { + b.isAppLimited = false + } + } + + // There might have been no packets acknowledged at the moment when the + // current packet was sent. In that case, there is no bandwidth sample to + // make. + if sentPacketPointer.lastAckedPacketSentTime.IsZero() { + return *sample + } + + // Infinite rate indicates that the sampler is supposed to discard the + // current send rate sample and use only the ack rate. + sendRate := infBandwidth + if sentPacketPointer.sentTime.After(sentPacketPointer.lastAckedPacketSentTime) { + sendRate = BandwidthFromDelta( + sentPacketPointer.sendTimeState.totalBytesSent-sentPacketPointer.totalBytesSentAtLastAckedPacket, + sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime)) + } + + var a0 ackPoint + if b.overestimateAvoidance && b.chooseA0Point(sentPacketPointer.sendTimeState.totalBytesAcked, &a0) { + } else { + a0.ackTime = sentPacketPointer.lastAckedPacketAckTime + a0.totalBytesAcked = sentPacketPointer.sendTimeState.totalBytesAcked + } + + // During the slope calculation, ensure that ack time of the current packet is + // always larger than the time of the previous packet, otherwise division by + // zero or integer underflow can occur. + if ackTime.Sub(a0.ackTime) <= 0 { + return *sample + } + + ackRate := BandwidthFromDelta(b.totalBytesAcked-a0.totalBytesAcked, ackTime.Sub(a0.ackTime)) + + sample.bandwidth = Min(sendRate, ackRate) + // Note: this sample does not account for delayed acknowledgement time. This + // means that the RTT measurements here can be artificially high, especially + // on low bandwidth connections. + sample.rtt = ackTime.Sub(sentPacketPointer.sentTime) + sample.sendRate = sendRate + sentPacketToSendTimeState(sentPacketPointer, &sample.stateAtSend) + + return *sample +} + +func (b *bandwidthSampler) onAckEventEnd( + bandwidthEstimate Bandwidth, + isNewMaxBandwidth bool, + roundTripCount roundTripCount, +) congestion.ByteCount { + newlyAckedBytes := b.totalBytesAcked - b.totalBytesAckedAfterLastAckEvent + if newlyAckedBytes == 0 { + return 0 + } + b.totalBytesAckedAfterLastAckEvent = b.totalBytesAcked + extraAcked := b.maxAckHeightTracker.Update( + bandwidthEstimate, + isNewMaxBandwidth, + roundTripCount, + b.lastSentPacket, + b.lastAckedPacket, + b.lastAckedPacketAckTime, + newlyAckedBytes) + // If |extra_acked| is zero, i.e. this ack event marks the start of a new ack + // aggregation epoch, save LessRecentPoint, which is the last ack point of the + // previous epoch, as a A0 candidate. + if b.overestimateAvoidance && extraAcked == 0 { + b.a0Candidates.PushBack(*b.recentAckPoints.LessRecentPoint()) + } + return extraAcked +} + +func sentPacketToSendTimeState(sentPacket *connectionStateOnSentPacket, sendTimeState *sendTimeState) { + *sendTimeState = sentPacket.sendTimeState + sendTimeState.isValid = true +} + +// BytesFromBandwidthAndTimeDelta calculates the bytes +// from a bandwidth(bits per second) and a time delta +func bytesFromBandwidthAndTimeDelta(bandwidth Bandwidth, delta time.Duration) congestion.ByteCount { + return (congestion.ByteCount(bandwidth) * congestion.ByteCount(delta)) / + (congestion.ByteCount(time.Second) * 8) +} + +func timeDeltaFromBytesAndBandwidth(bytes congestion.ByteCount, bandwidth Bandwidth) time.Duration { + return time.Duration(bytes*8) * time.Second / time.Duration(bandwidth) +} diff --git a/transport/tuic/congestion_v2/bbr_sender.go b/transport/tuic/congestion_v2/bbr_sender.go new file mode 100644 index 00000000..a7700fa1 --- /dev/null +++ b/transport/tuic/congestion_v2/bbr_sender.go @@ -0,0 +1,927 @@ +package congestion + +// src from https://github.com/google/quiche/blob/e7872fc9e12bb1d46a118949c3d4da36de58aa44/quiche/quic/core/congestion_control/bbr_sender.cc + +import ( + "fmt" + "net" + "time" + + "github.com/metacubex/quic-go/congestion" + + "github.com/zhangyunhao116/fastrand" +) + +// BbrSender implements BBR congestion control algorithm. BBR aims to estimate +// the current available Bottleneck Bandwidth and RTT (hence the name), and +// regulates the pacing rate and the size of the congestion window based on +// those signals. +// +// BBR relies on pacing in order to function properly. Do not use BBR when +// pacing is disabled. +// + +const ( + invalidPacketNumber = -1 + initialCongestionWindowPackets = 32 + + // Constants based on TCP defaults. + // The minimum CWND to ensure delayed acks don't reduce bandwidth measurements. + // Does not inflate the pacing rate. + defaultMinimumCongestionWindow = 4 * congestion.ByteCount(congestion.InitialPacketSizeIPv4) + + // The gain used for the STARTUP, equal to 2/ln(2). + defaultHighGain = 2.885 + // The newly derived gain for STARTUP, equal to 4 * ln(2) + derivedHighGain = 2.773 + // The newly derived CWND gain for STARTUP, 2. + derivedHighCWNDGain = 2.0 +) + +// The cycle of gains used during the PROBE_BW stage. +var pacingGain = [...]float64{1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0} + +const ( + // The length of the gain cycle. + gainCycleLength = len(pacingGain) + // The size of the bandwidth filter window, in round-trips. + bandwidthWindowSize = gainCycleLength + 2 + + // The time after which the current min_rtt value expires. + minRttExpiry = 10 * time.Second + // The minimum time the connection can spend in PROBE_RTT mode. + probeRttTime = 200 * time.Millisecond + // If the bandwidth does not increase by the factor of |kStartupGrowthTarget| + // within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection + // will exit the STARTUP mode. + startupGrowthTarget = 1.25 + roundTripsWithoutGrowthBeforeExitingStartup = int64(3) + + // Flag. + defaultStartupFullLossCount = 8 + quicBbr2DefaultLossThreshold = 0.02 + maxBbrBurstPackets = 3 +) + +type bbrMode int + +const ( + // Startup phase of the connection. + bbrModeStartup = iota + // After achieving the highest possible bandwidth during the startup, lower + // the pacing rate in order to drain the queue. + bbrModeDrain + // Cruising mode. + bbrModeProbeBw + // Temporarily slow down sending in order to empty the buffer and measure + // the real minimum RTT. + bbrModeProbeRtt +) + +// Indicates how the congestion control limits the amount of bytes in flight. +type bbrRecoveryState int + +const ( + // Do not limit. + bbrRecoveryStateNotInRecovery = iota + // Allow an extra outstanding byte for each byte acknowledged. + bbrRecoveryStateConservation + // Allow two extra outstanding bytes for each byte acknowledged (slow + // start). + bbrRecoveryStateGrowth +) + +type bbrSender struct { + rttStats congestion.RTTStatsProvider + clock Clock + pacer *Pacer + + mode bbrMode + + // Bandwidth sampler provides BBR with the bandwidth measurements at + // individual points. + sampler *bandwidthSampler + + // The number of the round trips that have occurred during the connection. + roundTripCount roundTripCount + + // The packet number of the most recently sent packet. + lastSentPacket congestion.PacketNumber + // Acknowledgement of any packet after |current_round_trip_end_| will cause + // the round trip counter to advance. + currentRoundTripEnd congestion.PacketNumber + + // Number of congestion events with some losses, in the current round. + numLossEventsInRound uint64 + + // Number of total bytes lost in the current round. + bytesLostInRound congestion.ByteCount + + // The filter that tracks the maximum bandwidth over the multiple recent + // round-trips. + maxBandwidth *WindowedFilter[Bandwidth, roundTripCount] + + // Minimum RTT estimate. Automatically expires within 10 seconds (and + // triggers PROBE_RTT mode) if no new value is sampled during that period. + minRtt time.Duration + // The time at which the current value of |min_rtt_| was assigned. + minRttTimestamp time.Time + + // The maximum allowed number of bytes in flight. + congestionWindow congestion.ByteCount + + // The initial value of the |congestion_window_|. + initialCongestionWindow congestion.ByteCount + + // The largest value the |congestion_window_| can achieve. + maxCongestionWindow congestion.ByteCount + + // The smallest value the |congestion_window_| can achieve. + minCongestionWindow congestion.ByteCount + + // The pacing gain applied during the STARTUP phase. + highGain float64 + + // The CWND gain applied during the STARTUP phase. + highCwndGain float64 + + // The pacing gain applied during the DRAIN phase. + drainGain float64 + + // The current pacing rate of the connection. + pacingRate Bandwidth + + // The gain currently applied to the pacing rate. + pacingGain float64 + // The gain currently applied to the congestion window. + congestionWindowGain float64 + + // The gain used for the congestion window during PROBE_BW. Latched from + // quic_bbr_cwnd_gain flag. + congestionWindowGainConstant float64 + // The number of RTTs to stay in STARTUP mode. Defaults to 3. + numStartupRtts int64 + + // Number of round-trips in PROBE_BW mode, used for determining the current + // pacing gain cycle. + cycleCurrentOffset int + // The time at which the last pacing gain cycle was started. + lastCycleStart time.Time + + // Indicates whether the connection has reached the full bandwidth mode. + isAtFullBandwidth bool + // Number of rounds during which there was no significant bandwidth increase. + roundsWithoutBandwidthGain int64 + // The bandwidth compared to which the increase is measured. + bandwidthAtLastRound Bandwidth + + // Set to true upon exiting quiescence. + exitingQuiescence bool + + // Time at which PROBE_RTT has to be exited. Setting it to zero indicates + // that the time is yet unknown as the number of packets in flight has not + // reached the required value. + exitProbeRttAt time.Time + // Indicates whether a round-trip has passed since PROBE_RTT became active. + probeRttRoundPassed bool + + // Indicates whether the most recent bandwidth sample was marked as + // app-limited. + lastSampleIsAppLimited bool + // Indicates whether any non app-limited samples have been recorded. + hasNoAppLimitedSample bool + + // Current state of recovery. + recoveryState bbrRecoveryState + // Receiving acknowledgement of a packet after |end_recovery_at_| will cause + // BBR to exit the recovery mode. A value above zero indicates at least one + // loss has been detected, so it must not be set back to zero. + endRecoveryAt congestion.PacketNumber + // A window used to limit the number of bytes in flight during loss recovery. + recoveryWindow congestion.ByteCount + // If true, consider all samples in recovery app-limited. + isAppLimitedRecovery bool // not used + + // When true, pace at 1.5x and disable packet conservation in STARTUP. + slowerStartup bool // not used + // When true, disables packet conservation in STARTUP. + rateBasedStartup bool // not used + + // When true, add the most recent ack aggregation measurement during STARTUP. + enableAckAggregationDuringStartup bool + // When true, expire the windowed ack aggregation values in STARTUP when + // bandwidth increases more than 25%. + expireAckAggregationInStartup bool + + // If true, will not exit low gain mode until bytes_in_flight drops below BDP + // or it's time for high gain mode. + drainToTarget bool + + // If true, slow down pacing rate in STARTUP when overshooting is detected. + detectOvershooting bool + // Bytes lost while detect_overshooting_ is true. + bytesLostWhileDetectingOvershooting congestion.ByteCount + // Slow down pacing rate if + // bytes_lost_while_detecting_overshooting_ * + // bytes_lost_multiplier_while_detecting_overshooting_ > IW. + bytesLostMultiplierWhileDetectingOvershooting uint8 + // When overshooting is detected, do not drop pacing_rate_ below this value / + // min_rtt. + cwndToCalculateMinPacingRate congestion.ByteCount + + // Max congestion window when adjusting network parameters. + maxCongestionWindowWithNetworkParametersAdjusted congestion.ByteCount // not used + + // Params. + maxDatagramSize congestion.ByteCount + // Recorded on packet sent. equivalent |unacked_packets_->bytes_in_flight()| + bytesInFlight congestion.ByteCount +} + +var _ congestion.CongestionControl = &bbrSender{} + +func NewBbrSender( + clock Clock, + initialMaxDatagramSize congestion.ByteCount, + initialCongestionWindowPackets congestion.ByteCount, +) *bbrSender { + return newBbrSender( + clock, + initialMaxDatagramSize, + initialCongestionWindowPackets*initialMaxDatagramSize, + congestion.MaxCongestionWindowPackets*initialMaxDatagramSize, + ) +} + +func newBbrSender( + clock Clock, + initialMaxDatagramSize, + initialCongestionWindow, + initialMaxCongestionWindow congestion.ByteCount, +) *bbrSender { + b := &bbrSender{ + clock: clock, + mode: bbrModeStartup, + sampler: newBandwidthSampler(roundTripCount(bandwidthWindowSize)), + lastSentPacket: invalidPacketNumber, + currentRoundTripEnd: invalidPacketNumber, + maxBandwidth: NewWindowedFilter(roundTripCount(bandwidthWindowSize), MaxFilter[Bandwidth]), + congestionWindow: initialCongestionWindow, + initialCongestionWindow: initialCongestionWindow, + maxCongestionWindow: initialMaxCongestionWindow, + minCongestionWindow: defaultMinimumCongestionWindow, + highGain: defaultHighGain, + highCwndGain: defaultHighGain, + drainGain: 1.0 / defaultHighGain, + pacingGain: 1.0, + congestionWindowGain: 1.0, + congestionWindowGainConstant: 2.0, + numStartupRtts: roundTripsWithoutGrowthBeforeExitingStartup, + recoveryState: bbrRecoveryStateNotInRecovery, + endRecoveryAt: invalidPacketNumber, + recoveryWindow: initialMaxCongestionWindow, + bytesLostMultiplierWhileDetectingOvershooting: 2, + cwndToCalculateMinPacingRate: initialCongestionWindow, + maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow, + maxDatagramSize: initialMaxDatagramSize, + } + b.pacer = NewPacer(func() congestion.ByteCount { + // Pacer wants bytes per second, but Bandwidth is in bits per second. + return congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond)) + }) + + /* + if b.tracer != nil { + b.lastState = logging.CongestionStateStartup + b.tracer.UpdatedCongestionState(logging.CongestionStateStartup) + } + */ + + b.enterStartupMode(b.clock.Now()) + b.setHighCwndGain(derivedHighCWNDGain) + + return b +} + +func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { + b.rttStats = provider +} + +// TimeUntilSend implements the SendAlgorithm interface. +func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { + return b.pacer.TimeUntilSend() +} + +// HasPacingBudget implements the SendAlgorithm interface. +func (b *bbrSender) HasPacingBudget(now time.Time) bool { + return b.pacer.Budget(now) >= b.maxDatagramSize +} + +// OnPacketSent implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketSent( + sentTime time.Time, + bytesInFlight congestion.ByteCount, + packetNumber congestion.PacketNumber, + bytes congestion.ByteCount, + isRetransmittable bool, +) { + b.pacer.SentPacket(sentTime, bytes) + + b.lastSentPacket = packetNumber + b.bytesInFlight = bytesInFlight + + if bytesInFlight == 0 { + b.exitingQuiescence = true + } + + b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable) +} + +// CanSend implements the SendAlgorithm interface. +func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight < b.GetCongestionWindow() +} + +// MaybeExitSlowStart implements the SendAlgorithm interface. +func (b *bbrSender) MaybeExitSlowStart() { + // Do nothing +} + +// OnPacketAcked implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes, priorInFlight congestion.ByteCount, eventTime time.Time) { + // Do nothing. +} + +// OnPacketLost implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { + // Do nothing. +} + +// OnRetransmissionTimeout implements the SendAlgorithm interface. +func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) { + // Do nothing. +} + +// SetMaxDatagramSize implements the SendAlgorithm interface. +func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) { + if s < b.maxDatagramSize { + panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s)) + } + cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow + b.maxDatagramSize = s + if cwndIsMinCwnd { + b.congestionWindow = b.minCongestionWindow + } + b.pacer.SetMaxDatagramSize(s) +} + +// InSlowStart implements the SendAlgorithmWithDebugInfos interface. +func (b *bbrSender) InSlowStart() bool { + return b.mode == bbrModeStartup +} + +// InRecovery implements the SendAlgorithmWithDebugInfos interface. +func (b *bbrSender) InRecovery() bool { + return b.recoveryState != bbrRecoveryStateNotInRecovery +} + +// GetCongestionWindow implements the SendAlgorithmWithDebugInfos interface. +func (b *bbrSender) GetCongestionWindow() congestion.ByteCount { + if b.mode == bbrModeProbeRtt { + return b.probeRttCongestionWindow() + } + + if b.InRecovery() { + return Min(b.congestionWindow, b.recoveryWindow) + } + + return b.congestionWindow +} + +func (b *bbrSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { + // Do nothing. +} + +func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { + totalBytesAckedBefore := b.sampler.TotalBytesAcked() + totalBytesLostBefore := b.sampler.TotalBytesLost() + + var isRoundStart, minRttExpired bool + var excessAcked, bytesLost congestion.ByteCount + + // The send state of the largest packet in acked_packets, unless it is + // empty. If acked_packets is empty, it's the send state of the largest + // packet in lost_packets. + var lastPacketSendState sendTimeState + + b.maybeApplimited(priorInFlight) + + // Update bytesInFlight + b.bytesInFlight = priorInFlight + for _, p := range ackedPackets { + b.bytesInFlight -= p.BytesAcked + } + for _, p := range lostPackets { + b.bytesInFlight -= p.BytesLost + } + + if len(ackedPackets) != 0 { + lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber + isRoundStart = b.updateRoundTripCounter(lastAckedPacket) + b.updateRecoveryState(lastAckedPacket, len(lostPackets) != 0, isRoundStart) + } + + sample := b.sampler.OnCongestionEvent(eventTime, + ackedPackets, lostPackets, b.maxBandwidth.GetBest(), infBandwidth, b.roundTripCount) + if sample.lastPacketSendState.isValid { + b.lastSampleIsAppLimited = sample.lastPacketSendState.isAppLimited + b.hasNoAppLimitedSample = b.hasNoAppLimitedSample || !b.lastSampleIsAppLimited + } + // Avoid updating |max_bandwidth_| if a) this is a loss-only event, or b) all + // packets in |acked_packets| did not generate valid samples. (e.g. ack of + // ack-only packets). In both cases, sampler_.total_bytes_acked() will not + // change. + if totalBytesAckedBefore != b.sampler.TotalBytesAcked() { + if !sample.sampleIsAppLimited || sample.sampleMaxBandwidth > b.maxBandwidth.GetBest() { + b.maxBandwidth.Update(sample.sampleMaxBandwidth, b.roundTripCount) + } + } + + if sample.sampleRtt != infRTT { + minRttExpired = b.maybeUpdateMinRtt(eventTime, sample.sampleRtt) + } + bytesLost = b.sampler.TotalBytesLost() - totalBytesLostBefore + + excessAcked = sample.extraAcked + lastPacketSendState = sample.lastPacketSendState + + if len(lostPackets) != 0 { + b.numLossEventsInRound++ + b.bytesLostInRound += bytesLost + } + + // Handle logic specific to PROBE_BW mode. + if b.mode == bbrModeProbeBw { + b.updateGainCyclePhase(eventTime, priorInFlight, len(lostPackets) != 0) + } + + // Handle logic specific to STARTUP and DRAIN modes. + if isRoundStart && !b.isAtFullBandwidth { + b.checkIfFullBandwidthReached(&lastPacketSendState) + } + + b.maybeExitStartupOrDrain(eventTime) + + // Handle logic specific to PROBE_RTT. + b.maybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired) + + // Calculate number of packets acked and lost. + bytesAcked := b.sampler.TotalBytesAcked() - totalBytesAckedBefore + + // After the model is updated, recalculate the pacing rate and congestion + // window. + b.calculatePacingRate(bytesLost) + b.calculateCongestionWindow(bytesAcked, excessAcked) + b.calculateRecoveryWindow(bytesAcked, bytesLost) + + // Cleanup internal state. + if len(lostPackets) != 0 { + lastLostPacket := lostPackets[len(lostPackets)-1].PacketNumber + b.sampler.RemoveObsoletePackets(lastLostPacket) + } + if isRoundStart { + b.numLossEventsInRound = 0 + b.bytesLostInRound = 0 + } +} + +func (b *bbrSender) PacingRate() Bandwidth { + if b.pacingRate == 0 { + return Bandwidth(b.highGain * float64( + BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt()))) + } + + return b.pacingRate +} + +func (b *bbrSender) hasGoodBandwidthEstimateForResumption() bool { + return b.hasNonAppLimitedSample() +} + +func (b *bbrSender) hasNonAppLimitedSample() bool { + return b.hasNoAppLimitedSample +} + +// Sets the pacing gain used in STARTUP. Must be greater than 1. +func (b *bbrSender) setHighGain(highGain float64) { + b.highGain = highGain + if b.mode == bbrModeStartup { + b.pacingGain = highGain + } +} + +// Sets the CWND gain used in STARTUP. Must be greater than 1. +func (b *bbrSender) setHighCwndGain(highCwndGain float64) { + b.highCwndGain = highCwndGain + if b.mode == bbrModeStartup { + b.congestionWindowGain = highCwndGain + } +} + +// Sets the gain used in DRAIN. Must be less than 1. +func (b *bbrSender) setDrainGain(drainGain float64) { + b.drainGain = drainGain +} + +// What's the current estimated bandwidth in bytes per second. +func (b *bbrSender) bandwidthEstimate() Bandwidth { + return b.maxBandwidth.GetBest() +} + +// Returns the current estimate of the RTT of the connection. Outside of the +// edge cases, this is minimum RTT. +func (b *bbrSender) getMinRtt() time.Duration { + if b.minRtt != 0 { + return b.minRtt + } + // min_rtt could be available if the handshake packet gets neutered then + // gets acknowledged. This could only happen for QUIC crypto where we do not + // drop keys. + minRtt := b.rttStats.MinRTT() + if minRtt == 0 { + return 100 * time.Millisecond + } else { + return minRtt + } +} + +// Computes the target congestion window using the specified gain. +func (b *bbrSender) getTargetCongestionWindow(gain float64) congestion.ByteCount { + bdp := bdpFromRttAndBandwidth(b.getMinRtt(), b.bandwidthEstimate()) + congestionWindow := congestion.ByteCount(gain * float64(bdp)) + + // BDP estimate will be zero if no bandwidth samples are available yet. + if congestionWindow == 0 { + congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow)) + } + + return Max(congestionWindow, b.minCongestionWindow) +} + +// The target congestion window during PROBE_RTT. +func (b *bbrSender) probeRttCongestionWindow() congestion.ByteCount { + return b.minCongestionWindow +} + +func (b *bbrSender) maybeUpdateMinRtt(now time.Time, sampleMinRtt time.Duration) bool { + // Do not expire min_rtt if none was ever available. + minRttExpired := b.minRtt != 0 && now.After(b.minRttTimestamp.Add(minRttExpiry)) + if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 { + b.minRtt = sampleMinRtt + b.minRttTimestamp = now + } + + return minRttExpired +} + +// Enters the STARTUP mode. +func (b *bbrSender) enterStartupMode(now time.Time) { + b.mode = bbrModeStartup + // b.maybeTraceStateChange(logging.CongestionStateStartup) + b.pacingGain = b.highGain + b.congestionWindowGain = b.highCwndGain +} + +// Enters the PROBE_BW mode. +func (b *bbrSender) enterProbeBandwidthMode(now time.Time) { + b.mode = bbrModeProbeBw + // b.maybeTraceStateChange(logging.CongestionStateProbeBw) + b.congestionWindowGain = b.congestionWindowGainConstant + + // Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is + // excluded because in that case increased gain and decreased gain would not + // follow each other. + b.cycleCurrentOffset = int(fastrand.Int31n(congestion.PacketsPerConnectionID)) % (gainCycleLength - 1) + if b.cycleCurrentOffset >= 1 { + b.cycleCurrentOffset += 1 + } + + b.lastCycleStart = now + b.pacingGain = pacingGain[b.cycleCurrentOffset] +} + +// Updates the round-trip counter if a round-trip has passed. Returns true if +// the counter has been advanced. +func (b *bbrSender) updateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool { + if b.currentRoundTripEnd == invalidPacketNumber || lastAckedPacket > b.currentRoundTripEnd { + b.roundTripCount++ + b.currentRoundTripEnd = b.lastSentPacket + return true + } + return false +} + +// Updates the current gain used in PROBE_BW mode. +func (b *bbrSender) updateGainCyclePhase(now time.Time, priorInFlight congestion.ByteCount, hasLosses bool) { + // In most cases, the cycle is advanced after an RTT passes. + shouldAdvanceGainCycling := now.After(b.lastCycleStart.Add(b.getMinRtt())) + // If the pacing gain is above 1.0, the connection is trying to probe the + // bandwidth by increasing the number of bytes in flight to at least + // pacing_gain * BDP. Make sure that it actually reaches the target, as long + // as there are no losses suggesting that the buffers are not able to hold + // that much. + if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.getTargetCongestionWindow(b.pacingGain) { + shouldAdvanceGainCycling = false + } + + // If pacing gain is below 1.0, the connection is trying to drain the extra + // queue which could have been incurred by probing prior to it. If the number + // of bytes in flight falls down to the estimated BDP value earlier, conclude + // that the queue has been successfully drained and exit this cycle early. + if b.pacingGain < 1.0 && b.bytesInFlight <= b.getTargetCongestionWindow(1) { + shouldAdvanceGainCycling = true + } + + if shouldAdvanceGainCycling { + b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % gainCycleLength + b.lastCycleStart = now + // Stay in low gain mode until the target BDP is hit. + // Low gain mode will be exited immediately when the target BDP is achieved. + if b.drainToTarget && b.pacingGain < 1 && + pacingGain[b.cycleCurrentOffset] == 1 && + b.bytesInFlight > b.getTargetCongestionWindow(1) { + return + } + b.pacingGain = pacingGain[b.cycleCurrentOffset] + } +} + +// Tracks for how many round-trips the bandwidth has not increased +// significantly. +func (b *bbrSender) checkIfFullBandwidthReached(lastPacketSendState *sendTimeState) { + if b.lastSampleIsAppLimited { + return + } + + target := Bandwidth(float64(b.bandwidthAtLastRound) * startupGrowthTarget) + if b.bandwidthEstimate() >= target { + b.bandwidthAtLastRound = b.bandwidthEstimate() + b.roundsWithoutBandwidthGain = 0 + if b.expireAckAggregationInStartup { + // Expire old excess delivery measurements now that bandwidth increased. + b.sampler.ResetMaxAckHeightTracker(0, b.roundTripCount) + } + return + } + + b.roundsWithoutBandwidthGain++ + if b.roundsWithoutBandwidthGain >= b.numStartupRtts || + b.shouldExitStartupDueToLoss(lastPacketSendState) { + b.isAtFullBandwidth = true + } +} + +func (b *bbrSender) maybeApplimited(bytesInFlight congestion.ByteCount) { + congestionWindow := b.GetCongestionWindow() + if bytesInFlight >= congestionWindow { + return + } + availableBytes := congestionWindow - bytesInFlight + drainLimited := b.mode == bbrModeDrain && bytesInFlight > congestionWindow/2 + if !drainLimited || availableBytes > maxBbrBurstPackets*b.maxDatagramSize { + b.sampler.OnAppLimited() + } +} + +// Transitions from STARTUP to DRAIN and from DRAIN to PROBE_BW if +// appropriate. +func (b *bbrSender) maybeExitStartupOrDrain(now time.Time) { + if b.mode == bbrModeStartup && b.isAtFullBandwidth { + b.mode = bbrModeDrain + // b.maybeTraceStateChange(logging.CongestionStateDrain) + b.pacingGain = b.drainGain + b.congestionWindowGain = b.highCwndGain + } + if b.mode == bbrModeDrain && b.bytesInFlight <= b.getTargetCongestionWindow(1) { + b.enterProbeBandwidthMode(now) + } +} + +// Decides whether to enter or exit PROBE_RTT. +func (b *bbrSender) maybeEnterOrExitProbeRtt(now time.Time, isRoundStart, minRttExpired bool) { + if minRttExpired && !b.exitingQuiescence && b.mode != bbrModeProbeRtt { + b.mode = bbrModeProbeRtt + // b.maybeTraceStateChange(logging.CongestionStateProbRtt) + b.pacingGain = 1.0 + // Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight| + // is at the target small value. + b.exitProbeRttAt = time.Time{} + } + + if b.mode == bbrModeProbeRtt { + b.sampler.OnAppLimited() + // b.maybeTraceStateChange(logging.CongestionStateApplicationLimited) + + if b.exitProbeRttAt.IsZero() { + // If the window has reached the appropriate size, schedule exiting + // PROBE_RTT. The CWND during PROBE_RTT is kMinimumCongestionWindow, but + // we allow an extra packet since QUIC checks CWND before sending a + // packet. + if b.bytesInFlight < b.probeRttCongestionWindow()+congestion.MaxPacketBufferSize { + b.exitProbeRttAt = now.Add(probeRttTime) + b.probeRttRoundPassed = false + } + } else { + if isRoundStart { + b.probeRttRoundPassed = true + } + if now.Sub(b.exitProbeRttAt) >= 0 && b.probeRttRoundPassed { + b.minRttTimestamp = now + if !b.isAtFullBandwidth { + b.enterStartupMode(now) + } else { + b.enterProbeBandwidthMode(now) + } + } + } + } + + b.exitingQuiescence = false +} + +// Determines whether BBR needs to enter, exit or advance state of the +// recovery. +func (b *bbrSender) updateRecoveryState(lastAckedPacket congestion.PacketNumber, hasLosses, isRoundStart bool) { + // Disable recovery in startup, if loss-based exit is enabled. + if !b.isAtFullBandwidth { + return + } + + // Exit recovery when there are no losses for a round. + if hasLosses { + b.endRecoveryAt = b.lastSentPacket + } + + switch b.recoveryState { + case bbrRecoveryStateNotInRecovery: + if hasLosses { + b.recoveryState = bbrRecoveryStateConservation + // This will cause the |recovery_window_| to be set to the correct + // value in CalculateRecoveryWindow(). + b.recoveryWindow = 0 + // Since the conservation phase is meant to be lasting for a whole + // round, extend the current round as if it were started right now. + b.currentRoundTripEnd = b.lastSentPacket + } + case bbrRecoveryStateConservation: + if isRoundStart { + b.recoveryState = bbrRecoveryStateGrowth + } + fallthrough + case bbrRecoveryStateGrowth: + // Exit recovery if appropriate. + if !hasLosses && lastAckedPacket > b.endRecoveryAt { + b.recoveryState = bbrRecoveryStateNotInRecovery + } + } +} + +// Determines the appropriate pacing rate for the connection. +func (b *bbrSender) calculatePacingRate(bytesLost congestion.ByteCount) { + if b.bandwidthEstimate() == 0 { + return + } + + targetRate := Bandwidth(b.pacingGain * float64(b.bandwidthEstimate())) + if b.isAtFullBandwidth { + b.pacingRate = targetRate + return + } + + // Pace at the rate of initial_window / RTT as soon as RTT measurements are + // available. + if b.pacingRate == 0 && b.rttStats.MinRTT() != 0 { + b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT()) + return + } + + if b.detectOvershooting { + b.bytesLostWhileDetectingOvershooting += bytesLost + // Check for overshooting with network parameters adjusted when pacing rate + // > target_rate and loss has been detected. + if b.pacingRate > targetRate && b.bytesLostWhileDetectingOvershooting > 0 { + if b.hasNoAppLimitedSample || + b.bytesLostWhileDetectingOvershooting*congestion.ByteCount(b.bytesLostMultiplierWhileDetectingOvershooting) > b.initialCongestionWindow { + // We are fairly sure overshoot happens if 1) there is at least one + // non app-limited bw sample or 2) half of IW gets lost. Slow pacing + // rate. + b.pacingRate = Max(targetRate, BandwidthFromDelta(b.cwndToCalculateMinPacingRate, b.rttStats.MinRTT())) + b.bytesLostWhileDetectingOvershooting = 0 + b.detectOvershooting = false + } + } + } + + // Do not decrease the pacing rate during startup. + b.pacingRate = Max(b.pacingRate, targetRate) +} + +// Determines the appropriate congestion window for the connection. +func (b *bbrSender) calculateCongestionWindow(bytesAcked, excessAcked congestion.ByteCount) { + if b.mode == bbrModeProbeRtt { + return + } + + targetWindow := b.getTargetCongestionWindow(b.congestionWindowGain) + if b.isAtFullBandwidth { + // Add the max recently measured ack aggregation to CWND. + targetWindow += b.sampler.MaxAckHeight() + } else if b.enableAckAggregationDuringStartup { + // Add the most recent excess acked. Because CWND never decreases in + // STARTUP, this will automatically create a very localized max filter. + targetWindow += excessAcked + } + + // Instead of immediately setting the target CWND as the new one, BBR grows + // the CWND towards |target_window| by only increasing it |bytes_acked| at a + // time. + if b.isAtFullBandwidth { + b.congestionWindow = Min(targetWindow, b.congestionWindow+bytesAcked) + } else if b.congestionWindow < targetWindow || + b.sampler.TotalBytesAcked() < b.initialCongestionWindow { + // If the connection is not yet out of startup phase, do not decrease the + // window. + b.congestionWindow += bytesAcked + } + + // Enforce the limits on the congestion window. + b.congestionWindow = Max(b.congestionWindow, b.minCongestionWindow) + b.congestionWindow = Min(b.congestionWindow, b.maxCongestionWindow) +} + +// Determines the appropriate window that constrains the in-flight during recovery. +func (b *bbrSender) calculateRecoveryWindow(bytesAcked, bytesLost congestion.ByteCount) { + if b.recoveryState == bbrRecoveryStateNotInRecovery { + return + } + + // Set up the initial recovery window. + if b.recoveryWindow == 0 { + b.recoveryWindow = b.bytesInFlight + bytesAcked + b.recoveryWindow = Max(b.minCongestionWindow, b.recoveryWindow) + return + } + + // Remove losses from the recovery window, while accounting for a potential + // integer underflow. + if b.recoveryWindow >= bytesLost { + b.recoveryWindow = b.recoveryWindow - bytesLost + } else { + b.recoveryWindow = b.maxDatagramSize + } + + // In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH, + // release additional |bytes_acked| to achieve a slow-start-like behavior. + if b.recoveryState == bbrRecoveryStateGrowth { + b.recoveryWindow += bytesAcked + } + + // Always allow sending at least |bytes_acked| in response. + b.recoveryWindow = Max(b.recoveryWindow, b.bytesInFlight+bytesAcked) + b.recoveryWindow = Max(b.minCongestionWindow, b.recoveryWindow) +} + +// Return whether we should exit STARTUP due to excessive loss. +func (b *bbrSender) shouldExitStartupDueToLoss(lastPacketSendState *sendTimeState) bool { + if b.numLossEventsInRound < defaultStartupFullLossCount || !lastPacketSendState.isValid { + return false + } + + inflightAtSend := lastPacketSendState.bytesInFlight + + if inflightAtSend > 0 && b.bytesLostInRound > 0 { + if b.bytesLostInRound > congestion.ByteCount(float64(inflightAtSend)*quicBbr2DefaultLossThreshold) { + return true + } + return false + } + return false +} + +func bdpFromRttAndBandwidth(rtt time.Duration, bandwidth Bandwidth) congestion.ByteCount { + return congestion.ByteCount(rtt) * congestion.ByteCount(bandwidth) / congestion.ByteCount(BytesPerSecond) / congestion.ByteCount(time.Second) +} + +func GetInitialPacketSize(addr net.Addr) congestion.ByteCount { + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := addr.(*net.UDPAddr); ok { + if udpAddr.IP.To4() != nil { + return congestion.InitialPacketSizeIPv4 + } else { + return congestion.InitialPacketSizeIPv6 + } + } else { + return congestion.MinInitialPacketSize + } +} diff --git a/transport/tuic/congestion_v2/clock.go b/transport/tuic/congestion_v2/clock.go new file mode 100644 index 00000000..405fae70 --- /dev/null +++ b/transport/tuic/congestion_v2/clock.go @@ -0,0 +1,18 @@ +package congestion + +import "time" + +// A Clock returns the current time +type Clock interface { + Now() time.Time +} + +// DefaultClock implements the Clock interface using the Go stdlib clock. +type DefaultClock struct{} + +var _ Clock = DefaultClock{} + +// Now gets the current time +func (DefaultClock) Now() time.Time { + return time.Now() +} diff --git a/transport/tuic/congestion_v2/minmax_go120.go b/transport/tuic/congestion_v2/minmax_go120.go new file mode 100644 index 00000000..1266edbc --- /dev/null +++ b/transport/tuic/congestion_v2/minmax_go120.go @@ -0,0 +1,19 @@ +//go:build !go1.21 + +package congestion + +import "golang.org/x/exp/constraints" + +func Max[T constraints.Ordered](a, b T) T { + if a < b { + return b + } + return a +} + +func Min[T constraints.Ordered](a, b T) T { + if a < b { + return a + } + return b +} diff --git a/transport/tuic/congestion_v2/minmax_go121.go b/transport/tuic/congestion_v2/minmax_go121.go new file mode 100644 index 00000000..65b06726 --- /dev/null +++ b/transport/tuic/congestion_v2/minmax_go121.go @@ -0,0 +1,13 @@ +//go:build go1.21 + +package congestion + +import "cmp" + +func Max[T cmp.Ordered](a, b T) T { + return max(a, b) +} + +func Min[T cmp.Ordered](a, b T) T { + return min(a, b) +} diff --git a/transport/tuic/congestion_v2/pacer.go b/transport/tuic/congestion_v2/pacer.go new file mode 100644 index 00000000..ba4ca138 --- /dev/null +++ b/transport/tuic/congestion_v2/pacer.go @@ -0,0 +1,71 @@ +package congestion + +import ( + "math" + "time" + + "github.com/metacubex/quic-go/congestion" +) + +const ( + maxBurstPackets = 10 +) + +// Pacer implements a token bucket pacing algorithm. +type Pacer struct { + budgetAtLastSent congestion.ByteCount + maxDatagramSize congestion.ByteCount + lastSentTime time.Time + getBandwidth func() congestion.ByteCount // in bytes/s +} + +func NewPacer(getBandwidth func() congestion.ByteCount) *Pacer { + p := &Pacer{ + budgetAtLastSent: maxBurstPackets * congestion.InitialPacketSizeIPv4, + maxDatagramSize: congestion.InitialPacketSizeIPv4, + getBandwidth: getBandwidth, + } + return p +} + +func (p *Pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { + budget := p.Budget(sendTime) + if size > budget { + p.budgetAtLastSent = 0 + } else { + p.budgetAtLastSent = budget - size + } + p.lastSentTime = sendTime +} + +func (p *Pacer) Budget(now time.Time) congestion.ByteCount { + if p.lastSentTime.IsZero() { + return p.maxBurstSize() + } + budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + return Min(p.maxBurstSize(), budget) +} + +func (p *Pacer) maxBurstSize() congestion.ByteCount { + return Max( + congestion.ByteCount((congestion.MinPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9, + maxBurstPackets*p.maxDatagramSize, + ) +} + +// TimeUntilSend returns when the next packet should be sent. +// It returns the zero value of time.Time if a packet can be sent immediately. +func (p *Pacer) TimeUntilSend() time.Time { + if p.budgetAtLastSent >= p.maxDatagramSize { + return time.Time{} + } + return p.lastSentTime.Add(Max( + congestion.MinPacingDelay, + time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/ + float64(p.getBandwidth())))*time.Nanosecond, + )) +} + +func (p *Pacer) SetMaxDatagramSize(s congestion.ByteCount) { + p.maxDatagramSize = s +} diff --git a/transport/tuic/congestion_v2/packet_number_indexed_queue.go b/transport/tuic/congestion_v2/packet_number_indexed_queue.go new file mode 100644 index 00000000..119d36f6 --- /dev/null +++ b/transport/tuic/congestion_v2/packet_number_indexed_queue.go @@ -0,0 +1,199 @@ +package congestion + +import ( + "github.com/metacubex/quic-go/congestion" +) + +// packetNumberIndexedQueue is a queue of mostly continuous numbered entries +// which supports the following operations: +// - adding elements to the end of the queue, or at some point past the end +// - removing elements in any order +// - retrieving elements +// If all elements are inserted in order, all of the operations above are +// amortized O(1) time. +// +// Internally, the data structure is a deque where each element is marked as +// present or not. The deque starts at the lowest present index. Whenever an +// element is removed, it's marked as not present, and the front of the deque is +// cleared of elements that are not present. +// +// The tail of the queue is not cleared due to the assumption of entries being +// inserted in order, though removing all elements of the queue will return it +// to its initial state. +// +// Note that this data structure is inherently hazardous, since an addition of +// just two entries will cause it to consume all of the memory available. +// Because of that, it is not a general-purpose container and should not be used +// as one. + +type entryWrapper[T any] struct { + present bool + entry T +} + +type packetNumberIndexedQueue[T any] struct { + entries RingBuffer[entryWrapper[T]] + numberOfPresentEntries int + firstPacket congestion.PacketNumber +} + +func newPacketNumberIndexedQueue[T any](size int) *packetNumberIndexedQueue[T] { + q := &packetNumberIndexedQueue[T]{ + firstPacket: invalidPacketNumber, + } + + q.entries.Init(size) + + return q +} + +// Emplace inserts data associated |packet_number| into (or past) the end of the +// queue, filling up the missing intermediate entries as necessary. Returns +// true if the element has been inserted successfully, false if it was already +// in the queue or inserted out of order. +func (p *packetNumberIndexedQueue[T]) Emplace(packetNumber congestion.PacketNumber, entry *T) bool { + if packetNumber == invalidPacketNumber || entry == nil { + return false + } + + if p.IsEmpty() { + p.entries.PushBack(entryWrapper[T]{ + present: true, + entry: *entry, + }) + p.numberOfPresentEntries = 1 + p.firstPacket = packetNumber + return true + } + + // Do not allow insertion out-of-order. + if packetNumber <= p.LastPacket() { + return false + } + + // Handle potentially missing elements. + offset := int(packetNumber - p.FirstPacket()) + if gap := offset - p.entries.Len(); gap > 0 { + for i := 0; i < gap; i++ { + p.entries.PushBack(entryWrapper[T]{}) + } + } + + p.entries.PushBack(entryWrapper[T]{ + present: true, + entry: *entry, + }) + p.numberOfPresentEntries++ + return true +} + +// GetEntry Retrieve the entry associated with the packet number. Returns the pointer +// to the entry in case of success, or nullptr if the entry does not exist. +func (p *packetNumberIndexedQueue[T]) GetEntry(packetNumber congestion.PacketNumber) *T { + ew := p.getEntryWraper(packetNumber) + if ew == nil { + return nil + } + + return &ew.entry +} + +// Remove, Same as above, but if an entry is present in the queue, also call f(entry) +// before removing it. +func (p *packetNumberIndexedQueue[T]) Remove(packetNumber congestion.PacketNumber, f func(T)) bool { + ew := p.getEntryWraper(packetNumber) + if ew == nil { + return false + } + if f != nil { + f(ew.entry) + } + ew.present = false + p.numberOfPresentEntries-- + + if packetNumber == p.FirstPacket() { + p.clearup() + } + + return true +} + +// RemoveUpTo, but not including |packet_number|. +// Unused slots in the front are also removed, which means when the function +// returns, |first_packet()| can be larger than |packet_number|. +func (p *packetNumberIndexedQueue[T]) RemoveUpTo(packetNumber congestion.PacketNumber) { + for !p.entries.Empty() && + p.firstPacket != invalidPacketNumber && + p.firstPacket < packetNumber { + if p.entries.Front().present { + p.numberOfPresentEntries-- + } + p.entries.PopFront() + p.firstPacket++ + } + p.clearup() + + return +} + +// IsEmpty return if queue is empty. +func (p *packetNumberIndexedQueue[T]) IsEmpty() bool { + return p.numberOfPresentEntries == 0 +} + +// NumberOfPresentEntries returns the number of entries in the queue. +func (p *packetNumberIndexedQueue[T]) NumberOfPresentEntries() int { + return p.numberOfPresentEntries +} + +// EntrySlotsUsed returns the number of entries allocated in the underlying deque. This is +// proportional to the memory usage of the queue. +func (p *packetNumberIndexedQueue[T]) EntrySlotsUsed() int { + return p.entries.Len() +} + +// LastPacket returns packet number of the first entry in the queue. +func (p *packetNumberIndexedQueue[T]) FirstPacket() (packetNumber congestion.PacketNumber) { + return p.firstPacket +} + +// LastPacket returns packet number of the last entry ever inserted in the queue. Note that the +// entry in question may have already been removed. Zero if the queue is +// empty. +func (p *packetNumberIndexedQueue[T]) LastPacket() (packetNumber congestion.PacketNumber) { + if p.IsEmpty() { + return invalidPacketNumber + } + + return p.firstPacket + congestion.PacketNumber(p.entries.Len()-1) +} + +func (p *packetNumberIndexedQueue[T]) clearup() { + for !p.entries.Empty() && !p.entries.Front().present { + p.entries.PopFront() + p.firstPacket++ + } + if p.entries.Empty() { + p.firstPacket = invalidPacketNumber + } +} + +func (p *packetNumberIndexedQueue[T]) getEntryWraper(packetNumber congestion.PacketNumber) *entryWrapper[T] { + if packetNumber == invalidPacketNumber || + p.IsEmpty() || + packetNumber < p.firstPacket { + return nil + } + + offset := int(packetNumber - p.firstPacket) + if offset >= p.entries.Len() { + return nil + } + + ew := p.entries.Offset(offset) + if ew == nil || !ew.present { + return nil + } + + return ew +} diff --git a/transport/tuic/congestion_v2/ringbuffer.go b/transport/tuic/congestion_v2/ringbuffer.go new file mode 100644 index 00000000..e110c00f --- /dev/null +++ b/transport/tuic/congestion_v2/ringbuffer.go @@ -0,0 +1,118 @@ +package congestion + +// A RingBuffer is a ring buffer. +// It acts as a heap that doesn't cause any allocations. +type RingBuffer[T any] struct { + ring []T + headPos, tailPos int + full bool +} + +// Init preallocs a buffer with a certain size. +func (r *RingBuffer[T]) Init(size int) { + r.ring = make([]T, size) +} + +// Len returns the number of elements in the ring buffer. +func (r *RingBuffer[T]) Len() int { + if r.full { + return len(r.ring) + } + if r.tailPos >= r.headPos { + return r.tailPos - r.headPos + } + return r.tailPos - r.headPos + len(r.ring) +} + +// Empty says if the ring buffer is empty. +func (r *RingBuffer[T]) Empty() bool { + return !r.full && r.headPos == r.tailPos +} + +// PushBack adds a new element. +// If the ring buffer is full, its capacity is increased first. +func (r *RingBuffer[T]) PushBack(t T) { + if r.full || len(r.ring) == 0 { + r.grow() + } + r.ring[r.tailPos] = t + r.tailPos++ + if r.tailPos == len(r.ring) { + r.tailPos = 0 + } + if r.tailPos == r.headPos { + r.full = true + } +} + +// PopFront returns the next element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) PopFront() T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue") + } + r.full = false + t := r.ring[r.headPos] + r.ring[r.headPos] = *new(T) + r.headPos++ + if r.headPos == len(r.ring) { + r.headPos = 0 + } + return t +} + +// Offset returns the offset element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first +// and check if the index larger than buffer length. +func (r *RingBuffer[T]) Offset(index int) *T { + if r.Empty() || index >= r.Len() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: offset from invalid index") + } + offset := (r.headPos + index) % len(r.ring) + return &r.ring[offset] +} + +// Front returns the front element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) Front() *T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: front from an empty queue") + } + return &r.ring[r.headPos] +} + +// Back returns the back element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) Back() *T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: back from an empty queue") + } + return r.Offset(r.Len() - 1) +} + +// Grow the maximum size of the queue. +// This method assume the queue is full. +func (r *RingBuffer[T]) grow() { + oldRing := r.ring + newSize := len(oldRing) * 2 + if newSize == 0 { + newSize = 1 + } + r.ring = make([]T, newSize) + headLen := copy(r.ring, oldRing[r.headPos:]) + copy(r.ring[headLen:], oldRing[:r.headPos]) + r.headPos, r.tailPos, r.full = 0, len(oldRing), false +} + +// Clear removes all elements. +func (r *RingBuffer[T]) Clear() { + var zeroValue T + for i := range r.ring { + r.ring[i] = zeroValue + } + r.headPos, r.tailPos, r.full = 0, 0, false +} diff --git a/transport/tuic/congestion_v2/windowed_filter.go b/transport/tuic/congestion_v2/windowed_filter.go new file mode 100644 index 00000000..2421b48b --- /dev/null +++ b/transport/tuic/congestion_v2/windowed_filter.go @@ -0,0 +1,162 @@ +package congestion + +import ( + "golang.org/x/exp/constraints" +) + +// Implements Kathleen Nichols' algorithm for tracking the minimum (or maximum) +// estimate of a stream of samples over some fixed time interval. (E.g., +// the minimum RTT over the past five minutes.) The algorithm keeps track of +// the best, second best, and third best min (or max) estimates, maintaining an +// invariant that the measurement time of the n'th best >= n-1'th best. + +// The algorithm works as follows. On a reset, all three estimates are set to +// the same sample. The second best estimate is then recorded in the second +// quarter of the window, and a third best estimate is recorded in the second +// half of the window, bounding the worst case error when the true min is +// monotonically increasing (or true max is monotonically decreasing) over the +// window. +// +// A new best sample replaces all three estimates, since the new best is lower +// (or higher) than everything else in the window and it is the most recent. +// The window thus effectively gets reset on every new min. The same property +// holds true for second best and third best estimates. Specifically, when a +// sample arrives that is better than the second best but not better than the +// best, it replaces the second and third best estimates but not the best +// estimate. Similarly, a sample that is better than the third best estimate +// but not the other estimates replaces only the third best estimate. +// +// Finally, when the best expires, it is replaced by the second best, which in +// turn is replaced by the third best. The newest sample replaces the third +// best. + +type WindowedFilterValue interface { + any +} + +type WindowedFilterTime interface { + constraints.Integer | constraints.Float +} + +type WindowedFilter[V WindowedFilterValue, T WindowedFilterTime] struct { + // Time length of window. + windowLength T + estimates []entry[V, T] + comparator func(V, V) int +} + +type entry[V WindowedFilterValue, T WindowedFilterTime] struct { + sample V + time T +} + +// Compares two values and returns true if the first is greater than or equal +// to the second. +func MaxFilter[O constraints.Ordered](a, b O) int { + if a > b { + return 1 + } else if a < b { + return -1 + } + return 0 +} + +// Compares two values and returns true if the first is less than or equal +// to the second. +func MinFilter[O constraints.Ordered](a, b O) int { + if a < b { + return 1 + } else if a > b { + return -1 + } + return 0 +} + +func NewWindowedFilter[V WindowedFilterValue, T WindowedFilterTime](windowLength T, comparator func(V, V) int) *WindowedFilter[V, T] { + return &WindowedFilter[V, T]{ + windowLength: windowLength, + estimates: make([]entry[V, T], 3, 3), + comparator: comparator, + } +} + +// Changes the window length. Does not update any current samples. +func (f *WindowedFilter[V, T]) SetWindowLength(windowLength T) { + f.windowLength = windowLength +} + +func (f *WindowedFilter[V, T]) GetBest() V { + return f.estimates[0].sample +} + +func (f *WindowedFilter[V, T]) GetSecondBest() V { + return f.estimates[1].sample +} + +func (f *WindowedFilter[V, T]) GetThirdBest() V { + return f.estimates[2].sample +} + +// Updates best estimates with |sample|, and expires and updates best +// estimates as necessary. +func (f *WindowedFilter[V, T]) Update(newSample V, newTime T) { + // Reset all estimates if they have not yet been initialized, if new sample + // is a new best, or if the newest recorded estimate is too old. + if f.comparator(f.estimates[0].sample, *new(V)) == 0 || + f.comparator(newSample, f.estimates[0].sample) >= 0 || + newTime-f.estimates[2].time > f.windowLength { + f.Reset(newSample, newTime) + return + } + + if f.comparator(newSample, f.estimates[1].sample) >= 0 { + f.estimates[1] = entry[V, T]{newSample, newTime} + f.estimates[2] = f.estimates[1] + } else if f.comparator(newSample, f.estimates[2].sample) >= 0 { + f.estimates[2] = entry[V, T]{newSample, newTime} + } + + // Expire and update estimates as necessary. + if newTime-f.estimates[0].time > f.windowLength { + // The best estimate hasn't been updated for an entire window, so promote + // second and third best estimates. + f.estimates[0] = f.estimates[1] + f.estimates[1] = f.estimates[2] + f.estimates[2] = entry[V, T]{newSample, newTime} + // Need to iterate one more time. Check if the new best estimate is + // outside the window as well, since it may also have been recorded a + // long time ago. Don't need to iterate once more since we cover that + // case at the beginning of the method. + if newTime-f.estimates[0].time > f.windowLength { + f.estimates[0] = f.estimates[1] + f.estimates[1] = f.estimates[2] + } + return + } + if f.comparator(f.estimates[1].sample, f.estimates[0].sample) == 0 && + newTime-f.estimates[1].time > f.windowLength/4 { + // A quarter of the window has passed without a better sample, so the + // second-best estimate is taken from the second quarter of the window. + f.estimates[1] = entry[V, T]{newSample, newTime} + f.estimates[2] = f.estimates[1] + return + } + + if f.comparator(f.estimates[2].sample, f.estimates[1].sample) == 0 && + newTime-f.estimates[2].time > f.windowLength/2 { + // We've passed a half of the window without a better estimate, so take + // a third-best estimate from the second half of the window. + f.estimates[2] = entry[V, T]{newSample, newTime} + } +} + +// Resets all estimates to new sample. +func (f *WindowedFilter[V, T]) Reset(newSample V, newTime T) { + f.estimates[2] = entry[V, T]{newSample, newTime} + f.estimates[1] = f.estimates[2] + f.estimates[0] = f.estimates[1] +} + +func (f *WindowedFilter[V, T]) Clear() { + f.estimates = make([]entry[V, T], 3, 3) +}