diff --git a/bitmapcontainer.go b/bitmapcontainer.go index 822add1f..1b705501 100644 --- a/bitmapcontainer.go +++ b/bitmapcontainer.go @@ -226,6 +226,7 @@ func (bcmi *bitmapContainerManyIterator) nextMany(hs uint32, buf []uint32) int { return n } +// nextMany64 returns the number of values added to the buffer func (bcmi *bitmapContainerManyIterator) nextMany64(hs uint64, buf []uint64) int { n := 0 base := bcmi.base diff --git a/bitmapcontainer_test.go b/bitmapcontainer_test.go index bbb6dd1f..d1feb0fc 100644 --- a/bitmapcontainer_test.go +++ b/bitmapcontainer_test.go @@ -454,3 +454,34 @@ func TestBitMapContainerValidate(t *testing.T) { assert.Error(t, bc.validate()) } + +func TestBitmapcontainerNextHasMany(t *testing.T) { + t.Run("Empty Bitmap", func(t *testing.T) { + bc := newBitmapContainer() + iterator := newBitmapContainerManyIterator(bc) + high := uint64(1024) + buf := []uint64{} + result := iterator.nextMany64(high, buf) + assert.Equal(t, 0, result) + }) + + t.Run("512 in iterator and buf size 512", func(t *testing.T) { + bc := newBitmapContainer() + bc.iaddRange(0, 512) + iterator := newBitmapContainerManyIterator(bc) + high := uint64(1024) + buf := make([]uint64, 512) + result := iterator.nextMany64(high, buf) + assert.Equal(t, 512, result) + }) + + t.Run("512 in iterator and buf size 256", func(t *testing.T) { + bc := newBitmapContainer() + bc.iaddRange(0, 512) + iterator := newBitmapContainerManyIterator(bc) + high := uint64(1024) + buf := make([]uint64, 256) + result := iterator.nextMany64(high, buf) + assert.Equal(t, 256, result) + }) +} diff --git a/manyiterator_test.go b/manyiterator_test.go new file mode 100644 index 00000000..d4fee753 --- /dev/null +++ b/manyiterator_test.go @@ -0,0 +1,42 @@ +package roaring + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestManyIterator(t *testing.T) { + type searchTest struct { + name string + iterator shortIterator + high uint64 + buf []uint64 + expectedValue int + } + + tests := 
[]searchTest{ + { + "no values", + shortIterator{}, + uint64(1024), + []uint64{}, + 0, + }, + { + "1 value ", + shortIterator{[]uint16{uint16(1)}, 0}, + uint64(1024), + make([]uint64, 1), + 1, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + iterator := testCase.iterator + result := iterator.nextMany64(testCase.high, testCase.buf) + assert.Equal(t, testCase.expectedValue, result) + }) + } +} diff --git a/setutil.go b/setutil.go index 29ba4c3a..8def774f 100644 --- a/setutil.go +++ b/setutil.go @@ -1,26 +1,12 @@ package roaring -func equal(a, b []uint16) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} - func difference(set1 []uint16, set2 []uint16, buffer []uint16) int { - if 0 == len(set2) { + if len(set2) == 0 { buffer = buffer[:len(set1)] - for k := 0; k < len(set1); k++ { - buffer[k] = set1[k] - } + copy(buffer, set1) return len(set1) } - if 0 == len(set1) { + if len(set1) == 0 { return 0 } pos := 0 @@ -134,6 +120,7 @@ func exclusiveUnion2by2(set1 []uint16, set2 []uint16, buffer []uint16) int { return pos } +// union2by2Cardinality computes the cardinality of the union func union2by2Cardinality(set1 []uint16, set2 []uint16) int { pos := 0 k1 := 0 @@ -196,6 +183,7 @@ func intersection2by2( } } +// intersection2by2Cardinality computes the cardinality of the intersection func intersection2by2Cardinality( set1 []uint16, set2 []uint16, @@ -209,41 +197,42 @@ func intersection2by2Cardinality( } } +// intersects2by2 computes whether the two sets intersect func intersects2by2( set1 []uint16, set2 []uint16, ) bool { // could be optimized if one set is much larger than the other one - if (0 == len(set1)) || (0 == len(set2)) { + if (len(set1) == 0) || (len(set2) == 0) { return false } - k1 := 0 - k2 := 0 - s1 := set1[k1] - s2 := set2[k2] + index1 := 0 + index2 := 0 + value1 := set1[index1] + value2 := set2[index2] mainwhile: for { - if s2 < s1 { + if 
value2 < value1 { for { - k2++ - if k2 == len(set2) { + index2++ + if index2 == len(set2) { break mainwhile } - s2 = set2[k2] - if s2 >= s1 { + value2 = set2[index2] + if value2 >= value1 { break } } } - if s1 < s2 { + if value1 < value2 { for { - k1++ - if k1 == len(set1) { + index1++ + if index1 == len(set1) { break mainwhile } - s1 = set1[k1] - if s1 >= s2 { + value1 = set1[index1] + if value1 >= value2 { break } } @@ -260,7 +249,7 @@ func localintersect2by2( set2 []uint16, buffer []uint16, ) int { - if (0 == len(set1)) || (0 == len(set2)) { + if (len(set1) == 0) || (len(set2) == 0) { return 0 } k1 := 0 @@ -313,56 +302,57 @@ mainwhile: return pos } +// localintersect2by2Cardinality computes the cardinality of the intersection func localintersect2by2Cardinality( set1 []uint16, set2 []uint16, ) int { - if (0 == len(set1)) || (0 == len(set2)) { + if (len(set1) == 0) || (len(set2) == 0) { return 0 } - k1 := 0 - k2 := 0 + index1 := 0 + index2 := 0 pos := 0 - s1 := set1[k1] - s2 := set2[k2] + value1 := set1[index1] + value2 := set2[index2] mainwhile: for { - if s2 < s1 { + if value2 < value1 { for { - k2++ - if k2 == len(set2) { + index2++ + if index2 == len(set2) { break mainwhile } - s2 = set2[k2] - if s2 >= s1 { + value2 = set2[index2] + if value2 >= value1 { break } } } - if s1 < s2 { + if value1 < value2 { for { - k1++ - if k1 == len(set1) { + index1++ + if index1 == len(set1) { break mainwhile } - s1 = set1[k1] - if s1 >= s2 { + value1 = set1[index1] + if value1 >= value2 { break } } } else { // (set2[k2] == set1[k1]) pos++ - k1++ - if k1 == len(set1) { + index1++ + if index1 == len(set1) { break } - s1 = set1[k1] - k2++ - if k2 == len(set2) { + value1 = set1[index1] + index2++ + if index2 == len(set2) { break } - s2 = set2[k2] + value2 = set2[index2] } } return pos diff --git a/setutil_test.go b/setutil_test.go index 1d60e094..41275199 100644 --- a/setutil_test.go +++ b/setutil_test.go @@ -260,3 +260,133 @@ func TestBinarySearchPastWithBounds(t *testing.T) { 
}) } } + +func makeLargeSet(start int) []uint16 { + data := make([]uint16, 0, 256) + for i := 0; i < 256; i++ { + data = append(data, uint16(start+i)) + } + return data +} + +func TestSetUtilIntersection2By2Cardinality(t *testing.T) { + type searchTest struct { + name string + data1 []uint16 + data2 []uint16 + expectedValue int + } + + tests := []searchTest{ + { + "cardinality 1 intersection", + []uint16{0, 1, 2, 3, 4, 9}, + []uint16{8, 9, 10, 11, 12}, + 1, + }, + { + "empty set", + []uint16{}, + []uint16{8, 9, 10, 11, 12}, + 0, + }, + { + "cardinality 0", + []uint16{1}, + []uint16{8, 9, 10, 11, 12}, + 0, + }, + { + "large first set - cardinality 0", + makeLargeSet(1024), + []uint16{8, 9}, + 0, + }, + { + "large second set - cardinality 0", + []uint16{8, 9}, + makeLargeSet(1024), + 0, + }, + { + "large first set - cardinality 2", + makeLargeSet(0), + []uint16{8, 9}, + 2, + }, + { + "large second set - cardinality 2", + []uint16{8, 9}, + makeLargeSet(0), + 2, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + result := intersection2by2Cardinality(testCase.data1, testCase.data2) + assert.Equal(t, result, testCase.expectedValue) + }) + } +} + +func TestSetUtilUnionBy2Cardinality(t *testing.T) { + type searchTest struct { + name string + data1 []uint16 + data2 []uint16 + expectedValue int + } + + tests := []searchTest{ + { + "cardinality 1 intersection", + []uint16{0, 1, 2, 3, 4, 9}, + []uint16{8, 9, 10, 11, 12}, + 10, + }, + { + "empty set ", + []uint16{}, + []uint16{8, 9, 10, 11, 12}, + 5, + }, + { + "cardinality 6", + []uint16{1}, + []uint16{8, 9, 10, 11, 12}, + 6, + }, + { + "large first set - cardinality 258", + makeLargeSet(1024), + []uint16{8, 9}, + 258, + }, + { + "large second set - cardinality 0", + []uint16{8, 9}, + makeLargeSet(1024), + 258, + }, + { + "large first set - cardinality 2", + makeLargeSet(0), + []uint16{8, 9}, + 256, + }, + { + "large second set - cardinality 2", + []uint16{8, 9}, + makeLargeSet(0), + 
256, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + result := union2by2Cardinality(testCase.data1, testCase.data2) + assert.Equal(t, result, testCase.expectedValue) + }) + } +}