diff --git a/cuda_rasterizer/auxiliary.h b/cuda_rasterizer/auxiliary.h
index 4d4b9b78..cc749563 100644
--- a/cuda_rasterizer/auxiliary.h
+++ b/cuda_rasterizer/auxiliary.h
@@ -96,6 +96,31 @@ __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, cons
 	return transformed;
 }
 
+// Projects a world-space point onto an equirectangular (longitude/latitude)
+// image. The point is first moved to view space with the 4x3 view matrix.
+//   x: longitude atan2(x, z) divided by pi                  -> [-1, 1]
+//   y: latitude  atan2(y, sqrt(x^2 + z^2)) divided by pi/2  -> [-1, 1]
+//   z: Euclidean distance of the point from the camera.
+__forceinline__ __device__ float3 point_to_equirect(
+	float3 p_orig,
+	const float* viewmatrix)
+{
+	// Single-precision constants: M_PI is a double and would promote the
+	// divisions below to double-precision arithmetic on the device.
+	const float PI_F = 3.14159265358979323846f;
+	const float HALF_PI_F = 1.57079632679489661923f;
+
+	float3 direction_vector = transformPoint4x3(p_orig, viewmatrix);
+	float direction_vector_length = sqrtf(direction_vector.x * direction_vector.x + direction_vector.y * direction_vector.y + direction_vector.z * direction_vector.z);
+	float longitude = atan2f(direction_vector.x, direction_vector.z);
+	float latitude = atan2f(direction_vector.y , sqrtf(direction_vector.x * direction_vector.x + direction_vector.z * direction_vector.z));
+	float normalized_latitude = latitude / HALF_PI_F;
+	float normalized_longitude = longitude / PI_F;
+	float3 p_view = {normalized_longitude, normalized_latitude, direction_vector_length};
+	return p_view;
+}
+
+
 __forceinline__ __device__ float dnormvdz(float3 v, float3 dv)
 {
 	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
@@ -163,6 +188,32 @@ __forceinline__ __device__ bool in_frustum(int idx,
 	return true;
 }
 
+// Near-culling test for the spherical camera. Writes the equirectangular
+// coordinates of Gaussian idx (see point_to_equirect) to p_view and returns
+// false when the point is within 0.2 units of the camera. projmatrix and
+// cam_pos are not read.
+__forceinline__ __device__ bool in_sphere(int idx,
+	const float* orig_points,
+	const float* viewmatrix,
+	const float* projmatrix,
+	const glm::vec3* cam_pos,
+	bool prefiltered,
+	float3& p_view)
+{
+	float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
+	p_view = point_to_equirect(p_orig, viewmatrix);
+	if (p_view.z <= 0.2f) // p_view.z is the distance to the camera; a full sphere has no lateral bounds to test
+	{
+		if (prefiltered)
+		{
+			printf("Point is filtered although prefiltered is set. This shouldn't happen!");
+			__trap();
+		}
+		return false;
+	}
+	return true;
+}
+
 #define CHECK_CUDA(A, debug) \
 A; if(debug) { \
 auto ret = cudaDeviceSynchronize(); \
diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu
index 4aa41e1c..9d8b8a74 100644
--- a/cuda_rasterizer/backward.cu
+++ b/cuda_rasterizer/backward.cu
@@ -273,6 +273,140 @@ __global__ void computeCov2DCUDA(int P,
 	dL_dmeans[idx] = dL_dmean;
 }
 
+// Backward version of the spherical INVERSE 2D covariance matrix computation
+// (due to length launched as separate kernel before the other backward
+// preprocessing steps). One thread per Gaussian; Gaussians whose radius is
+// not positive are skipped. Writes dL_dcov[6 * idx ..] and dL_dmeans[idx].
+//
+// The Jacobian built below is J = diag(h_x / |t|, h_x / |t|, 0) with t the
+// view-space mean, so the mean only influences the covariance through |t|.
+__global__ void computesphericalCov2DCUDA(int P,
+	const float3* means,
+	const int* radii,
+	const float* cov3Ds,
+	const float h_x, float h_y,
+	const float* view_matrix,
+	const float* dL_dconics,
+	float3* dL_dmeans,
+	float* dL_dcov)
+{
+	auto idx = cg::this_grid().thread_rank();
+	if (idx >= P || !(radii[idx] > 0))
+		return;
+
+	// Reading location of 3D covariance for this Gaussian
+	const float* cov3D = cov3Ds + 6 * idx;
+
+	// Fetch gradients, recompute 2D covariance and relevant
+	// intermediate forward results needed in the backward.
+	float3 mean = means[idx];
+	float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] };
+	float3 t = transformPoint4x3(mean, view_matrix);
+
+	float t_length = sqrtf(t.x * t.x + t.y * t.y + t.z * t.z);
+
+	float3 t_unit_focal = {0.0f, 0.0f, t_length};
+
+	glm::mat3 J = glm::mat3(
+		h_x / t_unit_focal.z, 0.0f, -(h_x * t_unit_focal.x) / (t_unit_focal.z * t_unit_focal.z),
+		0.0f, h_x / t_unit_focal.z, -(h_x * t_unit_focal.y) / (t_unit_focal.z * t_unit_focal.z),
+		0, 0, 0);
+
+	glm::mat3 W = glm::mat3(
+		view_matrix[0], view_matrix[4], view_matrix[8],
+		view_matrix[1], view_matrix[5], view_matrix[9],
+		view_matrix[2], view_matrix[6], view_matrix[10]);
+
+	glm::mat3 Vrk = glm::mat3(
+		cov3D[0], cov3D[1], cov3D[2],
+		cov3D[1], cov3D[3], cov3D[4],
+		cov3D[2], cov3D[4], cov3D[5]);
+
+	glm::mat3 T = W * J;
+
+	glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T;
+
+	// Use helper variables for 2D covariance entries. More compact.
+	float a = cov2D[0][0] += 0.3f;
+	float b = cov2D[0][1];
+	float c = cov2D[1][1] += 0.3f;
+
+	float denom = a * c - b * b;
+	float dL_da = 0, dL_db = 0, dL_dc = 0;
+	float denom2inv = 1.0f / ((denom * denom) + 0.0000001f);
+
+	if (denom2inv != 0)
+	{
+		// Gradients of loss w.r.t. entries of 2D covariance matrix,
+		// given gradients of loss w.r.t. conic matrix (inverse covariance matrix).
+		// e.g., dL / da = dL / d_conic_a * d_conic_a / d_a
+		dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y + (denom - a * c) * dL_dconic.z);
+		dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y + (denom - a * c) * dL_dconic.x);
+		dL_db = denom2inv * 2 * (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y + a * b * dL_dconic.z);
+
+		// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
+		// given gradients w.r.t. 2D covariance matrix (diagonal).
+		// cov2D = transpose(T) * transpose(Vrk) * T;
+		dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db + T[1][0] * T[1][0] * dL_dc);
+		dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc);
+		dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db + T[1][2] * T[1][2] * dL_dc);
+
+		// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
+		// given gradients w.r.t. 2D covariance matrix (off-diagonal).
+		// Off-diagonal elements appear twice --> double the gradient.
+		// cov2D = transpose(T) * transpose(Vrk) * T;
+		dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][1] * dL_dc;
+		dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][2] * dL_dc;
+		dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db + 2 * T[1][1] * T[1][2] * dL_dc;
+	}
+	else
+	{
+		for (int i = 0; i < 6; i++)
+			dL_dcov[6 * idx + i] = 0;
+	}
+
+	// Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T
+	// cov2D = transpose(T) * transpose(Vrk) * T;
+	float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_da +
+		(T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db;
+	float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_da +
+		(T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db;
+	float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_da +
+		(T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db;
+	float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc +
+		(T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db;
+	float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc +
+		(T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db;
+	float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc +
+		(T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db;
+
+	// Gradients of loss w.r.t. the two mean-dependent entries of the Jacobian
+	// T = W * J
+	float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02;
+	float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12;
+
+	// Gradients of loss w.r.t. transformed Gaussian mean t.
+	// J00 = J11 = h_x / |t| while J02 and J12 are identically zero, hence
+	// dJ00/dt = dJ11/dt = -h_x * t / |t|^3. The pinhole derivative in
+	// 1 / t.z must not be used here: it does not match the Jacobian above
+	// and produces inf/NaN for Gaussians beside the camera (t.z ~ 0), which
+	// are regular visible points in a 360 degree image.
+	float inv_len3 = 1.0f / (t_length * t_length * t_length + 0.0000001f);
+	float grad_scale = -h_x * inv_len3 * (dL_dJ00 + dL_dJ11);
+	float dL_dtx = grad_scale * t.x;
+	float dL_dty = grad_scale * t.y;
+	float dL_dtz = grad_scale * t.z;
+
+	// Account for transformation of mean to t
+	// t = transformPoint4x3(mean, view_matrix);
+	float3 dL_dmean = transformVec4x3Transpose({ dL_dtx, dL_dty, dL_dtz }, view_matrix);
+
+	// Gradients of loss w.r.t. Gaussian means, but only the portion
+	// that is caused because the mean affects the covariance matrix.
+	// Additional mean gradient is accumulated in BACKWARD::preprocess.
+	dL_dmeans[idx] = dL_dmean;
+}
+
 // Backward pass for the conversion of scale and rotation to a
 // 3D covariance matrix for each Gaussian.
 __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots)
@@ -395,6 +529,77 @@ __global__ void preprocessCUDA(
 	computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot);
 }
 
+// Backward pass of the spherical preprocessing steps, except for the
+// covariance computation and inversion (handled by computesphericalCov2DCUDA,
+// which must run first because this kernel accumulates into dL_dmeans).
+// One thread per Gaussian. proj is kept in the signature but not read.
+// NOTE(review): the template parameter list on the next line is missing in
+// this patch text -- confirm against the original source.
+template
+__global__ void preprocessspehricalCUDA(
+	int P, int D, int M,
+	const float3* means,
+	const int* radii,
+	const float* shs,
+	const bool* clamped,
+	const glm::vec3* scales,
+	const glm::vec4* rotations,
+	const float scale_modifier,
+	const float* view_matrix,
+	const float* proj,
+	const glm::vec3* campos,
+	const float3* dL_dmean2D,
+	glm::vec3* dL_dmeans,
+	float* dL_dcolor,
+	float* dL_dcov3D,
+	float* dL_dsh,
+	glm::vec3* dL_dscale,
+	glm::vec4* dL_drot)
+{
+	auto idx = cg::this_grid().thread_rank();
+	if (idx >= P || !(radii[idx] > 0))
+		return;
+
+	// View-space mean t. The forward pass derives the 2D mean from it with
+	// point_to_equirect, not with the perspective projection matrix:
+	//   u = atan2(t.x, t.z) / pi,  v = atan2(t.y, sqrt(t.x^2 + t.z^2)) / (pi / 2)
+	float3 t = transformPoint4x3(means[idx], view_matrix);
+	const float INV_PI_F = 0.31830988618379067154f;
+	float r2 = t.x * t.x + t.z * t.z;
+	float r = sqrtf(r2);
+	float len2 = r2 + t.y * t.y;
+	// Small epsilons guard the poles (r -> 0), where longitude is undefined.
+	float inv_r2 = 1.0f / (r2 + 0.0000001f);
+	float inv_r_len2 = 1.0f / (r * len2 + 0.0000001f);
+	float inv_len2 = 1.0f / (len2 + 0.0000001f);
+
+	// Compute loss gradient w.r.t. view-space mean due to gradients of 2D
+	// means from rendering procedure (chain rule through u and v above).
+	float dL_du = dL_dmean2D[idx].x * INV_PI_F;
+	float dL_dv = dL_dmean2D[idx].y * 2.0f * INV_PI_F;
+	float3 dL_dt = {
+		dL_du * t.z * inv_r2 - dL_dv * t.x * t.y * inv_r_len2,
+		dL_dv * r * inv_len2,
+		-dL_du * t.x * inv_r2 - dL_dv * t.y * t.z * inv_r_len2 };
+
+	// Account for transformation of mean to t
+	float3 dL_dmean_view = transformVec4x3Transpose(dL_dt, view_matrix);
+	glm::vec3 dL_dmean(dL_dmean_view.x, dL_dmean_view.y, dL_dmean_view.z);
+
+	// That's the second part of the mean gradient. Previous computation
+	// of cov2D and following SH conversion also affects it.
+	dL_dmeans[idx] += dL_dmean;
+
+	// Compute gradient updates due to computing colors from SHs
+	if (shs)
+		computeColorFromSH(idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh);
+
+	// Compute gradient updates due to computing covariance from scale/rotation
+	if (scales)
+		computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot);
+}
+
+
 // Backward version of the rendering procedure.
 template
 __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
@@ -621,6 +807,70 @@ void BACKWARD::preprocess(
 		dL_drot);
 }
 
+void BACKWARD::preprocessspherical(
+	int P, int D, int M,
+	const float3* means3D,
+	const int* radii,
+	const float* shs,
+	const bool* clamped,
+	const glm::vec3* scales,
+	const glm::vec4* rotations,
+	const float scale_modifier,
+	const float* cov3Ds,
+	const float* viewmatrix,
+	const float* projmatrix,
+	const float focal_x, float focal_y,
+	const float tan_fovx, float tan_fovy,
+	const glm::vec3* campos,
+	const float3* dL_dmean2D,
+	const float* dL_dconic,
+	glm::vec3* dL_dmean3D,
+	float* dL_dcolor,
+	float* dL_dcov3D,
+	float* dL_dsh,
+	glm::vec3* dL_dscale,
+	glm::vec4* dL_drot)
+{
+	// Propagate gradients for the path of 2D conic matrix computation.
+	// Launched first: it overwrites dL_dmean3D (the second kernel adds to
+	// it) and fills dL_dcov3D. One thread per Gaussian, 256 per block.
+	// tan_fovx and tan_fovy are not forwarded to either kernel.
+	computesphericalCov2DCUDA << <(P + 255) / 256, 256 >> > (
+		P,
+		means3D,
+		radii,
+		cov3Ds,
+		focal_x,
+		focal_y,
+		viewmatrix,
+		dL_dconic,
+		(float3*)dL_dmean3D,
+		dL_dcov3D);
+
+	// Propagate gradients for remaining steps: finish 3D mean gradients,
+	// propagate color gradients to SH (if desired), propagate 3D covariance
+	// matrix gradients to scale and rotation.
+	preprocessspehricalCUDA << < (P + 255) / 256, 256 >> > (
+		P, D, M,
+		(float3*)means3D,
+		radii,
+		shs,
+		clamped,
+		(glm::vec3*)scales,
+		(glm::vec4*)rotations,
+		scale_modifier,
+		viewmatrix,
+		projmatrix,
+		campos,
+		(float3*)dL_dmean2D,
+		(glm::vec3*)dL_dmean3D,
+		dL_dcolor,
+		dL_dcov3D,
+		dL_dsh,
+		dL_dscale,
+		dL_drot);
+}
+
 void BACKWARD::render(
 	const dim3 grid, const dim3 block,
 	const uint2* ranges,
diff --git a/cuda_rasterizer/backward.h b/cuda_rasterizer/backward.h
index 93dd2e4b..6386954c 100644
--- a/cuda_rasterizer/backward.h
+++ b/cuda_rasterizer/backward.h
@@ -60,6 +60,30 @@ namespace BACKWARD
 		float* dL_dsh,
 		glm::vec3* dL_dscale,
 		glm::vec4* dL_drot);
+
+	void preprocessspherical(
+		int P, int D, int M,
+		const float3* means,
+		const int* radii,
+		const float* shs,
+		const bool* clamped,
+		const glm::vec3* scales,
+		const glm::vec4* rotations,
+		const float scale_modifier,
+		const float* cov3Ds,
+		const float* view,
+		const float* proj,
+		const float focal_x, float focal_y,
+		const float tan_fovx, float tan_fovy,
+		const glm::vec3* campos,
+		const float3* dL_dmean2D,
+		const float* dL_dconics,
+		glm::vec3* dL_dmeans,
+		float* dL_dcolor,
+		float* dL_dcov3D,
+		float* dL_dsh,
+		glm::vec3* dL_dscale,
+		glm::vec4* dL_drot);
 }
 
 #endif
