Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_library(genmetaballs_core
genmetaballs/src/cuda/core/geometry.cuh
genmetaballs/src/cuda/core/geometry.cu
genmetaballs/src/cuda/core/confidence.cuh
genmetaballs/src/cuda/core/image.cuh
)

# Set include directories for the core library
Expand Down
41 changes: 41 additions & 0 deletions genmetaballs/src/cuda/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
#include "core/camera.cuh"
#include "core/confidence.cuh"
#include "core/geometry.cuh"
#include "core/image.cuh"
#include "core/utils.cuh"

namespace nb = nanobind;

template <typename T, MemoryLocation location>
void bind_array2d(nb::module_& m, const char* name);
template <MemoryLocation location>
void bind_image(nb::module_& m, const char* name);
template <MemoryLocation location>
void bind_image_view(nb::module_& m, const char* name);

NB_MODULE(_genmetaballs_bindings, m) {

Expand Down Expand Up @@ -82,6 +87,15 @@ NB_MODULE(_genmetaballs_bindings, m) {
"Get the direction of the ray going through pixel (px, py) in camera frame",
nb::arg("px"), nb::arg("py"));

/*
* Image module bindings
*/
nb::module_ image = m.def_submodule("image", "Image data structure for GenMetaballs");
bind_image_view<MemoryLocation::HOST>(image, "CPUImageView");
bind_image<MemoryLocation::HOST>(image, "CPUImage");
bind_image_view<MemoryLocation::DEVICE>(image, "GPUImageView");
bind_image<MemoryLocation::DEVICE>(image, "GPUImage");

/*
* Confidence module bindings
*/
Expand Down Expand Up @@ -172,3 +186,30 @@ void bind_array2d(nb::module_& m, const char* name) {
.def_prop_ro("ndim", &Array2D<T, location>::ndim)
.def_prop_ro("size", &Array2D<T, location>::size);
}

template <MemoryLocation location>
void bind_image_view(nb::module_& m, const char* name) {
nb::class_<ImageView<location>>(m, name)
.def(nb::init<const Array2D<float, location>&, const Array2D<float, location>&>(),
nb::arg("confidence"), nb::arg("depth"))
.def_prop_ro("confidence", [](const ImageView<location>& view) { return view.confidence; })
.def_prop_ro("depth", [](const ImageView<location>& view) { return view.depth; })
.def_prop_ro("num_rows", &ImageView<location>::num_rows)
.def_prop_ro("num_cols", &ImageView<location>::num_cols)
.def("__repr__", [=](const ImageView<location>& view) {
return nb::str("{}(height={}, width={})")
.format(name, view.num_rows(), view.num_cols());
});
}

template <MemoryLocation location>
void bind_image(nb::module_& m, const char* name) {
nb::class_<Image<location>>(m, name)
.def(nb::init<uint32_t, uint32_t>(), nb::arg("height"), nb::arg("width"))
.def_prop_ro("num_rows", &Image<location>::num_rows)
.def_prop_ro("num_cols", &Image<location>::num_cols)
.def("as_view", &Image<location>::as_view, "Get a view of the image data as ImageView")
.def("__repr__", [=](const Image<location>& img) {
return nb::str("{}(height={}, width={})").format(name, img.num_rows(), img.num_cols());
});
}
2 changes: 1 addition & 1 deletion genmetaballs/src/cuda/core/camera.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct Intrinsics {
// Returns a 2D array of ray directions in camera frame in the specified pixel range
// and store them in the provided buffer. By default, the full image is used
template <MemoryLocation location>
CUDA_CALLABLE Array2D<Vec3D, location>& get_ray_directions(Array2D<Vec3D, location> buffer,
CUDA_CALLABLE Array2D<Vec3D, location>& get_ray_directions(Array2D<Vec3D, location>& buffer,
uint32_t px_start = 0,
uint32_t px_end = UINT32_MAX,
uint32_t py_start = 0,
Expand Down
74 changes: 69 additions & 5 deletions genmetaballs/src/cuda/core/image.cuh
Original file line number Diff line number Diff line change
@@ -1,12 +1,76 @@
#pragma once

#include <cstdint>
#include <cuda_runtime.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include "geometry.cuh"
#include "utils.cuh"

template <uint32_t width, uint32_t height>
struct Image {
Array2D<float, width, height> confidence;
Array2D<float, width, height> depth;
/* Non-owning view into an image */
template <MemoryLocation location>
class ImageView {
public:
Array2D<float, location> confidence;
Array2D<float, location> depth;

CUDA_CALLABLE constexpr auto num_rows() const noexcept {
return confidence.num_rows();
}
CUDA_CALLABLE constexpr auto num_cols() const noexcept {
return confidence.num_cols();
}
};

/* The image buffer which handles the allocation & deallocation of memory with RAII.
* While the underlying data may be stored in either host or device memory, this
* class is always managed from the host side.
*/
template <MemoryLocation location>
class Image {
private:
// Host memory -> thrust::host_vector
// Device memory -> thrust::device_vector
template <typename T>
using vector_t = std::conditional_t<location == MemoryLocation::HOST, thrust::host_vector<T>,
thrust::device_vector<T>>;

// RAII storage for the image data
vector_t<float> confidence_data_;
vector_t<float> depth_data_;

const uint32_t height_;
const uint32_t width_;

// Make all Image instantiations friends so they can access each other's private members
template <MemoryLocation other_location>
friend class Image;

public:
/* Allocate the memory for a new image & default initialize with zeros
* The "height" of the image correponds to "rows" in the 2D array, whereas
* the "width" of the image corresponds to "columns".
*/
__host__ Image(uint32_t height, uint32_t width)
: height_(height), width_(width), confidence_data_(height * width),
depth_data_(height * width) {}

/* Copy constructor from a Image which may reside in a different memory location */
template <MemoryLocation other_location>
__host__ Image(const Image<other_location>& other)
: height_(other.num_rows()), width_(other.num_cols()),
confidence_data_(other.confidence_data_), depth_data_(other.depth_data_) {}

/* Create a view of this image which points to the internal data */
CUDA_CALLABLE auto as_view() {
return ImageView<location>{{confidence_data_.data(), height_, width_},
{depth_data_.data(), height_, width_}};
}

CUDA_CALLABLE constexpr auto num_rows() const noexcept {
return height_;
}
CUDA_CALLABLE constexpr auto num_cols() const noexcept {
return width_;
}
};
24 changes: 23 additions & 1 deletion genmetaballs/src/genmetaballs/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

from genmetaballs._genmetaballs_bindings import geometry
from genmetaballs._genmetaballs_bindings.blender import (
FourParameterBlender,
Expand All @@ -8,10 +10,13 @@
TwoParameterConfidence,
ZeroParameterConfidence,
)
from genmetaballs._genmetaballs_bindings.image import CPUImage, GPUImage
from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid

type DeviceType = Literal["cpu", "gpu"]


def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D:
def array2d_float(data, device: DeviceType) -> CPUFloatArray2D | GPUFloatArray2D:
"""Create a FloatArray2D on the specified device from an array.

Args:
Expand All @@ -26,6 +31,22 @@ def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D:
raise ValueError(f"Unsupported device type: {device}")


def make_image(height: int, width: int, device: DeviceType) -> CPUImage | GPUImage:
"""Create an Image on the specified device.

Args:
height: The height of the image.
width: The width of the image.
device: 'cpu' or 'gpu' to specify the target device.
"""
if device == "cpu":
return CPUImage(height, width)
elif device == "gpu":
return GPUImage(height, width)
else:
raise ValueError(f"Unsupported device type: {device}")


__all__ = [
"array2d_float",
"ZeroParameterConfidence",
Expand All @@ -36,4 +57,5 @@ def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D:
"sigmoid",
"FourParameterBlender",
"ThreeParameterBlender",
"make_image",
]
71 changes: 71 additions & 0 deletions tests/cpp_tests/test_image.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include "core/image.cuh"

namespace test_image_gpu {

/* A simple kernel that changes the image data based on pixel coordinate
* Note that while we allocate the image with "Image" type, we pass it to the kernel as "ImageView"
* via as_view() method.
*/
__global__ void manipulate_image_kernel(ImageView<MemoryLocation::DEVICE> img) {
uint32_t row = threadIdx.y;
uint32_t col = threadIdx.x;

if (row < img.num_rows() && col < img.num_cols()) {
img.confidence[row][col] = static_cast<float>(row);
img.depth[row][col] = static_cast<float>(col);
}
}

} // namespace test_image_gpu

TEST(TestImage, ImageCreationHost) {
constexpr uint32_t height = 128;
constexpr uint32_t width = 256;

// Create an image in host memory
Image<MemoryLocation::HOST> img_buffer(height, width);
auto img = img_buffer.as_view();

// Check dimensions
EXPECT_EQ(img.num_rows(), height);
EXPECT_EQ(img.num_cols(), width);

// Check that confidence and depth are initialized to zero
for (uint32_t r = 0; r < height; ++r) {
for (uint32_t c = 0; c < width; ++c) {
EXPECT_FLOAT_EQ(img.confidence[r][c], 0.0f);
EXPECT_FLOAT_EQ(img.depth[r][c], 0.0f);
}
}
}

TEST(TestImage, ImageManipulationOnDevice) {
constexpr uint32_t height = 16;
constexpr uint32_t width = 16;

// Create an image in device memory
Image<MemoryLocation::DEVICE> img_device(height, width);

// Launch kernel to manipulate image data
dim3 threadsPerBlock(width, height);
test_image_gpu::manipulate_image_kernel<<<1, threadsPerBlock>>>(img_device.as_view());

// Copy image back to host for verification
Image<MemoryLocation::HOST> img_host = img_device;
cudaDeviceSynchronize();
auto img = img_host.as_view();

EXPECT_EQ(img.num_rows(), height);
EXPECT_EQ(img.num_cols(), width);

// Verify the manipulated data
for (uint32_t r = 0; r < height; ++r) {
for (uint32_t c = 0; c < width; ++c) {
EXPECT_FLOAT_EQ(img.confidence[r][c], static_cast<float>(r));
EXPECT_FLOAT_EQ(img.depth[r][c], static_cast<float>(c));
}
}
}
27 changes: 27 additions & 0 deletions tests/python_tests/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import jax.numpy as jnp
import pytest

from genmetaballs.core import make_image


@pytest.mark.parametrize("device", ["cpu", "gpu"])
def test_image_creation(device: str) -> None:
height, width = 480, 640
image = make_image(height, width, device=device)
assert image.num_rows == height
assert image.num_cols == width

# create views and check their types
image_view = image.as_view()
assert image_view.num_rows == height
assert image_view.num_cols == width

# check that the confidence and depth arrays can be converted to JAX arrays
confidence = image_view.confidence.as_jax()
depth = image_view.depth.as_jax()
assert isinstance(confidence, jnp.ndarray)
assert isinstance(depth, jnp.ndarray)

# make sure the arrays are initialized to zero
assert jnp.allclose(confidence, 0.0)
assert jnp.allclose(depth, 0.0)