diff --git a/iterator.go b/iterator.go index 2494256..db4617e 100644 --- a/iterator.go +++ b/iterator.go @@ -76,7 +76,7 @@ func (it *mapIterator[In, Out]) Next() (out Out, revision Revision, ok bool) { return } -// Filter skips objects for which the supplied predicate returns true +// Filter includes objects for which the supplied predicate returns true func Filter[Obj any, It Iterator[Obj]](iter It, pred func(Obj) bool) Iterator[Obj] { return &filterIterator[Obj]{ iter: iter, diff --git a/iterator_test.go b/iterator_test.go new file mode 100644 index 0000000..d21fc02 --- /dev/null +++ b/iterator_test.go @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package statedb + +import ( + "testing" + + "github.com/cilium/statedb/index" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFilter(t *testing.T) { + type testObject struct { + ID int + } + + db, _ := NewDB(nil, NewExpVarMetrics(false)) + idIndex := Index[*testObject, int]{ + Name: "id", + FromObject: func(t *testObject) index.KeySet { + return index.NewKeySet(index.Int(t.ID)) + }, + FromKey: index.Int, + Unique: true, + } + table, _ := NewTable("test", idIndex) + require.NoError(t, db.RegisterTable(table)) + + txn := db.WriteTxn(table) + table.Insert(txn, &testObject{ID: 1}) + table.Insert(txn, &testObject{ID: 2}) + table.Insert(txn, &testObject{ID: 3}) + table.Insert(txn, &testObject{ID: 4}) + table.Insert(txn, &testObject{ID: 5}) + txn.Commit() + + iter, _ := table.All(db.ReadTxn()) + filtered := CollectSet( + Map( + Filter( + iter, + func(obj *testObject) bool { + return obj.ID%2 == 0 + }, + ), + func(obj *testObject) int { + return obj.ID + }, + ), + ) + assert.Len(t, filtered, 2) + assert.True(t, filtered.Has(2)) + assert.True(t, filtered.Has(4)) +}