\ No newline at end of file
diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu
index c419a328..299dd3c0 100644
--- a/cuda_rasterizer/forward.cu
+++ b/cuda_rasterizer/forward.cu
@@ -8,7 +8,7 @@
  *
  * For inquiries contact george.drettakis@inria.fr
  */
-
+#include
 #include "forward.h"
 #include "auxiliary.h"
 #include
@@ -112,6 +112,46 @@ __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y,
 	return { float(cov[0][0]), float(cov[0][1]), float(cov[1][1]) };
 }
 
+
+// Forward version of 2D covariance matrix computation for the spherical camera
+__device__ float3 computesphericalCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix)
+{
+	// Spherical variant of the projection in "EWA Splatting" (Zwicker et
+	// al., 2002): the Jacobian is diag(focal_x / |t|, focal_x / |t|, 0), i.e.
+	// scaled by the distance |t| on both axes, with zero off-diagonal terms.
+	// focal_y, tan_fovx and tan_fovy are not read by this function.
+
+	float3 t = transformPoint4x3(mean, viewmatrix);
+
+	float t_length = sqrtf(t.x * t.x + t.y * t.y + t.z * t.z);
+
+	float3 t_unit_focal = {0.0f, 0.0f, t_length};
+	glm::mat3 J = glm::mat3(
+		focal_x / t_unit_focal.z, 0.0f, -(focal_x * t_unit_focal.x) / (t_unit_focal.z * t_unit_focal.z),
+		0.0f, focal_x / t_unit_focal.z, -(focal_x * t_unit_focal.y) / (t_unit_focal.z * t_unit_focal.z),
+		0, 0, 0);
+
+	glm::mat3 W = glm::mat3(
+		viewmatrix[0], viewmatrix[4], viewmatrix[8],
+		viewmatrix[1], viewmatrix[5], viewmatrix[9],
+		viewmatrix[2], viewmatrix[6], viewmatrix[10]);
+
+	glm::mat3 T = W * J;
+
+	glm::mat3 Vrk = glm::mat3(
+		cov3D[0], cov3D[1], cov3D[2],
+		cov3D[1], cov3D[3], cov3D[4],
+		cov3D[2], cov3D[4], cov3D[5]);
+
+	glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T;
+
+	// Apply low-pass filter: add 0.3 to both diagonal entries so the 2D
+	// covariance stays well-conditioned. Discard 3rd row and column.
+	cov[0][0] += 0.3f;
+	cov[1][1] += 0.3f;
+	return { float(cov[0][0]), float(cov[0][1]) , float(cov[1][1])};
+}
+
 // Forward method for converting scale and rotation properties of each
 // Gaussian to a 3D covariance matrix in world space. Also takes care
 // of quaternion normalization.
