From 5420012d43bcd6484e3e9b06e09793fe83dfa80c Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Fri, 3 Oct 2025 13:50:01 +0200 Subject: [PATCH 1/6] Add refactored GCC implementation --- bwe.go | 6 + ecn.go | 25 +++ gcc/arrival_group_accumulator.go | 57 +++++++ gcc/arrival_group_accumulator_test.go | 210 +++++++++++++++++++++++++ gcc/delay_rate_controller.go | 90 +++++++++++ gcc/delivery_rate_estimator.go | 93 +++++++++++ gcc/delivery_rate_estimator_test.go | 77 +++++++++ gcc/exponential_moving_average.go | 22 +++ gcc/exponential_moving_average_test.go | 125 +++++++++++++++ gcc/gcc.go | 3 + gcc/kalman.go | 87 ++++++++++ gcc/loss_rate_controller.go | 56 +++++++ gcc/loss_rate_controller_test.go | 89 +++++++++++ gcc/overuse_detector.go | 87 ++++++++++ gcc/overuse_detector_test.go | 189 ++++++++++++++++++++++ gcc/rate_controller.go | 82 ++++++++++ gcc/rate_controller_test.go | 146 +++++++++++++++++ gcc/send_side_bwe.go | 89 +++++++++++ gcc/state.go | 79 ++++++++++ gcc/state_test.go | 30 ++++ gcc/usage.go | 27 ++++ packet.go | 50 ++++++ 22 files changed, 1719 insertions(+) create mode 100644 bwe.go create mode 100644 ecn.go create mode 100644 gcc/arrival_group_accumulator.go create mode 100644 gcc/arrival_group_accumulator_test.go create mode 100644 gcc/delay_rate_controller.go create mode 100644 gcc/delivery_rate_estimator.go create mode 100644 gcc/delivery_rate_estimator_test.go create mode 100644 gcc/exponential_moving_average.go create mode 100644 gcc/exponential_moving_average_test.go create mode 100644 gcc/gcc.go create mode 100644 gcc/kalman.go create mode 100644 gcc/loss_rate_controller.go create mode 100644 gcc/loss_rate_controller_test.go create mode 100644 gcc/overuse_detector.go create mode 100644 gcc/overuse_detector_test.go create mode 100644 gcc/rate_controller.go create mode 100644 gcc/rate_controller_test.go create mode 100644 gcc/send_side_bwe.go create mode 100644 gcc/state.go create mode 100644 gcc/state_test.go create mode 100644 gcc/usage.go create mode 100644 packet.go diff --git a/bwe.go b/bwe.go new file mode 100644 index 0000000..461ecd4 --- /dev/null +++ b/bwe.go @@ -0,0 +1,6 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +// Package bwe implements data structures that are common to all bandwidth +// estimators. +package bwe diff --git a/ecn.go b/ecn.go new file mode 100644 index 0000000..8aaf33f --- /dev/null +++ b/ecn.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package bwe + +// ECN represents the ECN bits of an IP packet header. +type ECN uint8 + +const ( + // ECNNonECT signals Non ECN-Capable Transport, Non-ECT. + // nolint:misspell + ECNNonECT ECN = iota // 00 + + // ECNECT1 signals ECN Capable Transport, ECT(0). + // nolint:misspell + ECNECT1 // 01 + + // ECNECT0 signals ECN Capable Transport, ECT(1). + // nolint:misspell + ECNECT0 // 10 + + // ECNCE signals ECN Congestion Encountered, CE. + // nolint:misspell + ECNCE // 11 +) diff --git a/gcc/arrival_group_accumulator.go b/gcc/arrival_group_accumulator.go new file mode 100644 index 0000000..70a548f --- /dev/null +++ b/gcc/arrival_group_accumulator.go @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "time" + + "github.com/pion/bwe" +) + +type arrivalGroup []bwe.Packet + +type arrivalGroupAccumulator struct { + next arrivalGroup + burstInterval time.Duration + maxBurstDuration time.Duration +} + +func newArrivalGroupAccumulator() *arrivalGroupAccumulator { + return &arrivalGroupAccumulator{ + next: make([]bwe.Packet, 0), + burstInterval: 5 * time.Millisecond, + maxBurstDuration: 100 * time.Millisecond, + } +} + +func (a *arrivalGroupAccumulator) onPacketAcked(ack bwe.Packet) arrivalGroup { + if len(a.next) == 0 { + a.next = append(a.next, ack) + + return nil + } + + sendTimeDelta := ack.Departure.Sub(a.next[0].Departure) + if sendTimeDelta < a.burstInterval { + a.next = append(a.next, ack) + + return nil + } + + arrivalTimeDeltaLast := ack.Arrival.Sub(a.next[len(a.next)-1].Arrival) + arrivalTimeDeltaFirst := ack.Arrival.Sub(a.next[0].Arrival) + propagationDelta := arrivalTimeDeltaFirst - sendTimeDelta + + if propagationDelta < 0 && arrivalTimeDeltaLast <= a.burstInterval && arrivalTimeDeltaFirst < a.maxBurstDuration { + a.next = append(a.next, ack) + + return nil + } + + group := make(arrivalGroup, len(a.next)) + copy(group, a.next) + a.next = arrivalGroup{ack} + + return group +} diff --git a/gcc/arrival_group_accumulator_test.go b/gcc/arrival_group_accumulator_test.go new file mode 100644 index 0000000..656827a --- /dev/null +++ b/gcc/arrival_group_accumulator_test.go @@ -0,0 +1,210 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestArrivalGroupAccumulator(t *testing.T) { + triggerNewGroupElement := Acknowledgment{ + Departure: time.Time{}.Add(time.Second), + Arrival: time.Time{}.Add(time.Second), + } + cases := []struct { + name string + log []Acknowledgment + exp []arrivalGroup + }{ + { + name: "emptyCreatesNoGroups", + log: []Acknowledgment{}, + exp: []arrivalGroup{}, + }, + { + name: "createsSingleElementGroup", + log: []Acknowledgment{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(time.Millisecond), + }, + }, + }, + }, + { + name: "createsTwoElementGroup", + log: []Acknowledgment{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }}, + }, + { + name: "createsTwoArrivalGroups1", + log: []Acknowledgment{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(24 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }, + { + { + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(24 * time.Millisecond), + }, + }, + }, + }, + { + name: "ignoresOutOfOrderPackets", + log: []Acknowledgment{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(34 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(8 * time.Millisecond), + Arrival: time.Time{}.Add(30 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + }, + { + { + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(34 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(8 * time.Millisecond), + Arrival: time.Time{}.Add(30 * time.Millisecond), + }, + }, + }, + }, + { + name: "newGroupBecauseOfInterDepartureTime", + log: []Acknowledgment{ + { + SeqNr: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SeqNr: 1, + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SeqNr: 2, + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SeqNr: 3, + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + SeqNr: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SeqNr: 1, + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + }, + { + { + SeqNr: 2, + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SeqNr: 3, + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + }, + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + aga := newArrivalGroupAccumulator() + received := []arrivalGroup{} + for _, ack := range tc.log { + next := aga.onPacketAcked(ack) + if next != nil { + received = append(received, next) + } + } + assert.Equal(t, tc.exp, received) + }) + } +} diff --git a/gcc/delay_rate_controller.go b/gcc/delay_rate_controller.go new file mode 100644 index 0000000..c5beda3 --- /dev/null +++ b/gcc/delay_rate_controller.go @@ -0,0 +1,90 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "time" + + "github.com/pion/bwe" + "github.com/pion/logging" +) + +type delayRateController struct { + log logging.LeveledLogger + aga *arrivalGroupAccumulator + last arrivalGroup + kf *kalmanFilter + od *overuseDetector + rc *rateController + latestUsage usage + samples int +} + +func newDelayRateController(initialRate int, logger logging.LeveledLogger) *delayRateController { + return &delayRateController{ + log: logger, + aga: newArrivalGroupAccumulator(), + last: []bwe.Packet{}, + kf: newKalmanFilter(), + od: newOveruseDetector(true), + rc: newRateController(initialRate), + latestUsage: 0, + samples: 0, + } +} + +func (c *delayRateController) onPacketAcked(ack bwe.Packet) { + next := c.aga.onPacketAcked(ack) + if next == nil { + return + } + if len(next) == 0 { + // ignore empty groups, should never occur + return + } + if len(c.last) == 0 { + c.last = next + + return + } + + prevSize := groupSize(c.last) + nextSize := groupSize(next) + sizeDelta := nextSize - prevSize + + interArrivalTime := next[len(next)-1].Arrival.Sub(c.last[len(c.last)-1].Arrival) + interDepartureTime := next[len(next)-1].Departure.Sub(c.last[len(c.last)-1].Departure) + interGroupDelay := interArrivalTime - interDepartureTime + estimate := c.kf.update(float64(interGroupDelay.Milliseconds()), float64(sizeDelta)) + c.samples++ + c.latestUsage = c.od.update(ack.Arrival, estimate, c.samples) + c.last = next + c.log.Tracef( + "ts=%v.%06d, seq=%v, size=%v, interArrivalTime=%v, interDepartureTime=%v, interGroupDelay=%v, estimate=%v, threshold=%v, usage=%v, state=%v", // nolint + c.last[0].Departure.UTC().Format("2006/01/02 15:04:05"), + c.last[0].Departure.UTC().Nanosecond()/1e3, + next[0].SequenceNumber, + nextSize, + interArrivalTime.Microseconds(), + interDepartureTime.Microseconds(), + interGroupDelay.Microseconds(), + estimate, + c.od.delayThreshold, + int(c.latestUsage), + int(c.rc.s), + ) +} + +func (c *delayRateController) update(ts time.Time, lastDeliveryRate int, rtt time.Duration) int { + return c.rc.update(ts, c.latestUsage, lastDeliveryRate, rtt) +} + +func groupSize(group arrivalGroup) int { + sum := 0 + for _, ack := range group { + sum += int(ack.Size) + } + + return sum +} diff --git a/gcc/delivery_rate_estimator.go b/gcc/delivery_rate_estimator.go new file mode 100644 index 0000000..2e176c1 --- /dev/null +++ b/gcc/delivery_rate_estimator.go @@ -0,0 +1,93 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "container/heap" + "time" +) + +type deliveryRateHeapItem struct { + arrival time.Time + size int +} + +type deliveryRateHeap []deliveryRateHeapItem + +// Len implements heap.Interface. +func (d deliveryRateHeap) Len() int { + return len(d) +} + +// Less implements heap.Interface. +func (d deliveryRateHeap) Less(i int, j int) bool { + return d[i].arrival.Before(d[j].arrival) +} + +// Pop implements heap.Interface. +func (d *deliveryRateHeap) Pop() any { + old := *d + n := len(old) + x := old[n-1] + *d = old[0 : n-1] + + return x +} + +// Push implements heap.Interface. +func (d *deliveryRateHeap) Push(x any) { + // nolint + *d = append(*d, x.(deliveryRateHeapItem)) +} + +// Swap implements heap.Interface. +func (d deliveryRateHeap) Swap(i int, j int) { + d[i], d[j] = d[j], d[i] +} + +type deliveryRateEstimator struct { + window time.Duration + latestArrival time.Time + history *deliveryRateHeap +} + +func newDeliveryRateEstimator(window time.Duration) *deliveryRateEstimator { + return &deliveryRateEstimator{ + window: window, + latestArrival: time.Time{}, + history: &deliveryRateHeap{}, + } +} + +func (e *deliveryRateEstimator) onPacketAcked(arrival time.Time, size int) { + if arrival.After(e.latestArrival) { + e.latestArrival = arrival + } + heap.Push(e.history, deliveryRateHeapItem{ + arrival: arrival, + size: size, + }) +} + +func (e *deliveryRateEstimator) getRate() int { + deadline := e.latestArrival.Add(-e.window) + for len(*e.history) > 0 && (*e.history)[0].arrival.Before(deadline) { + heap.Pop(e.history) + } + earliest := e.latestArrival + sum := 0 + for _, i := range *e.history { + if i.arrival.Before(earliest) { + earliest = i.arrival + } + sum += i.size + } + d := e.latestArrival.Sub(earliest) + if d == 0 { + return 0 + } + rate := 8 * float64(sum) / d.Seconds() + + return int(rate) +} diff --git a/gcc/delivery_rate_estimator_test.go b/gcc/delivery_rate_estimator_test.go new file mode 100644 index 0000000..0324ffb --- /dev/null +++ b/gcc/delivery_rate_estimator_test.go @@ -0,0 +1,77 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDeliveryRateEstimator(t *testing.T) { + type ack struct { + arrival time.Time + size int + } + cases := []struct { + window time.Duration + acks []ack + expectedRate int + }{ + { + window: 0, + acks: []ack{}, + expectedRate: 0, + }, + { + window: time.Second, + acks: []ack{}, + expectedRate: 0, + }, + { + window: time.Second, + acks: []ack{ + {time.Time{}, 1200}, + }, + expectedRate: 0, + }, + { + window: time.Second, + acks: []ack{ + {time.Time{}.Add(time.Millisecond), 1200}, + }, + expectedRate: 0, + }, + { + window: time.Second, + acks: []ack{ + {time.Time{}.Add(time.Second), 1200}, + {time.Time{}.Add(1500 * time.Millisecond), 1200}, + {time.Time{}.Add(2 * time.Second), 1200}, + }, + expectedRate: 28800, + }, + { + window: time.Second, + acks: []ack{ + {time.Time{}.Add(500 * time.Millisecond), 1200}, + {time.Time{}.Add(time.Second), 1200}, + {time.Time{}.Add(1500 * time.Millisecond), 1200}, + {time.Time{}.Add(2 * time.Second), 1200}, + }, + expectedRate: 28800, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + e := newDeliveryRateEstimator(tc.window) + for _, ack := range tc.acks { + e.onPacketAcked(ack.arrival, ack.size) + } + assert.Equal(t, tc.expectedRate, e.getRate()) + }) + } +} diff --git a/gcc/exponential_moving_average.go b/gcc/exponential_moving_average.go new file mode 100644 index 0000000..c20450b --- /dev/null +++ b/gcc/exponential_moving_average.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +type exponentialMovingAverage struct { + initialized bool + alpha float64 + average float64 + variance float64 +} + +func (a *exponentialMovingAverage) update(sample float64) { + if !a.initialized { + a.average = sample + a.initialized = true + } else { + delta := sample - a.average + a.average = a.alpha*sample + (1-a.alpha)*a.average + a.variance = (1-a.alpha)*a.variance + a.alpha*(1-a.alpha)*(delta*delta) + } +} diff --git a/gcc/exponential_moving_average_test.go b/gcc/exponential_moving_average_test.go new file mode 100644 index 0000000..293b90c --- /dev/null +++ b/gcc/exponential_moving_average_test.go @@ -0,0 +1,125 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +// python to generate test cases: +// import numpy as np +// import pandas as pd +// data = np.random.randint(1, 10, size=10) +// df = pd.DataFrame(data) +// expectedAvg = df.ewm(alpha=0.9, adjust=False).mean() +// expectedVar = df.ewm(alpha=0.9, adjust=False).var(bias=True) + +func TestExponentialMovingAverage(t *testing.T) { + cases := []struct { + alpha float64 + updates []float64 + expectedAvg []float64 + expectedVar []float64 + }{ + { + alpha: 0.9, + updates: []float64{}, + expectedAvg: []float64{}, + expectedVar: []float64{}, + }, + { + alpha: 0.9, + updates: []float64{1, 2, 3, 4}, + expectedAvg: []float64{ + 1.000, + 1.900, + 2.890, + 3.889, + }, + expectedVar: []float64{ + 0.000000, + 0.090000, + 0.117900, + 0.122679, + }, + }, + { + alpha: 0.9, + updates: []float64{8, 8, 5, 1, 3, 1, 8, 2, 8, 9}, + expectedAvg: []float64{ + 8.000000, + 8.000000, + 5.300000, + 1.430000, + 2.843000, + 1.184300, + 7.318430, + 2.531843, + 7.453184, + 8.845318, + }, + expectedVar: []float64{ + 0.000000, + 0.000000, + 0.810000, + 1.745100, + 0.396351, + 0.345334, + 4.215372, + 2.967250, + 2.987792, + 0.514117, + }, + }, + { + alpha: 0.9, + updates: []float64{7, 5, 6, 7, 3, 6, 8, 9, 5, 5}, + expectedAvg: []float64{ + 7.000000, + 5.200000, + 5.920000, + 6.892000, + 3.389200, + 5.738920, + 7.773892, + 8.877389, + 5.387739, + 5.038774, + }, + expectedVar: []float64{ + 0.000000, + 0.360000, + 0.093600, + 0.114336, + 1.374723, + 0.750937, + 0.535217, + 0.188822, + 1.371955, + 0.150726, + }, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + a := exponentialMovingAverage{ + alpha: tc.alpha, + average: 0, + variance: 0, + } + avgs := []float64{} + vars := []float64{} + for _, u := range tc.updates { + a.update(u) + avgs = append(avgs, a.average) + vars = append(vars, a.variance) + } + assert.InDeltaSlice(t, tc.expectedAvg, avgs, 0.001) + assert.InDeltaSlice(t, tc.expectedVar, vars, 0.001) + }) + } +} diff --git a/gcc/gcc.go b/gcc/gcc.go new file mode 100644 index 0000000..871b875 --- /dev/null +++ b/gcc/gcc.go @@ -0,0 +1,3 @@ +// Package gcc implements a congestion controller based on +// https://datatracker.ietf.org/doc/html/draft-ietf-rmcat-gcc-02. +package gcc diff --git a/gcc/kalman.go b/gcc/kalman.go new file mode 100644 index 0000000..3079f31 --- /dev/null +++ b/gcc/kalman.go @@ -0,0 +1,87 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "math" +) + +type kalmanFilter struct { + state [2]float64 // [slope, offset] + + processNoise [2]float64 + e [2][2]float64 + avgNoise float64 + varNoise float64 +} + +func newKalmanFilter() *kalmanFilter { + kf := &kalmanFilter{ + state: [2]float64{8.0 / 512.0, 0}, + processNoise: [2]float64{1e-13, 1e-3}, + e: [2][2]float64{{100.0, 0}, {0, 1e-1}}, + varNoise: 50.0, + } + + return kf +} + +func (k *kalmanFilter) update(timeDelta float64, sizeDelta float64) float64 { + k.e[0][0] += k.processNoise[0] + k.e[1][1] += k.processNoise[1] + + // nolint + h := [2]float64{sizeDelta, 1.0} + Eh := [2]float64{ + k.e[0][0]*h[0] + k.e[0][1]*h[1], + k.e[1][0]*h[0] + k.e[1][1]*h[1], + } + residual := timeDelta - (k.state[0]*h[0] + k.state[1]*h[1]) + + maxResidual := 3.0 * math.Sqrt(k.varNoise) + if math.Abs(residual) < maxResidual { + k.updateNoiseEstimate(residual, timeDelta) + } else { + if residual < 0 { + k.updateNoiseEstimate(-maxResidual, timeDelta) + } else { + k.updateNoiseEstimate(maxResidual, timeDelta) + } + } + + denom := k.varNoise + h[0]*Eh[0] + h[1]*Eh[1] + + // nolint + K := [2]float64{ + Eh[0] / denom, Eh[1] / denom, + } + + IKh := [2][2]float64{ + {1.0 - K[0]*h[0], -K[0] * h[1]}, + {-K[1] * h[0], 1.0 - K[1]*h[1]}, + } + + e00 := k.e[0][0] + e01 := k.e[0][1] + + k.e[0][0] = e00*IKh[0][0] + k.e[1][0]*IKh[0][1] + k.e[0][1] = e01*IKh[0][0] + k.e[1][1]*IKh[0][1] + k.e[1][0] = e00*IKh[1][0] + k.e[1][0]*IKh[1][1] + k.e[1][1] = e01*IKh[1][0] + k.e[1][1]*IKh[1][1] + + k.state[0] += K[0] * residual + k.state[1] += K[1] * residual + + return k.state[1] +} + +func (k *kalmanFilter) updateNoiseEstimate(residual float64, timeDelta float64) { + alpha := 0.002 + beta := math.Pow(1-alpha, timeDelta*30.0/1000.0) + k.avgNoise = beta*k.avgNoise + (1-beta)*residual + k.varNoise = beta*k.varNoise + (1-beta)*(k.avgNoise-residual)*(k.avgNoise-residual) + if k.varNoise < 1 { + k.varNoise = 1 + } +} diff --git a/gcc/loss_rate_controller.go b/gcc/loss_rate_controller.go new file mode 100644 index 0000000..63fa41c --- /dev/null +++ b/gcc/loss_rate_controller.go @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +type lossRateController struct { + bitrate int + min, max float64 + + packetsSinceLastUpdate int + arrivedSinceLastUpdate int + lostSinceLastUpdate int +} + +func newLossRateController(initialRate, minRate, maxRate int) *lossRateController { + return &lossRateController{ + bitrate: initialRate, + min: float64(minRate), + max: float64(maxRate), + packetsSinceLastUpdate: 0, + arrivedSinceLastUpdate: 0, + lostSinceLastUpdate: 0, + } +} + +func (l *lossRateController) onPacketAcked() { + l.packetsSinceLastUpdate++ + l.arrivedSinceLastUpdate++ +} + +func (l *lossRateController) onPacketLost() { + l.packetsSinceLastUpdate++ + l.lostSinceLastUpdate++ +} + +func (l *lossRateController) update(lastDeliveryRate int) int { + lossRate := float64(l.lostSinceLastUpdate) / float64(l.packetsSinceLastUpdate) + var target float64 + if lossRate > 0.1 { + target = float64(l.bitrate) * (1 - 0.5*lossRate) + target = max(target, l.min) + } else if lossRate < 0.02 { + target = float64(l.bitrate) * 1.05 + target = max(min(target, 1.5*float64(lastDeliveryRate)), float64(l.bitrate)) + target = min(target, l.max) + } + if target != 0 { + l.bitrate = int(target) + } + + l.packetsSinceLastUpdate = 0 + l.arrivedSinceLastUpdate = 0 + l.lostSinceLastUpdate = 0 + + return l.bitrate +} diff --git a/gcc/loss_rate_controller_test.go b/gcc/loss_rate_controller_test.go new file mode 100644 index 0000000..ad0a425 --- /dev/null +++ b/gcc/loss_rate_controller_test.go @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLossRateController(t *testing.T) { + cases := []struct { + init, min, max int + acked int + lost int + deliveredRate int + expectedRate int + }{ + {}, // all zeros + { + init: 100_000, + min: 100_000, + max: 1_000_000, + acked: 0, + lost: 0, + deliveredRate: 0, + expectedRate: 100_000, + }, + { + init: 100_000, + min: 100_000, + max: 1_000_000, + acked: 99, + lost: 1, + deliveredRate: 100_000, + expectedRate: 105_000, + }, + { + init: 100_000, + min: 100_000, + max: 1_000_000, + acked: 99, + lost: 1, + deliveredRate: 90_000, + expectedRate: 105_000, + }, + { + init: 100_000, + min: 100_000, + max: 1_000_000, + acked: 95, + lost: 5, + deliveredRate: 99_000, + expectedRate: 100_000, + }, + { + init: 100_000, + min: 50_000, + max: 1_000_000, + acked: 89, + lost: 11, + deliveredRate: 90_000, + expectedRate: 94_500, + }, + { + init: 100_000, + min: 100_000, + max: 1_000_000, + acked: 89, + lost: 11, + deliveredRate: 90_000, + expectedRate: 100_000, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + lrc := newLossRateController(tc.init, tc.min, tc.max) + for i := 0; i < tc.acked; i++ { + lrc.onPacketAcked() + } + for i := 0; i < tc.lost; i++ { + lrc.onPacketLost() + } + assert.Equal(t, tc.expectedRate, lrc.update(tc.deliveredRate)) + }) + } +} diff --git a/gcc/overuse_detector.go b/gcc/overuse_detector.go new file mode 100644 index 0000000..2235e4f --- /dev/null +++ b/gcc/overuse_detector.go @@ -0,0 +1,87 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "math" + "time" +) + +const ( + kU = 0.01 + kD = 0.00018 + + maxNumDeltas = 60 +) + +type overuseDetector struct { + adaptiveThreshold bool + overUseTimeThreshold time.Duration + delayThreshold float64 + lastEstimate time.Duration + lastUpdate time.Time + firstOverUse time.Time + inOveruse bool + lastUsage usage +} + +func newOveruseDetector(adaptive bool) *overuseDetector { + return &overuseDetector{ + adaptiveThreshold: adaptive, + overUseTimeThreshold: 10 * time.Millisecond, + delayThreshold: 12.5, + lastEstimate: 0, + lastUpdate: time.Time{}, + firstOverUse: time.Time{}, + inOveruse: false, + } +} + +func (d *overuseDetector) update(ts time.Time, trend float64, numDeltas int) usage { + if numDeltas < 2 { + return usageNormal + } + modifiedTrend := float64(min(numDeltas, maxNumDeltas)) * trend + + switch { + case modifiedTrend > d.delayThreshold: + if d.firstOverUse.IsZero() { + d.firstOverUse = ts + } + if ts.Sub(d.firstOverUse) > d.overUseTimeThreshold { + d.firstOverUse = time.Time{} + d.lastUsage = usageOver + } + case modifiedTrend < -d.delayThreshold: + d.firstOverUse = time.Time{} + d.lastUsage = usageUnder + default: + d.firstOverUse = time.Time{} + d.lastUsage = usageNormal + } + if d.adaptiveThreshold { + d.adaptThreshold(ts, modifiedTrend) + } + + return d.lastUsage +} + +func (d *overuseDetector) adaptThreshold(ts time.Time, modifiedTrend float64) { + if d.lastUpdate.IsZero() { + d.lastUpdate = ts + } + if math.Abs(modifiedTrend) > d.delayThreshold+15 { + d.lastUpdate = ts + + return + } + k := kU + if math.Abs(modifiedTrend) < d.delayThreshold { + k = kD + } + delta := min(ts.Sub(d.lastUpdate), 100*time.Millisecond) + d.delayThreshold += k * (math.Abs(modifiedTrend) - d.delayThreshold) * float64(delta.Milliseconds()) + d.delayThreshold = max(min(d.delayThreshold, 600.0), 6.0) + d.lastUpdate = ts +} diff --git a/gcc/overuse_detector_test.go b/gcc/overuse_detector_test.go new file mode 100644 index 0000000..11299b2 --- /dev/null +++ b/gcc/overuse_detector_test.go @@ -0,0 +1,189 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestOveruseDetectorUpdate(t *testing.T) { + type estimate struct { + ts time.Time + estimate float64 + numDeltas int + } + cases := []struct { + name string + adaptive bool + values []estimate + expected []usage + }{ + { + name: "noEstimateNoUsageStatic", + adaptive: false, + values: []estimate{}, + expected: []usage{}, + }, + { + name: "overuseStatic", + adaptive: false, + values: []estimate{ + {time.Time{}, 1.0, 1}, + {time.Time{}.Add(5 * time.Millisecond), 20, 2}, + {time.Time{}.Add(20 * time.Millisecond), 30, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "normaluseStatic", + adaptive: false, + values: []estimate{{estimate: 0}}, + expected: []usage{usageNormal}, + }, + { + name: "underuseStatic", + adaptive: false, + values: []estimate{{time.Time{}, -20, 2}}, + expected: []usage{usageUnder}, + }, + { + name: "noOverUseBeforeDelayStatic", + adaptive: false, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(2 * time.Millisecond), 30, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "noOverUseIfEstimateDecreasedStatic", + adaptive: false, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(10 * time.Millisecond), 40, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + {time.Time{}.Add(35 * time.Millisecond), 3, 4}, + }, + expected: []usage{usageNormal, usageNormal, usageOver, usageNormal}, + }, + { + name: "noEstimateNoUsageAdaptive", + adaptive: true, + values: []estimate{}, + expected: []usage{}, + }, + { + name: "overuseAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}, 1, 1}, + {time.Time{}.Add(5 * time.Millisecond), 20, 2}, + {time.Time{}.Add(20 * time.Millisecond), 30, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "normaluseAdaptive", + adaptive: true, + values: []estimate{{estimate: 0}}, + expected: []usage{usageNormal}, + }, + { + name: "underuseAdaptive", + adaptive: true, + values: []estimate{{time.Time{}, -20, 2}}, + expected: []usage{usageUnder}, + }, + { + name: "noOverUseBeforeDelayAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(2 * time.Millisecond), 30, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "noOverUseIfEstimateDecreasedAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(10 * time.Millisecond), 40, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + {time.Time{}.Add(35 * time.Millisecond), 3, 4}, + }, + expected: []usage{usageNormal, usageNormal, usageOver, usageNormal}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + od := newOveruseDetector(tc.adaptive) + received := []usage{} + for _, e := range tc.values { + u := od.update(e.ts, e.estimate, e.numDeltas) + received = append(received, u) + } + assert.Equal(t, tc.expected, received) + }) + } +} + +func TestOveruseDetectorAdaptThreshold(t *testing.T) { + cases := []struct { + name string + od *overuseDetector + ts time.Time + estimate float64 + expectedThreshold float64 + }{ + { + name: "minThreshold", + od: &overuseDetector{}, + ts: time.Time{}, + estimate: 0, + expectedThreshold: 6, + }, + { + name: "increase", + od: &overuseDetector{ + delayThreshold: 12.5, + lastUpdate: time.Time{}.Add(time.Second), + }, + ts: time.Time{}.Add(2 * time.Second), + estimate: 25, + expectedThreshold: 25, + }, + { + name: "maxThreshold", + od: &overuseDetector{ + delayThreshold: 6, + lastUpdate: time.Time{}, + }, + ts: time.Time{}.Add(time.Second), + estimate: 6.1, + expectedThreshold: 6, + }, + { + name: "decrease", + od: &overuseDetector{ + delayThreshold: 12.5, + lastUpdate: time.Time{}, + }, + ts: time.Time{}.Add(time.Second), + estimate: 1, + expectedThreshold: 12.5, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.od.adaptThreshold(tc.ts, tc.estimate) + assert.Equal(t, tc.expectedThreshold, tc.od.delayThreshold) + }) + } +} diff --git a/gcc/rate_controller.go b/gcc/rate_controller.go new file mode 100644 index 0000000..2690984 --- /dev/null +++ b/gcc/rate_controller.go @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "math" + "time" +) + +type rateController struct { + s state + rate int + + decreaseFactor float64 // (beta) + lastUpdate time.Time + lastDecrease *exponentialMovingAverage +} + +func newRateController(initialRate int) *rateController { + return &rateController{ + s: stateIncrease, + rate: initialRate, + decreaseFactor: 0.85, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{}, + } +} + +func (c *rateController) update(ts time.Time, u usage, deliveredRate int, rtt time.Duration) int { + nextState := c.s.transition(u) + c.s = nextState + + if c.s == stateIncrease { + var target float64 + if c.canIncreaseMultiplicatively(float64(deliveredRate)) { + window := ts.Sub(c.lastUpdate) + target = c.multiplicativeIncrease(float64(c.rate), window) + } else { + bitsPerFrame := float64(c.rate) / 30.0 + packetsPerFrame := math.Ceil(bitsPerFrame / (1200 * 8)) + expectedPacketSizeBits := bitsPerFrame / packetsPerFrame + target = c.additiveIncrease(float64(c.rate), int(expectedPacketSizeBits), rtt) + } + c.rate = int(max(min(target, 1.5*float64(deliveredRate)), float64(c.rate))) + } + + if c.s == stateDecrease { + c.rate = int(c.decreaseFactor * float64(deliveredRate)) + c.lastDecrease.update(float64(c.rate)) + } + + c.lastUpdate = ts + + return c.rate +} + +func (c *rateController) canIncreaseMultiplicatively(deliveredRate float64) bool { + if c.lastDecrease.average == 0 { + return true + } + stdDev := math.Sqrt(c.lastDecrease.variance) + lower := c.lastDecrease.average - 3*stdDev + upper := c.lastDecrease.average + 3*stdDev + + return deliveredRate < lower || deliveredRate > upper +} + +func (c *rateController) multiplicativeIncrease(rate float64, window time.Duration) float64 { + exponent := min(window.Seconds(), 1.0) + eta := math.Pow(1.08, exponent) + target := eta * rate + + return target +} + +func (c *rateController) additiveIncrease(rate float64, expectedPacketSizeBits int, window time.Duration) float64 { + alpha := 0.5 * min(window.Seconds(), 1.0) + target := rate + max(1000, alpha*float64(expectedPacketSizeBits)) + + return target +} diff --git a/gcc/rate_controller_test.go b/gcc/rate_controller_test.go new file mode 100644 index 0000000..ebef313 --- /dev/null +++ b/gcc/rate_controller_test.go @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRateController(t *testing.T) { + cases := []struct { + name string + rc rateController + ts time.Time + u usage + delivered int + rtt time.Duration + expectedRate int + }{ + { + name: "zero", + rc: rateController{ + s: 0, + rate: 0, + decreaseFactor: 0, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{}, + }, + ts: time.Time{}, + u: 0, + delivered: 0, + rtt: 0, + expectedRate: 0, + }, + { + name: "multiplicativeIncrease", + rc: rateController{ + s: stateIncrease, + rate: 100, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{}, + }, + ts: time.Time{}.Add(time.Second), + u: usageNormal, + delivered: 100, + rtt: 0, + expectedRate: 108, + }, + { + name: "minimumAdditiveIncrease", + rc: rateController{ + s: stateIncrease, + rate: 100_000, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{ + average: 100_000, + }, + }, + ts: time.Time{}.Add(time.Second), + u: usageNormal, + delivered: 100_000, + rtt: 20 * time.Millisecond, + expectedRate: 101_000, + }, + { + name: "additiveIncrease", + rc: rateController{ + s: stateIncrease, + rate: 1_000_000, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{ + average: 1_000_000, + }, + }, + ts: time.Time{}.Add(time.Second), + u: usageNormal, + delivered: 1_000_000, + rtt: 2000 * time.Millisecond, + expectedRate: 1_004166, + }, + { + name: "minimumAdditiveIncreaseAppLimited", + rc: rateController{ + s: stateIncrease, + rate: 100_000, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{ + average: 100_000, + }, + }, + ts: time.Time{}.Add(time.Second), + u: usageNormal, + delivered: 50_000, + rtt: 20 * time.Millisecond, + expectedRate: 100_000, + }, + { + name: "additiveIncreaseAppLimited", + rc: rateController{ + s: stateIncrease, + rate: 1_000_000, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{ + average: 1_000_000, + }, + }, + ts: time.Time{}.Add(time.Second), + u: usageNormal, + delivered: 100_000, + rtt: 2000 * time.Millisecond, + expectedRate: 1_000_000, + }, + { + name: "decrease", + rc: rateController{ + s: stateDecrease, + rate: 1_000_000, + decreaseFactor: 0.9, + lastUpdate: time.Time{}, + lastDecrease: &exponentialMovingAverage{ + average: 1_000_000, + }, + }, + ts: time.Time{}.Add(time.Second), + u: usageOver, + delivered: 1_000_000, + rtt: 2000 * time.Millisecond, + expectedRate: 900_000, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + res := tc.rc.update(tc.ts, tc.u, tc.delivered, tc.rtt) + assert.Equal(t, tc.expectedRate, res) + }) + } +} diff --git a/gcc/send_side_bwe.go b/gcc/send_side_bwe.go new file mode 100644 index 0000000..2cd28ff --- /dev/null +++ b/gcc/send_side_bwe.go @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "time" + + "github.com/pion/bwe" + "github.com/pion/logging" +) + +// Option is a functional option for a SendSideController. +type Option func(*SendSideController) error + +// Logger configures a custom logger for a SendSideController. +func Logger(l logging.LeveledLogger) Option { + return func(ssc *SendSideController) error { + ssc.log = l + ssc.drc.log = l + + return nil + } +} + +// SendSideController is a sender side congestion controller. +type SendSideController struct { + log logging.LeveledLogger + dre *deliveryRateEstimator + lrc *lossRateController + drc *delayRateController + targetRate int +} + +// NewSendSideController creates a new SendSideController with initial, min and +// max rates. +func NewSendSideController(initialRate, minRate, maxRate int, opts ...Option) (*SendSideController, error) { + ssc := &SendSideController{ + log: logging.NewDefaultLoggerFactory().NewLogger("bwe_send_side_controller"), + dre: newDeliveryRateEstimator(time.Second), + lrc: newLossRateController(initialRate, minRate, maxRate), + drc: newDelayRateController(initialRate, logging.NewDefaultLoggerFactory().NewLogger("bwe_delay_rate_controller")), + targetRate: initialRate, + } + for _, opt := range opts { + if err := opt(ssc); err != nil { + return nil, err + } + } + + return ssc, nil +} + +// OnAcks must be called when new acknowledgments arrive. arrival is the arrival +// time of the feedback, RTT is the last measured RTT and acks is a list of +// Acknowledgments contained in the latest feedback. Packets MUST not be +// acknowledged more than once. +func (c *SendSideController) OnAcks(arrival time.Time, rtt time.Duration, acks []bwe.Packet) int { + if len(acks) == 0 { + return c.targetRate + } + + for _, ack := range acks { + if ack.Arrived { + c.lrc.onPacketAcked() + if !ack.Arrival.IsZero() { + c.dre.onPacketAcked(ack.Arrival, int(ack.Size)) + c.drc.onPacketAcked(ack) + } + } else { + c.lrc.onPacketLost() + } + } + + delivered := c.dre.getRate() + lossTarget := c.lrc.update(delivered) + delayTarget := c.drc.update(arrival, delivered, rtt) + c.targetRate = min(lossTarget, delayTarget) + c.log.Tracef( + "rtt=%v, delivered=%v, lossTarget=%v, delayTarget=%v, target=%v", + rtt.Nanoseconds(), + delivered, + lossTarget, + delayTarget, + c.targetRate, + ) + + return c.targetRate +} diff --git a/gcc/state.go b/gcc/state.go new file mode 100644 index 0000000..6d0a274 --- /dev/null +++ b/gcc/state.go @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import "fmt" + +type state int + +const ( + stateDecrease state = -1 + stateHold state = 0 + stateIncrease state = 1 +) + +func (s state) transition(u usage) state { + switch s { + case stateHold: + return transitionFromHold(u) + case stateIncrease: + return transitionFromIncrease(u) + case stateDecrease: + return transitionFromDecrease(u) + } + + return stateIncrease +} + +func transitionFromHold(u usage) state { + switch u { + case usageOver: + return stateDecrease + case usageNormal: + return stateIncrease + case usageUnder: + return stateHold + } + + return stateIncrease +} + +func transitionFromIncrease(u usage) state { + switch u { + case usageOver: + return stateDecrease + case usageNormal: + return stateIncrease + case usageUnder: + return stateHold + } + + return stateIncrease +} + +func transitionFromDecrease(u usage) state { + switch u { + case usageOver: + return stateDecrease + case usageNormal: + return stateHold + case usageUnder: + return stateHold + } + + return stateIncrease +} + +func (s state) String() string { + switch s { + case stateIncrease: + return "increase" + case stateDecrease: + return "decrease" + case stateHold: + return "hold" + default: + return fmt.Sprintf("invalid state: %d", s) + } +} diff --git a/gcc/state_test.go b/gcc/state_test.go new file mode 100644 index 0000000..68acf56 --- /dev/null +++ b/gcc/state_test.go @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestState(t *testing.T) { + t.Run("hold", func(t *testing.T) { + assert.Equal(t, stateDecrease, stateHold.transition(usageOver)) + assert.Equal(t, stateIncrease, stateHold.transition(usageNormal)) + assert.Equal(t, stateHold, stateHold.transition(usageUnder)) + }) + + t.Run("increase", func(t *testing.T) { + assert.Equal(t, stateDecrease, stateIncrease.transition(usageOver)) + assert.Equal(t, stateIncrease, stateIncrease.transition(usageNormal)) + assert.Equal(t, stateHold, stateIncrease.transition(usageUnder)) + }) + + t.Run("decrease", func(t *testing.T) { + assert.Equal(t, stateDecrease, stateDecrease.transition(usageOver)) + assert.Equal(t, stateHold, stateDecrease.transition(usageNormal)) + assert.Equal(t, stateHold, stateDecrease.transition(usageUnder)) + }) +} diff --git a/gcc/usage.go b/gcc/usage.go new file mode 100644 index 0000000..d3ccfac --- /dev/null +++ b/gcc/usage.go @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import "fmt" + +type usage int + +const ( + usageUnder usage = -1 + usageNormal usage = 0 + usageOver usage = 1 +) + +func (u usage) String() string { + switch u { + case usageOver: + return "overuse" + case usageUnder: + return "underuse" + case usageNormal: + return "normal" + default: + return fmt.Sprintf("invalid usage: %d", u) + } +} diff --git a/packet.go b/packet.go new file mode 100644 index 0000000..4f14f7a --- /dev/null +++ b/packet.go @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package bwe + +import ( + "fmt" + "time" +) + +// An Packet stores send and receive information about a packet. +type Packet struct { + // StreamID is the ID of the stream to which the packet belongs. The + // StreamID MUST be unique among all streams controlled by the congestion + // controller. + StreamID uint64 + + // SequenceNumber is the sequence number of the packet within its stream. + // SequenceNumbers of consecutive packets might have gaps. + SequenceNumber uint64 + + // TransportWideSequenceNumber is a transport wide sequence number of the + // packet. It MUST be unique over all streams and it MUST increase by 1 for + // every outgoing packet. + TransportWideSequenceNumber uint64 + + // Size is the size of the packet in bytes. + Size int + + // Arrived indicates if the packet arrived at the receiver. False does not + // necessarily mean the packet was lost, it might still be in transit. + Arrived bool + + // Departure is the departure time of the packet taken at the sender. It + // should be the time measured at the latest possible moment before sending + // the packet. + Departure time.Time + + // Arrival is the arrival time of the packet at the receiver. Arrival and + // Departure do not require synchronized clocks and can therefore not + // directly be compared. + Arrival time.Time + + // ECN marking of the packet when it arrived. + ECN ECN +} + +func (a Packet) String() string { + return fmt.Sprintf("seq=%v, departure=%v, arrival=%v", a.SequenceNumber, a.Departure, a.Arrival) +} From 02484fcd1dd7939a12cd0f3a18edc8b0946d0bfd Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Fri, 10 Oct 2025 18:46:12 +0200 Subject: [PATCH 2/6] Use GCC in tests --- ecn.go | 25 ---- gcc/arrival_group_accumulator.go | 53 +++++++-- gcc/arrival_group_accumulator_test.go | 72 ++++++------ gcc/delay_rate_controller.go | 22 ++-- gcc/gcc.go | 3 + gcc/send_side_bwe.go | 43 +++---- go.mod | 2 +- go.sum | 4 +- .../log_format_test.go => log_format_test.go | 2 +- packet.go | 50 -------- simulation/peer_test.go => peer_test.go | 70 ++++++++--- ...ect_codec_test.go => perfect_codec_test.go | 20 ++-- simulation/simulation.go | 6 - virtual_network_test.go | 110 ++++++++++++++++++ simulation/vnet_test.go => vnet_test.go | 86 ++------------ 15 files changed, 307 insertions(+), 261 deletions(-) delete mode 100644 ecn.go rename simulation/log_format_test.go => log_format_test.go (98%) delete mode 100644 packet.go rename simulation/peer_test.go => peer_test.go (83%) rename simulation/perfect_codec_test.go => perfect_codec_test.go (92%) delete mode 100644 simulation/simulation.go create mode 100644 virtual_network_test.go rename simulation/vnet_test.go => vnet_test.go (55%) diff --git a/ecn.go b/ecn.go deleted file mode 100644 index 8aaf33f..0000000 --- a/ecn.go +++ /dev/null @@ -1,25 +0,0 @@ -// SPDX-FileCopyrightText: 2025 The Pion community -// SPDX-License-Identifier: MIT - -package bwe - -// ECN represents the ECN bits of an IP packet header. -type ECN uint8 - -const ( - // ECNNonECT signals Non ECN-Capable Transport, Non-ECT. - // nolint:misspell - ECNNonECT ECN = iota // 00 - - // ECNECT1 signals ECN Capable Transport, ECT(0). - // nolint:misspell - ECNECT1 // 01 - - // ECNECT0 signals ECN Capable Transport, ECT(1). - // nolint:misspell - ECNECT0 // 10 - - // ECNCE signals ECN Congestion Encountered, CE. - // nolint:misspell - ECNCE // 11 -) diff --git a/gcc/arrival_group_accumulator.go b/gcc/arrival_group_accumulator.go index 70a548f..647e277 100644 --- a/gcc/arrival_group_accumulator.go +++ b/gcc/arrival_group_accumulator.go @@ -5,11 +5,16 @@ package gcc import ( "time" - - "github.com/pion/bwe" ) -type arrivalGroup []bwe.Packet +type arrivalGroupItem struct { + SequenceNumber uint64 + Departure time.Time + Arrival time.Time + Size int +} + +type arrivalGroup []arrivalGroupItem type arrivalGroupAccumulator struct { next arrivalGroup @@ -19,39 +24,63 @@ type arrivalGroupAccumulator struct { func newArrivalGroupAccumulator() *arrivalGroupAccumulator { return &arrivalGroupAccumulator{ - next: make([]bwe.Packet, 0), + next: make([]arrivalGroupItem, 0), burstInterval: 5 * time.Millisecond, maxBurstDuration: 100 * time.Millisecond, } } -func (a *arrivalGroupAccumulator) onPacketAcked(ack bwe.Packet) arrivalGroup { +func (a *arrivalGroupAccumulator) onPacketAcked( + sequenceNumber uint64, + size int, + departure, arrival time.Time, +) arrivalGroup { if len(a.next) == 0 { - a.next = append(a.next, ack) + a.next = append(a.next, arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }) return nil } - sendTimeDelta := ack.Departure.Sub(a.next[0].Departure) + sendTimeDelta := departure.Sub(a.next[0].Departure) if sendTimeDelta < a.burstInterval { - a.next = append(a.next, ack) + a.next = append(a.next, arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }) return nil } - arrivalTimeDeltaLast := ack.Arrival.Sub(a.next[len(a.next)-1].Arrival) - arrivalTimeDeltaFirst := ack.Arrival.Sub(a.next[0].Arrival) + arrivalTimeDeltaLast := arrival.Sub(a.next[len(a.next)-1].Arrival) + arrivalTimeDeltaFirst := arrival.Sub(a.next[0].Arrival) propagationDelta := arrivalTimeDeltaFirst - sendTimeDelta if propagationDelta < 0 && arrivalTimeDeltaLast <= a.burstInterval && arrivalTimeDeltaFirst < a.maxBurstDuration { - a.next = append(a.next, ack) + a.next = append(a.next, arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }) return nil } group := make(arrivalGroup, len(a.next)) copy(group, a.next) - a.next = arrivalGroup{ack} + a.next = arrivalGroup{arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }} return group } diff --git a/gcc/arrival_group_accumulator_test.go b/gcc/arrival_group_accumulator_test.go index 656827a..ff75177 100644 --- a/gcc/arrival_group_accumulator_test.go +++ b/gcc/arrival_group_accumulator_test.go @@ -11,23 +11,28 @@ import ( ) func TestArrivalGroupAccumulator(t *testing.T) { - triggerNewGroupElement := Acknowledgment{ + type logItem struct { + SequenceNumber uint64 + Departure time.Time + Arrival time.Time + } + triggerNewGroupElement := logItem{ Departure: time.Time{}.Add(time.Second), Arrival: time.Time{}.Add(time.Second), } cases := []struct { name string - log []Acknowledgment + log []logItem exp []arrivalGroup }{ { name: "emptyCreatesNoGroups", - log: []Acknowledgment{}, + log: []logItem{}, exp: []arrivalGroup{}, }, { name: "createsSingleElementGroup", - log: []Acknowledgment{ + log: []logItem{ { Departure: time.Time{}, Arrival: time.Time{}.Add(time.Millisecond), @@ -45,7 +50,7 @@ func TestArrivalGroupAccumulator(t *testing.T) { }, { name: "createsTwoElementGroup", - log: []Acknowledgment{ + log: []logItem{ { Departure: time.Time{}, Arrival: time.Time{}.Add(15 * time.Millisecond), @@ -69,7 +74,7 @@ func TestArrivalGroupAccumulator(t *testing.T) { }, { name: "createsTwoArrivalGroups1", - log: []Acknowledgment{ + log: []logItem{ { Departure: time.Time{}, Arrival: time.Time{}.Add(15 * time.Millisecond), @@ -105,7 +110,7 @@ func TestArrivalGroupAccumulator(t *testing.T) { }, { name: "ignoresOutOfOrderPackets", - log: []Acknowledgment{ + log: []logItem{ { Departure: time.Time{}, Arrival: time.Time{}.Add(15 * time.Millisecond), @@ -141,52 +146,52 @@ func TestArrivalGroupAccumulator(t *testing.T) { }, { name: "newGroupBecauseOfInterDepartureTime", - log: []Acknowledgment{ + log: []logItem{ { - SeqNr: 0, - Departure: time.Time{}, - Arrival: time.Time{}.Add(4 * time.Millisecond), + SequenceNumber: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(4 * time.Millisecond), }, { - SeqNr: 1, - Departure: time.Time{}.Add(3 * time.Millisecond), - Arrival: time.Time{}.Add(4 * time.Millisecond), + SequenceNumber: 1, + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(4 * time.Millisecond), }, { - SeqNr: 2, - Departure: time.Time{}.Add(6 * time.Millisecond), - Arrival: time.Time{}.Add(10 * time.Millisecond), + SequenceNumber: 2, + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), }, { - SeqNr: 3, - Departure: time.Time{}.Add(9 * time.Millisecond), - Arrival: time.Time{}.Add(10 * time.Millisecond), + SequenceNumber: 3, + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), }, triggerNewGroupElement, }, exp: []arrivalGroup{ { { - SeqNr: 0, - Departure: time.Time{}, - Arrival: time.Time{}.Add(4 * time.Millisecond), + SequenceNumber: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(4 * time.Millisecond), }, { - SeqNr: 1, - Departure: time.Time{}.Add(3 * time.Millisecond), - Arrival: time.Time{}.Add(4 * time.Millisecond), + SequenceNumber: 1, + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(4 * time.Millisecond), }, }, { { - SeqNr: 2, - Departure: time.Time{}.Add(6 * time.Millisecond), - Arrival: time.Time{}.Add(10 * time.Millisecond), + SequenceNumber: 2, + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), }, { - SeqNr: 3, - Departure: time.Time{}.Add(9 * time.Millisecond), - Arrival: time.Time{}.Add(10 * time.Millisecond), + SequenceNumber: 3, + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), }, }, }, @@ -194,12 +199,11 @@ func TestArrivalGroupAccumulator(t *testing.T) { } for _, tc := range cases { - tc := tc t.Run(tc.name, func(t *testing.T) { aga := newArrivalGroupAccumulator() received := []arrivalGroup{} for _, ack := range tc.log { - next := aga.onPacketAcked(ack) + next := aga.onPacketAcked(ack.SequenceNumber, 0, ack.Departure, ack.Arrival) if next != nil { received = append(received, next) } diff --git a/gcc/delay_rate_controller.go b/gcc/delay_rate_controller.go index c5beda3..73dd2b3 100644 --- a/gcc/delay_rate_controller.go +++ b/gcc/delay_rate_controller.go @@ -6,7 +6,6 @@ package gcc import ( "time" - "github.com/pion/bwe" "github.com/pion/logging" ) @@ -21,11 +20,11 @@ type delayRateController struct { samples int } -func newDelayRateController(initialRate int, logger logging.LeveledLogger) *delayRateController { +func newDelayRateController(initialRate int) *delayRateController { return &delayRateController{ - log: logger, + log: logging.NewDefaultLoggerFactory().NewLogger("bwe_delay_rate_controller"), aga: newArrivalGroupAccumulator(), - last: []bwe.Packet{}, + last: []arrivalGroupItem{}, kf: newKalmanFilter(), od: newOveruseDetector(true), rc: newRateController(initialRate), @@ -34,8 +33,13 @@ func newDelayRateController(initialRate int, logger logging.LeveledLogger) *dela } } -func (c *delayRateController) onPacketAcked(ack bwe.Packet) { - next := c.aga.onPacketAcked(ack) +func (c *delayRateController) onPacketAcked(sequenceNumber uint64, size int, departure, arrival time.Time) { + next := c.aga.onPacketAcked( + sequenceNumber, + size, + departure, + arrival, + ) if next == nil { return } @@ -58,7 +62,7 @@ func (c *delayRateController) onPacketAcked(ack bwe.Packet) { interGroupDelay := interArrivalTime - interDepartureTime estimate := c.kf.update(float64(interGroupDelay.Milliseconds()), float64(sizeDelta)) c.samples++ - c.latestUsage = c.od.update(ack.Arrival, estimate, c.samples) + c.latestUsage = c.od.update(arrival, estimate, c.samples) c.last = next c.log.Tracef( "ts=%v.%06d, seq=%v, size=%v, interArrivalTime=%v, interDepartureTime=%v, interGroupDelay=%v, estimate=%v, threshold=%v, usage=%v, state=%v", // nolint @@ -82,8 +86,8 @@ func (c *delayRateController) update(ts time.Time, lastDeliveryRate int, rtt tim func groupSize(group arrivalGroup) int { sum := 0 - for _, ack := range group { - sum += int(ack.Size) + for _, item := range group { + sum += item.Size } return sum diff --git a/gcc/gcc.go b/gcc/gcc.go index 871b875..c10a0fd 100644 --- a/gcc/gcc.go +++ b/gcc/gcc.go @@ -1,3 +1,6 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + // Package gcc implements a congestion controller based on // https://datatracker.ietf.org/doc/html/draft-ietf-rmcat-gcc-02. package gcc diff --git a/gcc/send_side_bwe.go b/gcc/send_side_bwe.go index 2cd28ff..ecef101 100644 --- a/gcc/send_side_bwe.go +++ b/gcc/send_side_bwe.go @@ -6,7 +6,6 @@ package gcc import ( "time" - "github.com/pion/bwe" "github.com/pion/logging" ) @@ -39,7 +38,7 @@ func NewSendSideController(initialRate, minRate, maxRate int, opts ...Option) (* log: logging.NewDefaultLoggerFactory().NewLogger("bwe_send_side_controller"), dre: newDeliveryRateEstimator(time.Second), lrc: newLossRateController(initialRate, minRate, maxRate), - drc: newDelayRateController(initialRate, logging.NewDefaultLoggerFactory().NewLogger("bwe_delay_rate_controller")), + drc: newDelayRateController(initialRate), targetRate: initialRate, } for _, opt := range opts { @@ -51,30 +50,32 @@ func NewSendSideController(initialRate, minRate, maxRate int, opts ...Option) (* return ssc, nil } -// OnAcks must be called when new acknowledgments arrive. arrival is the arrival -// time of the feedback, RTT is the last measured RTT and acks is a list of -// Acknowledgments contained in the latest feedback. Packets MUST not be -// acknowledged more than once. -func (c *SendSideController) OnAcks(arrival time.Time, rtt time.Duration, acks []bwe.Packet) int { - if len(acks) == 0 { - return c.targetRate - } +func (c *SendSideController) OnLoss() { + c.lrc.onPacketLost() +} - for _, ack := range acks { - if ack.Arrived { - c.lrc.onPacketAcked() - if !ack.Arrival.IsZero() { - c.dre.onPacketAcked(ack.Arrival, int(ack.Size)) - c.drc.onPacketAcked(ack) - } - } else { - c.lrc.onPacketLost() - } +// OnAck must be called when new acknowledgments arrive. Packets MUST not be +// acknowledged more than once. +func (c *SendSideController) OnAck(sequenceNumber uint64, size int, departure, arrival time.Time) { + c.lrc.onPacketAcked() + if !arrival.IsZero() { + c.dre.onPacketAcked(arrival, size) + c.drc.onPacketAcked( + sequenceNumber, + size, + departure, + arrival, + ) } +} +// OnFeedback must be called when a new feedback report arrives. ts is the +// arrival timestamp of the feedback report. rtt is the latest RTT sample. It +// returns the new target rate. +func (c *SendSideController) OnFeedback(ts time.Time, rtt time.Duration) int { delivered := c.dre.getRate() lossTarget := c.lrc.update(delivered) - delayTarget := c.drc.update(arrival, delivered, rtt) + delayTarget := c.drc.update(ts, delivered, rtt) c.targetRate = min(lossTarget, delayTarget) c.log.Tracef( "rtt=%v, delivered=%v, lossTarget=%v, delayTarget=%v, target=%v", diff --git a/go.mod b/go.mod index 127c707..eee60b6 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/pion/bwe go 1.24 require ( - github.com/pion/interceptor v0.1.41-0.20250918133005-ab70b00249ad + github.com/pion/interceptor v0.1.42-0.20251016092317-ce5124bd6cdf github.com/pion/logging v0.2.4 github.com/pion/rtcp v1.2.16 github.com/pion/rtp v1.8.23 diff --git a/go.sum b/go.sum index 873f4c0..dacb3f1 100644 --- a/go.sum +++ b/go.sum @@ -12,8 +12,8 @@ github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q= github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8= github.com/pion/ice/v4 v4.0.10 h1:P59w1iauC/wPk9PdY8Vjl4fOFL5B+USq1+xbDcN6gT4= github.com/pion/ice/v4 v4.0.10/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= -github.com/pion/interceptor v0.1.41-0.20250918133005-ab70b00249ad h1:9Md9jf21oboaul3cm0ss/hn6KG0xsJ7CzPJjdDnpJqk= -github.com/pion/interceptor v0.1.41-0.20250918133005-ab70b00249ad/go.mod h1:nEt4187unvRXJFyjiw00GKo+kIuXMWQI9K89fsosDLY= +github.com/pion/interceptor v0.1.42-0.20251016092317-ce5124bd6cdf h1:/I4PY/suu+TndWEIdHBke29M58QIxUR3neZndVxVUzA= +github.com/pion/interceptor v0.1.42-0.20251016092317-ce5124bd6cdf/go.mod h1:sn0zW8AhYpX24aDub5BumnPp0Zr2nHk6gS2fxXSfFD0= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= diff --git a/simulation/log_format_test.go b/log_format_test.go similarity index 98% rename from simulation/log_format_test.go rename to log_format_test.go index 8ddb7af..7b8f899 100644 --- a/simulation/log_format_test.go +++ b/log_format_test.go @@ -3,7 +3,7 @@ //go:build !js -package simulation +package bwe_test import ( "fmt" diff --git a/packet.go b/packet.go deleted file mode 100644 index 4f14f7a..0000000 --- a/packet.go +++ /dev/null @@ -1,50 +0,0 @@ -// SPDX-FileCopyrightText: 2025 The Pion community -// SPDX-License-Identifier: MIT - -package bwe - -import ( - "fmt" - "time" -) - -// An Packet stores send and receive information about a packet. -type Packet struct { - // StreamID is the ID of the stream to which the packet belongs. The - // StreamID MUST be unique among all streams controlled by the congestion - // controller. - StreamID uint64 - - // SequenceNumber is the sequence number of the packet within its stream. - // SequenceNumbers of consecutive packets might have gaps. - SequenceNumber uint64 - - // TransportWideSequenceNumber is a transport wide sequence number of the - // packet. It MUST be unique over all streams and it MUST increase by 1 for - // every outgoing packet. - TransportWideSequenceNumber uint64 - - // Size is the size of the packet in bytes. - Size int - - // Arrived indicates if the packet arrived at the receiver. False does not - // necessarily mean the packet was lost, it might still be in transit. - Arrived bool - - // Departure is the departure time of the packet taken at the sender. It - // should be the time measured at the latest possible moment before sending - // the packet. - Departure time.Time - - // Arrival is the arrival time of the packet at the receiver. Arrival and - // Departure do not require synchronized clocks and can therefore not - // directly be compared. - Arrival time.Time - - // ECN marking of the packet when it arrived. - ECN ECN -} - -func (a Packet) String() string { - return fmt.Sprintf("seq=%v, departure=%v, arrival=%v", a.SequenceNumber, a.Departure, a.Arrival) -} diff --git a/simulation/peer_test.go b/peer_test.go similarity index 83% rename from simulation/peer_test.go rename to peer_test.go index b601eee..6e74935 100644 --- a/simulation/peer_test.go +++ b/peer_test.go @@ -3,13 +3,15 @@ //go:build !js -package simulation +package bwe_test import ( + "github.com/pion/bwe/gcc" "github.com/pion/interceptor" "github.com/pion/interceptor/pkg/packetdump" "github.com/pion/interceptor/pkg/rfc8888" "github.com/pion/interceptor/pkg/rtpfb" + "github.com/pion/interceptor/pkg/twcc" "github.com/pion/logging" "github.com/pion/transport/v3/vnet" "github.com/pion/webrtc/v4" @@ -79,17 +81,18 @@ func registerRTPFB() option { } } -// func registerTWCC() option { -// return func(p *peer) error { -// twcc, err := twcc.NewSenderInterceptor() -// if err != nil { -// return err -// } -// p.interceptorRegistry.Add(twcc) -// -// return nil -// } -// } +func registerTWCC() option { + return func(p *peer) error { + twcc, err := twcc.NewSenderInterceptor() + if err != nil { + return err + } + p.interceptorRegistry.Add(twcc) + + return nil + } +} + // // func registerTWCCHeaderExtension() option { // return func(p *peer) error { @@ -115,6 +118,18 @@ func registerCCFB() option { } } +func initGCC(onRateUpdate func(int)) option { + return func(p *peer) (err error) { + p.estimator, err = gcc.NewSendSideController(1_000_000, 128_000, 5_000_000) + if err != nil { + return err + } + p.onRateUpdate = onRateUpdate + + return nil + } +} + type peer struct { logger logging.LeveledLogger pc *webrtc.PeerConnection @@ -125,6 +140,9 @@ type peer struct { onRemoteTrack func(*webrtc.TrackRemote) onConnected func() + + estimator *gcc.SendSideController + onRateUpdate func(int) } func newPeer(opts ...option) (*peer, error) { @@ -274,9 +292,35 @@ func (p *peer) addRemoteTrack() error { func (p *peer) readRTCP(r *webrtc.RTPSender) { for { - _, _, err := r.ReadRTCP() + _, attr, err := r.ReadRTCP() if err != nil { return } + report, ok := attr.Get(rtpfb.CCFBAttributesKey).(rtpfb.Report) + if ok { + p.updateTargetRate(report) + } + } +} + +func (p *peer) updateTargetRate(report rtpfb.Report) { + if p.estimator != nil { + for _, pr := range report.PacketReports { + if pr.Arrived { + p.estimator.OnAck( + pr.SequenceNumber, + pr.Size, + pr.Departure, + pr.Arrival, + ) + } else { + p.estimator.OnLoss() + } + } + rate := p.estimator.OnFeedback(report.Arrival, report.RTT) + p.logger.Infof("new target rate: %v", rate) + if p.onRateUpdate != nil { + p.onRateUpdate(rate) + } } } diff --git a/simulation/perfect_codec_test.go b/perfect_codec_test.go similarity index 92% rename from simulation/perfect_codec_test.go rename to perfect_codec_test.go index 7c93fa5..a44b420 100644 --- a/simulation/perfect_codec_test.go +++ b/perfect_codec_test.go @@ -3,7 +3,7 @@ //go:build !js -package simulation +package bwe_test import ( "crypto/rand" @@ -46,14 +46,16 @@ func newPerfectCodec(writer sampleWriter, targetBitrateBps int) *perfectCodec { } // setTargetBitrate sets the target bitrate to r bits per second. -// func (c *perfectCodec) setTargetBitrate(r int) { -// c.wg.Go(func() { -// select { -// case c.bitrateUpdateCh <- r: -// case <-c.done: -// } -// }) -// } +func (c *perfectCodec) setTargetBitrate(r int) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + select { + case c.bitrateUpdateCh <- r: + case <-c.done: + } + }() +} // start begins the codec operation, generating frames at the configured frame rate. func (c *perfectCodec) start() { diff --git a/simulation/simulation.go b/simulation/simulation.go deleted file mode 100644 index 7f17136..0000000 --- a/simulation/simulation.go +++ /dev/null @@ -1,6 +0,0 @@ -// SPDX-FileCopyrightText: 2025 The Pion community -// SPDX-License-Identifier: MIT - -// Package simulation implements bandwidth estimation tests using the synctest -// package. -package simulation diff --git a/virtual_network_test.go b/virtual_network_test.go new file mode 100644 index 0000000..672cce4 --- /dev/null +++ b/virtual_network_test.go @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js && go1.25 + +package bwe_test + +import ( + "errors" + "testing" + + "github.com/pion/logging" + "github.com/pion/transport/v3/vnet" + "github.com/stretchr/testify/assert" +) + +type virtualNetwork struct { + wan *vnet.Router + left *vnet.Net + leftTBF *vnet.TokenBucketFilter + right *vnet.Net + rightTBF *vnet.TokenBucketFilter +} + +func (n *virtualNetwork) Close() error { + return errors.Join( + n.leftTBF.Close(), + n.rightTBF.Close(), + n.wan.Stop(), + ) +} + +func createVirtualNetwork(t *testing.T) *virtualNetwork { + t.Helper() + + wan, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + assert.NoError(t, err) + + leftRouter, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.1.0/24", + StaticIPs: []string{ + "10.0.1.1/10.0.1.101", + }, + LoggerFactory: logging.NewDefaultLoggerFactory(), + NATType: &vnet.NATType{ + Mode: vnet.NATModeNAT1To1, + }, + }) + assert.NoError(t, err) + + leftTBF, err := vnet.NewTokenBucketFilter(leftRouter, vnet.TBFRate(1_000_000), vnet.TBFMaxBurst(80_000)) + assert.NoError(t, err) + + err = wan.AddNet(leftTBF) + assert.NoError(t, err) + + err = wan.AddChildRouter(leftRouter) + assert.NoError(t, err) + + rightRouter, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.2.0/24", + StaticIPs: []string{ + "10.0.2.1/10.0.2.101", + }, + LoggerFactory: logging.NewDefaultLoggerFactory(), + NATType: &vnet.NATType{ + Mode: vnet.NATModeNAT1To1, + }, + }) + assert.NoError(t, err) + + rightTBF, err := vnet.NewTokenBucketFilter(rightRouter, vnet.TBFRate(1_000_000), vnet.TBFMaxBurst(80_000)) + assert.NoError(t, err) + + err = wan.AddNet(rightTBF) + assert.NoError(t, err) + + err = wan.AddChildRouter(rightRouter) + assert.NoError(t, err) + + err = wan.Start() + assert.NoError(t, err) + + leftNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"10.0.1.101"}, + StaticIP: "", + }) + assert.NoError(t, err) + err = leftRouter.AddNet(leftNet) + assert.NoError(t, err) + + rightNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"10.0.2.101"}, + StaticIP: "", + }) + assert.NoError(t, err) + err = rightRouter.AddNet(rightNet) + assert.NoError(t, err) + + return &virtualNetwork{ + wan: wan, + left: leftNet, + leftTBF: leftTBF, + right: rightNet, + rightTBF: rightTBF, + } +} diff --git a/simulation/vnet_test.go b/vnet_test.go similarity index 55% rename from simulation/vnet_test.go rename to vnet_test.go index ec8c342..cfb334c 100644 --- a/simulation/vnet_test.go +++ b/vnet_test.go @@ -3,7 +3,7 @@ //go:build !js && go1.25 -package simulation +package bwe_test import ( "errors" @@ -12,85 +12,10 @@ import ( "testing/synctest" "time" - "github.com/pion/logging" - "github.com/pion/transport/v3/vnet" "github.com/pion/webrtc/v4" "github.com/stretchr/testify/assert" ) -type network struct { - wan *vnet.Router - left *vnet.Net - right *vnet.Net -} - -func (n *network) Close() error { - return n.wan.Stop() -} - -func createVirtualNetwork(t *testing.T) *network { - t.Helper() - - wan, err := vnet.NewRouter(&vnet.RouterConfig{ - CIDR: "0.0.0.0/0", - LoggerFactory: logging.NewDefaultLoggerFactory(), - }) - assert.NoError(t, err) - - leftRouter, err := vnet.NewRouter(&vnet.RouterConfig{ - CIDR: "10.0.1.0/24", - StaticIPs: []string{ - "10.0.1.1/10.0.1.101", - }, - LoggerFactory: logging.NewDefaultLoggerFactory(), - NATType: &vnet.NATType{ - Mode: vnet.NATModeNAT1To1, - }, - }) - assert.NoError(t, err) - err = wan.AddRouter(leftRouter) - assert.NoError(t, err) - - rightRouter, err := vnet.NewRouter(&vnet.RouterConfig{ - CIDR: "10.0.2.0/24", - StaticIPs: []string{ - "10.0.2.1/10.0.2.101", - }, - LoggerFactory: logging.NewDefaultLoggerFactory(), - NATType: &vnet.NATType{ - Mode: vnet.NATModeNAT1To1, - }, - }) - assert.NoError(t, err) - err = wan.AddRouter(rightRouter) - assert.NoError(t, err) - - err = wan.Start() - assert.NoError(t, err) - - leftNet, err := vnet.NewNet(&vnet.NetConfig{ - StaticIPs: []string{"10.0.1.101"}, - StaticIP: "", - }) - assert.NoError(t, err) - err = leftRouter.AddNet(leftNet) - assert.NoError(t, err) - - rightNet, err := vnet.NewNet(&vnet.NetConfig{ - StaticIPs: []string{"10.0.2.101"}, - StaticIP: "", - }) - assert.NoError(t, err) - err = rightRouter.AddNet(rightNet) - assert.NoError(t, err) - - return &network{ - wan: wan, - left: leftNet, - right: rightNet, - } -} - func TestVnet(t *testing.T) { synctest.Test(t, func(t *testing.T) { t.Helper() @@ -103,6 +28,7 @@ func TestVnet(t *testing.T) { receiver, err := newPeer( registerDefaultCodecs(), setVNet(network.left, []string{"10.0.1.1"}), + registerTWCC(), onRemoteTrack(func(track *webrtc.TrackRemote) { close(onTrack) go func() { @@ -129,19 +55,23 @@ func TestVnet(t *testing.T) { err = receiver.addRemoteTrack() assert.NoError(t, err) + var codec *perfectCodec sender, err := newPeer( registerDefaultCodecs(), onConnected(func() { close(connected) }), setVNet(network.right, []string{"10.0.2.1"}), registerPacketLogger("sender"), registerRTPFB(), + initGCC(func(rate int) { + codec.setTargetBitrate(rate) + }), ) assert.NoError(t, err) track, err := sender.addLocalTrack() assert.NoError(t, err) - codec := newPerfectCodec(track, 1_000_000) + codec = newPerfectCodec(track, 1_000_000) go func() { <-connected codec.start() @@ -165,7 +95,7 @@ func TestVnet(t *testing.T) { case <-time.After(time.Second): assert.Fail(t, "on track not called") } - time.Sleep(10 * time.Second) + time.Sleep(100 * time.Second) close(done) err = codec.Close() From d795f8e9b702505e1be7e2df633ffb64fb426a38 Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Fri, 17 Oct 2025 21:19:25 +0200 Subject: [PATCH 3/6] Add better logging for synctest tests --- log_format_test.go | 13 ++++++++----- peer_test.go | 8 +++++--- vnet_test.go | 39 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/log_format_test.go b/log_format_test.go index 7b8f899..c85fa41 100644 --- a/log_format_test.go +++ b/log_format_test.go @@ -16,15 +16,14 @@ import ( ) type packetLogger struct { - vantagePoint string - direction string + logger *slog.Logger + direction string } func (l *packetLogger) LogRTPPacket(header *rtp.Header, payload []byte, attributes interceptor.Attributes) { ts := time.Now() - slog.Info( + l.logger.Info( "rtp", - "vantage-point", l.vantagePoint, "direction", l.direction, "ts", ts, "pt", header.PayloadType, @@ -38,6 +37,10 @@ func (l *packetLogger) LogRTPPacket(header *rtp.Header, payload []byte, attribut func (l *packetLogger) LogRTCPPackets(pkts []rtcp.Packet, attributes interceptor.Attributes) { for _, pkt := range pkts { - slog.Info("rtcp", "vantage-point", l.vantagePoint, "direction", l.direction, "type", fmt.Sprintf("%T", pkt)) + l.logger.Info( + "rtcp", + "direction", l.direction, + "type", fmt.Sprintf("%T", pkt), + ) } } diff --git a/peer_test.go b/peer_test.go index 6e74935..d22dd43 100644 --- a/peer_test.go +++ b/peer_test.go @@ -6,6 +6,8 @@ package bwe_test import ( + "log/slog" + "github.com/pion/bwe/gcc" "github.com/pion/interceptor" "github.com/pion/interceptor/pkg/packetdump" @@ -50,14 +52,14 @@ func registerDefaultCodecs() option { } } -func registerPacketLogger(vantagePoint string) option { +func registerPacketLogger(logger *slog.Logger) option { return func(p *peer) error { - ipl := &packetLogger{vantagePoint: vantagePoint, direction: "in"} + ipl := &packetLogger{logger: logger, direction: "in"} rd, err := packetdump.NewReceiverInterceptor(packetdump.PacketLog(ipl)) if err != nil { return err } - opl := &packetLogger{vantagePoint: vantagePoint, direction: "out"} + opl := &packetLogger{logger: logger, direction: "out"} sd, err := packetdump.NewSenderInterceptor(packetdump.PacketLog(opl)) if err != nil { return err diff --git a/vnet_test.go b/vnet_test.go index cfb334c..785436b 100644 --- a/vnet_test.go +++ b/vnet_test.go @@ -7,7 +7,11 @@ package bwe_test import ( "errors" + "fmt" "io" + "log/slog" + "os" + "path/filepath" "testing" "testing/synctest" "time" @@ -16,10 +20,41 @@ import ( "github.com/stretchr/testify/assert" ) +func testLogger(t *testing.T) (*slog.Logger, func()) { + t.Helper() + + logDir := os.Getenv("BWE_LOG_DIR") + if logDir == "" { + logDir = "logs" + } + if err := os.MkdirAll(logDir, 0o755); err != nil { + t.Fatalf("failed to create log dir %q: %v", logDir, err) + } + + filename := filepath.Join(logDir, fmt.Sprintf("%s.jsonl", t.Name())) + file, err := os.Create(filename) + if err != nil { + t.Fatalf("failed to create log file %q: %v", filename, err) + } + + handler := slog.NewJSONHandler(file, &slog.HandlerOptions{Level: slog.LevelInfo}) + logger := slog.New(handler) + + cleanup := func() { + file.Sync() + file.Close() + } + + return logger, cleanup +} + func TestVnet(t *testing.T) { synctest.Test(t, func(t *testing.T) { t.Helper() + logger, cleanup := testLogger(t) + defer cleanup() + onTrack := make(chan struct{}) connected := make(chan struct{}) done := make(chan struct{}) @@ -47,7 +82,7 @@ func TestVnet(t *testing.T) { } }() }), - registerPacketLogger("receiver"), + registerPacketLogger(logger.With("vantage-point", "receiver")), registerCCFB(), ) assert.NoError(t, err) @@ -60,7 +95,7 @@ func TestVnet(t *testing.T) { registerDefaultCodecs(), onConnected(func() { close(connected) }), setVNet(network.right, []string{"10.0.2.1"}), - registerPacketLogger("sender"), + registerPacketLogger(logger.With("vantage-point", "sender")), registerRTPFB(), initGCC(func(rate int) { codec.setTargetBitrate(rate) From 427e698487f6399b735925cff461171ab20f712e Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Fri, 17 Oct 2025 21:20:01 +0200 Subject: [PATCH 4/6] Ignore test logs --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 6e2f206..ad25391 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ bin/ vendor/ node_modules/ +logs/ ### Files ### ############# From bcc77ba2837aed72373df669437c6a108d5dfc6b Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Fri, 17 Oct 2025 23:05:45 +0200 Subject: [PATCH 5/6] Log unwrapped sequence numbers --- log_format_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++--- peer_test.go | 4 ++-- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/log_format_test.go b/log_format_test.go index c85fa41..7e51dcf 100644 --- a/log_format_test.go +++ b/log_format_test.go @@ -8,27 +8,40 @@ package bwe_test import ( "fmt" "log/slog" - "time" "github.com/pion/interceptor" "github.com/pion/rtcp" "github.com/pion/rtp" ) +const ( + maxSequenceNumberPlusOne = int64(65536) + breakpoint = 32768 // half of max uint16 +) + type packetLogger struct { logger *slog.Logger direction string + seq *unwrapper +} + +func newPacketLogger(logger *slog.Logger, direction string) *packetLogger { + return &packetLogger{ + logger: logger, + direction: direction, + seq: &unwrapper{}, + } } func (l *packetLogger) LogRTPPacket(header *rtp.Header, payload []byte, attributes interceptor.Attributes) { - ts := time.Now() + u := l.seq.Unwrap(header.SequenceNumber) l.logger.Info( "rtp", "direction", l.direction, - "ts", ts, "pt", header.PayloadType, "ssrc", header.SSRC, "sequence-number", header.SequenceNumber, + "unwrapped-sequence-number", u, "rtp-timestamp", header.Timestamp, "marker", header.Marker, "payload-size", len(payload), @@ -44,3 +57,41 @@ func (l *packetLogger) LogRTCPPackets(pkts []rtcp.Packet, attributes interceptor ) } } + +// Unwrapper stores an unwrapped sequence number. +type unwrapper struct { + init bool + lastUnwrapped int64 +} + +func isNewer(value, previous uint16) bool { + if value-previous == breakpoint { + return value > previous + } + + return value != previous && (value-previous) < breakpoint +} + +// Unwrap unwraps the next sequencenumber. +func (u *unwrapper) Unwrap(i uint16) int64 { + if !u.init { + u.init = true + u.lastUnwrapped = int64(i) + + return u.lastUnwrapped + } + + lastWrapped := uint16(u.lastUnwrapped) //nolint:gosec // G115 + delta := int64(i - lastWrapped) + if isNewer(i, lastWrapped) { + if delta < 0 { + delta += maxSequenceNumberPlusOne + } + } else if delta > 0 && u.lastUnwrapped+delta-maxSequenceNumberPlusOne >= 0 { + delta -= maxSequenceNumberPlusOne + } + + u.lastUnwrapped += delta + + return u.lastUnwrapped +} diff --git a/peer_test.go b/peer_test.go index d22dd43..359d98c 100644 --- a/peer_test.go +++ b/peer_test.go @@ -54,12 +54,12 @@ func registerDefaultCodecs() option { func registerPacketLogger(logger *slog.Logger) option { return func(p *peer) error { - ipl := &packetLogger{logger: logger, direction: "in"} + ipl := newPacketLogger(logger, "in") rd, err := packetdump.NewReceiverInterceptor(packetdump.PacketLog(ipl)) if err != nil { return err } - opl := &packetLogger{logger: logger, direction: "out"} + opl := newPacketLogger(logger, "out") sd, err := packetdump.NewSenderInterceptor(packetdump.PacketLog(opl)) if err != nil { return err From 7307c326a3a518bee9e8658364f359a33d7dd46a Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Fri, 17 Oct 2025 23:05:56 +0200 Subject: [PATCH 6/6] Rename test file --- vnet_test.go => bwe_test.go | 65 ++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 15 deletions(-) rename vnet_test.go => bwe_test.go (76%) diff --git a/vnet_test.go b/bwe_test.go similarity index 76% rename from vnet_test.go rename to bwe_test.go index 785436b..0973b67 100644 --- a/vnet_test.go +++ b/bwe_test.go @@ -6,9 +6,11 @@ package bwe_test import ( + "encoding/json" "errors" "fmt" "io" + "log" "log/slog" "os" "path/filepath" @@ -20,35 +22,47 @@ import ( "github.com/stretchr/testify/assert" ) -func testLogger(t *testing.T) (*slog.Logger, func()) { - t.Helper() +var logDir string - logDir := os.Getenv("BWE_LOG_DIR") +func TestMain(m *testing.M) { + logDir = os.Getenv("BWE_LOG_DIR") if logDir == "" { - logDir = "logs" + logDir = "test-web/logs" } if err := os.MkdirAll(logDir, 0o755); err != nil { - t.Fatalf("failed to create log dir %q: %v", logDir, err) + log.Printf("failed to create log dir %q: %v", logDir, err) + os.Exit(1) } - filename := filepath.Join(logDir, fmt.Sprintf("%s.jsonl", t.Name())) - file, err := os.Create(filename) + ec := m.Run() + + files, err := filepath.Glob(filepath.Join(logDir, "*.jsonl")) if err != nil { - t.Fatalf("failed to create log file %q: %v", filename, err) + log.Printf("Failed to list JSONL files: %v", err) } - handler := slog.NewJSONHandler(file, &slog.HandlerOptions{Level: slog.LevelInfo}) - logger := slog.New(handler) + var names []string + for _, f := range files { + names = append(names, filepath.Base(f)) + } - cleanup := func() { - file.Sync() - file.Close() + b, err := json.Marshal(names) + if err != nil { + log.Printf("Failed to marshal index.json: %v", err) + os.Exit(ec) } - return logger, cleanup + indexPath := filepath.Join(logDir, "index.json") + if err := os.WriteFile(indexPath, b, 0644); err != nil { + log.Printf("Failed to write index.json: %v", err) + } else { + log.Printf("Generated index.json with %d files", len(names)) + } + + os.Exit(ec) } -func TestVnet(t *testing.T) { +func TestBWE(t *testing.T) { synctest.Test(t, func(t *testing.T) { t.Helper() @@ -98,6 +112,7 @@ func TestVnet(t *testing.T) { registerPacketLogger(logger.With("vantage-point", "sender")), registerRTPFB(), initGCC(func(rate int) { + logger.Info("setting codec target bitrate", "rate", rate) codec.setTargetBitrate(rate) }), ) @@ -148,3 +163,23 @@ func TestVnet(t *testing.T) { synctest.Wait() }) } + +func testLogger(t *testing.T) (*slog.Logger, func()) { + t.Helper() + + filename := filepath.Join(logDir, fmt.Sprintf("%s.jsonl", t.Name())) + file, err := os.Create(filename) + if err != nil { + t.Fatalf("failed to create log file %q: %v", filename, err) + } + + handler := slog.NewJSONHandler(file, &slog.HandlerOptions{Level: slog.LevelInfo}) + logger := slog.New(handler) + + cleanup := func() { + file.Sync() + file.Close() + } + + return logger, cleanup +}