diff --git a/.github/workflows/job_test_agent_local.yaml b/.github/workflows/job_test_agent_local.yaml
new file mode 100644
index 0000000000..4c5276ff9e
--- /dev/null
+++ b/.github/workflows/job_test_agent_local.yaml
@@ -0,0 +1,25 @@
+name: Test Agent Local
+on:
+  workflow_call:
+
+
+
+jobs:
+  test_agent_local:
+    runs-on: ubuntu-latest
+    timeout-minutes: 60
+    steps:
+      - uses: actions/checkout@v4
+      - name: Install
+        uses: ./.github/actions/install
+        with:
+          go: true
+
+
+      - name: Build
+        run: task build
+        working-directory: apps/agent
+
+      - name: Test
+        run: go test -cover -json -timeout=60m -failfast ./pkg/... ./services/... | tparse -all -progress
+        working-directory: apps/agent
diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index cdc796edc6..d3302a778c 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -20,7 +20,9 @@ jobs:
     name: Test API
     uses: ./.github/workflows/job_test_api_local.yaml
 
-
+  test_agent_local:
+    name: Test Agent Local
+    uses: ./.github/workflows/job_test_agent_local.yaml
   # test_agent_integration:
   #   name: Test Agent Integration
   #   runs-on: ubuntu-latest
diff --git a/apps/agent/pkg/circuitbreaker/lib.go b/apps/agent/pkg/circuitbreaker/lib.go
index 04ff5fc0aa..327ebb2b09 100644
--- a/apps/agent/pkg/circuitbreaker/lib.go
+++ b/apps/agent/pkg/circuitbreaker/lib.go
@@ -169,7 +169,6 @@ func (cb *CB[Res]) preflight(ctx context.Context) error {
 	now := cb.config.clock.Now()
 
 	if now.After(cb.resetCountersAt) {
-		cb.logger.Info().Msg("resetting circuit breaker")
 		cb.requests = 0
 		cb.successes = 0
 		cb.failures = 0
diff --git a/apps/agent/pkg/clock/real_clock.go b/apps/agent/pkg/clock/real_clock.go
index 50e33a6a17..580be114e0 100644
--- a/apps/agent/pkg/clock/real_clock.go
+++ b/apps/agent/pkg/clock/real_clock.go
@@ -2,31 +2,15 @@ package clock
 
 import "time"
 
-type TestClock struct {
-	now time.Time
+type RealClock struct {
 }
 
-func NewTestClock(now ...time.Time) *TestClock {
-	if len(now) == 0 {
-		now = append(now, time.Now())
-	}
-	return &TestClock{now: now[0]}
+func New() *RealClock {
+	return &RealClock{}
 }
 
-var _ Clock = &TestClock{}
+var _ Clock = &RealClock{}
 
-func (c *TestClock) Now() time.Time {
-	return c.now
-}
-
-// Tick advances the clock by the given duration and returns the new time.
-func (c *TestClock) Tick(d time.Duration) time.Time {
-	c.now = c.now.Add(d)
-	return c.now
-}
-
-// Set sets the clock to the given time and returns the new time.
-func (c *TestClock) Set(t time.Time) time.Time {
-	c.now = t
-	return c.now
+func (c *RealClock) Now() time.Time {
+	return time.Now()
 }
diff --git a/apps/agent/pkg/clock/test_clock.go b/apps/agent/pkg/clock/test_clock.go
index 580be114e0..50e33a6a17 100644
--- a/apps/agent/pkg/clock/test_clock.go
+++ b/apps/agent/pkg/clock/test_clock.go
@@ -2,15 +2,31 @@ package clock
 
 import "time"
 
-type RealClock struct {
+type TestClock struct {
+	now time.Time
 }
 
-func New() *RealClock {
-	return &RealClock{}
+func NewTestClock(now ...time.Time) *TestClock {
+	if len(now) == 0 {
+		now = append(now, time.Now())
+	}
+	return &TestClock{now: now[0]}
 }
 
-var _ Clock = &RealClock{}
+var _ Clock = &TestClock{}
 
-func (c *RealClock) Now() time.Time {
-	return time.Now()
+func (c *TestClock) Now() time.Time {
+	return c.now
+}
+
+// Tick advances the clock by the given duration and returns the new time.
+func (c *TestClock) Tick(d time.Duration) time.Time {
+	c.now = c.now.Add(d)
+	return c.now
+}
+
+// Set sets the clock to the given time and returns the new time.
+func (c *TestClock) Set(t time.Time) time.Time {
+	c.now = t
+	return c.now
 }
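Note (editor's sketch, not part of the patch): the two clock files above are swapped so that the agent gets a real clock while tests keep a controllable one. The Clock interface itself is not shown in this diff; the sketch below assumes it only exposes Now() and illustrates why a test clock makes time-based logic such as counter resets deterministic.

```go
// Hypothetical sketch: an assumed Clock interface plus a controllable test
// implementation. Names mirror the patch but are illustrative only.
package main

import (
	"fmt"
	"time"
)

type Clock interface {
	Now() time.Time
}

type TestClock struct{ now time.Time }

func (c *TestClock) Now() time.Time { return c.now }

// Tick advances the fake time without sleeping, so expiry logic can be
// exercised instantly and deterministically.
func (c *TestClock) Tick(d time.Duration) time.Time {
	c.now = c.now.Add(d)
	return c.now
}

func main() {
	tc := &TestClock{now: time.Now()}
	var clock Clock = tc
	resetAt := clock.Now().Add(10 * time.Second)
	tc.Tick(11 * time.Second) // advance virtual time instead of sleeping
	fmt.Println(clock.Now().After(resetAt)) // true: a reset would trigger here
}
```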
diff --git a/apps/agent/services/ratelimit/mitigate.go b/apps/agent/services/ratelimit/mitigate.go
index 2d75b280f6..92c463e279 100644
--- a/apps/agent/services/ratelimit/mitigate.go
+++ b/apps/agent/services/ratelimit/mitigate.go
@@ -21,7 +21,6 @@ func (s *service) Mitigate(ctx context.Context, req *ratelimitv1.MitigateRequest
 	bucket, _ := s.getBucket(bucketKey{req.Identifier, req.Limit, duration})
 	bucket.Lock()
 	defer bucket.Unlock()
-
 	bucket.windows[req.Window.GetSequence()] = req.Window
 
 	return &ratelimitv1.MitigateResponse{}, nil
@@ -51,16 +50,20 @@ func (s *service) broadcastMitigation(req mitigateWindowRequest) {
 		return
 	}
 	for _, peer := range peers {
-		_, err := peer.client.Mitigate(ctx, connect.NewRequest(&ratelimitv1.MitigateRequest{
-			Identifier: req.identifier,
-			Limit:      req.limit,
-			Duration:   req.duration.Milliseconds(),
-			Window:     req.window,
-		}))
+		_, err := s.mitigateCircuitBreaker.Do(ctx, func(innerCtx context.Context) (*connect.Response[ratelimitv1.MitigateResponse], error) {
+			innerCtx, cancel := context.WithTimeout(innerCtx, 10*time.Second)
+			defer cancel()
+			return peer.client.Mitigate(innerCtx, connect.NewRequest(&ratelimitv1.MitigateRequest{
+				Identifier: req.identifier,
+				Limit:      req.limit,
+				Duration:   req.duration.Milliseconds(),
+				Window:     req.window,
+			}))
+		})
 		if err != nil {
 			s.logger.Err(err).Msg("failed to call mitigate")
 		} else {
-			s.logger.Info().Str("peerId", peer.id).Msg("broadcasted mitigation")
+			s.logger.Debug().Str("peerId", peer.id).Msg("broadcasted mitigation")
 		}
 	}
 }
diff --git a/apps/agent/services/ratelimit/ratelimit_mitigation_test.go b/apps/agent/services/ratelimit/ratelimit_mitigation_test.go
index b517591d5f..a768995115 100644
--- a/apps/agent/services/ratelimit/ratelimit_mitigation_test.go
+++ b/apps/agent/services/ratelimit/ratelimit_mitigation_test.go
@@ -23,7 +23,7 @@ import (
 )
 
 func TestExceedingTheLimitShouldNotifyAllNodes(t *testing.T) {
-	t.Skip()
+
 	for _, clusterSize := range []int{1, 3, 5} {
 		t.Run(fmt.Sprintf("Cluster Size %d", clusterSize), func(t *testing.T) {
 			logger := logging.New(nil)
@@ -94,12 +94,13 @@ func TestExceedingTheLimitShouldNotifyAllNodes(t *testing.T) {
 			ctx := context.Background()
 
 			// Saturate the window
-			for i := int64(0); i <= limit; i++ {
+			for i := int64(0); i < limit; i++ {
 				rl := util.RandomElement(ratelimiters)
 				res, err := rl.Ratelimit(ctx, req)
 				require.NoError(t, err)
 				t.Logf("saturate res: %+v", res)
 				require.True(t, res.Success)
+
 			}
 
 			time.Sleep(time.Second * 5)
@@ -107,10 +108,11 @@ func TestExceedingTheLimitShouldNotifyAllNodes(t *testing.T) {
 			// Let's hit everry node again
 			// They should all be mitigated
 			for i, rl := range ratelimiters {
+
 				res, err := rl.Ratelimit(ctx, req)
 				require.NoError(t, err)
 				t.Logf("res from %d: %+v", i, res)
-				// require.False(t, res.Success)
+				require.False(t, res.Success)
 			}
 
 		})
diff --git a/apps/agent/services/ratelimit/ratelimit_replication_test.go b/apps/agent/services/ratelimit/ratelimit_replication_test.go
index cae53d6dde..8e93fc19e7 100644
--- a/apps/agent/services/ratelimit/ratelimit_replication_test.go
+++ b/apps/agent/services/ratelimit/ratelimit_replication_test.go
@@ -24,8 +24,7 @@ import (
 	"github.com/unkeyed/unkey/apps/agent/pkg/util"
 )
 
-func TestReplication(t *testing.T) {
-	t.Skip()
+func TestSync(t *testing.T) {
 	type Node struct {
 		srv     *service
 		cluster cluster.Cluster
@@ -106,7 +105,7 @@ func TestReplication(t *testing.T) {
 	}
 
 	// Figure out who is the origin
-	_, err := nodes[1].srv.Ratelimit(ctx, req)
+	_, err := nodes[0].srv.Ratelimit(ctx, req)
 	require.NoError(t, err)
 
 	time.Sleep(5 * time.Second)
@@ -138,7 +137,6 @@ func TestReplication(t *testing.T) {
 	require.True(t, ok)
 	bucket.RLock()
 	window := bucket.getCurrentWindow(now)
-	t.Logf("window on origin: %+v", window)
 	counter := window.Counter
 	bucket.RUnlock()
 
diff --git a/apps/agent/services/ratelimit/ratelimit_test.go b/apps/agent/services/ratelimit/ratelimit_test.go
index 8d8edc01f7..3338aba0fc 100644
--- a/apps/agent/services/ratelimit/ratelimit_test.go
+++ b/apps/agent/services/ratelimit/ratelimit_test.go
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"net/http"
 	"net/url"
-	"sync"
 	"testing"
 	"time"
 
@@ -28,134 +27,143 @@ import (
 
 func TestAccuracy_fixed_time(t *testing.T) {
-	for _, clusterSize := range []int{5} {
-		t.Run(fmt.Sprintf("Cluster Size %d", clusterSize), func(t *testing.T) {
-			logger := logging.New(nil)
-			clusters := []cluster.Cluster{}
-			ratelimiters := []ratelimit.Service{}
-			serfAddrs := []string{}
-
-			for i := range clusterSize {
-				c, serfAddr, rpcAddr := createCluster(t, fmt.Sprintf("node-%d", i), serfAddrs)
-				serfAddrs = append(serfAddrs, serfAddr)
-				clusters = append(clusters, c)
-
-				rl, err := ratelimit.New(ratelimit.Config{
-					Logger:  logger,
-					Metrics: metrics.NewNoop(),
-					Cluster: c,
-				})
-				require.NoError(t, err)
-				ratelimiters = append(ratelimiters, rl)
-
-				srv, err := connectSrv.New(connectSrv.Config{
-					Logger:  logger,
-					Metrics: metrics.NewNoop(),
-					Image:   "does not matter",
-				})
-				require.NoError(t, err)
-				err = srv.AddService(connectSrv.NewRatelimitServer(rl, logger, "test-auth-token"))
-				require.NoError(t, err)
-
-				require.NoError(t, err)
-				u, err := url.Parse(rpcAddr)
-				require.NoError(t, err)
-				go srv.Listen(u.Host)
-
-				require.Eventually(t, func() bool {
-					client := ratelimitv1connect.NewRatelimitServiceClient(http.DefaultClient, rpcAddr)
-					res, livenessErr := client.Liveness(context.Background(), connect.NewRequest(&ratelimitv1.LivenessRequest{}))
-					require.NoError(t, livenessErr)
-					return res.Msg.Status == "ok"
-
-				},
-					time.Minute, 100*time.Millisecond)
-
-			}
-			require.Len(t, ratelimiters, clusterSize)
-			require.Len(t, serfAddrs, clusterSize)
-
-			for _, c := range clusters {
-				require.Eventually(t, func() bool {
-					return c.Size() == clusterSize
-				}, time.Minute, 100*time.Millisecond)
-			}
-
+	for _, clusterSize := range []int{1, 3, 5} {
+		t.Run(fmt.Sprintf("ClusterSize:%d", clusterSize), func(t *testing.T) {
 			for _, limit := range []int64{
 				5,
 				10,
 				100,
 			} {
-				for _, duration := range []time.Duration{
-					1 * time.Second,
-					10 * time.Second,
-					1 * time.Minute,
-					5 * time.Minute,
-					1 * time.Hour,
-				} {
-					for _, windows := range []int64{1, 2, 5, 10, 50} {
-						// Attack the ratelimit with 100x as much as it should let pass
-						requests := limit * windows * 100
-
-						for _, nIngressNodes := range []int{1, 3, clusterSize} {
-							if nIngressNodes > clusterSize {
-								nIngressNodes = clusterSize
-							}
-							t.Run(fmt.Sprintf("%d/%d ingress nodes: rate %d/%s %d requests across %d windows",
-								nIngressNodes,
-								clusterSize,
-								limit,
-								duration,
-								requests,
-								windows,
-							), func(t *testing.T) {
-
-								identifier := uid.New("test")
-								ingressNodes := ratelimiters[:nIngressNodes]
-
-								now := time.Now()
-								end := now.Add(duration * time.Duration(windows))
-								passed := int64(0)
-
-								dt := duration * time.Duration(windows) / time.Duration(requests)
-
-								for i := now; i.Before(end); i = i.Add(dt) {
-									rl := util.RandomElement(ingressNodes)
-
-									res, err := rl.Ratelimit(context.Background(), &ratelimitv1.RatelimitRequest{
-										// random time within one of the windows
-										Time:       util.Pointer(i.UnixMilli()),
-										Identifier: identifier,
-										Limit:      limit,
-										Duration:   duration.Milliseconds(),
-										Cost:       1,
-									})
-									require.NoError(t, err)
-									if res.Success {
-										passed++
-									}
-								}
+				type Node struct {
+					srv     ratelimit.Service
+					cluster cluster.Cluster
+				}
-
-								// At least 95% of the requests should pass
-								// lower := 0.95
-								// At most 150% + 75% per additional ingress node should pass
-								upper := 1.50 + 1.0*float64(len(ingressNodes)-1)
+				nodes := []Node{}
+				logger := logging.New(nil)
+				serfAddrs := []string{}
+
+				for i := 0; i < clusterSize; i++ {
+					node := Node{}
+					c, serfAddr, rpcAddr := createCluster(t, fmt.Sprintf("node-%d", i), serfAddrs)
+					serfAddrs = append(serfAddrs, serfAddr)
+					node.cluster = c
+
+					srv, err := ratelimit.New(ratelimit.Config{
+						Logger:  logger,
+						Metrics: metrics.NewNoop(),
+						Cluster: c,
+					})
+					require.NoError(t, err)
+					node.srv = srv
+					nodes = append(nodes, node)
+
+					cSrv, err := connectSrv.New(connectSrv.Config{
+						Logger:  logger,
+						Metrics: metrics.NewNoop(),
+						Image:   "does not matter",
+					})
+					require.NoError(t, err)
+					err = cSrv.AddService(connectSrv.NewRatelimitServer(srv, logger, "test-auth-token"))
+					require.NoError(t, err)
+
+					require.NoError(t, err)
+					u, err := url.Parse(rpcAddr)
+					require.NoError(t, err)
+
+					go cSrv.Listen(u.Host)
+
+					require.Eventually(t, func() bool {
+						client := ratelimitv1connect.NewRatelimitServiceClient(http.DefaultClient, rpcAddr)
+						res, livenessErr := client.Liveness(context.Background(), connect.NewRequest(&ratelimitv1.LivenessRequest{}))
+						require.NoError(t, livenessErr)
+						return res.Msg.Status == "ok"
+
+					},
+						time.Minute, 100*time.Millisecond)
+				}
+				require.Len(t, nodes, clusterSize)
+				require.Len(t, serfAddrs, clusterSize)
-
-								exactLimit := limit * (windows + 1)
-								// require.GreaterOrEqual(t, passed, int64(float64(exactLimit)*lower))
-								require.LessOrEqual(t, passed, int64(float64(exactLimit)*upper))
+				for _, n := range nodes {
+					require.Eventually(t, func() bool {
+						return n.cluster.Size() == clusterSize
+					}, time.Minute, 100*time.Millisecond)
+				}
-
-							})
-						}
+				t.Run(fmt.Sprintf("limit:%d", limit), func(t *testing.T) {
+
+					for _, duration := range []time.Duration{
+						10 * time.Second,
+						1 * time.Minute,
+						5 * time.Minute,
+						1 * time.Hour,
+					} {
+						t.Run(fmt.Sprintf("duration:%s", duration), func(t *testing.T) {
+
+							for _, windows := range []int64{1, 2, 5, 10, 50} {
+								t.Run(fmt.Sprintf("windows:%d", windows), func(t *testing.T) {
+
+									// Attack the ratelimit with 100x as much as it should let pass
+									requests := limit * windows * 100
+
+									for _, nIngressNodes := range []int{1, 3, clusterSize} {
+										if nIngressNodes > clusterSize {
+											nIngressNodes = clusterSize
+										}
+										t.Run(fmt.Sprintf("%d/%d ingress nodes",
+											nIngressNodes,
+											clusterSize,
+										), func(t *testing.T) {
+
+											identifier := uid.New("test")
+											ingressNodes := nodes[:nIngressNodes]
+
+											now := time.Now()
+											end := now.Add(duration * time.Duration(windows))
+											passed := int64(0)
+
+											dt := duration * time.Duration(windows) / time.Duration(requests)
+
+											for i := now; i.Before(end); i = i.Add(dt) {
+												rl := util.RandomElement(ingressNodes)
+
+												res, err := rl.srv.Ratelimit(context.Background(), &ratelimitv1.RatelimitRequest{
+													// random time within one of the windows
+													Time:       util.Pointer(i.UnixMilli()),
+													Identifier: identifier,
+													Limit:      limit,
+													Duration:   duration.Milliseconds(),
+													Cost:       1,
+												})
+												require.NoError(t, err)
+												if res.Success {
+													passed++
+												}
+											}
+
+											lower := limit * windows
+											// At most 150% + 75% per additional ingress node should pass
+											upper := 1.50 + 1.0*float64(len(ingressNodes)-1)
+
+											require.GreaterOrEqual(t, passed, lower)
+											require.LessOrEqual(t, passed, int64(float64(limit*(windows+1))*upper))
+
+										})
+									}
+								})
+							}
+						})
 					}
+					})
+					for _, n := range nodes {
+						require.NoError(t, n.cluster.Shutdown())
 					}
-				}
-				for _, c := range clusters {
-					require.NoError(t, c.Shutdown())
 				}
+			})
 	}
 }
@@ -205,35 +213,3 @@
 	return c, serfAddr, rpcAddr
 }
-
-func loadTest[T any](t *testing.T, rps int64, seconds int64, fn func() T) []T {
-	t.Helper()
-
-	resultsC := make(chan T)
-
-	var wg sync.WaitGroup
-
-	for range seconds {
-		for range rps {
-			time.Sleep(time.Second / time.Duration(rps))
-
-			wg.Add(1)
-			go func() {
-				resultsC <- fn()
-			}()
-		}
-	}
-
-	results := []T{}
-	go func() {
-		for res := range resultsC {
-			results = append(results, res)
-			wg.Done()
-
-		}
-	}()
-	wg.Wait()
-
-	return results
-
-}
 
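Note (editor's sketch, not part of the patch): the new accuracy test bounds the number of passed requests between `limit * windows` and `limit * (windows + 1)` scaled by `1.50 + 1.0` for each additional ingress node, as written in the code above. A small worked example with hypothetical inputs:

```go
// Worked example of the accuracy test's pass-count bounds, using assumed inputs.
package main

import "fmt"

// bounds mirrors the assertions in the test: lower = limit*windows,
// upper = limit*(windows+1) * (1.50 + 1.0 per additional ingress node).
func bounds(limit, windows int64, ingressNodes int) (lower, upper int64) {
	lower = limit * windows
	upperFactor := 1.50 + 1.0*float64(ingressNodes-1)
	exact := limit * (windows + 1)
	upper = int64(float64(exact) * upperFactor)
	return lower, upper
}

func main() {
	// Example: limit=100 per window, 5 windows, 3 ingress nodes taking traffic.
	lower, upper := bounds(100, 5, 3)
	fmt.Println(lower, upper) // 500 2100: passed must land in [500, 2100]
}
```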
diff --git a/apps/agent/services/ratelimit/service.go b/apps/agent/services/ratelimit/service.go
index d6e1bff85c..c40458a311 100644
--- a/apps/agent/services/ratelimit/service.go
+++ b/apps/agent/services/ratelimit/service.go
@@ -38,7 +38,8 @@ type service struct {
 	// Store a reference leaseId -> window key
 	leaseIdToKeyMap map[string]string
 
-	syncCircuitBreaker circuitbreaker.CircuitBreaker[*connect.Response[ratelimitv1.PushPullResponse]]
+	syncCircuitBreaker     circuitbreaker.CircuitBreaker[*connect.Response[ratelimitv1.PushPullResponse]]
+	mitigateCircuitBreaker circuitbreaker.CircuitBreaker[*connect.Response[ratelimitv1.MitigateResponse]]
 }
 
 type Config struct {
@@ -64,6 +65,15 @@ func New(cfg Config) (*service, error) {
 		buckets:             make(map[string]*bucket),
 		leaseIdToKeyMapLock: sync.RWMutex{},
 		leaseIdToKeyMap:     make(map[string]string),
+
+		mitigateCircuitBreaker: circuitbreaker.New[*connect.Response[ratelimitv1.MitigateResponse]](
+			"ratelimit.broadcastMitigation",
+			circuitbreaker.WithLogger(cfg.Logger),
+			circuitbreaker.WithCyclicPeriod(10*time.Second),
+			circuitbreaker.WithTimeout(time.Minute),
+			circuitbreaker.WithMaxRequests(100),
+			circuitbreaker.WithTripThreshold(50),
+		),
 		syncCircuitBreaker: circuitbreaker.New[*connect.Response[ratelimitv1.PushPullResponse]](
 			"ratelimit.syncWithOrigin",
 			circuitbreaker.WithLogger(cfg.Logger),
diff --git a/apps/agent/services/ratelimit/sliding_window.go b/apps/agent/services/ratelimit/sliding_window.go
index 78bbb4337c..f8cd41b1e3 100644
--- a/apps/agent/services/ratelimit/sliding_window.go
+++ b/apps/agent/services/ratelimit/sliding_window.go
@@ -2,7 +2,6 @@ package ratelimit
 
 import (
 	"context"
-	"math"
 	"time"
 
 	ratelimitv1 "github.com/unkeyed/unkey/apps/agent/gen/proto/ratelimit/v1"
@@ -109,6 +108,15 @@ func (r *service) CheckWindows(ctx context.Context, req ratelimitRequest) (prev
 	return prev, curr
 }
 
+// ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
+// ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
+// Experimentally, we are reverting this to fixed-window until we can get rid
+// of the cloudflare cachelayer.
+//
+// Throughout this function there is commented out and annotated code that we
+// need to reenable later. Such code is also marked with the comment "FIXED-WINDOW"
+// ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
+// ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
 func (r *service) Take(ctx context.Context, req ratelimitRequest) ratelimitResponse {
 	ctx, span := tracing.Start(ctx, "slidingWindow.Take")
 	defer span.End()
@@ -127,13 +135,21 @@ func (r *service) Take(ctx context.Context, req ratelimitRequest) ratelimitRespo
 	currentWindow := bucket.getCurrentWindow(req.Time)
 	previousWindow := bucket.getPreviousWindow(req.Time)
 
-	currentWindowPercentage := float64(req.Time.UnixMilli()-currentWindow.Start) / float64(req.Duration.Milliseconds())
-	previousWindowPercentage := 1.0 - currentWindowPercentage
+	// FIXED-WINDOW
+	// uncomment
+	// currentWindowPercentage := float64(req.Time.UnixMilli()-currentWindow.Start) / float64(req.Duration.Milliseconds())
+	// previousWindowPercentage := 1.0 - currentWindowPercentage
 
 	// Calculate the current count including all leases
-	fromPreviousWindow := float64(previousWindow.Counter) * previousWindowPercentage
-	fromCurrentWindow := float64(currentWindow.Counter)
-	current := int64(math.Ceil(fromCurrentWindow + fromPreviousWindow))
+	// FIXED-WINDOW
+	// uncomment
+	// fromPreviousWindow := float64(previousWindow.Counter) * previousWindowPercentage
+	// fromCurrentWindow := float64(currentWindow.Counter)
+
+	// FIXED-WINDOW
+	// replace this with the following line
+	// current := int64(math.Ceil(fromCurrentWindow + fromPreviousWindow))
+	current := currentWindow.Counter
 
 	// r.logger.Info().Int64("fromCurrentWindow", fromCurrentWindow).Int64("fromPreviousWindow", fromPreviousWindow).Time("now", req.Time).Time("currentWindow.start", currentWindow.start).Int64("msSinceStart", msSinceStart).Float64("currentWindowPercentage", currentWindowPercentage).Float64("previousWindowPercentage", previousWindowPercentage).Bool("currentWindowExists", currentWindowExists).Bool("previousWindowExists", previousWindowExists).Int64("current", current).Interface("buckets", r.buckets).Send()
 	// currentWithLeases := id.current
@@ -180,12 +196,12 @@
 	currentWindow.Counter += req.Cost
 	if currentWindow.Counter >= req.Limit && !currentWindow.MitigateBroadcasted && r.mitigateBuffer != nil {
 		currentWindow.MitigateBroadcasted = true
-		// r.mitigateBuffer <- mitigateWindowRequest{
-		// 	identifier: req.Identifier,
-		// 	limit:      req.Limit,
-		// 	duration:   req.Duration,
-		// 	window:     currentWindow,
-		// }
+		r.mitigateBuffer <- mitigateWindowRequest{
+			identifier: req.Identifier,
+			limit:      req.Limit,
+			duration:   req.Duration,
+			window:     currentWindow,
+		}
 	}
 
 	current += req.Cost
@@ -264,6 +280,7 @@ func (r *service) SetCounter(ctx context.Context, requests ...setCounterRequest)
 
 func newWindow(sequence int64, t time.Time, duration time.Duration) *ratelimitv1.Window {
 	return &ratelimitv1.Window{
+		Sequence:            sequence,
 		MitigateBroadcasted: false,
 		Start:               t.Truncate(duration).UnixMilli(),
 		Duration:            duration.Milliseconds(),
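Note (editor's sketch, not part of the patch): the FIXED-WINDOW comments above temporarily replace sliding-window interpolation with a plain fixed-window count. The commented-out code weighs the previous window by the fraction of it that still overlaps the interval ending at the request time, while the replacement takes only the current window's counter. A minimal sketch of both calculations, with illustrative names rather than the service's actual types:

```go
// Sketch of the two counting strategies referenced in the FIXED-WINDOW comments.
package main

import (
	"fmt"
	"math"
	"time"
)

// slidingCount: the previous window contributes proportionally to how much of
// it still falls inside the sliding interval ending at now.
func slidingCount(prevCounter, currCounter int64, now, currStart time.Time, duration time.Duration) int64 {
	currentWindowPercentage := float64(now.Sub(currStart).Milliseconds()) / float64(duration.Milliseconds())
	previousWindowPercentage := 1.0 - currentWindowPercentage
	return int64(math.Ceil(float64(currCounter) + float64(prevCounter)*previousWindowPercentage))
}

// fixedCount: what Take falls back to for now; only the current window counts.
func fixedCount(currCounter int64) int64 {
	return currCounter
}

func main() {
	start := time.Now().Truncate(time.Minute)
	now := start.Add(15 * time.Second) // 25% into the current window
	fmt.Println(slidingCount(80, 10, now, start, time.Minute)) // 10 + 0.75*80 = 70
	fmt.Println(fixedCount(10))                                // 10
}
```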
diff --git a/apps/api/src/pkg/ratelimit/client.ts b/apps/api/src/pkg/ratelimit/client.ts
index 39317fc2ef..fc231d9d04 100644
--- a/apps/api/src/pkg/ratelimit/client.ts
+++ b/apps/api/src/pkg/ratelimit/client.ts
@@ -3,6 +3,7 @@ import type { Logger } from "@unkey/worker-logging";
 import type { Metrics } from "../metrics";
 
 import type { Context } from "../hono/app";
+import { retry } from "../util/retry";
 import { Agent } from "./agent";
 import {
   type RateLimiter,
@@ -14,13 +15,13 @@ import {
 export class AgentRatelimiter implements RateLimiter {
   private readonly logger: Logger;
   private readonly metrics: Metrics;
-  private readonly cache: Map<string, { reset: number; current: number; blocked: boolean }>;
+  private readonly cache: Map<string, { reset: number; current: number }>;
   private readonly agent: Agent;
   constructor(opts: {
     agent: { url: string; token: string };
     logger: Logger;
     metrics: Metrics;
-    cache: Map<string, { reset: number; current: number; blocked: boolean }>;
+    cache: Map<string, { reset: number; current: number }>;
   }) {
     this.logger = opts.logger;
     this.metrics = opts.metrics;
@@ -35,7 +36,7 @@ export class AgentRatelimiter implements RateLimiter {
     return [req.identifier, window, req.shard].join("::");
   }
 
-  private setCache(id: string, current: number, reset: number, blocked: boolean) {
+  private setCacheMax(id: string, current: number, reset: number) {
     const maxEntries = 10_000;
     this.metrics.emit({
       metric: "metric.cache.size",
@@ -54,7 +55,11 @@ export class AgentRatelimiter implements RateLimiter {
         }
       }
     }
-    this.cache.set(id, { reset, current, blocked });
+    const cached = this.cache.get(id) ?? { reset: 0, current: 0 };
+    if (current > cached.current) {
+      this.cache.set(id, { reset, current });
+      return current;
+    }
   }
 
   public async limit(
@@ -122,8 +127,8 @@
      * This might not happen too often, but in extreme cases the cache should hit and we can skip
      * the request to the durable object entirely, which speeds everything up and is cheaper for us
      */
-    const cached = this.cache.get(id) ?? { current: 0, reset: 0, blocked: false };
-    if (cached.blocked) {
+    const cached = this.cache.get(id) ?? { current: 0, reset: 0 };
+    if (cached.current >= req.limit) {
       return Ok({
         pass: false,
         current: cached.current,
@@ -133,31 +138,22 @@ export class AgentRatelimiter implements RateLimiter {
       });
     }
 
-    const p = (async () => {
-      const a = await this.callAgent(c, {
+    const p = retry(3, async () =>
+      this.callAgent(c, {
         requestId: c.get("requestId"),
         identifier: req.identifier,
         cost,
         duration: req.interval,
         limit: req.limit,
         name: req.name,
-      });
-      if (a.err) {
+      }).catch((err) => {
         this.logger.error("error calling agent", {
-          error: a.err.message,
-          json: JSON.stringify(a.err),
-        });
-        return await this.callAgent(c, {
-          requestId: c.get("requestId"),
-          identifier: req.identifier,
-          cost,
-          duration: req.interval,
-          limit: req.limit,
-          name: req.name,
+          error: err.message,
+          json: JSON.stringify(err),
         });
-      }
-      return a;
-    })();
+        throw err;
+      }),
+    );
 
     // A rollout of the sync rate limiting
     // Isolates younger than 60s must not sync. It would cause a stampede of requests as the cache is entirely empty
@@ -169,7 +165,7 @@ export class AgentRatelimiter implements RateLimiter {
     if (sync) {
       const res = await p;
      if (res.val) {
-        this.setCache(id, res.val.current, res.val.reset, !res.val.pass);
+        this.setCacheMax(id, res.val.current, res.val.reset);
       }
       return res;
     }
@@ -180,7 +176,7 @@ export class AgentRatelimiter implements RateLimiter {
         this.logger.error(res.err.message);
         return;
       }
-      this.setCache(id, res.val.current, res.val.reset, !res.val.pass);
+      this.setCacheMax(id, res.val.current, res.val.reset);
 
       this.metrics.emit({
         workspaceId: req.workspaceId,
@@ -203,7 +199,7 @@ export class AgentRatelimiter implements RateLimiter {
       });
     }
     cached.current += cost;
-    this.setCache(id, cached.current, reset, false);
+    this.setCacheMax(id, cached.current, reset);
 
     return Ok({
       pass: true,
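Note (editor's sketch, not part of the patch): the client change above drops the cached `blocked` flag in favour of a monotonic counter. The cache only ever keeps the highest `current` it has seen for an identifier, and a request is short-circuited when that cached value already meets the limit. A rough sketch of that policy in Go (illustrative only; the real implementation is the TypeScript above):

```go
// Sketch of the "keep the max, short-circuit at the limit" cache policy.
package main

import "fmt"

type entry struct {
	current int64
	reset   int64
}

type cache map[string]entry

// setCacheMax only overwrites the entry when the new count is higher, so a
// stale, lower response can never un-block an identifier.
func (c cache) setCacheMax(id string, current, reset int64) {
	cached := c[id] // zero value if absent
	if current > cached.current {
		c[id] = entry{current: current, reset: reset}
	}
}

// shouldBlock mirrors the cached.current >= req.limit check in the client.
func (c cache) shouldBlock(id string, limit int64) bool {
	return c[id].current >= limit
}

func main() {
	c := cache{}
	c.setCacheMax("id1", 3, 1000)
	c.setCacheMax("id1", 2, 2000) // ignored: 2 < 3
	fmt.Println(c.shouldBlock("id1", 3)) // true
	fmt.Println(c.shouldBlock("id1", 5)) // false
}
```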