@@ -230,6 +270,7 @@ __global__ void preprocessCUDA(int P, int D, int M,
 	float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
 	float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
 	float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
+
 	float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };
 	uint2 rect_min, rect_max;
 	getRect(point_image, my_radius, rect_min, rect_max, grid);
@@ -255,6 +296,119 @@ __global__ void preprocessCUDA(int P, int D, int M,
 	tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
 }
 
+
+// Perform initial steps for each Gaussian prior to rasterization with the
+// spherical (equirectangular) camera: near culling, projection to
+// longitude/latitude image coordinates, 2D covariance / conic computation,
+// tile bounding rectangle and (if no colors are given) SH-to-RGB conversion.
+// One thread per Gaussian; all outputs are indexed by the Gaussian index.
+// NOTE(review): the template parameter list on the next line is missing in
+// this patch text although the body uses C (stride of rgb per Gaussian) --
+// confirm against the original source.
+template
+__global__ void preprocesssphericalCUDA(int P, int D, int M,
+	const float* orig_points,
+	const glm::vec3* scales,
+	const float scale_modifier,
+	const glm::vec4* rotations,
+	const float* opacities,
+	const float* shs,
+	bool* clamped,
+	const float* cov3D_precomp,
+	const float* colors_precomp,
+	const float* viewmatrix,
+	const float* projmatrix,
+	const glm::vec3* cam_pos,
+	const int W, int H,
+	const float tan_fovx, float tan_fovy,
+	const float focal_x, float focal_y,
+	int* radii,
+	float2* points_xy_image,
+	float* depths,
+	float* cov3Ds,
+	float* rgb,
+	float4* conic_opacity,
+	const dim3 grid,
+	uint32_t* tiles_touched,
+	bool prefiltered)
+{
+	auto idx = cg::this_grid().thread_rank();
+	if (idx >= P)
+		return;
+	float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
+
+	// Initialize radius and touched tiles to 0. If this isn't changed,
+	// this Gaussian will not be processed further.
+	radii[idx] = 0;
+	tiles_touched[idx] = 0;
+
+	// Perform near culling, quit if outside.
+	float3 p_view;
+	if (!in_sphere(idx, orig_points, viewmatrix, projmatrix, cam_pos, prefiltered, p_view))
+		return;
+
+	// x/y are the normalized longitude/latitude; z linearly remaps the distance
+	// range [0.2, 100] to [-1, 1]. Float literals avoid a double -> float narrowing.
+	float3 p_proj = {p_view.x, p_view.y, 2.0f * (p_view.z - 0.2f) / (100.0f - 0.2f) - 1.0f};
+
+	// If 3D covariance matrix is precomputed, use it, otherwise compute
+	// from scaling and rotation parameters.
+	const float* cov3D;
+	if (cov3D_precomp != nullptr)
+	{
+		cov3D = cov3D_precomp + idx * 6;
+	}
+	else
+	{
+		computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
+		cov3D = cov3Ds + idx * 6;
+	}
+
+	// Compute 2D screen-space covariance matrix
+	float3 cov = computesphericalCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);
+	// Invert covariance (EWA algorithm)
+	float det = (cov.x * cov.z - cov.y * cov.y);
+	if (det == 0.0f)
+		return;
+	float det_inv = 1.f / det;
+	float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };
+
+	// Compute extent in screen space (by finding eigenvalues of
+	// 2D covariance matrix). Use extent to compute a bounding rectangle
+	// of screen-space tiles that this Gaussian overlaps with. Quit if
+	// rectangle covers 0 tiles.
+	float mid = 0.5f * (cov.x + cov.z);
+	float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
+	float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
+
+	float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
+
+	float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };
+	uint2 rect_min, rect_max;
+	getRect(point_image, my_radius , rect_min, rect_max, grid);
+	if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
+		return;
+
+
+	// If colors have been precomputed, use them, otherwise convert
+	// spherical harmonics coefficients to RGB color.
+	if (colors_precomp == nullptr)
+	{
+		glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
+		rgb[idx * C + 0] = result.x;
+		rgb[idx * C + 1] = result.y;
+		rgb[idx * C + 2] = result.z;
+	}
+
+	// Store some useful helper data for the next steps.
+	// Depth is the distance to the camera (p_view.z), not the view-space z.
+	depths[idx] = p_view.z;
+	radii[idx] = my_radius;
+	points_xy_image[idx] = point_image;
+	// Inverse 2D covariance and opacity neatly pack into one float4
+	conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] };
+	tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
+}
+
 // Main rasterization method. Collaboratively works on one tile per
 // block, each thread treats one pixel. Alternates between fetching
 // and rasterizing data.
