diff --git a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/Main.java b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/Main.java index acac265e58..7614c4e10f 100644 --- a/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/Main.java +++ b/apps/threat-detection-backend/src/main/java/com/akto/threat/backend/Main.java @@ -14,6 +14,9 @@ import com.akto.threat.backend.service.ThreatActorService; import com.akto.threat.backend.service.ThreatApiService; import com.akto.threat.backend.tasks.FlushMessagesToDB; +import com.akto.threat.backend.cron.PercentilesCron; +import com.akto.util.AccountTask; +import com.akto.dto.Account; import com.mongodb.ConnectionString; import com.mongodb.MongoClientSettings; import com.mongodb.ReadPreference; @@ -22,6 +25,7 @@ import com.mongodb.client.MongoClients; import org.bson.codecs.configuration.CodecRegistry; import org.bson.codecs.pojo.PojoCodecProvider; +import java.util.function.Consumer; public class Main { @@ -72,6 +76,25 @@ public static void main(String[] args) throws Exception { ThreatApiService threatApiService = new ThreatApiService(threatProtectionMongo); ApiDistributionDataService apiDistributionDataService = new ApiDistributionDataService(threatProtectionMongo); + try { + PercentilesCron percentilesCron = new PercentilesCron(threatProtectionMongo); + logger.infoAndAddToDb("Starting PercentilesCron for all accounts", com.akto.log.LoggerMaker.LogDb.RUNTIME); + AccountTask.instance.executeTask(new Consumer() { + @Override + public void accept(Account account) { + try { + String accountDb = String.valueOf(account.getId()); + percentilesCron.cron(accountDb); + logger.infoAndAddToDb("Scheduled PercentilesCron for account " + accountDb, com.akto.log.LoggerMaker.LogDb.RUNTIME); + } catch (Exception e) { + logger.errorAndAddToDb("Failed scheduling PercentilesCron for account: " + account.getId() + " due to: " + e.getMessage(), com.akto.log.LoggerMaker.LogDb.RUNTIME); + } + } + 
package com.akto.threat.backend.cron;

import com.akto.dao.ApiInfoDao;
import com.akto.dao.context.Context;
import com.akto.log.LoggerMaker;
import com.akto.log.LoggerMaker.LogDb;
import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.BucketStats;
import com.akto.threat.backend.db.ApiDistributionDataModel;
import com.akto.threat.backend.service.ApiDistributionDataService;
import com.akto.utils.ThreatApiDistributionUtils;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.UpdateOptions;
import com.mongodb.client.model.Updates;
import org.bson.conversions.Bson;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

/**
 * Periodically recomputes per-API rate-limit percentiles (p50/p75/p90) from the
 * "api_distribution_data" collection of an account database and persists them
 * into ApiInfo under the "rateLimits.&lt;windowSize&gt;" keys.
 */
public class PercentilesCron {

    private static final LoggerMaker logger = new LoggerMaker(PercentilesCron.class, LogDb.THREAT_DETECTION);
    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
    private final MongoClient mongoClient;

    /** Trailing window (days) of distribution data used as the rate-limit baseline. */
    public static final int DEFAULT_BASELINE_DAYS = 2;
    /** Oldest record must be at least this many days old before percentiles are trusted. */
    private static final int MIN_INITIAL_AGE_DAYS = 2;
    /** Aggregation window sizes (minutes) for which rate limits are maintained. */
    private static final List<Integer> WINDOW_SIZES = Arrays.asList(5, 15, 30);

    public PercentilesCron(MongoClient mongoClient) {
        this.mongoClient = mongoClient;
    }

    /**
     * Schedules {@link #runOnce(String)} for the given account: immediately, then
     * every 2 days. Errors are logged and never kill the scheduler thread.
     */
    public void cron(String accountId) {
        scheduler.scheduleAtFixedRate(() -> {
            try {
                try {
                    // Account ids are numeric; setting the thread-local lets downstream
                    // DAOs (ApiInfoDao) resolve the right account database.
                    Context.accountId.set(Integer.parseInt(accountId));
                } catch (NumberFormatException ignored) {
                    // keep context unset if accountId isn't a number
                }
                runOnce(accountId);
            } catch (Exception e) {
                logger.errorAndAddToDb("error in PercentilesCron: accountId " + accountId + " " + e.getMessage());
            } finally {
                Context.resetContextThreadLocals();
            }
        }, 0, 2, TimeUnit.DAYS);
    }

    /**
     * One full pass: for every distinct (apiCollectionId, url, method) present in
     * api_distribution_data and every supported window size, recompute and persist
     * the rate-limit percentiles.
     */
    public void runOnce(String accountId) {
        MongoCollection<ApiDistributionDataModel> coll = this.mongoClient
                .getDatabase(accountId)
                .getCollection("api_distribution_data", ApiDistributionDataModel.class);

        // Collect the distinct API keys first so each API is processed exactly once.
        // NOTE(review): assumes '|' never occurs in url/method — confirm upstream sanitization.
        Set<String> keys = new HashSet<>();
        try (MongoCursor<ApiDistributionDataModel> cursor = coll.find().iterator()) {
            while (cursor.hasNext()) {
                ApiDistributionDataModel doc = cursor.next();
                keys.add(doc.apiCollectionId + "|" + doc.url + "|" + doc.method);
            }
        }

        for (String key : keys) {
            String[] parts = key.split("\\|", -1);
            int apiCollectionId = Integer.parseInt(parts[0]);
            String url = parts[1];
            String method = parts[2];

            for (int windowSize : WINDOW_SIZES) {
                // Skip APIs that don't yet have MIN_INITIAL_AGE_DAYS of history for this window size.
                if (!hasMinimumInitialAge(accountId, apiCollectionId, url, method, MIN_INITIAL_AGE_DAYS, windowSize)) {
                    logger.infoAndAddToDb("Skipping rateLimits update due to insufficient data age for apiCollectionId " + apiCollectionId
                            + " url " + url + " method " + method + " windowSize " + windowSize);
                    continue;
                }

                // Fetch the last baseline days of distribution data for this window size.
                List<BucketStats> distributionData =
                        fetchBucketStats(DEFAULT_BASELINE_DAYS, accountId, apiCollectionId, url, method, windowSize);
                PercentilesResult result = calculatePercentiles(distributionData);
                updateApiInfo(result, apiCollectionId, url, method, windowSize);
            }
        }
    }

    /**
     * Writes the computed percentiles into ApiInfo.rateLimits.&lt;windowSize&gt;.* for
     * the given API. Uses upsert=false: only existing ApiInfo documents are updated.
     */
    public void updateApiInfo(PercentilesResult r, int apiCollectionId, String url, String method, int windowSize) {
        try {
            ApiInfoDao.instance.getMCollection().updateOne(
                    ApiInfoDao.getFilter(url, method, apiCollectionId),
                    Updates.combine(
                            Updates.set("rateLimits." + windowSize + ".p50", r.p50),
                            Updates.set("rateLimits." + windowSize + ".p75", r.p75),
                            Updates.set("rateLimits." + windowSize + ".p90", r.p90),
                            Updates.set("rateLimits." + windowSize + ".max_requests", r.maxRequests)
                    ),
                    new UpdateOptions().upsert(false)
            );
            // Log through the logger's own db (THREAT_DETECTION) instead of overriding with RUNTIME.
            logger.infoAndAddToDb("Updated rateLimits for apiCollectionId " + apiCollectionId + " url " + url
                    + " method " + method + " windowSize " + windowSize);
        } catch (Exception e) {
            logger.errorAndAddToDb("Failed updating rateLimits for apiCollectionId " + apiCollectionId + " url " + url
                    + " method " + method + " windowSize " + windowSize + ": " + e.getMessage());
        }
    }

    /**
     * Lower bound for windowStart covering the last {@code baselinePeriodDays}.
     * windowStart is stored as epoch minutes (epoch seconds / 60).
     */
    private long getWindowStartForBaselinePeriod(int baselinePeriodDays) {
        long currentMinutesSinceEpoch = Context.now() / 60;
        long baselineMinutes = (long) baselinePeriodDays * 24L * 60L;
        return currentMinutesSinceEpoch - baselineMinutes;
    }

    /**
     * Fetches per-bucket distribution stats for the given API over the past
     * {@code baseLinePeriod} days for one window size.
     */
    public List<BucketStats> fetchBucketStats(int baseLinePeriod, String accountId, int apiCollectionId,
                                              String url, String method, int windowSize) {
        Bson filter = Filters.and(
                Filters.eq("apiCollectionId", apiCollectionId),
                Filters.eq("url", url),
                Filters.eq("method", method),
                Filters.eq("windowSize", windowSize),
                Filters.gte("windowStart", (int) getWindowStartForBaselinePeriod(baseLinePeriod))
        );

        return ApiDistributionDataService.fetchBucketStats(accountId, filter, mongoClient);
    }

    /**
     * Returns true if at least one record for this API and window size has a
     * windowStart that is {@code minAgeDays} or more in the past.
     */
    public boolean hasMinimumInitialAge(String accountId, int apiCollectionId, String url, String method,
                                        int minAgeDays, int windowSize) {
        MongoCollection<ApiDistributionDataModel> coll = this.mongoClient
                .getDatabase(accountId)
                .getCollection("api_distribution_data", ApiDistributionDataModel.class);

        Bson filter = Filters.and(
                Filters.eq("apiCollectionId", apiCollectionId),
                Filters.eq("url", url),
                Filters.eq("method", method),
                Filters.eq("windowSize", windowSize),
                Filters.lte("windowStart", (int) getWindowStartForBaselinePeriod(minAgeDays))
        );

        try (MongoCursor<ApiDistributionDataModel> cursor = coll.find(filter).limit(1).iterator()) {
            return cursor.hasNext();
        }
    }

    /**
     * Derives p50/p75/p90 rate limits from per-bucket user-count stats.
     * Each bucket contributes its p75 user count; percentile targets are located
     * by walking buckets in ascending label order and returning the upper bound
     * of the bucket where the cumulative count crosses the target.
     *
     * Returns all -1 when there is no usable data. Does not mutate the input list.
     */
    public PercentilesResult calculatePercentiles(List<BucketStats> bucketStats) {
        if (bucketStats == null || bucketStats.isEmpty()) return new PercentilesResult(-1, -1, -1, -1);

        // TODO: What value should we pick for number of users from each bucket windows???
        // Choosing p75 for now.
        long totalUsers = 0;
        for (BucketStats stats : bucketStats) totalUsers += stats.getP75();
        if (totalUsers <= 0) return new PercentilesResult(-1, -1, -1, -1);

        double p50Target = totalUsers * 0.5d;
        double p75Target = totalUsers * 0.75d;
        double p90Target = totalUsers * 0.9d;

        // Sort a copy by numeric bucket index ("b1".."b14") so the caller's list isn't mutated.
        List<BucketStats> sorted = new ArrayList<>(bucketStats);
        sorted.sort(Comparator.comparingInt(b -> Integer.parseInt(b.getBucketLabel().substring(1))));

        long cumulative = 0;
        Integer p50Val = null, p75Val = null, p90Val = null;
        for (BucketStats stats : sorted) {
            cumulative += stats.getP75();
            if (p50Val == null && cumulative >= p50Target) p50Val = ThreatApiDistributionUtils.getBucketUpperBound(stats.getBucketLabel());
            if (p75Val == null && cumulative >= p75Target) p75Val = ThreatApiDistributionUtils.getBucketUpperBound(stats.getBucketLabel());
            if (p90Val == null && cumulative >= p90Target) {
                // p90 reached implies p50/p75 already set (targets are ordered), so we can stop.
                p90Val = ThreatApiDistributionUtils.getBucketUpperBound(stats.getBucketLabel());
                break;
            }
        }

        // If a percentile was never reached, fall back to the last bucket's upper bound (max value).
        if (p50Val == null) p50Val = ThreatApiDistributionUtils.getBucketUpperBound("b14");
        if (p75Val == null) p75Val = ThreatApiDistributionUtils.getBucketUpperBound("b14");
        if (p90Val == null) p90Val = ThreatApiDistributionUtils.getBucketUpperBound("b14");

        // maxRequests is not derived yet; -1 marks "unset".
        return new PercentilesResult(p50Val, p75Val, p90Val, -1);
    }

    /** Immutable carrier for computed percentile rate limits; -1 means "no data". */
    public static class PercentilesResult {
        final int p50;
        final int p75;
        final int p90;
        final int maxRequests;

        public PercentilesResult(int p50, int p75, int p90, int maxRequests) {
            this.p50 = p50;
            this.p75 = p75;
            this.p90 = p90;
            this.maxRequests = maxRequests;
        }
    }
}
package com.akto.threat.backend.db;

import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.ApiDistributionDataRequestPayload;
import com.akto.threat.backend.cron.PercentilesCron;
import com.akto.utils.ThreatApiDistributionUtils;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.ReplaceOptions;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Rolling per-bucket user-count history plus summary stats for one API and
 * window size, stored in the "api_rate_limit_bucket_statistics" collection.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class ApiRateLimitBucketStatisticsModel {

    public static final String ID = "_id";
    public static final String BUCKETS = "buckets";

    // Format: apiCollectionId_method_url_windowSize (see getBucketStatsDocIdForApi)
    private String id;
    private List<Bucket> buckets;
    private float rateLimitConfidence;

    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public static class Bucket {
        public static final String LABEL = "bucketLabel";
        public static final String USER_COUNTS = "userCounts";
        public static final String STATS = "stats";

        // See ThreatApiDistributionUtils.BUCKET_RANGES for the label scheme (b1..b14).
        private String label;
        private List<UserCountData> userCounts;
        private Stats stats;
    }

    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public static class UserCountData {
        public static final String USERS = "users";
        public static final String WINDOW_START = "windowStart";

        private int users;
        private int windowStart; // epoch minutes of the window start
    }

    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public static class Stats {
        public static final String MIN = "min";
        public static final String MAX = "max";
        public static final String P25 = "p25";
        public static final String P50 = "p50";
        public static final String P75 = "p75";

        private int min;
        private int max;
        private int p25;
        private int p50;
        private int p75;
    }

    /** Deterministic document id for one API + window size. */
    public static String getBucketStatsDocIdForApi(int apiCollectionId, String method, String url, int windowSize) {
        return apiCollectionId + "_" + method + "_" + url + "_" + windowSize;
    }

    /**
     * Merges incoming distribution windows into the stored per-bucket history and
     * recomputes stats — one read-modify-replace round trip per document id.
     */
    public static void calculateStatistics(
            String accountId,
            MongoClient mongoClient,
            Map<String, List<ApiDistributionDataRequestPayload.DistributionData>> frequencyBuckets) {
        if (frequencyBuckets == null || frequencyBuckets.isEmpty()) return;

        MongoCollection<ApiRateLimitBucketStatisticsModel> coll = mongoClient
                .getDatabase(accountId)
                .getCollection("api_rate_limit_bucket_statistics", ApiRateLimitBucketStatisticsModel.class);

        for (Map.Entry<String, List<ApiDistributionDataRequestPayload.DistributionData>> entry : frequencyBuckets.entrySet()) {
            String docId = entry.getKey();
            List<ApiDistributionDataRequestPayload.DistributionData> updates = entry.getValue();
            if (updates == null || updates.isEmpty()) continue;

            ApiRateLimitBucketStatisticsModel doc = Optional.ofNullable(coll.find(Filters.eq(ID, docId)).first())
                    .orElseGet(() -> newEmptyModel(docId));

            doc = applyUpdates(doc, updates);

            // NOTE(review): read-modify-replace is not atomic across concurrent writers — confirm
            // this path is single-writer per account, or move to an atomic update pipeline.
            coll.replaceOne(Filters.eq(ID, docId), doc, new ReplaceOptions().upsert(true));
        }
    }

    /** Fresh model with every standard bucket initialized to empty history and zero stats. */
    private static ApiRateLimitBucketStatisticsModel newEmptyModel(String docId) {
        ApiRateLimitBucketStatisticsModel m = new ApiRateLimitBucketStatisticsModel();
        m.id = docId;
        m.buckets = new ArrayList<>();
        m.rateLimitConfidence = 0.0f;
        for (ThreatApiDistributionUtils.Range range : ThreatApiDistributionUtils.getBucketRanges()) {
            m.buckets.add(new Bucket(range.label, new ArrayList<>(), new Stats(0, 0, 0, 0, 0)));
        }
        return m;
    }

    /**
     * Applies distribution windows to {@code doc} (created if null): upserts each
     * window's per-bucket user count (absent labels count as zero so every bucket
     * keeps a full timeline), evicts history beyond capacity, then recomputes stats.
     */
    static ApiRateLimitBucketStatisticsModel applyUpdates(ApiRateLimitBucketStatisticsModel doc,
                                                          List<ApiDistributionDataRequestPayload.DistributionData> updates) {
        if (doc == null) {
            doc = newEmptyModel(null);
        }
        // Guard: the original indexed updates.get(0) unconditionally.
        if (updates == null || updates.isEmpty()) return doc;

        int windowSize = updates.get(0).getWindowSize();
        int capacity = capacityForWindowSize(windowSize);

        for (ApiDistributionDataRequestPayload.DistributionData update : updates) {
            int windowStart = (int) update.getWindowStartEpochMin();
            Map<String, Integer> dist = update.getDistributionMap();

            for (Bucket bucket : doc.buckets) {
                int users = dist.getOrDefault(bucket.label, 0);
                upsertUserCount(bucket.userCounts, windowStart, users);
                evictToCapacity(bucket.userCounts, capacity);
            }
        }

        // Recompute summary stats per bucket from the retained history.
        for (Bucket bucket : doc.buckets) {
            List<Integer> values = bucket.userCounts.stream()
                    .map(uc -> uc.users)
                    .collect(Collectors.toList());
            if (values.isEmpty()) {
                bucket.stats = new Stats(0, 0, 0, 0, 0);
            } else {
                Collections.sort(values);
                int min = values.get(0);
                int max = values.get(values.size() - 1);
                int p25 = ThreatApiDistributionUtils.percentile(values, 25);
                int p50 = ThreatApiDistributionUtils.percentile(values, 50);
                int p75 = ThreatApiDistributionUtils.percentile(values, 75);
                bucket.stats = new Stats(min, max, p25, p50, p75);
            }
        }

        return doc;
    }

    /** Inserts or replaces the count for windowStart, keeping the list sorted by windowStart. */
    private static void upsertUserCount(List<UserCountData> list, int windowStart, int users) {
        if (list == null) return;
        int idx = Collections.binarySearch(list, new UserCountData(0, windowStart),
                Comparator.comparingInt(a -> a.windowStart));
        if (idx >= 0) {
            list.get(idx).users = users;
        } else {
            list.add(-idx - 1, new UserCountData(users, windowStart));
        }
    }

    /** Drops the oldest windows (list head) until size &lt;= capacity; single O(n) bulk removal. */
    private static void evictToCapacity(List<UserCountData> list, int capacity) {
        if (list == null || list.size() <= capacity) return;
        list.subList(0, list.size() - capacity).clear();
    }

    /** Number of windows retained: DEFAULT_BASELINE_DAYS worth of windowSize-minute windows. */
    private static int capacityForWindowSize(int windowSize) {
        // e.g. 5-minute windows over 2 days => 576 entries.
        // TODO: Pick this from threat configuration instead.
        int approx = (PercentilesCron.DEFAULT_BASELINE_DAYS * 24 * 60) / Math.max(1, windowSize);
        return Math.max(1, approx);
    }
}
ArrayList<>()).add(protoData); + UpdateOptions options = new UpdateOptions().upsert(true); bulkUpdates.add(new UpdateOneModel<>(filter, update, options)); @@ -65,52 +76,43 @@ public ApiDistributionDataResponsePayload saveApiDistributionData(String account .getCollection("api_distribution_data", ApiDistributionDataModel.class) .bulkWrite(bulkUpdates, new BulkWriteOptions().ordered(false)); + ApiRateLimitBucketStatisticsModel.calculateStatistics(accountId, this.mongoClient, frequencyBuckets); return ApiDistributionDataResponsePayload.newBuilder().build(); } - public FetchApiDistributionDataResponse getDistributionStats(String accountId, FetchApiDistributionDataRequest fetchApiDistributionDataRequest) { - - MongoCollection coll = - this.mongoClient - .getDatabase(accountId) - .getCollection("api_distribution_data", ApiDistributionDataModel.class); - - Bson filter = Filters.and( - Filters.eq("apiCollectionId", fetchApiDistributionDataRequest.getApiCollectionId()), - Filters.eq("url", fetchApiDistributionDataRequest.getUrl()), - Filters.eq("method", fetchApiDistributionDataRequest.getMethod()), - Filters.eq("windowSize", 5), - Filters.gte("windowStart", fetchApiDistributionDataRequest.getStartWindow()), - Filters.lte("windowStart", fetchApiDistributionDataRequest.getEndWindow()) - ); + public static List fetchBucketStats(String accountId, Bson filters, MongoClient mongoClient) { + MongoCollection coll = mongoClient + .getDatabase(accountId) + .getCollection("api_distribution_data", ApiDistributionDataModel.class); Map> bucketToValues = new HashMap<>(); - try (MongoCursor cursor = coll.find(filter).iterator()) { + try (MongoCursor cursor = coll.find(filters).iterator()) { while (cursor.hasNext()) { ApiDistributionDataModel doc = cursor.next(); - if (doc.distribution == null) continue; + if (doc.distribution == null) + continue; for (Map.Entry entry : doc.distribution.entrySet()) { bucketToValues.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()) - 
.add(entry.getValue()); + .add(entry.getValue()); } } } - FetchApiDistributionDataResponse.Builder responseBuilder = FetchApiDistributionDataResponse.newBuilder(); + List bucketStats = new ArrayList<>(); for (Map.Entry> entry : bucketToValues.entrySet()) { String bucket = entry.getKey(); List values = entry.getValue(); if (values.isEmpty()) continue; - + Collections.sort(values); int min = values.get(0); int max = values.get(values.size() - 1); - int p25 = percentile(values, 25); - int p50 = percentile(values, 50); // median - int p75 = percentile(values, 75); - + int p25 = ThreatApiDistributionUtils.percentile(values, 25); + int p50 = ThreatApiDistributionUtils.percentile(values, 50); // median + int p75 = ThreatApiDistributionUtils.percentile(values, 75); + BucketStats stats = BucketStats.newBuilder() .setBucketLabel(bucket) .setMin(min) @@ -119,17 +121,28 @@ public FetchApiDistributionDataResponse getDistributionStats(String accountId, F .setP50(p50) .setP75(p75) .build(); - - responseBuilder.addBucketStats(stats); + + bucketStats.add(stats); } + + return bucketStats; - return responseBuilder.build(); } - private int percentile(List sorted, int p) { - if (sorted.isEmpty()) return 0; - int index = (int) Math.ceil(p / 100.0 * sorted.size()) - 1; - return sorted.get(Math.max(0, Math.min(index, sorted.size() - 1))); + public FetchApiDistributionDataResponse getDistributionStats(String accountId, FetchApiDistributionDataRequest fetchApiDistributionDataRequest) { + Bson filter = Filters.and( + Filters.eq("apiCollectionId", fetchApiDistributionDataRequest.getApiCollectionId()), + Filters.eq("url", fetchApiDistributionDataRequest.getUrl()), + Filters.eq("method", fetchApiDistributionDataRequest.getMethod()), + Filters.eq("windowSize", 5), + Filters.gte("windowStart", 1726461999 / 60), + Filters.lte("windowStart", 1757997999 / 60) + ); + + FetchApiDistributionDataResponse.Builder responseBuilder = FetchApiDistributionDataResponse.newBuilder(); + 
responseBuilder.addAllBucketStats(fetchBucketStats(accountId, filter, this.mongoClient)); + + return responseBuilder.build(); } } diff --git a/apps/threat-detection-backend/src/test/java/com/akto/threat/backend/cron/PercentilesCronTest.java b/apps/threat-detection-backend/src/test/java/com/akto/threat/backend/cron/PercentilesCronTest.java new file mode 100644 index 0000000000..42208bb74b --- /dev/null +++ b/apps/threat-detection-backend/src/test/java/com/akto/threat/backend/cron/PercentilesCronTest.java @@ -0,0 +1,190 @@ +package com.akto.threat.backend.cron; + +import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.BucketStats; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Field; +import java.util.*; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class PercentilesCronTest { + + private static BucketStats createBucketStats(String bucketLabel, int min, int max, int p25, int p50, int p75) { + return BucketStats.newBuilder() + .setBucketLabel(bucketLabel) + .setMin(min) + .setMax(max) + .setP25(p25) + .setP50(p50) + .setP75(p75) + .build(); + } + + private static int[] extractPercentiles(Object result) throws Exception { + Class c = result.getClass(); + Field p50 = c.getDeclaredField("p50"); + Field p75 = c.getDeclaredField("p75"); + Field p90 = c.getDeclaredField("p90"); + p50.setAccessible(true); + p75.setAccessible(true); + p90.setAccessible(true); + return new int[] { (int) p50.get(result), (int) p75.get(result), (int) p90.get(result) }; + } + + @Test + public void returnsNegativesForEmptyData() throws Exception { + PercentilesCron cron = new PercentilesCron(null); + List data = Collections.emptyList(); + + Object result = cron.calculatePercentiles(data); + int[] vals = extractPercentiles(result); + + assertEquals(-1, vals[0]); + assertEquals(-1, vals[1]); + assertEquals(-1, vals[2]); + } + + @Test + public void simpleThreeBucketDistribution() throws Exception { + // Using p75 values: b1=100, 
package com.akto.threat.backend.cron;

import com.akto.proto.generated.threat_detection.service.dashboard_service.v1.BucketStats;
import com.akto.threat.backend.cron.PercentilesCron.PercentilesResult;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Unit tests for {@link PercentilesCron#calculatePercentiles(List)}.
 *
 * The test lives in the same package as PercentilesCron, and calculatePercentiles
 * returns the declared PercentilesResult type with package-visible fields, so the
 * percentiles are read directly — no reflection needed.
 *
 * Bucket upper bounds referenced in expectations (from ThreatApiDistributionUtils):
 * b1=10, b2=50, b3=100, b4=250, b6=1000, b8=5000, b10=20000, b14=Integer.MAX_VALUE.
 */
public class PercentilesCronTest {

    private static BucketStats createBucketStats(String bucketLabel, int min, int max, int p25, int p50, int p75) {
        return BucketStats.newBuilder()
                .setBucketLabel(bucketLabel)
                .setMin(min)
                .setMax(max)
                .setP25(p25)
                .setP50(p50)
                .setP75(p75)
                .build();
    }

    @Test
    public void returnsNegativesForEmptyData() {
        PercentilesCron cron = new PercentilesCron(null);

        PercentilesResult result = cron.calculatePercentiles(Collections.emptyList());

        assertEquals(-1, result.p50);
        assertEquals(-1, result.p75);
        assertEquals(-1, result.p90);
    }

    @Test
    public void simpleThreeBucketDistribution() {
        // Using p75 values: b1=100, b2=200, b3=100 (total 400 users)
        // p50 target = 200 -> falls in b2 (cumulative 300) => upper bound 50
        // p75 target = 300 -> falls in b2 (cumulative 300) => upper bound 50
        // p90 target = 360 -> falls in b3 (cumulative 400) => upper bound 100
        PercentilesCron cron = new PercentilesCron(null);
        List<BucketStats> data = Arrays.asList(
                createBucketStats("b1", 80, 120, 90, 95, 100),   // p75=100 users
                createBucketStats("b2", 150, 250, 175, 190, 200), // p75=200 users
                createBucketStats("b3", 70, 120, 85, 90, 100)     // p75=100 users
        );

        PercentilesResult result = cron.calculatePercentiles(data);

        // b1 upper bound = 10, b2 upper bound = 50, b3 upper bound = 100
        assertEquals(50, result.p50);
        assertEquals(50, result.p75);
        assertEquals(100, result.p90);
    }

    @Test
    public void handlesSparseAndMissingBuckets() {
        // Only b4 and b6 present using p75 values: b4=200, b6=100 (total 300)
        // Upper bounds: b4=250, b6=1000
        // p50=150 -> falls in b4 (cumulative 200) => 250
        // p75=225 -> falls in b6 (cumulative 300) => 1000
        // p90=270 -> falls in b6 (cumulative 300) => 1000
        PercentilesCron cron = new PercentilesCron(null);
        List<BucketStats> data = Arrays.asList(
                createBucketStats("b4", 180, 220, 190, 195, 200), // p75=200
                createBucketStats("b6", 80, 120, 90, 95, 100)     // p75=100
        );

        PercentilesResult result = cron.calculatePercentiles(data);

        assertEquals(250, result.p50);
        assertEquals(1000, result.p75);
        assertEquals(1000, result.p90);
    }

    @Test
    public void reachesMaxUpperBoundWhenNeeded() {
        // Large counts only in the last bucket b14; b14 upper bound is Integer.MAX_VALUE.
        PercentilesCron cron = new PercentilesCron(null);
        List<BucketStats> data = Arrays.asList(
                createBucketStats("b14", 9000, 11000, 9500, 9750, 10000) // p75=10000
        );

        PercentilesResult result = cron.calculatePercentiles(data);

        assertEquals(Integer.MAX_VALUE, result.p50);
        assertEquals(Integer.MAX_VALUE, result.p75);
        assertEquals(Integer.MAX_VALUE, result.p90);
    }

    @Test
    public void exactBoundaryTargetsChooseUpperBoundOfThatBucket() {
        // b1: p75=50, b2: p75=25, b3: p75=25 (total 100)
        // Targets: p50=50, p75=75, p90=90
        // Cumulative: b1=50, b1+b2=75, b1+b2+b3=100
        // p50=50 -> exactly at b1 end => upper bound 10
        // p75=75 -> exactly at b2 end => upper bound 50
        // p90=90 -> falls in b3 => upper bound 100
        PercentilesCron cron = new PercentilesCron(null);
        List<BucketStats> data = Arrays.asList(
                createBucketStats("b1", 40, 60, 45, 48, 50), // p75=50
                createBucketStats("b2", 20, 30, 22, 24, 25), // p75=25
                createBucketStats("b3", 20, 30, 22, 24, 25)  // p75=25
        );

        PercentilesResult result = cron.calculatePercentiles(data);

        assertEquals(10, result.p50);
        assertEquals(50, result.p75);
        assertEquals(100, result.p90);
    }

    @Test
    public void testWithRealWorldExample() {
        // Simulating the example from the algorithm discussion; input deliberately
        // unordered to exercise the label sort inside calculatePercentiles.
        // Total users using p75: 19+33+54+78+97+105+97+78+54+33+20+12+9+7 = 696
        PercentilesCron cron = new PercentilesCron(null);
        List<BucketStats> data = Arrays.asList(
                createBucketStats("b1", 4, 24, 9, 14, 19),     // p75=19
                createBucketStats("b2", 18, 38, 23, 28, 33),   // p75=33
                createBucketStats("b4", 63, 83, 68, 73, 78),   // p75=78
                createBucketStats("b5", 82, 102, 87, 92, 97),  // p75=97
                createBucketStats("b6", 90, 110, 95, 100, 105),// p75=105
                createBucketStats("b7", 82, 102, 87, 92, 97),  // p75=97
                createBucketStats("b10", 18, 38, 23, 28, 33),  // p75=33
                createBucketStats("b11", 4, 32, 10, 15, 20),   // p75=20
                createBucketStats("b8", 63, 83, 68, 73, 78),   // p75=78
                createBucketStats("b9", 39, 59, 44, 49, 54),   // p75=54
                createBucketStats("b12", 0, 24, 2, 7, 12),     // p75=12
                createBucketStats("b13", 0, 21, 0, 4, 9),      // p75=9
                createBucketStats("b14", 0, 20, 0, 3, 7),      // p75=7
                createBucketStats("b3", 39, 59, 44, 49, 54)    // p75=54
        );

        PercentilesResult result = cron.calculatePercentiles(data);

        // Total = 696
        // p50 = 348   -> should fall in b6 (cumulative reaches 386)
        // p75 = 522   -> should fall in b8 (cumulative reaches 561)
        // p90 = 626.4 -> should fall in b10 (cumulative reaches 648)
        // b6 upper = 1000, b8 upper = 5000, b10 upper = 20000
        assertEquals(1000, result.p50);   // p50 in b6
        assertEquals(5000, result.p75);   // p75 in b8
        assertEquals(20000, result.p90);  // p90 in b10
    }

    @Test
    public void handlesZeroP75Values() {
        // Buckets with p75=0 contribute nothing to the cumulative count.
        PercentilesCron cron = new PercentilesCron(null);
        List<BucketStats> data = Arrays.asList(
                createBucketStats("b1", 0, 0, 0, 0, 0),       // p75=0
                createBucketStats("b2", 10, 20, 12, 15, 18),  // p75=18
                createBucketStats("b3", 0, 0, 0, 0, 0),       // p75=0
                createBucketStats("b4", 5, 10, 6, 7, 8)       // p75=8
        );

        PercentilesResult result = cron.calculatePercentiles(data);

        // Total = 0 + 18 + 0 + 8 = 26
        // p50 = 13   -> falls in b2
        // p75 = 19.5 -> falls in b4
        // p90 = 23.4 -> falls in b4
        assertEquals(50, result.p50);   // p50 in b2
        assertEquals(250, result.p75);  // p75 in b4
        assertEquals(250, result.p90);  // p90 in b4
    }
}
com.akto.proto.generated.threat_detection.service.dashboard_service.v1.ApiDistributionDataRequestPayload; +import com.akto.threat.backend.cron.PercentilesCron; +import com.akto.utils.ThreatApiDistributionUtils; +import org.junit.jupiter.api.Test; + +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; + +public class ApiRateLimitBucketStatisticsModelTest { + + private ApiDistributionDataRequestPayload.DistributionData dd(int collectionId, String url, String method, int windowSize, int windowStart, Map dist){ + return ApiDistributionDataRequestPayload.DistributionData.newBuilder() + .setApiCollectionId(collectionId) + .setUrl(url) + .setMethod(method) + .setWindowSize(windowSize) + .setWindowStartEpochMin(windowStart) + .putAllDistribution(dist) + .build(); + } + + private List generateRandomTwoDays(int collectionId, String url, String method, long startWindowStartEpochMin, long seed) { + final int windowSize = 5; + // 2 days 5 minute windows = 576 + final int windows = (PercentilesCron.DEFAULT_BASELINE_DAYS * 24 * 60) / windowSize; + Random rnd = new Random(seed); + + List out = new ArrayList<>(windows); + + // Build fixed label list once (b1..b14) + List labels = new ArrayList<>(ThreatApiDistributionUtils.getLABEL_TO_RANGE_MAP().keySet()); + + long ws = startWindowStartEpochMin; + for (int i = 0; i < windows; i++, ws += windowSize) { + Map dist = new HashMap<>(); + + // Generate realistic, skewed counts: more weight to lower buckets, occasional spikes + int base = 5 + rnd.nextInt(20); // small base load + for (int li = 0; li < labels.size(); li++) { + String label = labels.get(li); + int weight = Math.max(1, (labels.size() - li)); // higher for lower buckets + int noise = rnd.nextInt(3 + li); // slightly increasing noise for higher buckets + int spike = (rnd.nextDouble() < 0.02 && li <= 5) ? 
rnd.nextInt(50) : 0; // rare spikes in lower buckets + int val = Math.max(0, base * weight / 10 + noise + spike); + dist.put(label, val); + } + + out.add(dd(collectionId, url, method, windowSize, (int) ws, dist)); + } + + return out; + } + + @Test + public void realisticRandomTwoDays_generatesAndValidatesStats() { + ApiRateLimitBucketStatisticsModel doc = null; + + long start = 1_000_000; // arbitrary epoch minutes + List batch = generateRandomTwoDays(42, "/users", "GET", start, 12345L); + + // Apply in chunks to simulate multiple calls + int chunk = 100; + for (int i = 0; i < batch.size(); i += chunk) { + int end = Math.min(i + chunk, batch.size()); + doc = ApiRateLimitBucketStatisticsModel.applyUpdates(doc, batch.subList(i, end)); + } + + // Validate capacity and stats equality vs recomputation per bucket + int expectedCapacity = (PercentilesCron.DEFAULT_BASELINE_DAYS * 24 * 60) / 5; + assertNotNull(doc); + assertNotNull(doc.getBuckets()); + assertFalse(doc.getBuckets().isEmpty()); + + for (ApiRateLimitBucketStatisticsModel.Bucket b : doc.getBuckets()) { + // Capacity should match 2 days of 5-min windows + assertEquals(expectedCapacity, b.getUserCounts().size(), "capacity mismatch for " + b.getLabel()); + + // Recompute stats + List vals = new ArrayList<>(); + for (ApiRateLimitBucketStatisticsModel.UserCountData ucd : b.getUserCounts()) vals.add(ucd.getUsers()); + Collections.sort(vals); + + int min = vals.isEmpty() ? 0 : vals.get(0); + int max = vals.isEmpty() ? 
0 : vals.get(vals.size()-1); + int p25 = ThreatApiDistributionUtils.percentile(vals, 25); + int p50 = ThreatApiDistributionUtils.percentile(vals, 50); + int p75 = ThreatApiDistributionUtils.percentile(vals, 75); + + assertEquals(min, b.getStats().getMin(), "min mismatch for " + b.getLabel()); + assertEquals(max, b.getStats().getMax(), "max mismatch for " + b.getLabel()); + assertEquals(p25, b.getStats().getP25(), "p25 mismatch for " + b.getLabel()); + assertEquals(p50, b.getStats().getP50(), "p50 mismatch for " + b.getLabel()); + assertEquals(p75, b.getStats().getP75(), "p75 mismatch for " + b.getLabel()); + } + } + @Test + public void insertsAndComputesStats_singleWindow() { + ApiRateLimitBucketStatisticsModel doc = null; + List updates = new ArrayList<>(); + + Map dist = new HashMap<>(); + dist.put("b1", 5); + dist.put("b2", 10); + // other buckets implicitly zero + + updates.add(dd(1, "/a", "GET", 5, 1000, dist)); + + doc = ApiRateLimitBucketStatisticsModel.applyUpdates(doc, updates); + + // Expect one entry in each bucket's userCounts at windowStart=1000 (zeros for others) + for (ApiRateLimitBucketStatisticsModel.Bucket b : doc.getBuckets()) { + assertEquals(1, b.getUserCounts().size()); + assertEquals(1000, b.getUserCounts().get(0).getWindowStart()); + } + + // Stats for b1: only [5] + ApiRateLimitBucketStatisticsModel.Bucket b1 = doc.getBuckets().stream().filter(b -> b.getLabel().equals("b1")).findFirst().get(); + assertEquals(5, b1.getStats().getMin()); + assertEquals(5, b1.getStats().getMax()); + assertEquals(5, b1.getStats().getP25()); + assertEquals(5, b1.getStats().getP50()); + assertEquals(5, b1.getStats().getP75()); + + // Stats for some other bucket (e.g. 
b3) should be zeros + ApiRateLimitBucketStatisticsModel.Bucket b3 = doc.getBuckets().stream().filter(b -> b.getLabel().equals("b3")).findFirst().get(); + assertEquals(0, b3.getStats().getMin()); + assertEquals(0, b3.getStats().getMax()); + assertEquals(0, b3.getStats().getP25()); + assertEquals(0, b3.getStats().getP50()); + assertEquals(0, b3.getStats().getP75()); + } + + @Test + public void overwriteExistingWindow_updatesValue() { + ApiRateLimitBucketStatisticsModel doc = null; + List updates1 = new ArrayList<>(); + Map dist1 = new HashMap<>(); + dist1.put("b1", 10); + updates1.add(dd(1, "/a", "GET", 5, 1000, dist1)); + doc = ApiRateLimitBucketStatisticsModel.applyUpdates(doc, updates1); + + // overwrite same windowStart with new value + List updates2 = new ArrayList<>(); + Map dist2 = new HashMap<>(); + dist2.put("b1", 20); + updates2.add(dd(1, "/a", "GET", 5, 1000, dist2)); + doc = ApiRateLimitBucketStatisticsModel.applyUpdates(doc, updates2); + + ApiRateLimitBucketStatisticsModel.Bucket b1 = doc.getBuckets().stream().filter(b -> b.getLabel().equals("b1")).findFirst().get(); + assertEquals(1, b1.getUserCounts().size()); + assertEquals(20, b1.getUserCounts().get(0).getUsers()); + assertEquals(20, b1.getStats().getP50()); + } + + @Test + public void evictionAtCapacity_keepsMostRecentWindows() { + ApiRateLimitBucketStatisticsModel doc = null; + + // capacity for 5-min windows is DEFAULT_BASELINE_DAYS * 24 * 60 / 5 + int capacity = (PercentilesCron.DEFAULT_BASELINE_DAYS * 24 * 60) / 5; + + // Insert capacity+2 windows + List batch = new ArrayList<>(); + for (int i = 0; i < capacity + 2; i++) { + Map d = new HashMap<>(); + d.put("b2", i); // increasing values for easy verification + batch.add(dd(1, "/a", "GET", 5, 1000 + i, d)); + } + + doc = ApiRateLimitBucketStatisticsModel.applyUpdates(doc, batch); + + ApiRateLimitBucketStatisticsModel.Bucket b2 = doc.getBuckets().stream().filter(b -> b.getLabel().equals("b2")).findFirst().get(); + assertEquals(capacity, 
b2.getUserCounts().size()); + // Oldest should be windowStart = 1000 + 2 now (evicted two entries) + assertEquals(1002, b2.getUserCounts().get(0).getWindowStart()); + // Most recent should be 1000 + (capacity + 1) + assertEquals(1000 + capacity + 1, b2.getUserCounts().get(capacity - 1).getWindowStart()); + } + + @Test + public void percentilesComputedOnSortedValues() { + ApiRateLimitBucketStatisticsModel doc = null; + List updates = new ArrayList<>(); + + // Same bucket over multiple windows, unsorted arrival + updates.add(dd(1, "/a", "GET", 5, 1002, Collections.singletonMap("b1", 30))); + updates.add(dd(1, "/a", "GET", 5, 1000, Collections.singletonMap("b1", 10))); + updates.add(dd(1, "/a", "GET", 5, 1001, Collections.singletonMap("b1", 20))); + + doc = ApiRateLimitBucketStatisticsModel.applyUpdates(doc, updates); + + ApiRateLimitBucketStatisticsModel.Bucket b1 = doc.getBuckets().stream().filter(b -> b.getLabel().equals("b1")).findFirst().get(); + // values are [10, 20, 30] -> p25=10, p50=20, p75=30 + assertEquals(10, b1.getStats().getP25()); + assertEquals(20, b1.getStats().getP50()); + assertEquals(30, b1.getStats().getP75()); + } +} + + diff --git a/apps/threat-detection/src/main/java/com/akto/threat/detection/ip_api_counter/DistributionCalculator.java b/apps/threat-detection/src/main/java/com/akto/threat/detection/ip_api_counter/DistributionCalculator.java index 632a24b1b9..d889ac8394 100644 --- a/apps/threat-detection/src/main/java/com/akto/threat/detection/ip_api_counter/DistributionCalculator.java +++ b/apps/threat-detection/src/main/java/com/akto/threat/detection/ip_api_counter/DistributionCalculator.java @@ -2,10 +2,10 @@ import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +import com.akto.utils.ThreatApiDistributionUtils; public class DistributionCalculator { /** @@ -28,14 +28,7 @@ public DistributionCalculator() { cmsCounterLayer = 
CmsCounterLayer.getInstance(); } - private static final List BUCKET_RANGES = Arrays.asList( - new Range(1, 10, "b1"), new Range(11, 50, "b2"), new Range(51, 100, "b3"), - new Range(101, 250, "b4"), new Range(251, 500, "b5"), new Range(501, 1000, "b6"), - new Range(1001, 2500, "b7"), new Range(2501, 5000, "b8"), new Range(5001, 10000, "b9"), - new Range(10001, 20000, "b10"), new Range(20001, 35000, "b11"), new Range(35001, 50000, "b12"), - new Range(50001, 100000, "b13"), new Range(100001, Integer.MAX_VALUE, "b14") - ); - + /** * Get the end of the window for a given current epoch minute and window size. * Example: @@ -80,7 +73,7 @@ public void updateFrequencyBuckets(String apiKey, long currentEpochMin, String i Map bucketMap = apiMap.get(apiKey); // Initialize all bucket labels to 0 if not present - for (Range r : BUCKET_RANGES) { + for (ThreatApiDistributionUtils.Range r : ThreatApiDistributionUtils.getBucketRanges()) { bucketMap.putIfAbsent(r.label, 0); } @@ -129,7 +122,7 @@ public long getSlidingWindowCount(String ipApiCmsKey, long currentEpochMin, int } private String getBucketLabel(long count) { - for (Range r : BUCKET_RANGES) { + for (ThreatApiDistributionUtils.Range r : ThreatApiDistributionUtils.getBucketRanges()) { if (count >= r.min && count < r.max) return r.label; } return "b1"; // fallback @@ -148,15 +141,4 @@ public Map> getBucketDistribution(int windowSize, l String windowKey = windowSize + "|" + windowStart; return frequencyBuckets.get(windowKey); } - - private static class Range { - int min, max; - String label; - - Range(int min, int max, String label) { - this.min = min; - this.max = max; - this.label = label; - } - } -} +} \ No newline at end of file diff --git a/libs/utils/src/main/java/com/akto/utils/ThreatApiDistributionUtils.java b/libs/utils/src/main/java/com/akto/utils/ThreatApiDistributionUtils.java new file mode 100644 index 0000000000..3fde839f99 --- /dev/null +++ b/libs/utils/src/main/java/com/akto/utils/ThreatApiDistributionUtils.java 
@@ -0,0 +1,57 @@ +package com.akto.utils; + +import java.util.Arrays; +import java.util.List; +import java.util.HashMap; +import java.util.Map; + +import lombok.Getter; + +public class ThreatApiDistributionUtils { + + public static class Range { + public final int min; + public final int max; + public final String label; + + public Range(int min, int max, String label) { + this.min = min; + this.max = max; + this.label = label; + } + } + + private static final List BUCKET_RANGES = Arrays.asList( + new Range(1, 10, "b1"), new Range(11, 50, "b2"), new Range(51, 100, "b3"), + new Range(101, 250, "b4"), new Range(251, 500, "b5"), new Range(501, 1000, "b6"), + new Range(1001, 2500, "b7"), new Range(2501, 5000, "b8"), new Range(5001, 10000, "b9"), + new Range(10001, 20000, "b10"), new Range(20001, 35000, "b11"), new Range(35001, 50000, "b12"), + new Range(50001, 100000, "b13"), new Range(100001, Integer.MAX_VALUE, "b14") + ); + + @Getter + private static final Map LABEL_TO_RANGE_MAP = new HashMap<>(); + static { + for (Range range : BUCKET_RANGES) { + LABEL_TO_RANGE_MAP.put(range.label, range); + } + } + + public static List getBucketRanges() { + return BUCKET_RANGES; + } + + public static int getBucketUpperBound(String bucketLabel){ + return LABEL_TO_RANGE_MAP.get(bucketLabel).max; + } + + public static int getBucketLowerBound(String bucketLabel){ + return LABEL_TO_RANGE_MAP.get(bucketLabel).min; + } + + public static int percentile(List sorted, int p) { + if (sorted.isEmpty()) return 0; + int index = (int) Math.ceil(p / 100.0 * sorted.size()) - 1; + return sorted.get(Math.max(0, Math.min(index, sorted.size() - 1))); + } +} \ No newline at end of file