From 32fd64e9c3b2101b3d4a6a40fa423e9e8e0629ba Mon Sep 17 00:00:00 2001 From: Xiaoyan Wang Date: Sat, 29 Nov 2025 22:34:36 -0500 Subject: [PATCH 1/4] Introduce `Image` and `ImageView` class for allocate & manipulate image data on device/host --- CMakeLists.txt | 1 + genmetaballs/src/cuda/core/image.cuh | 65 +++++++++++++++++++++++++--- tests/cpp_tests/test_image.cu | 44 +++++++++++++++++++ 3 files changed, 105 insertions(+), 5 deletions(-) create mode 100644 tests/cpp_tests/test_image.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 989b9ce..dd31fad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,6 +21,7 @@ add_library(genmetaballs_core genmetaballs/src/cuda/core/geometry.cuh genmetaballs/src/cuda/core/geometry.cu genmetaballs/src/cuda/core/confidence.cuh + genmetaballs/src/cuda/core/image.cuh ) # Set include directories for the core library diff --git a/genmetaballs/src/cuda/core/image.cuh b/genmetaballs/src/cuda/core/image.cuh index d8ac7f5..c1beb40 100644 --- a/genmetaballs/src/cuda/core/image.cuh +++ b/genmetaballs/src/cuda/core/image.cuh @@ -1,12 +1,67 @@ #pragma once #include +#include +#include +#include -#include "geometry.cuh" #include "utils.cuh" -template -struct Image { - Array2D confidence; - Array2D depth; +/* Non-owning view into an image */ +template +class ImageView { +public: + Array2D confidence; + Array2D depth; + + CUDA_CALLABLE auto num_rows() const noexcept { + return confidence.num_rows(); + } + CUDA_CALLABLE auto num_cols() const noexcept { + return confidence.num_cols(); + } +}; + +/* The image buffer which handles the allocation & deallocation of memory with RAII. + * While the underlying data may be stored in either host or device memory, this + * class is always managed from the host side. + */ +template +class Image { +private: + // Host memory -> thrust::host_vector + // Device memory -> thrust::device_vector + template + using vector_t = std::conditional_t, + thrust::device_vector>; + + // RAII storage for the image data + vector_t confidence_data_; + vector_t depth_data_; + + uint32_t height_; + uint32_t width_; + +public: + /* Allocate the memory for a new image & default initialize with zeros + * The "height" of the image correponds to "rows" in the 2D array, whereas + * the "width" of the image corresponds to "columns". + */ + __host__ Image(uint32_t height, uint32_t width) + : height_(height), width_(width), confidence_data_(height * width), + depth_data_(height * width) {} + + /* Create a view of this image which points to the internal data */ + CUDA_CALLABLE ImageView as_view() { + return ImageView{ + Array2D(confidence_data_.data(), height_, width_), + Array2D(depth_data_.data(), height_, width_)}; + } + + CUDA_CALLABLE auto num_rows() const noexcept { + return height_; + } + CUDA_CALLABLE auto num_cols() const noexcept { + return width_; + } }; diff --git a/tests/cpp_tests/test_image.cu b/tests/cpp_tests/test_image.cu new file mode 100644 index 0000000..0ccba16 --- /dev/null +++ b/tests/cpp_tests/test_image.cu @@ -0,0 +1,44 @@ +#include +#include +#include + +#include "core/image.cuh" + +namespace test_image_gpu { + +/* A simple kernel that changes the image data based on pixel coordinate + * Note that while we allocate the image with "Image" type, we pass it to the kernel as "ImageView" + * via as_view() method. + */ +__global__ void manipulate_image_kernel(ImageView img) { + uint32_t row = threadIdx.y; + uint32_t col = threadIdx.x; + + if (row < img.num_rows() && col < img.num_cols()) { + img.confidence[row][col] = static_cast(row); + img.depth[row][col] = static_cast(col); + } +} + +} // namespace test_image_gpu + +TEST(TestImage, ImageCreationHost) { + constexpr uint32_t height = 128; + constexpr uint32_t width = 256; + + // Create an image in host memory + Image img_buffer(height, width); + auto img = img_buffer.as_view(); + + // Check dimensions + EXPECT_EQ(img.num_rows(), height); + EXPECT_EQ(img.num_cols(), width); + + // Check that confidence and depth are initialized to zero + for (uint32_t r = 0; r < height; ++r) { + for (uint32_t c = 0; c < width; ++c) { + EXPECT_FLOAT_EQ(img.confidence[r][c], 0.0f); + EXPECT_FLOAT_EQ(img.depth[r][c], 0.0f); + } + } +} From 5d4ed486a3cdd71394984ba8a6260b635bf39b8e Mon Sep 17 00:00:00 2001 From: Xiaoyan Wang Date: Sun, 30 Nov 2025 01:18:52 -0500 Subject: [PATCH 2/4] Test image modification on GPU & add ability to move `Image` between devices --- genmetaballs/src/cuda/core/image.cuh | 29 ++++++++++++++++++---------- tests/cpp_tests/test_image.cu | 29 +++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/genmetaballs/src/cuda/core/image.cuh b/genmetaballs/src/cuda/core/image.cuh index c1beb40..2e7ed52 100644 --- a/genmetaballs/src/cuda/core/image.cuh +++ b/genmetaballs/src/cuda/core/image.cuh @@ -14,10 +14,10 @@ public: Array2D confidence; Array2D depth; - CUDA_CALLABLE auto num_rows() const noexcept { + CUDA_CALLABLE constexpr auto num_rows() const noexcept { return confidence.num_rows(); } - CUDA_CALLABLE auto num_cols() const noexcept { + CUDA_CALLABLE constexpr auto num_cols() const noexcept { return confidence.num_cols(); } }; @@ -39,8 +39,12 @@ private: vector_t confidence_data_; vector_t depth_data_; - uint32_t height_; - uint32_t width_; + const uint32_t height_; + const uint32_t width_; + + // Make all Image instantiations friends so they can access each other's private members + template + friend class Image; public: /* Allocate the memory for a new image & default initialize with zeros @@ -51,17 +55,22 @@ public: : height_(height), width_(width), confidence_data_(height * width), depth_data_(height * width) {} + /* Copy constructor from a Image which may reside in a different memory location */ + template + __host__ Image(const Image& other) + : height_(other.num_rows()), width_(other.num_cols()), + confidence_data_(other.confidence_data_), depth_data_(other.depth_data_) {} + /* Create a view of this image which points to the internal data */ - CUDA_CALLABLE ImageView as_view() { - return ImageView{ - Array2D(confidence_data_.data(), height_, width_), - Array2D(depth_data_.data(), height_, width_)}; + CUDA_CALLABLE auto as_view() { + return ImageView{{confidence_data_.data(), height_, width_}, + {depth_data_.data(), height_, width_}}; } - CUDA_CALLABLE auto num_rows() const noexcept { + CUDA_CALLABLE constexpr auto num_rows() const noexcept { return height_; } - CUDA_CALLABLE auto num_cols() const noexcept { + CUDA_CALLABLE constexpr auto num_cols() const noexcept { return width_; } }; diff --git a/tests/cpp_tests/test_image.cu b/tests/cpp_tests/test_image.cu index 0ccba16..4252f7f 100644 --- a/tests/cpp_tests/test_image.cu +++ b/tests/cpp_tests/test_image.cu @@ -1,6 +1,5 @@ #include #include -#include #include "core/image.cuh" @@ -42,3 +41,31 @@ TEST(TestImage, ImageCreationHost) { } } } + +TEST(TestImage, ImageManipulationOnDevice) { + constexpr uint32_t height = 16; + constexpr uint32_t width = 16; + + // Create an image in device memory + Image img_device(height, width); + + // Launch kernel to manipulate image data + dim3 threadsPerBlock(width, height); + test_image_gpu::manipulate_image_kernel<<<1, threadsPerBlock>>>(img_device.as_view()); + + // Copy image back to host for verification + Image img_host = img_device; + cudaDeviceSynchronize(); + auto img = img_host.as_view(); + + EXPECT_EQ(img.num_rows(), height); + EXPECT_EQ(img.num_cols(), width); + + // Verify the manipulated data + for (uint32_t r = 0; r < height; ++r) { + for (uint32_t c = 0; c < width; ++c) { + EXPECT_FLOAT_EQ(img.confidence[r][c], static_cast(r)); + EXPECT_FLOAT_EQ(img.depth[r][c], static_cast(c)); + } + } +} From 1e94af58bf8e0fbc121ee503f37c2323df9cc4a8 Mon Sep 17 00:00:00 2001 From: Xiaoyan Wang Date: Sun, 30 Nov 2025 02:07:26 -0500 Subject: [PATCH 3/4] Python binding & smoke test for image classes --- genmetaballs/src/cuda/bindings.cu | 31 ++++++++++++++++ genmetaballs/src/cuda/core/camera.cuh | 2 +- .../src/genmetaballs/core/__init__.py | 36 +++++++++++++++++++ tests/python_tests/test_image.py | 19 ++++++++++ 4 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 tests/python_tests/test_image.py diff --git a/genmetaballs/src/cuda/bindings.cu b/genmetaballs/src/cuda/bindings.cu index d51c30c..8776e13 100644 --- a/genmetaballs/src/cuda/bindings.cu +++ b/genmetaballs/src/cuda/bindings.cu @@ -8,6 +8,7 @@ #include "core/camera.cuh" #include "core/confidence.cuh" #include "core/geometry.cuh" +#include "core/image.cuh" #include "core/utils.cuh" namespace nb = nanobind; @@ -82,6 +83,36 @@ NB_MODULE(_genmetaballs_bindings, m) { "Get the direction of the ray going through pixel (px, py) in camera frame", nb::arg("px"), nb::arg("py")); + /* + * Image module bindings + * Note that only Host (CPU) version is exposed for simplicity, as the image data is usually + * needed for visualization only. + */ + nb::module_ image = m.def_submodule("image", "Image data structure for GenMetaballs"); + nb::class_>(image, "CPUImageView") + .def(nb::init&, + const Array2D&>(), + nb::arg("confidence"), nb::arg("depth")) + .def_prop_ro("confidence", + [](const ImageView& view) { return view.confidence; }) + .def_prop_ro("depth", + [](const ImageView& view) { return view.depth; }) + .def_prop_ro("num_rows", &ImageView::num_rows) + .def_prop_ro("num_cols", &ImageView::num_cols) + .def("__repr__", [](const ImageView& view) { + return nb::str("CPUImageView(height={}, width={})") + .format(view.num_rows(), view.num_cols()); + }); + nb::class_>(image, "CPUImage") + .def(nb::init(), nb::arg("height"), nb::arg("width")) + .def_prop_ro("num_rows", &Image::num_rows) + .def_prop_ro("num_cols", &Image::num_cols) + .def("as_view", &Image::as_view, + "Get a view of the image data as ImageView") + .def("__repr__", [](const Image& img) { + return nb::str("CPUImage(height={}, width={})").format(img.num_rows(), img.num_cols()); + }); + /* * Confidence module bindings */ diff --git a/genmetaballs/src/cuda/core/camera.cuh b/genmetaballs/src/cuda/core/camera.cuh index d43c490..53d5a31 100644 --- a/genmetaballs/src/cuda/core/camera.cuh +++ b/genmetaballs/src/cuda/core/camera.cuh @@ -21,7 +21,7 @@ struct Intrinsics { // Returns a 2D array of ray directions in camera frame in the specified pixel range // and store them in the provided buffer. By default, the full image is used template - CUDA_CALLABLE Array2D& get_ray_directions(Array2D buffer, + CUDA_CALLABLE Array2D& get_ray_directions(Array2D& buffer, uint32_t px_start = 0, uint32_t px_end = UINT32_MAX, uint32_t py_start = 0, diff --git a/genmetaballs/src/genmetaballs/core/__init__.py b/genmetaballs/src/genmetaballs/core/__init__.py index 17c4ba5..91ef46a 100644 --- a/genmetaballs/src/genmetaballs/core/__init__.py +++ b/genmetaballs/src/genmetaballs/core/__init__.py @@ -1,3 +1,5 @@ +import jax.numpy as jnp + from genmetaballs._genmetaballs_bindings import geometry from genmetaballs._genmetaballs_bindings.blender import ( FourParameterBlender, @@ -8,6 +10,7 @@ TwoParameterConfidence, ZeroParameterConfidence, ) +from genmetaballs._genmetaballs_bindings.image import CPUImage, CPUImageView from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid @@ -26,6 +29,38 @@ def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D: raise ValueError(f"Unsupported device type: {device}") +class Image: + def __init__(self, height: int, width: int) -> None: + """Create an Image on CPU + + Unlike the C++ version, this Image class keep the buffer internally. + This is because Python does reference counting and manage the memory + automatically for us. + + Args: + height: Number of rows in the image. + width: Number of columns in the image. + """ + self._image = CPUImage(height, width) + # keep a view for easy access + self._view: CPUImageView = self._image.as_view() + + @property + def confidence(self) -> jnp.ndarray: + """Get the confidence array.""" + return self._view.confidence.as_jax() + + @property + def depth(self) -> jnp.ndarray: + """Get the depth array.""" + return self._view.depth.as_jax() + + @property + def shape(self) -> tuple[int, int]: + """Get the shape of the image as (height, width).""" + return (self._image.num_rows, self._image.num_cols) + + __all__ = [ "array2d_float", "ZeroParameterConfidence", @@ -36,4 +71,5 @@ def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D: "sigmoid", "FourParameterBlender", "ThreeParameterBlender", + "Image", ] diff --git a/tests/python_tests/test_image.py b/tests/python_tests/test_image.py new file mode 100644 index 0000000..4f90204 --- /dev/null +++ b/tests/python_tests/test_image.py @@ -0,0 +1,19 @@ +import jax.numpy as jnp + +from genmetaballs.core import Image + + +def test_image_creation(): + height, width = 480, 640 + image = Image(height, width) + assert image.shape == (height, width) + assert image.confidence.shape == (height, width) + assert image.depth.shape == (height, width) + + # check types + assert isinstance(image.confidence, jnp.ndarray) + assert isinstance(image.depth, jnp.ndarray) + + # make sure the arrays are initialized to zero + assert jnp.allclose(image.confidence, 0.0) + assert jnp.allclose(image.depth, 0.0) From 41eb163b5e4ed8b720c0663b21d19b7667a8605f Mon Sep 17 00:00:00 2001 From: Xiaoyan Wang Date: Sun, 30 Nov 2025 02:52:34 -0500 Subject: [PATCH 4/4] Bind GPU Image & update test --- genmetaballs/src/cuda/bindings.cu | 60 +++++++++++-------- .../src/genmetaballs/core/__init__.py | 52 ++++++---------- tests/python_tests/test_image.py | 30 ++++++---- 3 files changed, 73 insertions(+), 69 deletions(-) diff --git a/genmetaballs/src/cuda/bindings.cu b/genmetaballs/src/cuda/bindings.cu index 8776e13..3ae2652 100644 --- a/genmetaballs/src/cuda/bindings.cu +++ b/genmetaballs/src/cuda/bindings.cu @@ -15,6 +15,10 @@ namespace nb = nanobind; template void bind_array2d(nb::module_& m, const char* name); +template +void bind_image(nb::module_& m, const char* name); +template +void bind_image_view(nb::module_& m, const char* name); NB_MODULE(_genmetaballs_bindings, m) { @@ -85,33 +89,12 @@ NB_MODULE(_genmetaballs_bindings, m) { /* * Image module bindings - * Note that only Host (CPU) version is exposed for simplicity, as the image data is usually - * needed for visualization only. */ nb::module_ image = m.def_submodule("image", "Image data structure for GenMetaballs"); - nb::class_>(image, "CPUImageView") - .def(nb::init&, - const Array2D&>(), - nb::arg("confidence"), nb::arg("depth")) - .def_prop_ro("confidence", - [](const ImageView& view) { return view.confidence; }) - .def_prop_ro("depth", - [](const ImageView& view) { return view.depth; }) - .def_prop_ro("num_rows", &ImageView::num_rows) - .def_prop_ro("num_cols", &ImageView::num_cols) - .def("__repr__", [](const ImageView& view) { - return nb::str("CPUImageView(height={}, width={})") - .format(view.num_rows(), view.num_cols()); - }); - nb::class_>(image, "CPUImage") - .def(nb::init(), nb::arg("height"), nb::arg("width")) - .def_prop_ro("num_rows", &Image::num_rows) - .def_prop_ro("num_cols", &Image::num_cols) - .def("as_view", &Image::as_view, - "Get a view of the image data as ImageView") - .def("__repr__", [](const Image& img) { - return nb::str("CPUImage(height={}, width={})").format(img.num_rows(), img.num_cols()); - }); + bind_image_view(image, "CPUImageView"); + bind_image(image, "CPUImage"); + bind_image_view(image, "GPUImageView"); + bind_image(image, "GPUImage"); /* * Confidence module bindings @@ -203,3 +186,30 @@ void bind_array2d(nb::module_& m, const char* name) { .def_prop_ro("ndim", &Array2D::ndim) .def_prop_ro("size", &Array2D::size); } + +template +void bind_image_view(nb::module_& m, const char* name) { + nb::class_>(m, name) + .def(nb::init&, const Array2D&>(), + nb::arg("confidence"), nb::arg("depth")) + .def_prop_ro("confidence", [](const ImageView& view) { return view.confidence; }) + .def_prop_ro("depth", [](const ImageView& view) { return view.depth; }) + .def_prop_ro("num_rows", &ImageView::num_rows) + .def_prop_ro("num_cols", &ImageView::num_cols) + .def("__repr__", [=](const ImageView& view) { + return nb::str("{}(height={}, width={})") + .format(name, view.num_rows(), view.num_cols()); + }); +} + +template +void bind_image(nb::module_& m, const char* name) { + nb::class_>(m, name) + .def(nb::init(), nb::arg("height"), nb::arg("width")) + .def_prop_ro("num_rows", &Image::num_rows) + .def_prop_ro("num_cols", &Image::num_cols) + .def("as_view", &Image::as_view, "Get a view of the image data as ImageView") + .def("__repr__", [=](const Image& img) { + return nb::str("{}(height={}, width={})").format(name, img.num_rows(), img.num_cols()); + }); +} diff --git a/genmetaballs/src/genmetaballs/core/__init__.py b/genmetaballs/src/genmetaballs/core/__init__.py index 91ef46a..ac2a6fd 100644 --- a/genmetaballs/src/genmetaballs/core/__init__.py +++ b/genmetaballs/src/genmetaballs/core/__init__.py @@ -1,4 +1,4 @@ -import jax.numpy as jnp +from typing import Literal from genmetaballs._genmetaballs_bindings import geometry from genmetaballs._genmetaballs_bindings.blender import ( @@ -10,11 +10,13 @@ TwoParameterConfidence, ZeroParameterConfidence, ) -from genmetaballs._genmetaballs_bindings.image import CPUImage, CPUImageView +from genmetaballs._genmetaballs_bindings.image import CPUImage, GPUImage from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid +type DeviceType = Literal["cpu", "gpu"] -def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D: + +def array2d_float(data, device: DeviceType) -> CPUFloatArray2D | GPUFloatArray2D: """Create a FloatArray2D on the specified device from an array. Args: @@ -29,36 +31,20 @@ def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D: raise ValueError(f"Unsupported device type: {device}") -class Image: - def __init__(self, height: int, width: int) -> None: - """Create an Image on CPU - - Unlike the C++ version, this Image class keep the buffer internally. - This is because Python does reference counting and manage the memory - automatically for us. - - Args: - height: Number of rows in the image. - width: Number of columns in the image. - """ - self._image = CPUImage(height, width) - # keep a view for easy access - self._view: CPUImageView = self._image.as_view() +def make_image(height: int, width: int, device: DeviceType) -> CPUImage | GPUImage: + """Create an Image on the specified device. - @property - def confidence(self) -> jnp.ndarray: - """Get the confidence array.""" - return self._view.confidence.as_jax() - - @property - def depth(self) -> jnp.ndarray: - """Get the depth array.""" - return self._view.depth.as_jax() - - @property - def shape(self) -> tuple[int, int]: - """Get the shape of the image as (height, width).""" - return (self._image.num_rows, self._image.num_cols) + Args: + height: The height of the image. + width: The width of the image. + device: 'cpu' or 'gpu' to specify the target device. + """ + if device == "cpu": + return CPUImage(height, width) + elif device == "gpu": + return GPUImage(height, width) + else: + raise ValueError(f"Unsupported device type: {device}") __all__ = [ @@ -71,5 +57,5 @@ def shape(self) -> tuple[int, int]: "sigmoid", "FourParameterBlender", "ThreeParameterBlender", - "Image", + "make_image", ] diff --git a/tests/python_tests/test_image.py b/tests/python_tests/test_image.py index 4f90204..f87fb5a 100644 --- a/tests/python_tests/test_image.py +++ b/tests/python_tests/test_image.py @@ -1,19 +1,27 @@ import jax.numpy as jnp +import pytest -from genmetaballs.core import Image +from genmetaballs.core import make_image -def test_image_creation(): +@pytest.mark.parametrize("device", ["cpu", "gpu"]) +def test_image_creation(device: str) -> None: height, width = 480, 640 - image = Image(height, width) - assert image.shape == (height, width) - assert image.confidence.shape == (height, width) - assert image.depth.shape == (height, width) + image = make_image(height, width, device=device) + assert image.num_rows == height + assert image.num_cols == width - # check types - assert isinstance(image.confidence, jnp.ndarray) - assert isinstance(image.depth, jnp.ndarray) + # create views and check their types + image_view = image.as_view() + assert image_view.num_rows == height + assert image_view.num_cols == width + + # check that the confidence and depth arrays can be converted to JAX arrays + confidence = image_view.confidence.as_jax() + depth = image_view.depth.as_jax() + assert isinstance(confidence, jnp.ndarray) + assert isinstance(depth, jnp.ndarray) # make sure the arrays are initialized to zero - assert jnp.allclose(image.confidence, 0.0) - assert jnp.allclose(image.depth, 0.0) + assert jnp.allclose(confidence, 0.0) + assert jnp.allclose(depth, 0.0)