@@ -354,6 +509,7 @@ renderCUDA(
 			for (int ch = 0; ch < CHANNELS; ch++)
 				C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;
 
+
 			T = test_T;
 
 			// Keep track of last range entry to update this
@@ -452,4 +608,63 @@ void FORWARD::preprocess(int P, int D, int M,
 		tiles_touched,
 		prefiltered
 		);
+}
+
+
+// Host-side launcher for preprocesssphericalCUDA: one thread per Gaussian,
+// 256 threads per block. Arguments are forwarded unchanged; note that the
+// kernel takes tan_fov before focal, the reverse of this wrapper's order.
+void FORWARD::preprocessspherical(int P, int D, int M,
+	const float* means3D,
+	const glm::vec3* scales,
+	const float scale_modifier,
+	const glm::vec4* rotations,
+	const float* opacities,
+	const float* shs,
+	bool* clamped,
+	const float* cov3D_precomp,
+	const float* colors_precomp,
+	const float* viewmatrix,
+	const float* projmatrix,
+	const glm::vec3* cam_pos,
+	const int W, int H,
+	const float focal_x, float focal_y,
+	const float tan_fovx, float tan_fovy,
+	int* radii,
+	float2* means2D,
+	float* depths,
+	float* cov3Ds,
+	float* rgb,
+	float4* conic_opacity,
+	const dim3 grid,
+	uint32_t* tiles_touched,
+	bool prefiltered)
+{
+	preprocesssphericalCUDA << <(P + 255) / 256, 256 >> > (
+		P, D, M,
+		means3D,
+		scales,
+		scale_modifier,
+		rotations,
+		opacities,
+		shs,
+		clamped,
+		cov3D_precomp,
+		colors_precomp,
+		viewmatrix,
+		projmatrix,
+		cam_pos,
+		W, H,
+		tan_fovx, tan_fovy,
+		focal_x, focal_y,
+		radii,
+		means2D,
+		depths,
+		cov3Ds,
+		rgb,
+		conic_opacity,
+		grid,
+		tiles_touched,
+		prefiltered
+		);
 }
