diff --git a/CMakeLists.txt b/CMakeLists.txt index 989b9ce..dd31fad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,6 +21,7 @@ add_library(genmetaballs_core genmetaballs/src/cuda/core/geometry.cuh genmetaballs/src/cuda/core/geometry.cu genmetaballs/src/cuda/core/confidence.cuh + genmetaballs/src/cuda/core/image.cuh ) # Set include directories for the core library diff --git a/genmetaballs/src/cuda/bindings.cu b/genmetaballs/src/cuda/bindings.cu index d51c30c..3ae2652 100644 --- a/genmetaballs/src/cuda/bindings.cu +++ b/genmetaballs/src/cuda/bindings.cu @@ -8,12 +8,17 @@ #include "core/camera.cuh" #include "core/confidence.cuh" #include "core/geometry.cuh" +#include "core/image.cuh" #include "core/utils.cuh" namespace nb = nanobind; template void bind_array2d(nb::module_& m, const char* name); +template +void bind_image(nb::module_& m, const char* name); +template +void bind_image_view(nb::module_& m, const char* name); NB_MODULE(_genmetaballs_bindings, m) { @@ -82,6 +87,15 @@ NB_MODULE(_genmetaballs_bindings, m) { "Get the direction of the ray going through pixel (px, py) in camera frame", nb::arg("px"), nb::arg("py")); + /* + * Image module bindings + */ + nb::module_ image = m.def_submodule("image", "Image data structure for GenMetaballs"); + bind_image_view(image, "CPUImageView"); + bind_image(image, "CPUImage"); + bind_image_view(image, "GPUImageView"); + bind_image(image, "GPUImage"); + /* * Confidence module bindings */ @@ -172,3 +186,30 @@ void bind_array2d(nb::module_& m, const char* name) { .def_prop_ro("ndim", &Array2D::ndim) .def_prop_ro("size", &Array2D::size); } + +template +void bind_image_view(nb::module_& m, const char* name) { + nb::class_>(m, name) + .def(nb::init&, const Array2D&>(), + nb::arg("confidence"), nb::arg("depth")) + .def_prop_ro("confidence", [](const ImageView& view) { return view.confidence; }) + .def_prop_ro("depth", [](const ImageView& view) { return view.depth; }) + .def_prop_ro("num_rows", &ImageView::num_rows) + .def_prop_ro("num_cols", &ImageView::num_cols) + .def("__repr__", [=](const ImageView& view) { + return nb::str("{}(height={}, width={})") + .format(name, view.num_rows(), view.num_cols()); + }); +} + +template +void bind_image(nb::module_& m, const char* name) { + nb::class_>(m, name) + .def(nb::init(), nb::arg("height"), nb::arg("width")) + .def_prop_ro("num_rows", &Image::num_rows) + .def_prop_ro("num_cols", &Image::num_cols) + .def("as_view", &Image::as_view, "Get a view of the image data as ImageView") + .def("__repr__", [=](const Image& img) { + return nb::str("{}(height={}, width={})").format(name, img.num_rows(), img.num_cols()); + }); +} diff --git a/genmetaballs/src/cuda/core/camera.cuh b/genmetaballs/src/cuda/core/camera.cuh index d43c490..53d5a31 100644 --- a/genmetaballs/src/cuda/core/camera.cuh +++ b/genmetaballs/src/cuda/core/camera.cuh @@ -21,7 +21,7 @@ struct Intrinsics { // Returns a 2D array of ray directions in camera frame in the specified pixel range // and store them in the provided buffer. By default, the full image is used template - CUDA_CALLABLE Array2D& get_ray_directions(Array2D buffer, + CUDA_CALLABLE Array2D& get_ray_directions(Array2D& buffer, uint32_t px_start = 0, uint32_t px_end = UINT32_MAX, uint32_t py_start = 0, diff --git a/genmetaballs/src/cuda/core/image.cuh b/genmetaballs/src/cuda/core/image.cuh index d8ac7f5..2e7ed52 100644 --- a/genmetaballs/src/cuda/core/image.cuh +++ b/genmetaballs/src/cuda/core/image.cuh @@ -1,12 +1,76 @@ #pragma once #include +#include +#include +#include -#include "geometry.cuh" #include "utils.cuh" -template -struct Image { - Array2D confidence; - Array2D depth; +/* Non-owning view into an image */ +template +class ImageView { +public: + Array2D confidence; + Array2D depth; + + CUDA_CALLABLE constexpr auto num_rows() const noexcept { + return confidence.num_rows(); + } + CUDA_CALLABLE constexpr auto num_cols() const noexcept { + return confidence.num_cols(); + } +}; + +/* The image buffer which handles the allocation & deallocation of memory with RAII. + * While the underlying data may be stored in either host or device memory, this + * class is always managed from the host side. + */ +template +class Image { +private: + // Host memory -> thrust::host_vector + // Device memory -> thrust::device_vector + template + using vector_t = std::conditional_t, + thrust::device_vector>; + + // RAII storage for the image data + vector_t confidence_data_; + vector_t depth_data_; + + const uint32_t height_; + const uint32_t width_; + + // Make all Image instantiations friends so they can access each other's private members + template + friend class Image; + +public: + /* Allocate the memory for a new image & default initialize with zeros + * The "height" of the image correponds to "rows" in the 2D array, whereas + * the "width" of the image corresponds to "columns". + */ + __host__ Image(uint32_t height, uint32_t width) + : height_(height), width_(width), confidence_data_(height * width), + depth_data_(height * width) {} + + /* Copy constructor from a Image which may reside in a different memory location */ + template + __host__ Image(const Image& other) + : height_(other.num_rows()), width_(other.num_cols()), + confidence_data_(other.confidence_data_), depth_data_(other.depth_data_) {} + + /* Create a view of this image which points to the internal data */ + CUDA_CALLABLE auto as_view() { + return ImageView{{confidence_data_.data(), height_, width_}, + {depth_data_.data(), height_, width_}}; + } + + CUDA_CALLABLE constexpr auto num_rows() const noexcept { + return height_; + } + CUDA_CALLABLE constexpr auto num_cols() const noexcept { + return width_; + } }; diff --git a/genmetaballs/src/genmetaballs/core/__init__.py b/genmetaballs/src/genmetaballs/core/__init__.py index 17c4ba5..ac2a6fd 100644 --- a/genmetaballs/src/genmetaballs/core/__init__.py +++ b/genmetaballs/src/genmetaballs/core/__init__.py @@ -1,3 +1,5 @@ +from typing import Literal + from genmetaballs._genmetaballs_bindings import geometry from genmetaballs._genmetaballs_bindings.blender import ( FourParameterBlender, @@ -8,10 +10,13 @@ TwoParameterConfidence, ZeroParameterConfidence, ) +from genmetaballs._genmetaballs_bindings.image import CPUImage, GPUImage from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid +type DeviceType = Literal["cpu", "gpu"] + -def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D: +def array2d_float(data, device: DeviceType) -> CPUFloatArray2D | GPUFloatArray2D: """Create a FloatArray2D on the specified device from an array. Args: @@ -26,6 +31,22 @@ def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D: raise ValueError(f"Unsupported device type: {device}") +def make_image(height: int, width: int, device: DeviceType) -> CPUImage | GPUImage: + """Create an Image on the specified device. + + Args: + height: The height of the image. + width: The width of the image. + device: 'cpu' or 'gpu' to specify the target device. + """ + if device == "cpu": + return CPUImage(height, width) + elif device == "gpu": + return GPUImage(height, width) + else: + raise ValueError(f"Unsupported device type: {device}") + + __all__ = [ "array2d_float", "ZeroParameterConfidence", @@ -36,4 +57,5 @@ def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D: "sigmoid", "FourParameterBlender", "ThreeParameterBlender", + "make_image", ] diff --git a/tests/cpp_tests/test_image.cu b/tests/cpp_tests/test_image.cu new file mode 100644 index 0000000..4252f7f --- /dev/null +++ b/tests/cpp_tests/test_image.cu @@ -0,0 +1,71 @@ +#include +#include + +#include "core/image.cuh" + +namespace test_image_gpu { + +/* A simple kernel that changes the image data based on pixel coordinate + * Note that while we allocate the image with "Image" type, we pass it to the kernel as "ImageView" + * via as_view() method. + */ +__global__ void manipulate_image_kernel(ImageView img) { + uint32_t row = threadIdx.y; + uint32_t col = threadIdx.x; + + if (row < img.num_rows() && col < img.num_cols()) { + img.confidence[row][col] = static_cast(row); + img.depth[row][col] = static_cast(col); + } +} + +} // namespace test_image_gpu + +TEST(TestImage, ImageCreationHost) { + constexpr uint32_t height = 128; + constexpr uint32_t width = 256; + + // Create an image in host memory + Image img_buffer(height, width); + auto img = img_buffer.as_view(); + + // Check dimensions + EXPECT_EQ(img.num_rows(), height); + EXPECT_EQ(img.num_cols(), width); + + // Check that confidence and depth are initialized to zero + for (uint32_t r = 0; r < height; ++r) { + for (uint32_t c = 0; c < width; ++c) { + EXPECT_FLOAT_EQ(img.confidence[r][c], 0.0f); + EXPECT_FLOAT_EQ(img.depth[r][c], 0.0f); + } + } +} + +TEST(TestImage, ImageManipulationOnDevice) { + constexpr uint32_t height = 16; + constexpr uint32_t width = 16; + + // Create an image in device memory + Image img_device(height, width); + + // Launch kernel to manipulate image data + dim3 threadsPerBlock(width, height); + test_image_gpu::manipulate_image_kernel<<<1, threadsPerBlock>>>(img_device.as_view()); + + // Copy image back to host for verification + Image img_host = img_device; + cudaDeviceSynchronize(); + auto img = img_host.as_view(); + + EXPECT_EQ(img.num_rows(), height); + EXPECT_EQ(img.num_cols(), width); + + // Verify the manipulated data + for (uint32_t r = 0; r < height; ++r) { + for (uint32_t c = 0; c < width; ++c) { + EXPECT_FLOAT_EQ(img.confidence[r][c], static_cast(r)); + EXPECT_FLOAT_EQ(img.depth[r][c], static_cast(c)); + } + } +} diff --git a/tests/python_tests/test_image.py b/tests/python_tests/test_image.py new file mode 100644 index 0000000..f87fb5a --- /dev/null +++ b/tests/python_tests/test_image.py @@ -0,0 +1,27 @@ +import jax.numpy as jnp +import pytest + +from genmetaballs.core import make_image + + +@pytest.mark.parametrize("device", ["cpu", "gpu"]) +def test_image_creation(device: str) -> None: + height, width = 480, 640 + image = make_image(height, width, device=device) + assert image.num_rows == height + assert image.num_cols == width + + # create views and check their types + image_view = image.as_view() + assert image_view.num_rows == height + assert image_view.num_cols == width + + # check that the confidence and depth arrays can be converted to JAX arrays + confidence = image_view.confidence.as_jax() + depth = image_view.depth.as_jax() + assert isinstance(confidence, jnp.ndarray) + assert isinstance(depth, jnp.ndarray) + + # make sure the arrays are initialized to zero + assert jnp.allclose(confidence, 0.0) + assert jnp.allclose(depth, 0.0)