Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets")
add_library(genmetaballs_core
genmetaballs/src/cuda/core/camera.cu
genmetaballs/src/cuda/core/camera.cuh
genmetaballs/src/cuda/core/utils.cu
genmetaballs/src/cuda/core/utils.cuh
genmetaballs/src/cuda/core/fmb.cuh
genmetaballs/src/cuda/core/confidence.cuh
genmetaballs/src/cuda/core/fmb.cu
genmetaballs/src/cuda/core/geometry.cuh
genmetaballs/src/cuda/core/fmb.cuh
genmetaballs/src/cuda/core/forward.cu
genmetaballs/src/cuda/core/forward.cuh
genmetaballs/src/cuda/core/geometry.cu
genmetaballs/src/cuda/core/confidence.cuh
genmetaballs/src/cuda/core/geometry.cuh
genmetaballs/src/cuda/core/image.cuh
genmetaballs/src/cuda/core/intersector.cuh
genmetaballs/src/cuda/core/utils.cu
genmetaballs/src/cuda/core/utils.cuh
)

# Set include directories for the core library
Expand Down
4 changes: 3 additions & 1 deletion genmetaballs/src/cuda/core/camera.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ struct Intrinsics {
CUDA_CALLABLE Vec3D get_ray_direction(uint32_t px, uint32_t py) const;
};

using PixelCoord = cuda::std::pair<uint32_t, uint32_t>;