\ No newline at end of file
diff --git a/cuda_rasterizer/forward.h b/cuda_rasterizer/forward.h
index 3c11cb91..876c30ab 100644
--- a/cuda_rasterizer/forward.h
+++ b/cuda_rasterizer/forward.h
@@ 
-47,6 +47,32 @@ namespace FORWARD uint32_t* tiles_touched, bool prefiltered); + void preprocessspherical(int P, int D, int M, + const float* orig_points, + const glm::vec3* scales, + const float scale_modifier, + const glm::vec4* rotations, + const float* opacities, + const float* shs, + bool* clamped, + const float* cov3D_precomp, + const float* colors_precomp, + const float* viewmatrix, + const float* projmatrix, + const glm::vec3* cam_pos, + const int W, int H, + const float focal_x, float focal_y, + const float tan_fovx, float tan_fovy, + int* radii, + float2* points_xy_image, + float* depths, + float* cov3Ds, + float* colors, + float4* conic_opacity, + const dim3 grid, + uint32_t* tiles_touched, + bool prefiltered); + // Main rasterization method. void render( const dim3 grid, dim3 block, diff --git a/cuda_rasterizer/rasterizer.h b/cuda_rasterizer/rasterizer.h index 81544ef6..b2b1550d 100644 --- a/cuda_rasterizer/rasterizer.h +++ b/cuda_rasterizer/rasterizer.h @@ -52,6 +52,30 @@ namespace CudaRasterizer int* radii = nullptr, bool debug = false); + static int forwardspherical( + std::function geometryBuffer, + std::function binningBuffer, + std::function imageBuffer, + const int P, int D, int M, + const float* background, + const int width, int height, + const float* means3D, + const float* shs, + const float* colors_precomp, + const float* opacities, + const float* scales, + const float scale_modifier, + const float* rotations, + const float* cov3D_precomp, + const float* viewmatrix, + const float* projmatrix, + const float* cam_pos, + const float tan_fovx, float tan_fovy, + const bool prefiltered, + float* out_color, + int* radii = nullptr, + bool debug = false); + static void backward( const int P, int D, int M, int R, const float* background, @@ -82,6 +106,37 @@ namespace CudaRasterizer float* dL_dscale, float* dL_drot, bool debug); + + static void backwardspherical( + const int P, int D, int M, int R, + const float* background, + const int width, int 
height, + const float* means3D, + const float* shs, + const float* colors_precomp, + const float* scales, + const float scale_modifier, + const float* rotations, + const float* cov3D_precomp, + const float* viewmatrix, + const float* projmatrix, + const float* campos, + const float tan_fovx, float tan_fovy, + const int* radii, + char* geom_buffer, + char* binning_buffer, + char* image_buffer, + const float* dL_dpix, + float* dL_dmean2D, + float* dL_dconic, + float* dL_dopacity, + float* dL_dcolor, + float* dL_dmean3D, + float* dL_dcov3D, + float* dL_dsh, + float* dL_dscale, + float* dL_drot, + bool debug); }; }; diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index f8782ac4..0d08aa54 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -335,6 +335,149 @@ int CudaRasterizer::Rasterizer::forward( return num_rendered; } + +// Forward rendering procedure for differentiable rasterization +// of Gaussians. +int CudaRasterizer::Rasterizer::forwardspherical( + std::function geometryBuffer, + std::function binningBuffer, + std::function imageBuffer, + const int P, int D, int M, + const float* background, + const int width, int height, + const float* means3D, + const float* shs, + const float* colors_precomp, + const float* opacities, + const float* scales, + const float scale_modifier, + const float* rotations, + const float* cov3D_precomp, + const float* viewmatrix, + const float* projmatrix, + const float* cam_pos, + const float tan_fovx, float tan_fovy, + const bool prefiltered, + float* out_color, + int* radii, + bool debug) +{ + const float focal_y = height / (2.0f * tan_fovy) / 2; + const float focal_x = width / (2.0f * tan_fovx) / 4; + + size_t chunk_size = required(P); + char* chunkptr = geometryBuffer(chunk_size); + GeometryState geomState = GeometryState::fromChunk(chunkptr, P); + + if (radii == nullptr) + { + radii = geomState.internal_radii; + } + + dim3 tile_grid((width + BLOCK_X - 1) / 
BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); + dim3 block(BLOCK_X, BLOCK_Y, 1); + + // Dynamically resize image-based auxiliary buffers during training + size_t img_chunk_size = required(width * height); + char* img_chunkptr = imageBuffer(img_chunk_size); + ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height); + + if (NUM_CHANNELS != 3 && colors_precomp == nullptr) + { + throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!"); + } + + // Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB) + CHECK_CUDA(FORWARD::preprocessspherical( + P, D, M, + means3D, + (glm::vec3*)scales, + scale_modifier, + (glm::vec4*)rotations, + opacities, + shs, + geomState.clamped, + cov3D_precomp, + colors_precomp, + viewmatrix, projmatrix, + (glm::vec3*)cam_pos, + width, height, + focal_x, focal_y, + tan_fovx, tan_fovy, + radii, + geomState.means2D, + geomState.depths, + geomState.cov3D, + geomState.rgb, + geomState.conic_opacity, + tile_grid, + geomState.tiles_touched, + prefiltered + ), debug) + + // Compute prefix sum over full list of touched tile counts by Gaussians + // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8] + CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug) + + // Retrieve total number of Gaussian instances to launch and resize aux buffers + int num_rendered; + CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug); + + size_t binning_chunk_size = required(num_rendered); + char* binning_chunkptr = binningBuffer(binning_chunk_size); + BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered); + + // For each instance to be rendered, produce adequate [ tile | depth ] key + // and corresponding dublicated Gaussian indices to be sorted + duplicateWithKeys << <(P + 255) / 256, 256 >> > ( + P, + geomState.means2D, + 
geomState.depths, + geomState.point_offsets, + binningState.point_list_keys_unsorted, + binningState.point_list_unsorted, + radii, + tile_grid) + CHECK_CUDA(, debug) + + int bit = getHigherMsb(tile_grid.x * tile_grid.y); + + // Sort complete list of (duplicated) Gaussian indices by keys + CHECK_CUDA(cub::DeviceRadixSort::SortPairs( + binningState.list_sorting_space, + binningState.sorting_size, + binningState.point_list_keys_unsorted, binningState.point_list_keys, + binningState.point_list_unsorted, binningState.point_list, + num_rendered, 0, 32 + bit), debug) + + CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug); + + // Identify start and end of per-tile workloads in sorted list + if (num_rendered > 0) + identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > ( + num_rendered, + binningState.point_list_keys, + imgState.ranges); + CHECK_CUDA(, debug) + + // Let each tile blend its range of Gaussians independently in parallel + const float* feature_ptr = colors_precomp != nullptr ? 
colors_precomp : geomState.rgb; + CHECK_CUDA(FORWARD::render( + tile_grid, block, + imgState.ranges, + binningState.point_list, + width, height, + geomState.means2D, + feature_ptr, + geomState.conic_opacity, + imgState.accum_alpha, + imgState.n_contrib, + background, + out_color), debug) + + return num_rendered; +} + // Produce necessary gradients for optimization, corresponding // to forward render pass void CudaRasterizer::Rasterizer::backward( @@ -431,4 +574,103 @@ void CudaRasterizer::Rasterizer::backward( dL_dsh, (glm::vec3*)dL_dscale, (glm::vec4*)dL_drot), debug) +} + + +// Produce necessary gradients for optimization, corresponding +// to forward render pass +void CudaRasterizer::Rasterizer::backwardspherical( + const int P, int D, int M, int R, + const float* background, + const int width, int height, + const float* means3D, + const float* shs, + const float* colors_precomp, + const float* scales, + const float scale_modifier, + const float* rotations, + const float* cov3D_precomp, + const float* viewmatrix, + const float* projmatrix, + const float* campos, + const float tan_fovx, float tan_fovy, + const int* radii, + char* geom_buffer, + char* binning_buffer, + char* img_buffer, + const float* dL_dpix, + float* dL_dmean2D, + float* dL_dconic, + float* dL_dopacity, + float* dL_dcolor, + float* dL_dmean3D, + float* dL_dcov3D, + float* dL_dsh, + float* dL_dscale, + float* dL_drot, + bool debug) +{ + GeometryState geomState = GeometryState::fromChunk(geom_buffer, P); + BinningState binningState = BinningState::fromChunk(binning_buffer, R); + ImageState imgState = ImageState::fromChunk(img_buffer, width * height); + + if (radii == nullptr) + { + radii = geomState.internal_radii; + } + + const float focal_y = height / (2.0f * tan_fovy) / 4; + const float focal_x = width / (2.0f * tan_fovx) / 4; + + const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); + const dim3 block(BLOCK_X, BLOCK_Y, 1); + + // Compute loss gradients 
w.r.t. 2D mean position, conic matrix, + // opacity and RGB of Gaussians from per-pixel loss gradients. + // If we were given precomputed colors and not SHs, use them. + const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb; + CHECK_CUDA(BACKWARD::render( + tile_grid, + block, + imgState.ranges, + binningState.point_list, + width, height, + background, + geomState.means2D, + geomState.conic_opacity, + color_ptr, + imgState.accum_alpha, + imgState.n_contrib, + dL_dpix, + (float3*)dL_dmean2D, + (float4*)dL_dconic, + dL_dopacity, + dL_dcolor), debug) + + // Take care of the rest of preprocessing. Was the precomputed covariance + // given to us or a scales/rot pair? If precomputed, pass that. If not, + // use the one we computed ourselves. + const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D; + CHECK_CUDA(BACKWARD::preprocessspherical(P, D, M, + (float3*)means3D, + radii, + shs, + geomState.clamped, + (glm::vec3*)scales, + (glm::vec4*)rotations, + scale_modifier, + cov3D_ptr, + viewmatrix, + projmatrix, + focal_x, focal_y, + tan_fovx, tan_fovy, + (glm::vec3*)campos, + (float3*)dL_dmean2D, + dL_dconic, + (glm::vec3*)dL_dmean3D, + dL_dcolor, + dL_dcov3D, + dL_dsh, + (glm::vec3*)dL_dscale, + (glm::vec4*)dL_drot), debug) } \ No newline at end of file diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index bbef37d1..822ecde1 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -83,13 +83,19 @@ def forward( if raster_settings.debug: cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted try: - num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) + if raster_settings.spherical: + num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_spherical_gaussians(*args) + else: + num_rendered, color, radii, geomBuffer, binningBuffer, 
imgBuffer = _C.rasterize_gaussians(*args) except Exception as ex: torch.save(cpu_args, "snapshot_fw.dump") print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.") raise ex else: - num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) + if raster_settings.spherical: + num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_spherical_gaussians(*args) + else: + num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) # Keep relevant tensors for backward ctx.raster_settings = raster_settings @@ -132,13 +138,19 @@ def backward(ctx, grad_out_color, _): if raster_settings.debug: cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted try: - grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) + if raster_settings.spherical: + grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_spherical_gaussians_backward(*args) + else: + grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) except Exception as ex: torch.save(cpu_args, "snapshot_bw.dump") print("\nAn error occured in backward. 
Writing snapshot_bw.dump for debugging.\n") raise ex else: - grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) + if raster_settings.spherical: + grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_spherical_gaussians_backward(*args) + else: + grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) grads = ( grad_means3D, @@ -166,6 +178,7 @@ class GaussianRasterizationSettings(NamedTuple): sh_degree : int campos : torch.Tensor prefiltered : bool + spherical : bool debug : bool class GaussianRasterizer(nn.Module): @@ -184,6 +197,24 @@ def markVisible(self, positions): return visible + def set_raster_viewproj(self, viewmatrix, projmatrix): + settings = GaussianRasterizationSettings( + image_height=self.raster_settings.image_height, + image_width=self.raster_settings.image_width, + tanfovx=self.raster_settings.tanfovx, + tanfovy=self.raster_settings.tanfovy, + bg=self.raster_settings.bg, + scale_modifier=self.raster_settings.scale_modifier, + viewmatrix=viewmatrix, + projmatrix=projmatrix, + sh_degree=self.raster_settings.sh_degree, + campos=self.raster_settings.campos, + prefiltered=self.raster_settings.prefiltered, + spherical=self.raster_settings.spherical, + debug=self.raster_settings.debug + ) + self.raster_settings = settings + def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None): raster_settings = self.raster_settings diff --git a/ext.cpp b/ext.cpp index d7687795..b4e86e1c 100644 --- a/ext.cpp +++ b/ext.cpp @@ -15,5 +15,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rasterize_gaussians", &RasterizeGaussiansCUDA); m.def("rasterize_gaussians_backward", 
&RasterizeGaussiansBackwardCUDA); + m.def("rasterize_spherical_gaussians", &RasterizeGaussiansSphericalCUDA); + m.def("rasterize_spherical_gaussians_backward", &RasterizeGaussiansBackwardSphericalCUDA); m.def("mark_visible", &markVisible); } \ No newline at end of file diff --git a/rasterize_points.cu b/rasterize_points.cu index ddc5cf8b..dd53f064 100644 --- a/rasterize_points.cu +++ b/rasterize_points.cu @@ -114,6 +114,88 @@ RasterizeGaussiansCUDA( return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer); } +std::tuple +RasterizeGaussiansSphericalCUDA( + const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& colors, + const torch::Tensor& opacity, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const int image_height, + const int image_width, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const bool prefiltered, + const bool debug) +{ + if (means3D.ndimension() != 2 || means3D.size(1) != 3) { + AT_ERROR("means3D must have dimensions (num_points, 3)"); + } + + const int P = means3D.size(0); + const int H = image_height; + const int W = image_width; + + auto int_opts = means3D.options().dtype(torch::kInt32); + auto float_opts = means3D.options().dtype(torch::kFloat32); + + torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); + torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); + + torch::Device device(torch::kCUDA); + torch::TensorOptions options(torch::kByte); + torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); + torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); + torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); + std::function geomFunc = 
resizeFunctional(geomBuffer); + std::function binningFunc = resizeFunctional(binningBuffer); + std::function imgFunc = resizeFunctional(imgBuffer); + + int rendered = 0; + if(P != 0) + { + int M = 0; + if(sh.size(0) != 0) + { + M = sh.size(1); + } + + rendered = CudaRasterizer::Rasterizer::forwardspherical( + geomFunc, + binningFunc, + imgFunc, + P, degree, M, + background.contiguous().data(), + W, H, + means3D.contiguous().data(), + sh.contiguous().data_ptr(), + colors.contiguous().data(), + opacity.contiguous().data(), + scales.contiguous().data_ptr(), + scale_modifier, + rotations.contiguous().data_ptr(), + cov3D_precomp.contiguous().data(), + viewmatrix.contiguous().data(), + projmatrix.contiguous().data(), + campos.contiguous().data(), + tan_fovx, + tan_fovy, + prefiltered, + out_color.contiguous().data(), + radii.contiguous().data(), + debug); + } + return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer); +} + std::tuple RasterizeGaussiansBackwardCUDA( const torch::Tensor& background, @@ -195,6 +277,87 @@ std::tuple + RasterizeGaussiansBackwardSphericalCUDA( + const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& radii, + const torch::Tensor& colors, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const torch::Tensor& dL_dout_color, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const torch::Tensor& geomBuffer, + const int R, + const torch::Tensor& binningBuffer, + const torch::Tensor& imageBuffer, + const bool debug) +{ + const int P = means3D.size(0); + const int H = dL_dout_color.size(1); + const int W = dL_dout_color.size(2); + + int M = 0; + if(sh.size(0) != 0) + { + M = sh.size(1); + } + + torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, 
means3D.options()); + torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options()); + torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options()); + torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options()); + torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options()); + torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options()); + torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options()); + torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options()); + torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options()); + + if(P != 0) + { + CudaRasterizer::Rasterizer::backwardspherical(P, degree, M, R, + background.contiguous().data(), + W, H, + means3D.contiguous().data(), + sh.contiguous().data(), + colors.contiguous().data(), + scales.data_ptr(), + scale_modifier, + rotations.data_ptr(), + cov3D_precomp.contiguous().data(), + viewmatrix.contiguous().data(), + projmatrix.contiguous().data(), + campos.contiguous().data(), + tan_fovx, + tan_fovy, + radii.contiguous().data(), + reinterpret_cast(geomBuffer.contiguous().data_ptr()), + reinterpret_cast(binningBuffer.contiguous().data_ptr()), + reinterpret_cast(imageBuffer.contiguous().data_ptr()), + dL_dout_color.contiguous().data(), + dL_dmeans2D.contiguous().data(), + dL_dconic.contiguous().data(), + dL_dopacity.contiguous().data(), + dL_dcolors.contiguous().data(), + dL_dmeans3D.contiguous().data(), + dL_dcov3D.contiguous().data(), + dL_dsh.contiguous().data(), + dL_dscales.contiguous().data(), + dL_drotations.contiguous().data(), + debug); + } + + return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations); +} + torch::Tensor markVisible( torch::Tensor& means3D, torch::Tensor& viewmatrix, diff --git a/rasterize_points.h b/rasterize_points.h index 9023d994..930bdc46 100644 --- a/rasterize_points.h +++ b/rasterize_points.h @@ -61,6 +61,53 @@ std::tuple +RasterizeGaussiansSphericalCUDA( + 
const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& colors, + const torch::Tensor& opacity, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const int image_height, + const int image_width, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const bool prefiltered, + const bool debug); + +std::tuple + RasterizeGaussiansBackwardSphericalCUDA( + const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& radii, + const torch::Tensor& colors, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const torch::Tensor& dL_dout_color, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const torch::Tensor& geomBuffer, + const int R, + const torch::Tensor& binningBuffer, + const torch::Tensor& imageBuffer, + const bool debug); + + torch::Tensor markVisible( torch::Tensor& means3D, torch::Tensor& viewmatrix,