From 40de87e93f406ab2b159b9d49f761eb700bd14ab Mon Sep 17 00:00:00 2001
From: Stephen Brennan <stephen.s.brennan@oracle.com>
Date: Thu, 26 Oct 2023 10:13:49 -0700
Subject: [PATCH] helpers: Infer bitmap size when possible for bitops

The Linux API for for_each_set_bit() involves a size, but when dealing
with complete array types (a relatively common case), we can default to
the size of the array in bits. Add this ability to the helper.

Signed-off-by: Stephen Brennan <stephen.s.brennan@oracle.com>
---
 drgn/helpers/linux/bitops.py              | 34 +++++++++++++++++------
 tests/linux_kernel/helpers/test_bitops.py |  2 ++
 2 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/drgn/helpers/linux/bitops.py b/drgn/helpers/linux/bitops.py
index 4d0ae2f97..42843864d 100644
--- a/drgn/helpers/linux/bitops.py
+++ b/drgn/helpers/linux/bitops.py
@@ -14,9 +14,9 @@
 ``unsigned long``.
 """
 
-from typing import Iterator
+from typing import Iterator, Optional
 
-from drgn import IntegerLike, Object, sizeof
+from drgn import IntegerLike, Object, TypeKind, sizeof
 
 __all__ = (
     "for_each_clear_bit",
@@ -25,14 +25,23 @@
 )
 
 
-def for_each_set_bit(bitmap: Object, size: IntegerLike) -> Iterator[int]:
+def for_each_set_bit(
+    bitmap: Object, size: Optional[IntegerLike] = None
+) -> Iterator[int]:
     """
     Iterate over all set (one) bits in a bitmap.
 
     :param bitmap: pointer to, or array of, ``unsigned long``
-    :param size: Size of *bitmap* in bits.
+    :param size: Size of *bitmap* in bits. When *bitmap* is a sized array type
+        (EG: ``unsigned long[2]``), this value will default to the size of the
+        array in bits.
     """
-    size = int(size)
+    if size is not None:
+        size = int(size)
+    elif bitmap.type_.kind == TypeKind.ARRAY and bitmap.type_.length is not None:
+        size = 8 * sizeof(bitmap)
+    else:
+        raise ValueError("bitmap is not a complete array type, and size is not given")
     word_bits = 8 * sizeof(bitmap.type_.type)
     for i in range((size + word_bits - 1) // word_bits):
         word = bitmap[i].value_()
@@ -41,14 +50,23 @@ def for_each_set_bit(bitmap: Object, size: IntegerLike) -> Iterator[int]:
                 yield (word_bits * i) + j
 
 
-def for_each_clear_bit(bitmap: Object, size: IntegerLike) -> Iterator[int]:
+def for_each_clear_bit(
+    bitmap: Object, size: Optional[IntegerLike] = None
+) -> Iterator[int]:
     """
     Iterate over all clear (zero) bits in a bitmap.
 
     :param bitmap: pointer to, or array of, ``unsigned long``
-    :param size: Size of *bitmap* in bits.
+    :param size: Size of *bitmap* in bits. When *bitmap* is a sized array type
+        (EG: ``unsigned long[2]``), this value will default to the size of the
+        array in bits.
     """
-    size = int(size)
+    if size is not None:
+        size = int(size)
+    elif bitmap.type_.kind == TypeKind.ARRAY and bitmap.type_.length is not None:
+        size = 8 * sizeof(bitmap)
+    else:
+        raise ValueError("bitmap is not a complete array type, and size is not given")
     word_bits = 8 * sizeof(bitmap.type_.type)
     for i in range((size + word_bits - 1) // word_bits):
         word = bitmap[i].value_()
diff --git a/tests/linux_kernel/helpers/test_bitops.py b/tests/linux_kernel/helpers/test_bitops.py
index 24d859599..f566aabb4 100644
--- a/tests/linux_kernel/helpers/test_bitops.py
+++ b/tests/linux_kernel/helpers/test_bitops.py
@@ -40,6 +40,7 @@ def test_for_each_set_bit(self):
         for type_ in self.valid_integer_types():
             bitmap = Object.from_bytes_(self.prog, type_, self.BITMAP)
             self.assertEqual(list(for_each_set_bit(bitmap, 128)), self.SET_BITS)
+            self.assertEqual(list(for_each_set_bit(bitmap)), self.SET_BITS)
             self.assertEqual(
                 list(for_each_set_bit(bitmap, 101)),
                 [bit for bit in self.SET_BITS if bit < 101],
@@ -49,6 +50,7 @@ def test_for_each_clear_bit(self):
         for type_ in self.valid_integer_types():
             bitmap = Object.from_bytes_(self.prog, type_, self.BITMAP)
             self.assertEqual(list(for_each_clear_bit(bitmap, 128)), self.CLEAR_BITS)
+            self.assertEqual(list(for_each_clear_bit(bitmap)), self.CLEAR_BITS)
             self.assertEqual(
                 list(for_each_clear_bit(bitmap, 100)),
                 [bit for bit in self.CLEAR_BITS if bit < 100],