Skip to content

Commit

Permalink
detect/threshold: expand cache support for rule tracking
Browse files Browse the repository at this point in the history
Use the same hash key as for the regular threshold storage,
so include gid, rev, tentant id.
  • Loading branch information
victorjulien committed Jun 28, 2024
1 parent 1e9fdc4 commit 7bcf364
Showing 1 changed file with 36 additions and 21 deletions.
57 changes: 36 additions & 21 deletions src/detect-engine-threshold.c
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,17 @@ uint32_t ThresholdsExpire(const SCTime_t ts)

#include "util-hash.h"

#define TC_ADDRESS 0
#define TC_SID 1
#define TC_GID 2
#define TC_REV 3
#define TC_TENANT 4

typedef struct ThresholdCacheItem {
int8_t track; // by_src/by_dst
int8_t ipv;
int8_t retval;
uint32_t addr;
uint32_t sid;
uint32_t key[5];
SCTime_t expires_at;
RB_ENTRY(ThresholdCacheItem) rb;
} ThresholdCacheItem;
Expand Down Expand Up @@ -297,8 +302,8 @@ static void ThresholdCacheExpire(SCTime_t now)

static uint32_t ThresholdCacheHashFunc(HashTable *ht, void *data, uint16_t datalen)
{
ThresholdCacheItem *tci = data;
int hash = tci->ipv * tci->track + tci->addr + tci->sid;
ThresholdCacheItem *e = data;
uint32_t hash = hashword(e->key, sizeof(e->key) / sizeof(uint32_t), 0) * (e->ipv + e->track);
hash = hash % ht->array_size;
return hash;
}
Expand All @@ -308,8 +313,8 @@ static char ThresholdCacheHashCompareFunc(
{
ThresholdCacheItem *tci1 = data1;
ThresholdCacheItem *tci2 = data2;
return tci1->ipv == tci2->ipv && tci1->track == tci2->track && tci1->addr == tci2->addr &&
tci1->sid == tci2->sid;
return tci1->ipv == tci2->ipv && tci1->track == tci2->track &&
memcmp(tci1->key, tci2->key, sizeof(tci1->key)) == 0;
}

static void ThresholdCacheHashFreeFunc(void *data)
Expand All @@ -319,7 +324,7 @@ static void ThresholdCacheHashFreeFunc(void *data)

/// \brief Thread local cache
static int SetupCache(const Packet *p, const int8_t track, const int8_t retval, const uint32_t sid,
SCTime_t expires)
const uint32_t gid, const uint32_t rev, SCTime_t expires)
{
if (!threshold_cache_ht) {
threshold_cache_ht = HashTableInit(256, ThresholdCacheHashFunc,
Expand All @@ -339,8 +344,11 @@ static int SetupCache(const Packet *p, const int8_t track, const int8_t retval,
.track = track,
.ipv = 4,
.retval = retval,
.addr = addr,
.sid = sid,
.key[TC_ADDRESS] = addr,
.key[TC_SID] = sid,
.key[TC_GID] = gid,
.key[TC_REV] = rev,
.key[TC_TENANT] = p->tenant_id,
.expires_at = expires,
};
ThresholdCacheItem *found = HashTableLookup(threshold_cache_ht, &lookup, 0);
Expand All @@ -350,8 +358,11 @@ static int SetupCache(const Packet *p, const int8_t track, const int8_t retval,
n->track = track;
n->ipv = 4;
n->retval = retval;
n->addr = addr;
n->sid = sid;
n->key[TC_ADDRESS] = addr;
n->key[TC_SID] = sid;
n->key[TC_GID] = gid;
n->key[TC_REV] = rev;
n->key[TC_TENANT] = p->tenant_id;
n->expires_at = expires;

if (HashTableAdd(threshold_cache_ht, n, 0) == 0) {
Expand Down Expand Up @@ -381,7 +392,8 @@ static int SetupCache(const Packet *p, const int8_t track, const int8_t retval,
* \retval -4 error - unsupported tracker
* \retval ret cached return code
*/
static int CheckCache(const Packet *p, const int8_t track, const uint32_t sid)
static int CheckCache(const Packet *p, const int8_t track, const uint32_t sid, const uint32_t gid,
const uint32_t rev)
{
cache_lookup_cnt++;

Expand All @@ -407,8 +419,11 @@ static int CheckCache(const Packet *p, const int8_t track, const uint32_t sid)
ThresholdCacheItem lookup = {
.track = track,
.ipv = 4,
.addr = addr,
.sid = sid,
.key[TC_ADDRESS] = addr,
.key[TC_SID] = sid,
.key[TC_GID] = gid,
.key[TC_REV] = rev,
.key[TC_TENANT] = p->tenant_id,
};
ThresholdCacheItem *found = HashTableLookup(threshold_cache_ht, &lookup, 0);
if (found) {
Expand Down Expand Up @@ -652,7 +667,7 @@ static int ThresholdSetup(const DetectThresholdData *td, ThresholdEntry *te,

static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *te,
const Packet *p, // ts only? - cache too
const uint32_t sid, PacketAlert *pa)
const uint32_t sid, const uint32_t gid, const uint32_t rev, PacketAlert *pa)
{
int ret = 0;
const SCTime_t packet_time = p->ts;
Expand All @@ -670,7 +685,7 @@ static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *t
ret = 2;

if (PacketIsIPv4(p)) {
SetupCache(p, td->track, (int8_t)ret, sid, entry);
SetupCache(p, td->track, (int8_t)ret, sid, gid, rev, entry);
}
}
} else {
Expand Down Expand Up @@ -705,7 +720,7 @@ static int ThresholdCheckUpdate(const DetectThresholdData *td, ThresholdEntry *t
ret = 2;

if (PacketIsIPv4(p)) {
SetupCache(p, td->track, (int8_t)ret, sid, entry);
SetupCache(p, td->track, (int8_t)ret, sid, gid, rev, entry);
}
}
} else {
Expand Down Expand Up @@ -819,7 +834,7 @@ static int ThresholdGetFromHash(struct Thresholds *tctx, const Packet *p, const
r = ThresholdSetup(td, te, p->ts, s->id, s->gid, s->rev, p->tenant_id);
} else {
// existing, check/update
r = ThresholdCheckUpdate(td, te, p, s->id, pa);
r = ThresholdCheckUpdate(td, te, p, s->id, s->gid, s->rev, pa);
}

(void)THashDecrUsecnt(res.data);
Expand Down Expand Up @@ -855,7 +870,7 @@ static int ThresholdHandlePacketFlow(Flow *f, Packet *p, const DetectThresholdDa
}
} else {
// existing, check/update
ret = ThresholdCheckUpdate(td, found, p, sid, pa);
ret = ThresholdCheckUpdate(td, found, p, sid, gid, rev, pa);
}
return ret;
}
Expand Down Expand Up @@ -886,7 +901,7 @@ int PacketAlertThreshold(DetectEngineCtx *de_ctx, DetectEngineThreadCtx *det_ctx
ret = ThresholdHandlePacketSuppress(p,td,s->id,s->gid);
} else if (td->track == TRACK_SRC) {
if (PacketIsIPv4(p) && (td->type == TYPE_LIMIT || td->type == TYPE_BOTH)) {
int cache_ret = CheckCache(p, td->track, s->id);
int cache_ret = CheckCache(p, td->track, s->id, s->gid, s->rev);
if (cache_ret >= 0) {
SCReturnInt(cache_ret);
}
Expand All @@ -895,7 +910,7 @@ int PacketAlertThreshold(DetectEngineCtx *de_ctx, DetectEngineThreadCtx *det_ctx
ret = ThresholdGetFromHash(&ctx, p, s, td, pa);
} else if (td->track == TRACK_DST) {
if (PacketIsIPv4(p) && (td->type == TYPE_LIMIT || td->type == TYPE_BOTH)) {
int cache_ret = CheckCache(p, td->track, s->id);
int cache_ret = CheckCache(p, td->track, s->id, s->gid, s->rev);
if (cache_ret >= 0) {
SCReturnInt(cache_ret);
}
Expand Down

0 comments on commit 7bcf364

Please sign in to comment.