struct PixelCoordRange {
uint32_t px_start;
uint32_t px_end;
Expand All @@ -38,7 +40,7 @@ struct PixelCoordRange {
uint32_t py;

// Returns the (px, py) coordinates of the current pixel
CUDA_CALLABLE cuda::std::pair<uint32_t, uint32_t> operator*() const;
CUDA_CALLABLE PixelCoord operator*() const;

// pre-increment operator that advances to the next pixel
CUDA_CALLABLE Iterator& operator++();
Expand Down
67 changes: 16 additions & 51 deletions genmetaballs/src/cuda/core/forward.cu
Original file line number Diff line number Diff line change
@@ -1,55 +1,20 @@
#include <cstdint>
#include <cuda_runtime.h>
#include <vector>

constexpr auto NUM_BLOCKS = dim3(10); // XXX madeup
constexpr auto THREADS_PER_BLOCK = dim3(10);

namespace FMB {

CUDA_CALLABLE std::vector<std::pair<PixelCoord, Ray>> get_pixel_coords_and_rays(
const dim3 thread_idx, const dim3 block_idx) {
std::vector<std::pair<PixelCoord, Ray>> res;

uint32_t i_beg = 0; // XXX TODO
uint32_t i_end = 0; // XXX TODO

for (int i = i_beg; i < i_end; i += blockDim.x) {
//...
}

return res;
#include "camera.cuh"
#include "forward.cuh"
#include "utils.cuh"

CUDA_CALLABLE PixelCoordRange get_pixel_coords(const dim3 thread_idx, const dim3 block_idx,
const dim3 block_dim, const dim3 grid_dim,
const Intrinsics& intr) {
// compute the number of pixels each thread should process
const auto num_pixels_x = int_ceil_div(intr.height, grid_dim.x * block_dim.x);
const auto num_pixels_y = int_ceil_div(intr.width, grid_dim.y * block_dim.y);
const auto start_x = (block_idx.x * block_dim.x + thread_idx.x) * num_pixels_x;
const auto start_y = (block_idx.y * block_dim.y + thread_idx.y) * num_pixels_y;
return PixelCoordRange{.px_start = start_x,
.px_end = min(start_x + num_pixels_x, intr.height),
.py_start = start_y,
.py_end = min(start_y + num_pixels_y, intr.width)};
}

template <class Getter, class Intersector, class Blender, class Confidence>
__global__ render_kernel(const Getter fmb_getter, const Blender blender,
Confidence const* confidence, Intrinsics const* intr, Pose const* extr,
Image* img) {
// TODO how to find the relevant chunk of computation from threadIdx,
// blockIdx, etc
auto pixel_coords_and_rays =
get_pixel_coords_and_rays(threadIdx, blockIdx, blockDim, gridDim, intr, extr);

for (const auto& [pixel_coords, ray] : pixel_coords_and_rays) {
float w0 = 0.0f, tf = 0.0f, sumexpd = 0.0f;
for (const auto& fmb : fmb_getter->get_metaballs(ray)) {
const auto& [t, d] = Intersector::intersect(fmb, ray, extr);
w = blender->blend(t, d, fmb, ray);
sumexpd += exp(d); // numerically unstable. use logsumexp
tf += t;
w0 += w;
}
img->confidence.at(pixel_coords) = confidence->get_confidence(sumexpd);
img->depth.at(pixel_coords) = tf / w0;
}
}

template <class Getter, class Intersector, class Blender, class Confidence>
void render_fmbs(const FMBs& fmbs, const Intrinsics& intr, const Pose& extr) {
// initialize the fmb_getter
typename Getter::Getter fmb_getter(fmbs, extr);
auto kernel = render_kernel<Getter, Intersector, Blender, Confidence>;
kernel<<<NUM_BLOCKS, THREADS_PER_BLOCK>>>(fmb_getter, fmbs, intr, extr);
}

}; // namespace FMB
48 changes: 48 additions & 0 deletions genmetaballs/src/cuda/core/forward.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#pragma once

#include <cstdint>
#include <cuda_runtime.h>

#include "camera.cuh"
#include "fmb.cuh"
#include "geometry.cuh"
#include "image.cuh"
#include "utils.cuh"

// TODO: tune this number
constexpr auto NUM_BLOCKS = dim3(4, 4);
constexpr auto THREADS_PER_BLOCK = dim3(16, 16);

CUDA_CALLABLE PixelCoordRange get_pixel_coords(const dim3 thread_idx, const dim3 block_idx,
const dim3 block_dim, const dim3 grid_dim,
const Intrinsics& intr);

template <typename Getter, typename Intersector, typename Blender, typename Confidence>
__global__ void render_kernel(const Getter fmb_getter, const Blender blender,
Confidence const* confidence, Intrinsics const intr, Pose const* extr,
ImageView<MemoryLocation::DEVICE> img) {
auto pixel_coords = get_pixel_coords(threadIdx, blockIdx, blockDim, gridDim, intr);

for (const auto& [px, py] : pixel_coords) {
float w0 = 0.0f, tf = 0.0f, sumexpd = 0.0f;
auto ray = intr.get_ray_direction(px, py);
for (const auto& fmb : fmb_getter->get_metaballs(ray)) {
const auto& [t, d] = Intersector::intersect(fmb, ray, extr);
auto w = blender->blend(t, d, fmb, ray);
sumexpd += exp(d); // numerically unstable. use logsumexp
tf += t;
w0 += w;
}
img.confidence[px][py] = confidence->get_confidence(sumexpd);
img.depth[px][py] = tf / w0;
}
}

template <typename Getter, typename Intersector, typename Blender, typename Confidence>
void render_fmbs(const FMBScene<MemoryLocation::DEVICE>& fmbs, const Intrinsics& intr,
const Pose& extr) {
// initialize the fmb_getter
auto fmb_getter = Getter(fmbs, extr);
auto& kernel = render_kernel<Getter, Intersector, Blender, Confidence>;
kernel<<<NUM_BLOCKS, THREADS_PER_BLOCK>>>(fmb_getter, fmbs, intr, extr);
}
6 changes: 6 additions & 0 deletions genmetaballs/src/cuda/core/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ CUDA_CALLABLE __forceinline__ float sigmoid(float x) {
return 1.0f / (1.0f + expf(-x));
}

// Integer ceiling division
template <typename T>
CUDA_CALLABLE constexpr T int_ceil_div(T a, T b) {
return (a + b - 1) / b;
}

enum class MemoryLocation { HOST, DEVICE };

// Non-owning 2D view into a contiguous array in either host or device memory
Expand Down
34 changes: 34 additions & 0 deletions tests/cpp_tests/test_forward.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include "core/forward.cuh"
#include "core/utils.cuh"
#include "thrust/device_vector.h"
#include "thrust/host_vector.h"

namespace get_pixel_coords_tests {
// A simple kernel that fills an Array2D with 1.0f in parallel
__global__ void fill_with_ones_kernel(Array2D<float, MemoryLocation::DEVICE> output,
const Intrinsics& intr) {
auto pixel_coords = get_pixel_coords(threadIdx, blockIdx, blockDim, gridDim, intr);
for (const auto [px, py] : pixel_coords) {
output[px][py] = 1.0f;
}
}
} // namespace get_pixel_coords_tests

// Test if fmb::get_pixel_coords correctly covers all image pixels
TEST(ForwardTest, GetPixelCoordsCoverage) {
const auto intrinsic =
Intrinsics{.height = 100, .width = 200, .fx = 1.0f, .fy = 1.0f, .cx = 50.0f, .cy = 100.0f};
auto buffer = thrust::device_vector<float>(intrinsic.height * intrinsic.width, 0.0f);
auto array2d =
Array2D<float, MemoryLocation::DEVICE>(buffer.data(), intrinsic.height, intrinsic.width);
constexpr dim3 block_dim(12, 8);
constexpr dim3 grid_dim(16, 24);
get_pixel_coords_tests::fill_with_ones_kernel<<<grid_dim, block_dim>>>(array2d, intrinsic);
auto host_buffer = thrust::host_vector<float>(buffer);
for (size_t i = 0; i < host_buffer.size(); ++i) {
EXPECT_EQ(host_buffer[i], 1.0f);
}
}
7 changes: 7 additions & 0 deletions tests/cpp_tests/test_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,10 @@ TYPED_TEST(Array2DTestFixture, MultipleViewsOfSameData) {
EXPECT_FLOAT_EQ(view1[1][2], 200.0f);
}
}

TEST(CeilDivTests, TestBasicCeillDivCorrectness) {
EXPECT_EQ(int_ceil_div(10, 3), 4);
EXPECT_EQ(int_ceil_div(9, 3), 3);
EXPECT_EQ(int_ceil_div(0, 5), 0);
EXPECT_EQ(int_ceil_div(1, 1), 1);
};