From 52ba85af9341215ace6dd3880c726630fa60e6b7 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Mon, 2 Sep 2024 18:03:18 +1200 Subject: [PATCH] Bugs fixed for multi-GPU support JaggedTensor slicing and indexing moved to CUDA kernels Small updates to basic_concepts docs and README Signed-off-by: Jonathan Swartz --- fvdb/README.md | 4 +- fvdb/ci/Dockerfile.runner | 16 +- fvdb/docs/tutorials/basic_concepts.md | 32 +- fvdb/fvdb/nn/modules.py | 8 +- fvdb/setup.py | 2 + fvdb/src/GridBatch.cpp | 111 ++++- fvdb/src/GridBatch.h | 30 +- fvdb/src/JaggedTensor.cpp | 139 +----- fvdb/src/detail/GridBatchImpl.cu | 8 +- fvdb/src/detail/TorchDeviceBuffer.cpp | 12 +- fvdb/src/detail/io/SaveNanoVDB.cpp | 2 + fvdb/src/detail/ops/JCat0.cu | 2 + fvdb/src/detail/ops/JIdxForJOffsets.cu | 6 +- fvdb/src/detail/ops/JOffsetsFromJIdx.cu | 161 ++++--- fvdb/src/detail/ops/JaggedTensorIndex.cu | 588 ++++++++++++++++++++--- fvdb/src/detail/ops/Ops.h | 9 +- fvdb/src/detail/ops/VolumeRender.cu | 10 +- fvdb/src/detail/utils/Utils.h | 26 + fvdb/src/python/GridBatchBinding.cpp | 3 + fvdb/tests/unit/test_jagged_tensor.py | 271 +++++++++-- 20 files changed, 1062 insertions(+), 378 deletions(-) diff --git a/fvdb/README.md b/fvdb/README.md index 36d50052d8..241953d6e9 100644 --- a/fvdb/README.md +++ b/fvdb/README.md @@ -19,7 +19,7 @@ Lastly, our [documentation](docs) provides deeper details on the concepts as wel ## Installing *f*VDB -fVDB is provided as an installable python package from *[todo: insert package distributor]*. We provide pre-built packages of the latest *f*VDB version for the following dependent library configurations: +fVDB is provided as an installable python package from conda. We provide pre-built packages of the latest *f*VDB version for the following dependent library configurations: | PyTorch | Python | CUDA | | -------------- | ---------- | ------- | @@ -34,7 +34,7 @@ fVDB is provided as an installable python package from *[todo: insert package di Use the following command to install `fvdb` into your environment. ```bash -conda install -c jswartz fvdb +conda install [TBD] ``` If you intend to use our learning material such as the [notebooks](notebooks) or [examples](examples), we recommend you start from the `fvdb_learn` conda environment which contains all the dependencies needed to run the learning material as well as build *f*VDB from source. 
To create this environment, run the following commands from the root of this repository: diff --git a/fvdb/ci/Dockerfile.runner b/fvdb/ci/Dockerfile.runner index 76473e6be5..ae67a12c3a 100644 --- a/fvdb/ci/Dockerfile.runner +++ b/fvdb/ci/Dockerfile.runner @@ -3,12 +3,9 @@ ARG CUDNN_VERSION=8 FROM nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu20.04 -ENV PATH /usr/local/cuda/bin:$PATH -ENV LD_LIBRARY_PATH /usr/lib64:/usr/local/cuda/lib64:/usr/local/cuda/lib:${LD_LIBRARY_PATH} - # # nvidia-container-runtime -ENV NVIDIA_VISIBLE_DEVICES all -ENV NVIDIA_DRIVER_CAPABILITIES compute,utility,graphics +ENV NVIDIA_VISIBLE_DEVICES=all +ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility,graphics RUN echo "Acquire { https::Verify-Peer false }" > /etc/apt/apt.conf.d/99verify-peer.conf \ && if [ -f /etc/apt/sources.list.d/cuda.list ]; then \ @@ -26,8 +23,6 @@ RUN echo "Acquire { https::Verify-Peer false }" > /etc/apt/apt.conf.d/99verify-p git \ unzip \ gfortran \ - libopenblas-dev \ - liblapack-dev \ ssh \ rsync \ iputils-ping \ @@ -37,15 +32,14 @@ RUN echo "Acquire { https::Verify-Peer false }" > /etc/apt/apt.conf.d/99verify-p WORKDIR /tmp RUN mkdir actions-runner && \ cd actions-runner && \ - curl -o actions-runner-linux-x64-2.316.0.tar.gz -L https://github.com/actions/runner/releases/download/v2.316.0/actions-runner-linux-x64-2.316.0.tar.gz && \ - tar xzf ./actions-runner-linux-x64-2.316.0.tar.gz && \ + curl -o actions-runner-linux-x64-2.319.1.tar.gz -L https://github.com/actions/runner/releases/download/v2.319.1/actions-runner-linux-x64-2.319.1.tar.gz && \ + tar xzf ./actions-runner-linux-x64-2.319.1.tar.gz && \ DEBIAN_FRONTEND=noninteractive ./bin/installdependencies.sh && \ - rm actions-runner-linux-x64-2.316.0.tar.gz + rm actions-runner-linux-x64-2.319.1.tar.gz # used for cross-compilation in docker build ENV FORCE_CUDA=1 ENV RUNNER_ALLOW_RUNASROOT=1 -ENV TORCH_CUDA_ARCH_LIST "6.1;7.0;7.5;8.0;8.6+PTX" # Install AWS CLI RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" && \ diff --git a/fvdb/docs/tutorials/basic_concepts.md b/fvdb/docs/tutorials/basic_concepts.md index 7fe7d3ae6b..f8523c5b22 100644 --- a/fvdb/docs/tutorials/basic_concepts.md +++ b/fvdb/docs/tutorials/basic_concepts.md @@ -22,11 +22,11 @@ Every operation in fVDB is built upon this kind of query (e.g. Sparse Convolutio Each grid in a `GridBatch` can have a different number of voxels (****e.g.**** in the mini batch of four cars above, each car has a different number of voxels). This means that unlike the dense case, fVDB needs to handle parallel operations over ***jagged batches***. I.e. batches containing different numbers of elements. -To handle jagged batches, fVDB provides a `JaggedTensor` class. Conceptually, a `JaggedTensor` is a list of tensors with shapes $[N_0, *], [N_1, *], \ldots, [N_B, *]$ where $B$ is the number of elements in the batch, $N_i$ is the number of elements in the $i^\text{th}$ batch item and $*$ is an arbitrary numer of additional dimensions that all match between the tensors. The figure below illustrates such a list of tensors pictorially. +To handle jagged batches, fVDB provides a `JaggedTensor` class. Conceptually, a `JaggedTensor` is a list of tensors with shapes $[N_0, *], [N_1, *], \ldots, [N_{B-1}, *]$ where $B$ is the number of elements in the batch, $N_i$ is the number of elements in the $i^\text{th}$ batch item and $*$ is an arbitrary number of additional dimensions that all match between the tensors.
The figure below illustrates such a list of tensors pictorially. ![jaggedtensor1.png](../imgs/fig/jaggedtensor1.png) -In practice, `JaggedTensor`s are represented in memory by concatenating each tensor in the list into a single `jdata` (for Jagged Data) tensor of shape $[N_0 + N_1 + \ldots + N_B, *]$. Additionally, each `JaggedTensor` stores an additional `jidx` tensor (for Jagged Indexes) of shape $[N_0 + N_1 + \ldots + N_B]$ containing one int per element in the jagged tensor. `jidx[i]` is the batch index of the $i^\text{th}$ element of `jdata`. Finally, a `JaggedTensor` contains a `joffsets` tensor (for Jagged Offsets) of shape $[B, 2]$ which indicates the start and end positions of the $i^\text{th}$ tensor in the batch. +In practice, `JaggedTensor`s are represented in memory by concatenating each tensor in the list into a single `jdata` (for Jagged Data) tensor of shape $[N_0 + N_1 + \ldots + N_{B-1}, *]$. Additionally, each `JaggedTensor` stores an additional `jidx` tensor (for Jagged Indexes) of shape $[N_0 + N_1 + \ldots + N_{B-1}]$ containing one int per element in the jagged tensor. `jidx[i]` is the batch index of the $i^\text{th}$ element of `jdata`. Finally, a `JaggedTensor` contains a `joffsets` tensor (for Jagged Offsets) of shape $[B, 2]$ which indicates the start and end positions of the $i^\text{th}$ tensor in the batch. ![jaggedtensor4.png](../imgs/fig/jaggedtensor4.png) @@ -36,6 +36,8 @@ Similarly, each `GridBatch` also has `jidx` and `joffsets` corresponding to the To illustrate the use of `GridBatch`and `JaggedTensor`, consider a simple example where we build a grid from a point cloud, splat some values onto the voxels of that grid, and then sample them again using a different set of points. +First, we construct a minibatch of grids using the input points. These input points have corresponding color attributes. + ```python import fvdb import torch @@ -48,7 +50,7 @@ pts2, clrs2 = pcu.load_mesh_vn("points2.ply") pts1, clrs1 = torch.from_numpy(pts1).cuda(), torch.from_numpy(clrs1).cuda() pts2, clrs2 = torch.from_numpy(pts2).cuda(), torch.from_numpy(clrs2).cuda() -# JaggedTensors of points and normals +# Creating JaggedTensors: one for points and one for colors points = fvdb.JaggedTensor([pts1, pts2]) colors = fvdb.JaggedTensor([clrs1, clrs2]) @@ -60,29 +62,27 @@ print(points[0].jdata.shape) print(points[1].jdata.shape) ``` -![We construct a minibatch of grids using the input points. These input points have corresponding color attributes](../imgs/fig/screenshot_000000.png.trim.png) +![Minibatch of grids constructed from the input points. These input points have corresponding color attributes.](../imgs/fig/screenshot_000000.png.trim.png) -We construct a minibatch of grids using the input points. These input points have corresponding color attributes +Next, we splat the colors at the points onto the constructed grid, yielding per-voxel colors. ```python -# Splat the normals into the grid with trilinear interpolation -# vox_normals is a JaggedTensor of per-voxel normas +# Splat the colors into the grid with trilinear interpolation +# vox_colors is a JaggedTensor of per-voxel colors vox_colors = grid.splat_trilinear(points, colors) ``` -![We then splat the colors at the points to the constructed grid, yielding per-voxel colors.](../imgs/fig/screenshot_000006.png.trim.png) +![Colors splatted from the input points onto the grid, yielding per-voxel colors.](../imgs/fig/screenshot_000006.png.trim.png) -We then splat the colors at the points to the constructed grid, yielding per-voxel colors.
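To make the `jdata`/`jidx`/`joffsets` layout described above concrete, here is a minimal sketch in plain PyTorch (not the fVDB API), assuming three example tensors and the `[B, 2]` start/end convention for `joffsets` described in the tutorial text:

```python
import torch

# Minimal sketch of the jagged layout described above (plain PyTorch, not fVDB).
# Assumes the [B, 2] start/end convention for joffsets.
tensors = [torch.randn(3, 2), torch.randn(5, 2), torch.randn(1, 2)]  # N_0=3, N_1=5, N_2=1

# jdata: all tensors concatenated along the first dimension
jdata = torch.cat(tensors, dim=0)                                    # shape [N_0 + N_1 + N_2, 2]

# jidx: one batch index per element of jdata
jidx = torch.cat([torch.full((t.shape[0],), i, dtype=torch.int32)
                  for i, t in enumerate(tensors)])                   # shape [N_0 + N_1 + N_2]

# joffsets: start/end position of each tensor inside jdata
lengths  = torch.tensor([t.shape[0] for t in tensors])
ends     = torch.cumsum(lengths, dim=0)
starts   = ends - lengths
joffsets = torch.stack([starts, ends], dim=1)                        # shape [B, 2]

# Recover the i-th tensor from the flat representation
i = 1
assert torch.equal(jdata[joffsets[i, 0]:joffsets[i, 1]], tensors[i])
```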
+Finally, we generate a new set of noisy points and sample the grid to recover colors at those new samples. ```python # Now let's generate some random points and sample the grid at those points -samples = fvdb.JaggedTensor([torch.rand(10_000, 3), torch.rand(11_000, 3)]).cuda() +sample_points = fvdb.JaggedTensor([torch.rand(10_000, 3), torch.rand(11_000, 3)]).cuda() -# sampled_normals is a JaggedTensor with the same shape as samples with -# one normal sampled from the grid at each point in samples -sampled_normals = grid.sample_trilinear(samples) +# sampled_colors is a JaggedTensor with the same shape as sample_points with +# one color sampled from the grid at each point +sampled_colors = grid.sample_trilinear(sample_points, vox_colors) ``` -![We now generate a new set of noisy points and sample the grid colors to recover colors at those new samples.](../imgs/fig/screenshot_000004.png.trim.png) - -We now generate a new set of noisy points and sample the grid colors to recover colors at those new samples. +![Colors resampled at random locations from the grid.](../imgs/fig/screenshot_000004.png.trim.png) diff --git a/fvdb/fvdb/nn/modules.py b/fvdb/fvdb/nn/modules.py index 7917bde6ac..26b448595e 100644 --- a/fvdb/fvdb/nn/modules.py +++ b/fvdb/fvdb/nn/modules.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: MPL-2.0 # import math -from typing import Optional, Union, List, Sequence +from typing import List, Optional, Sequence, Union import torch import torch.nn as nn @@ -10,6 +10,7 @@ import fvdb from fvdb import GridBatch, JaggedTensor + from .vdbtensor import VDBTensor @@ -267,6 +268,11 @@ def _dispatch_conv(self, in_feature, in_grid, in_kmap, out_grid): backend = self.backend + if self.allow_tf32 and self.weight.is_cuda: + assert ( + torch.cuda.get_device_capability()[0] >= 8 + ), "TF32 requires GPU with compute capability >= 8.0. Please set fvdb.nn.SparseConv3d.allow_tf32 = False." + if backend == "cutlass" and ( (not self.weight.is_cuda) or (self.in_channels, self.out_channels) not in self.CUTLASS_SUPPORTED_CHANNELS ): diff --git a/fvdb/setup.py b/fvdb/setup.py index 6de3f20d4b..906745111a 100644 --- a/fvdb/setup.py +++ b/fvdb/setup.py @@ -265,6 +265,8 @@ def download_and_install_cudnn(): "--extended-lambda", "--diag-suppress=186", "-diag-suppress=3189", + "-Xfatbin", + "-compress-all", ] user_nvcc_flags = os.getenv("NVCC_FLAGS", "").split() nvcc_flags += user_nvcc_flags diff --git a/fvdb/src/GridBatch.cpp b/fvdb/src/GridBatch.cpp index 2a2e0a90fd..d9e9ea517f 100644 --- a/fvdb/src/GridBatch.cpp +++ b/fvdb/src/GridBatch.cpp @@ -13,6 +13,7 @@ namespace fvdb { GridBatch::GridBatch(TorchDeviceOrString device, bool isMutable) { + detail::RAIIDeviceGuard guard(device.value()); mImpl = c10::make_intrusive(device.value(), isMutable); } @@ -25,6 +26,7 @@ GridBatch::GridBatch() { std::pair GridBatch::max_pool(Vec3iOrScalar pool_factor, const JaggedTensor &data, Vec3iOrScalar stride, torch::optional coarse_grid) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( data.ldim() == 1, "Expected data to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -56,6 +58,7 @@ GridBatch::max_pool(Vec3iOrScalar pool_factor, const JaggedTensor &data, Vec3iOr std::pair GridBatch::avg_pool(Vec3iOrScalar pool_factor, const JaggedTensor &data, Vec3iOrScalar stride, torch::optional coarse_grid) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( data.ldim() == 1, "Expected data to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", @@ -88,6 +91,7 @@ std::pair GridBatch::subdivide(Vec3iOrScalar subdiv_factor, const JaggedTensor &data, const torch::optional mask, torch::optional fine_grid) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( data.ldim() == 1, "Expected data to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -115,7 +119,8 @@ GridBatch::subdivide(Vec3iOrScalar subdiv_factor, const JaggedTensor &data, JaggedTensor GridBatch::read_from_dense(const torch::Tensor &dense_data, const Vec3iBatch &dense_origins) const { - torch::Tensor retData = + detail::RAIIDeviceGuard guard(device()); + torch::Tensor retData = detail::autograd::ReadFromDense::apply(impl(), dense_data, dense_origins)[0]; return impl()->jaggedTensor(retData, false); } @@ -124,6 +129,7 @@ torch::Tensor GridBatch::read_into_dense(const JaggedTensor &sparse_data, const torch::optional &min_coord, const torch::optional &grid_size) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( sparse_data.ldim() == 1, "Expected sparse_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -135,6 +141,7 @@ GridBatch::read_into_dense(const JaggedTensor &sparse_data, JaggedTensor GridBatch::fill_to_grid(const JaggedTensor &features, const GridBatch &other_grid, float default_value) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( features.ldim() == 1, "Expected features to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -147,6 +154,7 @@ GridBatch::fill_to_grid(const JaggedTensor &features, const GridBatch &other_gri JaggedTensor GridBatch::grid_to_world(const JaggedTensor &ijk) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ijk.ldim() == 1, "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -159,6 +167,7 @@ GridBatch::grid_to_world(const JaggedTensor &ijk) const { JaggedTensor GridBatch::world_to_grid(const JaggedTensor &xyz) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( xyz.ldim() == 1, "Expected xyz to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -171,6 +180,7 @@ GridBatch::world_to_grid(const JaggedTensor &xyz) const { torch::Tensor GridBatch::grid_to_world_matrices(const torch::Dtype &dtype) const { + detail::RAIIDeviceGuard guard(device()); std::vector retTorch; for (int64_t bi = 0; bi < grid_count(); ++bi) { retTorch.emplace_back(impl()->gridToWorldMatrix(bi)); @@ -181,6 +191,7 @@ GridBatch::grid_to_world_matrices(const torch::Dtype &dtype) const { torch::Tensor GridBatch::world_to_grid_matrices(const torch::Dtype &dtype) const { + detail::RAIIDeviceGuard guard(device()); std::vector retTorch; for (int64_t bi = 0; bi < grid_count(); ++bi) { retTorch.emplace_back(impl()->worldToGridMatrix(bi)); @@ -191,6 +202,7 @@ GridBatch::world_to_grid_matrices(const torch::Dtype &dtype) const { JaggedTensor GridBatch::sample_trilinear(const JaggedTensor &points, const JaggedTensor &voxel_data) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( points.ldim() == 1, "Expected points to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", @@ -207,6 +219,7 @@ GridBatch::sample_trilinear(const JaggedTensor &points, const JaggedTensor &voxe std::vector GridBatch::sample_trilinear_with_grad(const JaggedTensor &points, const JaggedTensor &voxel_data) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( points.ldim() == 1, "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -223,6 +236,7 @@ GridBatch::sample_trilinear_with_grad(const JaggedTensor &points, JaggedTensor GridBatch::sample_bezier(const JaggedTensor &points, const JaggedTensor &voxel_data) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( points.ldim() == 1, "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -239,6 +253,7 @@ GridBatch::sample_bezier(const JaggedTensor &points, const JaggedTensor &voxel_d std::vector GridBatch::sample_bezier_with_grad(const JaggedTensor &points, const JaggedTensor &voxel_data) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( points.ldim() == 1, "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -254,6 +269,7 @@ GridBatch::sample_bezier_with_grad(const JaggedTensor &points, JaggedTensor GridBatch::splat_trilinear(const JaggedTensor &points, const JaggedTensor &points_data) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( points.ldim() == 1, "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -273,6 +289,7 @@ GridBatch::splat_trilinear(const JaggedTensor &points, const JaggedTensor &point JaggedTensor GridBatch::splat_bezier(const JaggedTensor &points, const JaggedTensor &points_data) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( points.ldim() == 1, "Expected points to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", @@ -292,7 +309,8 @@ GridBatch::splat_bezier(const JaggedTensor &points, const JaggedTensor &points_d torch::Tensor GridBatch::voxel_size_at(int64_t bi, const torch::Dtype &dtype) const { - torch::Tensor retTorch = + detail::RAIIDeviceGuard guard(device()); + torch::Tensor retTorch = torch::empty({ 3 }, torch::TensorOptions().device(this->device()).dtype(dtype)); const nanovdb::Vec3d &voxSize = impl()->voxelSize(bi); retTorch[0] = voxSize[0]; @@ -303,7 +321,8 @@ GridBatch::voxel_size_at(int64_t bi, const torch::Dtype &dtype) const { torch::Tensor GridBatch::voxel_sizes(const torch::Dtype &dtype) const { - torch::Tensor retTorch = torch::empty( + detail::RAIIDeviceGuard guard(device()); + torch::Tensor retTorch = torch::empty( { grid_count(), 3 }, torch::TensorOptions().device(this->device()).dtype(dtype)); for (int64_t bi = 0; bi < grid_count(); bi += 1) { const nanovdb::Vec3d voxSize = impl()->voxelSize(bi); @@ -316,8 +335,9 @@ GridBatch::voxel_sizes(const torch::Dtype &dtype) const { torch::Tensor GridBatch::origin_at(int64_t bi, const torch::Dtype &dtype) const { - const nanovdb::Vec3d &voxelOrigin = impl()->voxelOrigin(bi); - torch::Tensor retTorch = + detail::RAIIDeviceGuard guard(device()); + const nanovdb::Vec3d &voxelOrigin = impl()->voxelOrigin(bi); + torch::Tensor retTorch = torch::empty({ 3 }, torch::TensorOptions().device(this->device()).dtype(dtype)); retTorch[0] = voxelOrigin[0]; retTorch[1] = voxelOrigin[1]; @@ -327,7 +347,8 @@ GridBatch::origin_at(int64_t bi, const torch::Dtype &dtype) const { torch::Tensor GridBatch::origins(const torch::Dtype &dtype) const { - torch::Tensor retTorch = torch::empty( + detail::RAIIDeviceGuard guard(device()); + torch::Tensor retTorch = torch::empty( { grid_count(), 3 }, torch::TensorOptions().device(this->device()).dtype(dtype)); for (int64_t bi = 0; bi < grid_count(); bi += 1) { const nanovdb::Vec3d &voxOrigin = impl()->voxelOrigin(bi); @@ -340,7 +361,8 @@ GridBatch::origins(const torch::Dtype &dtype) const { torch::Tensor GridBatch::num_voxels() const { - torch::Tensor retTorch = torch::empty( + detail::RAIIDeviceGuard guard(device()); + torch::Tensor retTorch = torch::empty( { grid_count() }, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); auto acc = retTorch.accessor(); @@ -352,6 +374,7 @@ GridBatch::num_voxels() const { torch::Tensor GridBatch::num_enabled_voxels() const { + detail::RAIIDeviceGuard guard(device()); if (!is_mutable()) { return num_voxels(); } @@ -367,6 +390,7 @@ GridBatch::num_enabled_voxels() const { int64_t GridBatch::num_enabled_voxels_at(int64_t bi) const { + detail::RAIIDeviceGuard guard(device()); if (!is_mutable()) { return num_voxels_at(bi); } @@ -377,7 +401,8 @@ GridBatch::num_enabled_voxels_at(int64_t bi) const { torch::Tensor GridBatch::cum_voxels() const { - torch::Tensor retTorch = torch::empty( + detail::RAIIDeviceGuard guard(device()); + torch::Tensor retTorch = torch::empty( { grid_count() }, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); auto acc = retTorch.accessor(); @@ -389,6 +414,7 @@ GridBatch::cum_voxels() const { torch::Tensor GridBatch::cum_enabled_voxels() const { + detail::RAIIDeviceGuard guard(device()); if (!is_mutable()) { return cum_voxels(); } @@ -404,7 +430,8 @@ GridBatch::cum_enabled_voxels() const { int64_t GridBatch::cum_enabled_voxels_at(int64_t bi) const { - int64_t nCum = 0; + detail::RAIIDeviceGuard guard(device()); + int64_t nCum = 0; for (int64_t b = 0; b < bi; ++b) { nCum += 
num_enabled_voxels_at(b); } @@ -413,7 +440,8 @@ GridBatch::cum_enabled_voxels_at(int64_t bi) const { torch::Tensor GridBatch::num_bytes() const { - torch::Tensor retTorch = torch::empty( + detail::RAIIDeviceGuard guard(device()); + torch::Tensor retTorch = torch::empty( { grid_count() }, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); auto acc = retTorch.accessor(); @@ -425,7 +453,8 @@ GridBatch::num_bytes() const { torch::Tensor GridBatch::num_leaf_nodes() const { - torch::Tensor retTorch = torch::empty( + detail::RAIIDeviceGuard guard(device()); + torch::Tensor retTorch = torch::empty( { grid_count() }, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); auto acc = retTorch.accessor(); @@ -437,6 +466,7 @@ GridBatch::num_leaf_nodes() const { void GridBatch::disable_ijk(const JaggedTensor &ijk) { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ijk.ldim() == 1, "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -448,6 +478,7 @@ GridBatch::disable_ijk(const JaggedTensor &ijk) { void GridBatch::enable_ijk(const JaggedTensor &ijk) { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ijk.ldim() == 1, "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -460,6 +491,7 @@ GridBatch::enable_ijk(const JaggedTensor &ijk) { void GridBatch::set_from_mesh(const JaggedTensor &mesh_vertices, const JaggedTensor &mesh_faces, const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins) { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( mesh_vertices.ldim() == 1, "Expected mesh_vertices to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -518,6 +550,7 @@ GridBatch::set_from_mesh(const JaggedTensor &mesh_vertices, const JaggedTensor & void GridBatch::set_from_points(const JaggedTensor &points, const Vec3i &pad_min, const Vec3i &pad_max, const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins) { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( points.ldim() == 1, "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -567,6 +600,7 @@ void GridBatch::set_from_nearest_voxels_to_points(const JaggedTensor &points, const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins) { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( points.ldim() == 1, "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -610,6 +644,7 @@ GridBatch::set_from_nearest_voxels_to_points(const JaggedTensor &points, void GridBatch::set_from_ijk(const JaggedTensor &coords, const Vec3i &pad_min, const Vec3i &pad_max, const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins) { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( coords.ldim() == 1, "Expected coords to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", @@ -651,6 +686,7 @@ void GridBatch::set_from_dense_grid(const int64_t num_grids, const Vec3i &dense_dims, const Vec3i &ijk_min, const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins, torch::optional mask) { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE(num_grids >= 0, "num_grids must be non-negative"); const nanovdb::Coord &size = dense_dims.value(); @@ -689,7 +725,8 @@ GridBatch::set_from_dense_grid(const int64_t num_grids, const Vec3i &dense_dims, GridBatch GridBatch::dual_grid(bool exclude_border) const { - GridBatch ret = GridBatch(device(), is_mutable()); + detail::RAIIDeviceGuard guard(device()); + GridBatch ret = GridBatch(device(), is_mutable()); if (grid_count() == 0) { return ret; } @@ -699,7 +736,8 @@ GridBatch::dual_grid(bool exclude_border) const { GridBatch GridBatch::coarsened_grid(Vec3iOrScalar branch_factor) const { - nanovdb::Coord branchFactorCoord = branch_factor.value(); + detail::RAIIDeviceGuard guard(device()); + nanovdb::Coord branchFactorCoord = branch_factor.value(); for (int i = 0; i < 3; i += 1) { TORCH_CHECK_VALUE(branchFactorCoord[i] > 0, "branch_factor must be strictly positive. Got [" + @@ -718,6 +756,8 @@ GridBatch::coarsened_grid(Vec3iOrScalar branch_factor) const { GridBatch GridBatch::subdivided_grid(Vec3iOrScalar subdiv_factor, const torch::optional mask) const { + detail::RAIIDeviceGuard guard(device()); + if (mask.has_value()) { TORCH_CHECK_VALUE( mask.value().ldim() == 1, @@ -743,9 +783,10 @@ GridBatch::subdivided_grid(Vec3iOrScalar subdiv_factor, GridBatch GridBatch::clipped_grid(const Vec3iBatch &ijk_min, const Vec3iBatch &ijk_max) const { - JaggedTensor activeVoxelMask = FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { + detail::RAIIDeviceGuard guard(device()); + JaggedTensor activeVoxelMask = FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchActiveVoxelsInBoundsMask(*impl(), ijk_min, - ijk_max, false); + ijk_max, false); }); JaggedTensor activeVoxelCoords = FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { @@ -765,6 +806,7 @@ GridBatch::clipped_grid(const Vec3iBatch &ijk_min, const Vec3iBatch &ijk_max) co std::pair GridBatch::clip(const JaggedTensor &features, const Vec3iBatch &ijk_min, const Vec3iBatch &ijk_max) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( features.ldim() == 1, "Expected features to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -800,6 +842,7 @@ GridBatch::clip(const JaggedTensor &features, const Vec3iBatch &ijk_min, std::vector GridBatch::marching_cubes(const JaggedTensor &field, double level) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( field.ldim() == 1, "Expected field to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -826,6 +869,7 @@ GridBatch::marching_cubes(const JaggedTensor &field, double level) const { JaggedTensor GridBatch::sparse_conv_halo(const JaggedTensor &input, const torch::Tensor &weight, int variant) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( input.ldim() == 1, "Expected input to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", @@ -841,6 +885,7 @@ GridBatch::sparse_conv_halo(const JaggedTensor &input, const torch::Tensor &weig GridBatch GridBatch::conv_grid(Vec3iOrScalar kernel_size, Vec3iOrScalar stride) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE(Vec3iOrScalar(0).value() < kernel_size.value(), "kernel_size must be strictly positive. Got " + kernel_size.toString()); TORCH_CHECK_VALUE(Vec3iOrScalar(0).value() < stride.value(), @@ -862,6 +907,7 @@ GridBatch::conv_grid(Vec3iOrScalar kernel_size, Vec3iOrScalar stride) const { void GridBatch::buildCoarseFromFineGrid(const GridBatch &fineGrid, nanovdb::Coord branchFactor) { + detail::RAIIDeviceGuard guard(device()); std::vector voxS, voxO; fineGrid.impl()->gridVoxelSizesAndOrigins(voxS, voxO); mImpl = c10::make_intrusive( @@ -874,6 +920,7 @@ void GridBatch::buildFineFromCoarseGrid(const GridBatch &coarseGrid, const torch::optional &subdivMask, nanovdb::Coord subdivFactor) { + detail::RAIIDeviceGuard guard(device()); if (subdivMask.has_value()) { TORCH_CHECK_VALUE( subdivMask.value().ldim() == 1, @@ -899,6 +946,7 @@ GridBatch::buildFineFromCoarseGrid(const GridBatch &coarseGr void GridBatch::buildDualFromPrimalGrid(const GridBatch &primalGrid, bool excludeBorder) { + detail::RAIIDeviceGuard guard(device()); std::vector voxS, voxO; primalGrid.impl()->gridVoxelSizesAndOrigins(voxS, voxO); mImpl = c10::make_intrusive( @@ -911,6 +959,7 @@ GridBatch::buildDualFromPrimalGrid(const GridBatch &primalGrid, bool excludeBord std::vector GridBatch::voxels_along_rays(const JaggedTensor &ray_origins, const JaggedTensor &ray_directions, int64_t max_vox, double eps, bool return_ijk, bool cumulative) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ray_origins.ldim() == 1, "Expected ray_origins to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -928,6 +977,7 @@ GridBatch::voxels_along_rays(const JaggedTensor &ray_origins, const JaggedTensor JaggedTensor GridBatch::segments_along_rays(const JaggedTensor &ray_origins, const JaggedTensor &ray_directions, int64_t max_segments, double eps, bool ignore_masked) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ray_origins.ldim() == 1, "Expected ray_origins to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -946,6 +996,7 @@ JaggedTensor GridBatch::ray_implicit_intersection(const JaggedTensor &ray_origins, const JaggedTensor &ray_directions, const JaggedTensor &gridScalars, double eps) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ray_origins.ldim() == 1, "Expected ray_origins to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -969,6 +1020,7 @@ GridBatch::uniform_ray_samples(const JaggedTensor &ray_origins, const JaggedTens const JaggedTensor &t_min, const JaggedTensor &t_max, double step_size, double cone_angle, bool include_end_segments, bool return_midpoint, double eps) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ray_origins.ldim() == 1, "Expected ray_origins to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -994,6 +1046,7 @@ GridBatch::uniform_ray_samples(const JaggedTensor &ray_origins, const JaggedTens JaggedTensor GridBatch::neighbor_indexes(const JaggedTensor &ijk, int32_t extent, int32_t bitshift) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ijk.ldim() == 1, "Expected ijk to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", @@ -1009,6 +1062,7 @@ GridBatch::neighbor_indexes(const JaggedTensor &ijk, int32_t extent, int32_t bit JaggedTensor GridBatch::points_in_active_voxel(const JaggedTensor &xyz, bool ignore_disabled) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( xyz.ldim() == 1, "Expected xyz to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -1021,6 +1075,7 @@ GridBatch::points_in_active_voxel(const JaggedTensor &xyz, bool ignore_disabled) JaggedTensor GridBatch::cubes_intersect_grid(const JaggedTensor &cube_centers, const Vec3dOrScalar &cube_min, const Vec3dOrScalar &cube_max, bool ignore_disabled) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( cube_centers.ldim() == 1, "Expected cube_centers to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -1034,6 +1089,7 @@ GridBatch::cubes_intersect_grid(const JaggedTensor &cube_centers, const Vec3dOrS JaggedTensor GridBatch::cubes_in_grid(const JaggedTensor &cube_centers, const Vec3dOrScalar &cube_min, const Vec3dOrScalar &cube_max, bool ignore_disabled) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( cube_centers.ldim() == 1, "Expected cube_centers to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -1046,6 +1102,7 @@ GridBatch::cubes_in_grid(const JaggedTensor &cube_centers, const Vec3dOrScalar & JaggedTensor GridBatch::enabled_mask() const { + detail::RAIIDeviceGuard guard(device()); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchEnabledMask(*impl(), false); }); @@ -1053,6 +1110,7 @@ GridBatch::enabled_mask() const { JaggedTensor GridBatch::disabled_mask() const { + detail::RAIIDeviceGuard guard(device()); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchEnabledMask(*impl(), true); }); @@ -1060,6 +1118,7 @@ GridBatch::disabled_mask() const { JaggedTensor GridBatch::coords_in_active_voxel(const JaggedTensor &ijk, bool ignore_disabled) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ijk.ldim() == 1, "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -1071,6 +1130,7 @@ GridBatch::coords_in_active_voxel(const JaggedTensor &ijk, bool ignore_disabled) JaggedTensor GridBatch::ijk_to_index(const JaggedTensor &ijk, bool cumulative) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ijk.ldim() == 1, "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", @@ -1082,6 +1142,7 @@ GridBatch::ijk_to_index(const JaggedTensor &ijk, bool cumulative) const { JaggedTensor GridBatch::ijk_to_inv_index(const JaggedTensor &ijk, bool cumulative) const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK_VALUE( ijk.ldim() == 1, "Expected ijk to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", @@ -1093,6 +1154,7 @@ GridBatch::ijk_to_inv_index(const JaggedTensor &ijk, bool cumulative) const { JaggedTensor GridBatch::ijk() const { + detail::RAIIDeviceGuard guard(device()); return FVDB_DISPATCH_KERNEL_DEVICE(this->device(), [&]() { return fvdb::detail::ops::dispatchActiveGridCoords(*impl(), true); }); @@ -1100,6 +1162,7 @@ GridBatch::ijk() const { JaggedTensor GridBatch::ijk_enabled() const { + detail::RAIIDeviceGuard guard(device()); return FVDB_DISPATCH_KERNEL_DEVICE(this->device(), [&]() { return fvdb::detail::ops::dispatchActiveGridCoords(*impl(), false); }); @@ -1107,8 +1170,9 @@ GridBatch::ijk_enabled() const { const torch::Tensor GridBatch::bbox() const { - const int64_t bs = grid_count(); - torch::Tensor ret = + detail::RAIIDeviceGuard guard(device()); + const int64_t bs = grid_count(); + torch::Tensor ret = torch::zeros({ bs, 2, 3 }, torch::TensorOptions().device(device()).dtype(torch::kInt32)); for (int64_t i = 0; i < bs; ++i) { const nanovdb::CoordBBox &bbox = impl()->bbox(i); @@ -1124,7 +1188,8 @@ GridBatch::bbox() const { const torch::Tensor GridBatch::bbox_at(int64_t bi) const { - torch::Tensor ret = + detail::RAIIDeviceGuard guard(device()); + torch::Tensor ret = torch::zeros({ 2, 3 }, torch::TensorOptions().device(device()).dtype(torch::kInt32)); const nanovdb::CoordBBox &bbox = impl()->bbox(bi); ret[0][0] = bbox.min()[0]; @@ -1138,8 +1203,9 @@ GridBatch::bbox_at(int64_t bi) const { const torch::Tensor GridBatch::dual_bbox() const { - const int64_t bs = grid_count(); - torch::Tensor ret = + detail::RAIIDeviceGuard guard(device()); + const int64_t bs = grid_count(); + torch::Tensor ret = torch::zeros({ bs, 2, 3 }, torch::TensorOptions().device(device()).dtype(torch::kInt32)); for (int64_t i = 0; i < bs; ++i) { const nanovdb::CoordBBox &bbox = impl()->dualBbox(i); @@ -1155,7 +1221,8 @@ GridBatch::dual_bbox() const { const torch::Tensor GridBatch::dual_bbox_at(int64_t bi) const { - torch::Tensor ret = + detail::RAIIDeviceGuard guard(device()); + torch::Tensor ret = torch::zeros({ 2, 3 }, torch::TensorOptions().device(device()).dtype(torch::kInt32)); const nanovdb::CoordBBox &bbox = impl()->dualBbox(bi); ret[0][0] = bbox.min()[0]; @@ -1169,6 +1236,7 @@ GridBatch::dual_bbox_at(int64_t bi) const { const torch::Tensor GridBatch::total_bbox() const { + detail::RAIIDeviceGuard guard(device()); const nanovdb::CoordBBox &bbox = impl()->totalBBox(); return torch::tensor({ { bbox.min()[0], bbox.min()[1], bbox.min()[2] }, { bbox.max()[0], bbox.max()[1], bbox.max()[2] } }, @@ -1177,6 +1245,7 @@ GridBatch::total_bbox() const { std::vector GridBatch::viz_edge_network(bool returnVoxelCoordinates) const { + detail::RAIIDeviceGuard guard(device()); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchGridEdgeNetwork(*impl(), returnVoxelCoordinates); diff --git a/fvdb/src/GridBatch.h b/fvdb/src/GridBatch.h index 796f165192..e2d70663ad 100644 --- a/fvdb/src/GridBatch.h +++ b/fvdb/src/GridBatch.h @@ -39,6 +39,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return A contiguous copy of this grid batch GridBatch contiguous() const { + detail::RAIIDeviceGuard guard(device()); return GridBatch(detail::GridBatchImpl::contiguous(impl())); } @@ -72,6 +73,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The number of grids indexed by this batch int64_t grid_count() const { + detail::RAIIDeviceGuard guard(device()); TORCH_CHECK(impl()->batchSize() <= MAX_GRIDS_PER_BATCH, "Cannot have more than 
", MAX_GRIDS_PER_BATCH, " grids in a batch"); return impl()->batchSize(); @@ -82,6 +84,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The total number of enabled voxels indexed by this batch of grids int64_t total_enabled_voxels() const { + detail::RAIIDeviceGuard guard(device()); return impl()->totalEnabledVoxels(false); } @@ -89,6 +92,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The total number of voxels indexed by this batch of grids int64_t total_voxels() const { + detail::RAIIDeviceGuard guard(device()); return impl()->totalVoxels(); } @@ -97,6 +101,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The number of voxels indexed by the bi^th grid in the batch int64_t num_voxels_at(int64_t bi) const { + detail::RAIIDeviceGuard guard(device()); return impl()->numVoxels(bi); } @@ -111,6 +116,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The cumulative number of voxels indexed by the first bi+1 grids int64_t cum_voxels_at(int64_t bi) const { + detail::RAIIDeviceGuard guard(device()); return impl()->cumVoxels(bi); } @@ -146,6 +152,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The total number of bytes required to store all grids indexed by this batch int64_t total_bytes() const { + detail::RAIIDeviceGuard guard(device()); return impl()->totalBytes(); } @@ -157,6 +164,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The total number of leaf nodes indexed by this batch of grids int64_t total_leaf_nodes() const { + detail::RAIIDeviceGuard guard(device()); return impl()->totalLeaves(); } @@ -171,6 +179,7 @@ struct GridBatch : torch::CustomClassHolder { /// grid in the batch torch::Tensor joffsets() const { + detail::RAIIDeviceGuard guard(device()); return impl()->voxelOffsets(true); } @@ -179,7 +188,8 @@ struct GridBatch : torch::CustomClassHolder { /// the i^th grid torch::Tensor jlidx() const { - const torch::Tensor ret = impl()->jlidx(true); + detail::RAIIDeviceGuard guard(device()); + const torch::Tensor ret = impl()->jlidx(true); if (ret.numel() == 0) { return torch::arange({ grid_count() }, torch::TensorOptions().device(device()).dtype(torch::kInt64)); @@ -193,7 +203,8 @@ struct GridBatch : torch::CustomClassHolder { /// of the i^th voxel torch::Tensor jidx() const { - const torch::Tensor ret = impl()->jidx(true); + detail::RAIIDeviceGuard guard(device()); + const torch::Tensor ret = impl()->jidx(true); if (grid_count() == 1 && ret.numel() == 0) { return torch::zeros({ total_voxels() }, torch::TensorOptions().device(device()).dtype(torch::kInt16)); @@ -206,6 +217,7 @@ struct GridBatch : torch::CustomClassHolder { /// @param voxel_size A 3D (shape [3,]) tensor specifying the voxel size to set for each grid inline void set_global_voxel_size(const Vec3dOrScalar &voxel_size) { + detail::RAIIDeviceGuard guard(device()); impl()->setGlobalVoxelSize(voxel_size.value()); } @@ -213,6 +225,7 @@ struct GridBatch : torch::CustomClassHolder { /// @param origin A 3D (shape [3,]) tensor specifying the voxel origin to set for each grid inline void set_global_origin(const Vec3d &origin) { + detail::RAIIDeviceGuard guard(device()); impl()->setGlobalVoxelOrigin(origin.value()); } @@ -220,6 +233,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return Whether the grid is mutable inline bool is_mutable() const { + detail::RAIIDeviceGuard guard(device()); return impl()->isMutable(); } @@ -236,6 +250,7 @@ struct GridBatch : torch::CustomClassHolder { /// this batch inline const std::vector primal_transforms() 
const { + detail::RAIIDeviceGuard guard(device()); std::vector transforms; transforms.reserve(grid_count()); for (int64_t bi = 0; bi < grid_count(); ++bi) { @@ -250,6 +265,7 @@ struct GridBatch : torch::CustomClassHolder { /// grids in this batch inline const std::vector dual_transforms() const { + detail::RAIIDeviceGuard guard(device()); std::vector transforms; transforms.reserve(grid_count()); for (int64_t bi = 0; bi < grid_count(); ++bi) { @@ -264,6 +280,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The primal transform of the bi^th grid in the batch inline const fvdb::detail::VoxelCoordTransform primal_transform_at(int64_t bi) const { + detail::RAIIDeviceGuard guard(device()); return impl()->primalTransform(bi); } @@ -273,6 +290,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The dual transform of the bi^th grid in the batch inline const fvdb::detail::VoxelCoordTransform dual_transform_at(int64_t bi) const { + detail::RAIIDeviceGuard guard(device()); return impl()->dualTransform(bi); } @@ -824,6 +842,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return A GridBatch representing the grid at the specified index GridBatch index(int64_t bi) const { + detail::RAIIDeviceGuard guard(device()); return GridBatch(impl()->index(bi)); } @@ -834,6 +853,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return A GridBatch representing the slice of this grid batch GridBatch index(size_t start, size_t stop, size_t step) const { + detail::RAIIDeviceGuard guard(device()); return GridBatch(impl()->index(start, stop, step)); } @@ -843,6 +863,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The grid batch vieweed at the specified indices GridBatch index(const std::vector &bi) const { + detail::RAIIDeviceGuard guard(device()); return GridBatch(impl()->index(bi)); } @@ -852,6 +873,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The grid batch vieweed at the specified indices GridBatch index(const std::vector &bi) const { + detail::RAIIDeviceGuard guard(device()); return GridBatch(impl()->index(bi)); } @@ -861,6 +883,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The grid batch vieweed at the specified indices GridBatch index(const torch::Tensor &bi) const { + detail::RAIIDeviceGuard guard(device()); return GridBatch(impl()->index(bi)); } @@ -872,6 +895,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return A JaggedTensor corresponding to the voxel grid of this grid batch JaggedTensor jagged_like(const torch::Tensor &data, bool ignore_disabled = true) const { + detail::RAIIDeviceGuard guard(device()); return impl()->jaggedTensor(data, ignore_disabled); } @@ -964,6 +988,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return A serialized grid batch encoded as a torch::Tensor of type int8 torch::Tensor serialize() const { + detail::RAIIDeviceGuard guard(device()); return impl()->serialize(); } @@ -972,6 +997,7 @@ struct GridBatch : torch::CustomClassHolder { /// @return The deserializes grid batch static GridBatch deserialize(const torch::Tensor &data) { + detail::RAIIDeviceGuard guard(data.device()); return GridBatch(detail::GridBatchImpl::deserialize(data)); } diff --git a/fvdb/src/JaggedTensor.cpp b/fvdb/src/JaggedTensor.cpp index 9f1531c5a1..3a01d7772d 100644 --- a/fvdb/src/JaggedTensor.cpp +++ b/fvdb/src/JaggedTensor.cpp @@ -573,14 +573,13 @@ JaggedTensor JaggedTensor::index(JaggedTensorIndex idx) const { if (idx.is_integer()) { return FVDB_DISPATCH_KERNEL_DEVICE(mData.device(), [&]() { - return 
detail::ops::dispatchJaggedTensorIndex(*this, idx.integer()); + return detail::ops::dispatchJaggedTensorIndexInt(*this, idx.integer()); }); } else if (idx.is_slice()) { int64_t start = idx.slice().start().as_int_unchecked(); int64_t end = idx.slice().stop().as_int_unchecked(); int64_t step = idx.slice().step().as_int_unchecked(); - TORCH_CHECK_INDEX(step == 1, - "step must be 1 for JaggedTensor. Only contiguous slicing is supported."); + TORCH_CHECK_VALUE(step >= 1, "step in slice must be >= 1"); // Deal with symbolic int if (start >= at::indexing::INDEX_MAX) { @@ -590,138 +589,18 @@ JaggedTensor::index(JaggedTensorIndex idx) const { end = 0; } - // Convert indexes to positive values - if (start < 0) { - start += mNumOuterLists; - } - if (end < 0) { - end += mNumOuterLists; - } - if (start >= end) { - start = end; - } - - start = std::max(start, (int64_t)0); - end = std::min(end, mNumOuterLists); - - if (mListIdx.size(0) == 0) { - TORCH_CHECK(ldim() == 1, "bad list indexes. this should never happen"); - const JOffsetsType startIdx = mOffsets[start].item(); - const JOffsetsType endIdx = mOffsets[end].item(); - const torch::Tensor retLidx = - mListIdx.numel() == 0 ? mListIdx - : mListIdx.index({ torch::indexing::Slice(start, end) }); - return JaggedTensor::from_data_offsets_and_list_ids( - mData.index({ torch::indexing::Slice(startIdx, endIdx) }), - mOffsets.index({ torch::indexing::Slice(start, end + 1) }) - startIdx, retLidx); - } else { - // Find all tensors that belong to the slice - const torch::Tensor outerLidx = mListIdx.index({ torch::indexing::Slice(), 0 }); - const torch::Tensor mask = outerLidx.ge(start).logical_and(outerLidx.lt(end)); - const torch::Tensor joffsetCat = - torch::stack({ mOffsets.index({ torch::indexing::Slice(0, num_tensors()) }), - mOffsets.index({ torch::indexing::Slice(1, num_tensors() + 1) }) }, - 1); - const torch::Tensor selectedOffsets = joffsetCat.index({ mask }); - - // Get the start and end offsets into the data tensor for the slice - JOffsetsType startIdx = - selectedOffsets.size(0) > 0 ? selectedOffsets[0][0].item() : 0; - JOffsetsType endIdx = - selectedOffsets.size(0) > 0 ? selectedOffsets[-1][1].item() : 0; - - // Slice the data tensor - const torch::Tensor retData = mData.index({ torch::indexing::Slice(startIdx, endIdx) }); - - // Subtract the start offset from the selected offsets to get the new offsets - // NOTE: This assumes offsets are always contiguous - const torch::Tensor retOffsets = - selectedOffsets.numel() > 0 - ? torch::cat({ selectedOffsets.index({ torch::indexing::Slice(), 0 }), - selectedOffsets.index({ -1, 1 }).unsqueeze(0) }) - - startIdx - : torch::zeros( - { 1 }, - torch::TensorOptions().dtype(JOffsetsScalarType).device(mData.device())); - - // Slice the list indices and subtract the start index - TORCH_CHECK(mListIdx.size(1) > 1, "bad list indexes. 
this should never happen"); - torch::Tensor retListIdx = mListIdx.index({ mask }); - retListIdx.index({ torch::indexing::Slice(), 0 }) -= start; - if (retListIdx.dim() == 0) { - retListIdx = retListIdx.unsqueeze(1); - } - const int64_t retNumOuterLists = end - start; - const torch::Tensor retJidx = jidx_from_joffsets(retOffsets, retData.size(0)); - return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( - retData, retOffsets, retJidx, retListIdx, retNumOuterLists); - } + return FVDB_DISPATCH_KERNEL_DEVICE(mData.device(), [&]() { + return detail::ops::dispatchJaggedTensorIndexSlice(*this, start, end, step); + }); } else if (idx.is_ellipsis()) { return *this; } else if (idx.is_jagged_tensor()) { const JaggedTensor &jtIndices = idx.jagged_tensor(); - TORCH_CHECK_VALUE(jtIndices.device() == device(), - "indices must be on the same device as the JaggedTensor"); - - TORCH_CHECK_INDEX( - jtIndices.mListIdx.dim() == mListIdx.dim(), - "Indices must have the same list structure as JaggedTensor being indexed"); - for (int i = 0; i < mListIdx.dim(); ++i) { - TORCH_CHECK_INDEX( - jtIndices.mListIdx.size(i) == mListIdx.size(i), - "Indices must have the same list structure as JaggedTensor being indexed"); - } - if (Config::global().pendanticErrorCheckingEnabled()) { - // This is a slow check that we cap optionally do for correctness. - TORCH_CHECK_INDEX( - torch::all(jtIndices.mListIdx == mListIdx).item(), - "Indices must have the same list structure as JaggedTensor being indexed. ", - "This error was raised because config.pendatic_error_checking was enabled"); - } - - c10::ScalarType idxdt = jtIndices.scalar_type(); - const bool isIndexType = (idxdt == c10::ScalarType::Long || idxdt == c10::ScalarType::Int || - idxdt == c10::ScalarType::Byte || idxdt == c10::ScalarType::Bool); - TORCH_CHECK_INDEX( - isIndexType, - "JaggedTensors used as indices must be long, int, byte or bool JaggedTensors but got ", - idxdt); - - torch::Tensor selidx; - if (jtIndices.scalar_type() == torch::kBool) { - selidx = jtIndices.jdata(); - } else { - // FIXME (Francis): We're not checking out of range here and it's sketchy! 
Fix in a - // unified CUDA kernel - selidx = jtIndices.jdata().clone(); - for (int i = 0; i < jtIndices.joffsets().size(0) - 1; ++i) { - const JOffsetsType start = jtIndices.joffsets()[i].item(); - const JOffsetsType end = jtIndices.joffsets()[i + 1].item(); - const JOffsetsType add = mOffsets[i].item(); - selidx.index({ torch::indexing::Slice(start, end) }).add_(add); - } - } - - const torch::Tensor retJdata = mData.index({ selidx }); - torch::Tensor retJidx = mBatchIdx.index({ selidx }); - if (retJidx.dim() > 1) { - std::vector idx; - idx.reserve(retJidx.dim()); - idx.push_back(at::indexing::Slice()); - for (int i = 1; i < retJidx.dim(); ++i) { - idx.push_back(0); - } - retJidx = retJidx.index(idx); - } - retJidx = retJidx.contiguous(); - const torch::Tensor retJOffsets = - joffsets_from_jidx_and_jdata(retJidx, retJdata, num_tensors()); - const torch::Tensor retListIdx = mListIdx; - - return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( - retJdata, retJOffsets, retJidx, retListIdx, mNumOuterLists); + return FVDB_DISPATCH_KERNEL_DEVICE(mData.device(), [&]() { + return detail::ops::dispatchJaggedTensorIndexJaggedTensor(*this, jtIndices); + }); } else { - TORCH_CHECK_INDEX(false, "Unsupported indexing operation"); + TORCH_CHECK_VALUE(false, "Unsupported indexing operation"); } } diff --git a/fvdb/src/detail/GridBatchImpl.cu b/fvdb/src/detail/GridBatchImpl.cu index 0159c991a8..c28279d8f9 100644 --- a/fvdb/src/detail/GridBatchImpl.cu +++ b/fvdb/src/detail/GridBatchImpl.cu @@ -87,8 +87,10 @@ GridBatchImpl::GridBatchImpl(nanovdb::GridHandle &&gridHdl, }; GridBatchImpl::~GridBatchImpl() { + torch::Device device = mGridHdl->buffer().device(); mHostGridMetadata.clear(); if (mDeviceGridMetadata != nullptr) { + c10::cuda::CUDAGuard deviceGuard(device); c10::cuda::CUDACachingAllocator::raw_delete(mDeviceGridMetadata); } }; @@ -172,10 +174,12 @@ void GridBatchImpl::syncMetadataToDeviceIfCUDA(bool blocking) { if (device().is_cuda()) { // There is something to sync and we're on a cuda device + // Global device guards as we operate on this. 
+ c10::cuda::CUDAGuard deviceGuard(device()); + // We haven't allocated the cuda memory yet, so we need to do that now if (mDeviceGridMetadata == nullptr) { // We need to allocate the memory on the device - c10::cuda::CUDAGuard deviceGuard(device()); size_t metaDataByteSize = sizeof(GridMetadata) * mHostGridMetadata.size(); at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(device().index()); mDeviceGridMetadata = @@ -321,6 +325,7 @@ GridBatchImpl::setGrid(nanovdb::GridHandle &&gridHdl, // Clear out old grid metadata mHostGridMetadata.clear(); if (mDeviceGridMetadata != nullptr) { + c10::cuda::CUDAGuard deviceGuard(device); c10::cuda::CUDACachingAllocator::raw_delete(mDeviceGridMetadata); mDeviceGridMetadata = nullptr; } @@ -355,6 +360,7 @@ GridBatchImpl::setGrid(nanovdb::GridHandle &&gridHdl, // We don't need the device copy of the global batch metadata anymore (we only carry around // the host version and pass it by value to device kernels), so delete it if constexpr (DeviceTag == torch::kCUDA) { + c10::cuda::CUDAGuard deviceGuard(device); c10::cuda::CUDACachingAllocator::raw_delete(deviceBatchMetadataPtr); } }); diff --git a/fvdb/src/detail/TorchDeviceBuffer.cpp b/fvdb/src/detail/TorchDeviceBuffer.cpp index f498a66802..32392d7658 100644 --- a/fvdb/src/detail/TorchDeviceBuffer.cpp +++ b/fvdb/src/detail/TorchDeviceBuffer.cpp @@ -37,18 +37,23 @@ GridHandle::copy( if (iAmHost && guideIsHost) { std::memcpy(buffer.data(), mBuffer.data(), mBuffer.size()); // deep copy of buffer in CPU RAM + return GridHandle(std::move(buffer)); } else if (iAmHost && guideIsDevice) { + const at::cuda::CUDAGuard device_guard{ guide.device() }; at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(guide.device().index()); cudaCheck(cudaMemcpyAsync(buffer.deviceData(), mBuffer.data(), mBuffer.size(), cudaMemcpyHostToDevice, defaultStream.stream())); cudaCheck(cudaStreamSynchronize(defaultStream.stream())); + return GridHandle(std::move(buffer)); } else if (iAmDevice && guideIsHost) { at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(mBuffer.device().index()); cudaCheck(cudaMemcpyAsync(buffer.data(), mBuffer.deviceData(), mBuffer.size(), cudaMemcpyDeviceToHost, defaultStream.stream())); cudaCheck(cudaStreamSynchronize(defaultStream.stream())); + return GridHandle(std::move(buffer)); } else if (iAmDevice && guideIsDevice) { + const at::cuda::CUDAGuard device_guard{ guide.device() }; if (mBuffer.device() == guide.device()) { at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(mBuffer.device().index()); @@ -68,8 +73,10 @@ GridHandle::copy( cudaMemcpyHostToDevice, outBufferStream.stream())); cudaCheck(cudaStreamSynchronize(outBufferStream.stream())); } + return GridHandle(std::move(buffer)); + } else { + TORCH_CHECK(false, "All host/device combos exhausted. 
This should never happen."); } - return GridHandle(std::move(buffer)); } } // namespace nanovdb @@ -143,11 +150,13 @@ TorchDeviceBuffer::toCpu(bool blocking) { // If this is a cuda device, copy the data to the CPU if (mDevice.is_cuda()) { + c10::cuda::CUDAGuard deviceGuard(mDevice); at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(mDevice.index()); copyDeviceToHostAndFreeDevice(defaultStream.stream(), blocking); } if (mGpuData != nullptr) { + c10::cuda::CUDAGuard deviceGuard(mDevice); c10::cuda::CUDACachingAllocator::raw_delete(mGpuData); mGpuData = nullptr; } @@ -247,6 +256,7 @@ TorchDeviceBuffer::init(uint64_t size, void *data /* = nullptr */, bool host /* void TorchDeviceBuffer::clear() { if (mGpuData) { + c10::cuda::CUDAGuard deviceGuard(mDevice); c10::cuda::CUDACachingAllocator::raw_delete(mGpuData); } if (mCpuData) { diff --git a/fvdb/src/detail/io/SaveNanoVDB.cpp b/fvdb/src/detail/io/SaveNanoVDB.cpp index 0d5c5f8ca9..8e8e8517d0 100644 --- a/fvdb/src/detail/io/SaveNanoVDB.cpp +++ b/fvdb/src/detail/io/SaveNanoVDB.cpp @@ -282,6 +282,7 @@ getIndexGrid(const GridBatch &gridBatch, const std::vector names = // Write out the full grid to the buffer if (isCuda) { + c10::cuda::CUDAGuard deviceGuard(gridBatch.device()); at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(gridBatch.device().index()); cudaMemcpyAsync(writeHead, readHead, sourceGridByteSize, cudaMemcpyDeviceToHost, @@ -382,6 +383,7 @@ saveIndexGridWithBlindData(const std::string &path, const GridBatch &gridBatch, // Copy the full bi^th index grid to the buffer const size_t sourceGridByteSize = nanoGridHdl.gridSize(bi); if (isCuda) { + c10::cuda::CUDAGuard deviceGuard(gridBatch.device()); at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(gridBatch.device().index()); cudaMemcpyAsync((void *)writeHead, (void *)readHead, sourceGridByteSize, diff --git a/fvdb/src/detail/ops/JCat0.cu b/fvdb/src/detail/ops/JCat0.cu index f140543b2f..628edba988 100644 --- a/fvdb/src/detail/ops/JCat0.cu +++ b/fvdb/src/detail/ops/JCat0.cu @@ -80,6 +80,8 @@ computeIndexPutArg( template <> JaggedTensor dispatchJCat0(const std::vector &vec) { + c10::cuda::CUDAGuard deviceGuard(vec[0].device()); + int64_t totalElements = 0; int64_t maxElements = 0; thrust::host_vector offsets; diff --git a/fvdb/src/detail/ops/JIdxForJOffsets.cu b/fvdb/src/detail/ops/JIdxForJOffsets.cu index 4b97d599b6..827e6875e4 100644 --- a/fvdb/src/detail/ops/JIdxForJOffsets.cu +++ b/fvdb/src/detail/ops/JIdxForJOffsets.cu @@ -55,9 +55,9 @@ dispatchJIdxForJOffsets(torch::Tensor joffsets, int64_t numElement torch::empty({ numElements }, torch::TensorOptions().dtype(fvdb::JIdxScalarType).device(joffsets.device())); - const int blockSize = 1024; - const int gridSize = (numElements + blockSize - 1) / blockSize; - jIdxForJOffsets<<>>( + const int NUM_THREADS = 1024; + const int NUM_BLOCKS = GET_BLOCKS(numElements, NUM_THREADS); + jIdxForJOffsets<<>>( joffsets.packed_accessor32(), retJIdx.packed_accessor32()); C10_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/fvdb/src/detail/ops/JOffsetsFromJIdx.cu b/fvdb/src/detail/ops/JOffsetsFromJIdx.cu index a06c1fd842..bfe3d61e79 100644 --- a/fvdb/src/detail/ops/JOffsetsFromJIdx.cu +++ b/fvdb/src/detail/ops/JOffsetsFromJIdx.cu @@ -18,78 +18,9 @@ setZero(T *thingToSet) { *thingToSet = 0; } -template <> -torch::Tensor -dispatchJOffsetsForJIdx(torch::Tensor jidx, torch::Tensor jdata, int64_t numTensors) { - TORCH_CHECK_VALUE(jdata.device().is_cuda(), "Invalid device for jdata"); - c10::cuda::CUDAGuard 
deviceGuard(jdata.device()); - - if (jidx.size(0) == 0 && numTensors == 1) { - torch::Tensor ret = torch::empty({ 2 }, JOffsetsScalarType); - auto acc = ret.accessor(); - acc[0] = 0; - acc[1] = jdata.size(0); - return ret.to(jdata.device()); - } - - TORCH_CHECK_VALUE(jidx.device().is_cuda(), "Invalid device for jidx"); - TORCH_CHECK_VALUE(jidx.scalar_type() == JIdxScalarType, "Invalid scalar type for jidx. Got ", - jidx.scalar_type(), " but expected ", JIdxScalarType); - TORCH_CHECK_VALUE(jidx.is_contiguous(), "jidx must be contiguous"); - TORCH_CHECK_VALUE(jidx.size(0) == jdata.size(0), - "jidx and jdata must have the same number of elments"); - - const size_t numItems = jidx.size(0); - - // FIXME: Francis -- write a dummy output iterator so we don't actually allocate here. - torch::Tensor dummyOut = torch::empty( - { numTensors }, torch::TensorOptions().dtype(JIdxScalarType).device(jdata.device())); - - torch::Tensor joffsetsOut = torch::empty( - { numTensors + 1 }, torch::TensorOptions().dtype(JIdxScalarType).device(jdata.device())); - torch::Tensor numRunsOut = - torch::empty({ 1 }, torch::TensorOptions().dtype(JIdxScalarType).device(jdata.device())); - - // Get current cuda stream for device - at::cuda::CUDAStream currentStream = at::cuda::getCurrentCUDAStream(jdata.device().index()); - - // Determine temporary device storage requirements - void *d_temp_storage = nullptr; - size_t temp_storage_bytes = 0; - cub::DeviceRunLengthEncode::Encode(d_temp_storage, temp_storage_bytes, - jidx.data_ptr(), // keys in - dummyOut.data_ptr(), // unique out (dummy) - joffsetsOut.data_ptr() + 1, // counts out - numRunsOut.data_ptr(), // num runs out - numItems, currentStream.stream()); - - // Allocate temporary storage - d_temp_storage = c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(temp_storage_bytes, - currentStream.stream()); - - // Do the actual reduce by key operation - cub::DeviceRunLengthEncode::Encode(d_temp_storage, temp_storage_bytes, - jidx.data_ptr(), // keys in - dummyOut.data_ptr(), // values out (dummy) - joffsetsOut.data_ptr() + 1, // unique out - numRunsOut.data_ptr(), // num runs out - numItems, currentStream.stream()); - - // Free up scratch memory - c10::cuda::CUDACachingAllocator::raw_delete(d_temp_storage); - - // Zero out the first element - setZero<<<1, 1>>>(joffsetsOut.data_ptr()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return joffsetsOut.cumsum(0, JOffsetsScalarType); -} - -template <> torch::Tensor -dispatchJOffsetsForJIdx(torch::Tensor jidx, torch::Tensor jdata, int64_t numTensors) { +joffsetsForJIdx(torch::Tensor jidx, torch::Tensor jdata, int64_t numTensors) { TORCH_CHECK_VALUE(jidx.dim() == 1, "jidx must be a 1D tensor"); - TORCH_CHECK_VALUE(jdata.device().is_cpu(), "Invalid device for jdata"); if (jidx.size(0) == 0 && numTensors == 1) { torch::Tensor ret = torch::empty({ 2 }, JOffsetsScalarType); @@ -99,8 +30,6 @@ dispatchJOffsetsForJIdx(torch::Tensor jidx, torch::Tensor jdata, in return ret.to(jdata.device()); } - TORCH_CHECK_VALUE(jidx.device().is_cpu(), "Invalid device for jidx"); - // Get the number of unique batch indices assuming jidx is always sorted // It should be of the form [0, ..., 0, 1, ..., 1, 3, ..., 3, ...] 
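As a side note for readers of this hunk: the offsets computation performed by the shared helper is easy to reproduce in plain PyTorch. The sketch below is illustrative only (it is not part of the patch; the jidx values and variable names are invented) and mirrors the unique-consecutive / index_put_ / cumsum sequence used in the function body:

```python
import torch

# jidx maps each element of jdata to the tensor it belongs to and is assumed sorted.
jidx = torch.tensor([0, 0, 0, 1, 1, 3])  # tensor 2 owns no elements
num_tensors = 4

# Run-length encode the sorted batch indices.
values, counts = torch.unique_consecutive(jidx, return_counts=True)

# Scatter the counts so that empty tensors still contribute a zero-sized run.
full_counts = torch.zeros(num_tensors + 1, dtype=torch.long)
full_counts[1:].index_put_((values,), counts)

# Prefix-sum the counts to obtain joffsets: tensor i spans jdata[joffsets[i]:joffsets[i+1]].
joffsets = torch.cumsum(full_counts, dim=0)
print(joffsets)  # tensor([0, 3, 5, 5, 6])
```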
std::tuple uniqueRes = @@ -112,12 +41,98 @@ dispatchJOffsetsForJIdx(torch::Tensor jidx, torch::Tensor jdata, in torch::full({ numTensors + 1 }, 0, torch::TensorOptions().dtype(JOffsetsScalarType).device(jdata.device())); fullBatchCounts.index({ torch::indexing::Slice(1, torch::indexing::None, 1) }) - .index_put_({ uniqueBatchValues.to(torch::kLong) }, uniqueBatchCounts); + .index_put_({ uniqueBatchValues }, uniqueBatchCounts); torch::Tensor cumOffsets = torch::cumsum(fullBatchCounts, 0, JOffsetsScalarType); return cumOffsets; } +template <> +torch::Tensor +dispatchJOffsetsForJIdx(torch::Tensor jidx, torch::Tensor jdata, int64_t numTensors) { + TORCH_CHECK_VALUE(jidx.dim() == 1, "jidx must be a 1D tensor"); + TORCH_CHECK_VALUE(jdata.device().is_cuda(), "Invalid device for jdata"); + TORCH_CHECK_VALUE(jidx.size(0) == 0 || jidx.device().is_cuda(), "Invalid device for jidx"); + c10::cuda::CUDAGuard deviceGuard(jdata.device()); + return joffsetsForJIdx(jidx, jdata, numTensors); + + // FIXME: (Francis) Sadly this implementation doesn't work with empty tensors, but the above one + // is still pretty good + /* + TORCH_CHECK_VALUE(jdata.device().is_cuda(), "Invalid device for jdata"); + c10::cuda::CUDAGuard deviceGuard(jdata.device()); + + if (jidx.size(0) == 0 && numTensors == 1) { + torch::Tensor ret = torch::empty({ 2 }, JOffsetsScalarType); + auto acc = ret.accessor(); + acc[0] = 0; + acc[1] = jdata.size(0); + return ret.to(jdata.device()); + } + + TORCH_CHECK_VALUE(jidx.device().is_cuda(), "Invalid device for jidx"); + TORCH_CHECK_VALUE(jidx.scalar_type() == JIdxScalarType, "Invalid scalar type for jidx. Got + ", jidx.scalar_type(), " but expected ", JIdxScalarType); + TORCH_CHECK_VALUE(jidx.is_contiguous(), "jidx must be contiguous"); + TORCH_CHECK_VALUE(jidx.size(0) == jdata.size(0), + "jidx and jdata must have the same number of elments"); + + const size_t numItems = jidx.size(0); + + // FIXME: Francis -- write a dummy output iterator so we don't actually allocate here. 
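+ // (If this path is ever revived, an output iterator along the lines of
+ // cub::DiscardOutputIterator might be usable as the dummy "unique out" target so the extra
+ // allocation can be dropped -- an untested suggestion, not something this patch relies on.)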
+ torch::Tensor dummyOut = torch::empty( + { numTensors }, torch::TensorOptions().dtype(JIdxScalarType).device(jdata.device())); + + torch::Tensor joffsetsOut = torch::empty( + { numTensors + 1 }, + torch::TensorOptions().dtype(JIdxScalarType).device(jdata.device())); torch::Tensor + numRunsOut = torch::empty({ 1 }, + torch::TensorOptions().dtype(JIdxScalarType).device(jdata.device())); + + // Get current cuda stream for device + at::cuda::CUDAStream currentStream = at::cuda::getCurrentCUDAStream(jdata.device().index()); + + // Determine temporary device storage requirements + void *d_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + cub::DeviceRunLengthEncode::Encode(d_temp_storage, temp_storage_bytes, + jidx.data_ptr(), // keys in + dummyOut.data_ptr(), // unique out + (dummy) joffsetsOut.data_ptr() + 1, // counts + out numRunsOut.data_ptr(), // num runs out + numItems, currentStream.stream()); + + // Allocate temporary storage + d_temp_storage = c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(temp_storage_bytes, + currentStream.stream()); + + // Do the actual reduce by key operation + cub::DeviceRunLengthEncode::Encode(d_temp_storage, temp_storage_bytes, + jidx.data_ptr(), // keys in + dummyOut.data_ptr(), // values out + (dummy) joffsetsOut.data_ptr() + 1, // unique + out numRunsOut.data_ptr(), // num runs out + numItems, currentStream.stream()); + + // Free up scratch memory + c10::cuda::CUDACachingAllocator::raw_delete(d_temp_storage); + + // Zero out the first element + setZero<<<1, 1>>>(joffsetsOut.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return joffsetsOut.cumsum(0, JOffsetsScalarType); + */ +} + +template <> +torch::Tensor +dispatchJOffsetsForJIdx(torch::Tensor jidx, torch::Tensor jdata, int64_t numTensors) { + TORCH_CHECK_VALUE(jidx.dim() == 1, "jidx must be a 1D tensor"); + TORCH_CHECK_VALUE(jdata.device().is_cpu(), "Invalid device for jdata"); + TORCH_CHECK_VALUE(jidx.size(0) == 0 || jidx.device().is_cpu(), "Invalid device for jidx"); + return joffsetsForJIdx(jidx, jdata, numTensors); +} + } // namespace ops } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/ops/JaggedTensorIndex.cu b/fvdb/src/detail/ops/JaggedTensorIndex.cu index 5262f596f3..cc9a83cfdc 100644 --- a/fvdb/src/detail/ops/JaggedTensorIndex.cu +++ b/fvdb/src/detail/ops/JaggedTensorIndex.cu @@ -1,8 +1,6 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include "Ops.h" - #include #include @@ -14,19 +12,11 @@ namespace fvdb { namespace detail { namespace ops { -// __global__ void makeJOffsetsForListJt(const TorchRAcc32 inJoffsets, int64_t -// idxVal, -// TorchRAcc32 outJoffsets) { -// JOffsetsType startIdx = inJoffsets[idxVal]; -// JOffsetsType endIdx = inJoffsets[idxVal + 1]; -// outJoffsets[0] = 0; -// outJoffsets[1] = endIdx - startIdx; -// } - +// This kernel computes the offsets for an integer indexing operation __global__ void -getJOffsetsMask(const int64_t idxVal, const TorchRAcc32 jlidx, - const TorchRAcc32 inJoffsets, - TorchRAcc32 offsetsAndRange) { +getJOffsetsIndexMask(const int64_t idxVal, const TorchRAcc32 jlidx, + const TorchRAcc32 inJoffsets, + TorchRAcc32 offsetsAndRange) { int32_t idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx >= jlidx.size(0)) { @@ -56,20 +46,470 @@ getJOffsetsMask(const int64_t idxVal, const TorchRAcc32 jlidx, } } -// __global__ void computeJLidx(const int64_t startIdx, const int64_t idxVal, -// const TorchRAcc32 inJLIdx, -// TorchRAcc32 outJLidx) { -// int32_t idx = threadIdx.x + 
blockIdx.x * blockDim.x; +// Computes a mask for the data tensor for a slice operation +__global__ void +makeDataSliceMask(const int64_t start, const int64_t end, const int64_t step, + const TorchRAcc32 inJIdx, const TorchRAcc32 inJLidx, + TorchRAcc32 outDataMask, bool isLdim1, bool oneTensor) { + int32_t idx = threadIdx.x + blockIdx.x * blockDim.x; + + if (idx >= outDataMask.size(0)) { + return; + } + + if (isLdim1) { + const JIdxType jidx = oneTensor ? 0 : inJIdx[idx]; // Which tensor we're in + const bool elementIncluded = (jidx >= start && jidx < end && (jidx - start) % step == 0); + outDataMask[idx] = elementIncluded; + } else { + const JIdxType jidx = oneTensor ? 0 : inJIdx[idx]; // Which tensor this element belongs to + const JLIdxType lidx = inJLidx[jidx][0]; // Which list this tensor belongs to + const bool isIncluded = (lidx >= start && lidx < end && + (lidx - start) % step == 0); // Is the list included in the slice? + outDataMask[idx] = + isIncluded; // The element belongs to a tensor in a list that is included in the slice + } +} + +// Computes a the new joffsets and jlidx tensor for a slice operation. Note that the output joffsets +// and jlidx have redundant values that need to be masked out. We allocate them to be the size of +// the input so we don't need to track the size of the output tensors. +// This kernel also computes the appropriate masks +__global__ void +makeOffsetsSliceMask(const int64_t start, const int64_t end, const int64_t step, + const TorchRAcc32 inJoffsets, + const TorchRAcc32 inJLidx, TorchRAcc32 outOffsetsMask, + TorchRAcc32 outTensorSizes, + TorchRAcc32 outJLIdx, bool isLdim1) { + int32_t idx = threadIdx.x + blockIdx.x * blockDim.x; + + if (idx == 0) { + outTensorSizes[0] = 0; + outOffsetsMask[0] = true; + } + + if (idx >= (inJoffsets.size(0) - 1)) { + return; + } + + const JOffsetsType inOrdinal = + isLdim1 ? idx : inJLidx[idx][0]; // either jidx when ldim1 or lidx when ldim2 + const JOffsetsType outOrdinal = + (inOrdinal - start + step - 1) / step; // which tensor or list this offset belongs to + + const bool offsetIncluded = + (inOrdinal >= start && inOrdinal < end && (inOrdinal - start) % step == 0); + outOffsetsMask[idx + 1] = offsetIncluded; + + if (offsetIncluded) { + outTensorSizes[idx + 1] = inJoffsets[idx + 1] - inJoffsets[idx]; + + if (!isLdim1) { + outJLIdx[idx][0] = outOrdinal; + outJLIdx[idx][1] = inJLidx[idx][1]; + } + } +} + +// When we're indexing with a jagged tensor, each indexing tensor i_AB has integers in the range +// [0, t_AB.size(dim))] where t_AB is the tensor being indexed. We need to convert these to global +// indices into the jdata tensor by adding the appropriate joffset to each index +template +__global__ void +calculateIndexShiftForEachElement(const TorchRAcc64 inJOffsets, + const TorchRAcc64 inJIdx, + TorchRAcc64 outAdd) { + int32_t idx = threadIdx.x + blockIdx.x * blockDim.x; + + if (idx >= outAdd.size(0)) { + return; + } + + JIdxType jidx = inJIdx.size(0) > 0 ? inJIdx[idx] : 0; + outAdd[idx] = inJOffsets[jidx]; +} + +// This corresponds to indexing with a JaggedTensor. i.e. using each tensor in an indexing +// JaggedTensor to index the corresponding tensor in the JaggedTensor +// i.e. 
jt = JaggedTensor([[t_11, t_12], [t_21, t_22, t_23], ...]) +// indices = JaggedTensor([[i_11, i_12], [i_21, i_22, i_23], ...]) +// jt[indices] -> JaggedTensor([[t_11[i_11], t_12[i_12]], [t_21[i_21], t_22[i_22], t_23[i_23]], +// ...]) +// Here indices can be integers or a boolean mask +JaggedTensor +jaggedTensorIndexJaggedTensorImpl(const JaggedTensor &jt, const JaggedTensor &jtIndices) { + TORCH_CHECK_VALUE(jtIndices.device() == jt.device(), + "indices must be on the same device as the JaggedTensor"); + + TORCH_CHECK_INDEX(jtIndices.jlidx().dim() == jt.jlidx().dim(), + "Indices must have the same list structure as JaggedTensor being indexed"); + for (int i = 0; i < jt.jlidx().dim(); ++i) { + TORCH_CHECK_INDEX( + jtIndices.jlidx().size(i) == jt.jlidx().size(i), + "Indices must have the same list structure as JaggedTensor being indexed"); + } + if (Config::global().pendanticErrorCheckingEnabled()) { + // This is a slow check that we cap optionally do for correctness. + TORCH_CHECK_INDEX( + torch::all(jtIndices.jlidx() == jt.jlidx()).item(), + "Indices must have the same list structure as JaggedTensor being indexed. ", + "This error was raised because config.pendatic_error_checking was enabled"); + } + + c10::ScalarType idxDtype = jtIndices.scalar_type(); + const bool isIndexType = + (idxDtype == c10::ScalarType::Long || idxDtype == c10::ScalarType::Int || + idxDtype == c10::ScalarType::Byte || idxDtype == c10::ScalarType::Bool); + TORCH_CHECK_INDEX( + isIndexType, + "JaggedTensors used as indices must be long, int, byte or bool JaggedTensors but got ", + idxDtype); + + torch::Tensor selidx; + if (jtIndices.scalar_type() == torch::kBool) { + selidx = jtIndices.jdata(); + } else { + if (jt.device().is_cpu()) { + // FIXME (Francis): We're not checking out of range here and it's sketchy! Fix in a + // unified CUDA kernel + selidx = jtIndices.jdata().clone(); + for (int i = 0; i < jtIndices.joffsets().size(0) - 1; ++i) { + const JOffsetsType start = jtIndices.joffsets()[i].item(); + const JOffsetsType end = jtIndices.joffsets()[i + 1].item(); + const JOffsetsType add = jt.joffsets()[i].item(); + selidx.index({ torch::indexing::Slice(start, end) }).add_(add); + } + } else { + torch::Tensor selidxAdd = + torch::empty({ jtIndices.jdata().size(0) }, jtIndices.jdata().options()); + + AT_DISPATCH_INTEGRAL_TYPES( + jtIndices.scalar_type(), "calculateIndexShiftForEachElement", [&] { + const int64_t MAX_BLOCKS = 4194302; // floor((2^32 - 1) / 1024) + const int64_t numBlocks = GET_BLOCKS(jtIndices.jdata().size(0), 1024); + TORCH_INTERNAL_ASSERT(numBlocks < MAX_BLOCKS, "Too many blocks"); + calculateIndexShiftForEachElement<<>>( + jt.joffsets() + .packed_accessor64(), + jtIndices.jidx().packed_accessor64(), + selidxAdd.packed_accessor64()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + for (int i = 1; i < jtIndices.jdata().dim(); i += 1) { + selidxAdd = selidxAdd.unsqueeze(-1); + } + selidx = selidxAdd + jtIndices.jdata(); + } + } + + const torch::Tensor retJdata = jt.jdata().index({ selidx }); + torch::Tensor retJidx = jt.jidx().size(0) > 0 ? 
jt.jidx().index({ selidx }) : jt.jidx(); + if (retJidx.dim() > 1) { + std::vector idx; + idx.reserve(retJidx.dim()); + idx.push_back(at::indexing::Slice()); + for (int i = 1; i < retJidx.dim(); ++i) { + idx.push_back(0); + } + retJidx = retJidx.index(idx); + } + retJidx = retJidx.contiguous(); + const torch::Tensor retJOffsets = + JaggedTensor::joffsets_from_jidx_and_jdata(retJidx, retJdata, jt.num_tensors()); + const torch::Tensor retListIdx = jt.jlidx(); + + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retJdata, retJOffsets, retJidx, + retListIdx, jt.num_outer_lists()); +} + +// This corresponds to indexing with a slice +// i.e. jt = JaggedTensor([...]) +// jt[2:11:4] -> JaggedTensor([...]) where every fourth entry from the third to the tenth list +// (inclusive) is selected +JaggedTensor +jaggedTensorIndexSliceCuda(const JaggedTensor &jt, int64_t start, int64_t end, int64_t step) { + // Convert indexes to positive values in the range [0, num_outer_lists] + if (start < 0) { + start += jt.num_outer_lists(); + } + if (end < 0) { + end += jt.num_outer_lists(); + } + if (start >= end) { + start = end; + } + start = std::max(start, (int64_t)0); + end = std::min(end, jt.num_outer_lists()); + + // Single list case with step size 1 (ldim = 1) + if (jt.ldim() == 1 && step == 1) { + TORCH_CHECK(jt.ldim() == 1, "bad list indexes. this should never happen"); + const JOffsetsType startIdx = jt.joffsets()[start].item(); + const JOffsetsType endIdx = jt.joffsets()[end].item(); + const torch::Tensor retLidx = + jt.jlidx().numel() == 0 ? jt.jlidx() + : jt.jlidx().index({ torch::indexing::Slice(start, end) }); + return JaggedTensor::from_data_offsets_and_list_ids( + jt.jdata().index({ torch::indexing::Slice(startIdx, endIdx) }), + jt.joffsets().index({ torch::indexing::Slice(start, end + 1) }) - startIdx, retLidx); + } + + // Compute a boolean mask for the data tensor and offsets as well as the tensor sizes (which we + // cumsum) and list ids The offsets mask is used so we can just write the tensor sizes/lidx to + // the output tensor and then select only the active values. 
Otherwise, we'd need something like + // a binsearch + const torch::TensorOptions maskOpts = + torch::TensorOptions().device(jt.device()).dtype(torch::kBool); + torch::Tensor dataMask = torch::empty({ jt.jdata().size(0) }, maskOpts); + torch::Tensor offsetsMask = torch::empty({ jt.joffsets().size(0) }, maskOpts); + torch::Tensor outJLIdx = torch::empty_like(jt.jlidx()); + torch::Tensor outJOffsets = torch::empty_like(jt.joffsets()); + + auto joffsetsAcc = jt.joffsets().packed_accessor32(); + auto jidxAcc = jt.jidx().packed_accessor32(); + auto jlidxAcc = jt.jlidx().packed_accessor32(); + auto dataMaskAcc = dataMask.packed_accessor32(); + auto offsetsMaskAcc = offsetsMask.packed_accessor32(); + auto outJOffsetsAcc = + outJOffsets.packed_accessor32(); + auto outJLIdxAcc = outJLIdx.packed_accessor32(); + + auto callKernel = [=]() { + const int64_t MAX_BLOCKS = 4194302; // floor((2^32 - 1) / 1024) + const int64_t numBlocksData = GET_BLOCKS(jt.jdata().size(0), 1024); + TORCH_INTERNAL_ASSERT(numBlocksData < MAX_BLOCKS, "Too many blocks"); + makeDataSliceMask<<>>(start, end, step, jidxAcc, jlidxAcc, dataMaskAcc, + jt.ldim() == 1, jt.num_tensors() == 1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + const int numBlocksOffsets = GET_BLOCKS(jt.joffsets().size(0) - 1, 1024); + TORCH_INTERNAL_ASSERT(numBlocksOffsets < MAX_BLOCKS, "Too many blocks"); + makeOffsetsSliceMask<<>>(start, end, step, joffsetsAcc, jlidxAcc, + offsetsMaskAcc, outJOffsetsAcc, + outJLIdxAcc, jt.ldim() == 1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }; + callKernel(); + + const torch::Tensor outData = jt.jdata().index({ dataMask }); + + outJOffsets = outJOffsets.index({ offsetsMask }); + torch::cumsum_out(outJOffsets, outJOffsets, 0); + + torch::Tensor outJIdx = + outJOffsets.size(0) > 2 + ? fvdb::JaggedTensor::jidx_from_joffsets(outJOffsets, outData.size(0)) + : torch::empty( + { 0 }, torch::TensorOptions().dtype(JIdxScalarType).device(jt.jdata().device())); + + outJLIdx = + jt.ldim() > 1 + ? outJLIdx.index( + { offsetsMask.index({ torch::indexing::Slice(1, offsetsMask.size(0), 1) }) }) + : torch::empty( + { 0, 1 }, + torch::TensorOptions().dtype(JLIdxScalarType).device(jt.jdata().device())); + + const JOffsetsType totalItems = (end - start + step - 1) / step; + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(outData, outJOffsets, outJIdx, + outJLIdx, totalItems); +} + +// This corresponds to indexing with a slice +// i.e. jt = JaggedTensor([...]) +// jt[2:11:4] -> JaggedTensor([...]) where every fourth entry from the third to the tenth list +// (inclusive) is selected +JaggedTensor +jaggedTensorIndexSliceCpu(const JaggedTensor &jt, int64_t start, int64_t end, int64_t step) { + // Convert indexes to positive values + if (start < 0) { + start += jt.num_outer_lists(); + } + if (end < 0) { + end += jt.num_outer_lists(); + } + if (start >= end) { + start = end; + } + + start = std::max(start, static_cast(0)); + end = std::min(end, jt.num_outer_lists()); + + if (jt.ldim() == 1 && step == 1) { + TORCH_CHECK(jt.ldim() == 1, "bad list indexes. this should never happen"); + const JOffsetsType startIdx = jt.joffsets()[start].item(); + const JOffsetsType endIdx = jt.joffsets()[end].item(); + const torch::Tensor retLidx = + jt.jlidx().numel() == 0 ? 
jt.jlidx() + : jt.jlidx().index({ torch::indexing::Slice(start, end) }); + return JaggedTensor::from_data_offsets_and_list_ids( + jt.jdata().index({ torch::indexing::Slice(startIdx, endIdx) }), + jt.joffsets().index({ torch::indexing::Slice(start, end + 1) }) - startIdx, retLidx); + } else if (jt.ldim() > 1 && step == 1) { + // Find all tensors that belong to the slice + const torch::Tensor outerLidx = jt.jlidx().index({ torch::indexing::Slice(), 0 }); + const torch::Tensor lidxMask = outerLidx.ge(start).logical_and(outerLidx.lt(end)); + const torch::Tensor joffsetCat = torch::stack( + { jt.joffsets().index({ torch::indexing::Slice(0, jt.num_tensors()) }), + jt.joffsets().index({ torch::indexing::Slice(1, jt.num_tensors() + 1) }) }, + 1); + + // Start and end element index of each tensor in the slice + const torch::Tensor selectedOffsets = joffsetCat.index({ lidxMask }); + + // Get the start and end offsets into the data tensor for the slice + JOffsetsType startIdx = + selectedOffsets.size(0) > 0 ? selectedOffsets[0][0].item() : 0; + JOffsetsType endIdx = + selectedOffsets.size(0) > 0 ? selectedOffsets[-1][1].item() : 0; + + // Slice the data tensor + const torch::Tensor retData = + jt.jdata().index({ torch::indexing::Slice(startIdx, endIdx) }); + + // Subtract the start offset from the selected offsets to get the new offsets + // NOTE: This assumes offsets are always contiguous + const torch::Tensor retOffsets = + selectedOffsets.numel() > 0 + ? torch::cat({ selectedOffsets.index({ torch::indexing::Slice(), 0 }), + selectedOffsets.index({ -1, 1 }).unsqueeze(0) }) - + startIdx + : torch::zeros( + { 1 }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(jt.jdata().device())); + + // Slice the list indices and subtract the start index + TORCH_CHECK(jt.jlidx().size(1) > 1, "bad list indexes. this should never happen"); + torch::Tensor retListIdx = jt.jlidx().index({ lidxMask }); + retListIdx.index({ torch::indexing::Slice(), 0 }) -= start; + if (retListIdx.dim() == 0) { + retListIdx = retListIdx.unsqueeze(1); + } + const int64_t retNumOuterLists = end - start; + const torch::Tensor retJidx = + retOffsets.size(0) > 2 + ? JaggedTensor::jidx_from_joffsets(retOffsets, retData.size(0)) + : torch::empty( + { 0 }, + torch::TensorOptions().dtype(JIdxScalarType).device(jt.jdata().device())); + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retData, retOffsets, retJidx, + retListIdx, retNumOuterLists); + } else if (jt.ldim() == 1 && step > 1) { + const JOffsetsType totalItems = (end - start + step - 1) / step; + const torch::TensorOptions offsetsOpts = + torch::TensorOptions().dtype(JOffsetsScalarType).device(jt.jdata().device()); + torch::Tensor dataMask = torch::zeros( + { jt.jdata().size(0) }, torch::TensorOptions().dtype(torch::kBool).device(jt.device())); + torch::Tensor retOffsets = torch::empty({ totalItems + 1 }, offsetsOpts); + + auto retOffsetsAcc = retOffsets.accessor(); + auto joffsetsAcc = jt.joffsets().accessor(); + int64_t count = 0; + retOffsetsAcc[0] = 0; + for (int64_t i = start; i < end; i += step) { + JOffsetsType startIdx = joffsetsAcc[i]; + JOffsetsType endIdx = joffsetsAcc[i + 1]; + dataMask.index({ torch::indexing::Slice(startIdx, endIdx) }).fill_(true); + retOffsetsAcc[count + 1] = endIdx - startIdx; + count += 1; + } + torch::cumsum_out(retOffsets, retOffsets, 0); + const torch::Tensor retData = jt.jdata().index({ dataMask }); + const torch::Tensor retJIdx = + retOffsets.size(0) > 2 + ? 
JaggedTensor::jidx_from_joffsets(retOffsets, retData.size(0)) + : torch::empty( + { 0 }, + torch::TensorOptions().dtype(JIdxScalarType).device(jt.jdata().device())); + const torch::Tensor retJLidx = torch::zeros( + { 0, 1 }, torch::TensorOptions().dtype(JLIdxScalarType).device(jt.jdata().device())); + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retData, retOffsets, retJIdx, + retJLidx, totalItems); + } else { + // Find all tensors that belong to the slice + const torch::Tensor outerLidx = jt.jlidx().index({ torch::indexing::Slice(), 0 }); + const torch::Tensor lidxMask = outerLidx.ge(start) + .logical_and(outerLidx.lt(end)) + .logical_and((outerLidx - start) % step == 0); + const torch::Tensor selectedOffsets = + torch::stack( + { jt.joffsets().index({ torch::indexing::Slice(0, jt.num_tensors()) }), + jt.joffsets().index({ torch::indexing::Slice(1, jt.num_tensors() + 1) }) }, + 1) + .index({ lidxMask }); + + const torch::Tensor selectedLidx = jt.jlidx().index({ lidxMask }); + + const torch::TensorOptions offsetsOpts = + torch::TensorOptions().dtype(JOffsetsScalarType).device(jt.jdata().device()); + const torch::TensorOptions lidxOpts = + torch::TensorOptions().dtype(JLIdxScalarType).device(jt.jdata().device()); + + torch::Tensor dataMask = torch::zeros( + { jt.jdata().size(0) }, torch::TensorOptions().dtype(torch::kBool).device(jt.device())); + torch::Tensor retOffsets = torch::empty({ selectedOffsets.size(0) + 1 }, offsetsOpts); + torch::Tensor retJLidx = torch::empty({ selectedOffsets.size(0), jt.ldim() }, lidxOpts); + + auto retOffsetsAcc = retOffsets.accessor(); + auto retJLidxAcc = retJLidx.accessor(); + auto selOffsetsAcc = selectedOffsets.accessor(); + auto selLidxAcc = selectedLidx.accessor(); + retOffsetsAcc[0] = 0; + JLIdxType count = -1; + for (int i = 0; i < retOffsets.size(0) - 1; i += 1) { + if (i == 0 || selLidxAcc[i][0] != selLidxAcc[i - 1][0]) { + count += 1; + } + + JOffsetsType startIdx = selOffsetsAcc[i][0]; + JOffsetsType endIdx = selOffsetsAcc[i][1]; + + dataMask.index({ torch::indexing::Slice(startIdx, endIdx) }).fill_(true); + retOffsetsAcc[i + 1] = endIdx - startIdx; + retJLidxAcc[i][0] = count; + retJLidxAcc[i][1] = selLidxAcc[i][1]; + } + count += 1; + torch::cumsum_out(retOffsets, retOffsets, 0); + const torch::Tensor retData = jt.jdata().index({ dataMask }); + const torch::Tensor retJIdx = + retOffsets.size(0) > 2 + ? JaggedTensor::jidx_from_joffsets(retOffsets, retData.size(0)) + : torch::empty( + { 0 }, + torch::TensorOptions().dtype(JIdxScalarType).device(jt.jdata().device())); + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retData, retOffsets, retJIdx, + retJLidx, count); + } +} + +// Special case of integer indexing where the JaggedTensor is just a list of tensors and not a list +// of lists of tensors. We call this from the CPU and GPU implementations which is why it's factored +// out i.e. jt = JaggedTensor([t_0, t_1, t_2, ..., t_n]) +// jt[2] -> JaggedTensor([t_2]) where the 3rd list is selected +JaggedTensor +jaggedTensorIndexIntOneList(const JaggedTensor &jt, int64_t idxVal) { + torch::Tensor joffsets = jt.joffsets(); + torch::Tensor jdata = jt.jdata(); + torch::Tensor jlidx = jt.jlidx(); -// if (idx >= outJLidx.size(0)) { -// return; -// } -// outJLidx[idx][0] = inJLIdx[idx + startIdx][0] - idxVal; -// outJLidx[idx][1] = inJLIdx[idx + startIdx][1]; -// } + TORCH_CHECK(jt.ldim() == 1, "bad list indexes. 
this should never happen"); + const JOffsetsType startIdx = joffsets[idxVal].item(); + const JOffsetsType endIdx = joffsets[idxVal + 1].item(); + const torch::Tensor retJoffsets = + torch::tensor({ JOffsetsType(0), endIdx - startIdx }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(jdata.device())); + const torch::Tensor retData = jdata.index({ torch::indexing::Slice(startIdx, endIdx) }); + const torch::Tensor retJidx = torch::empty({ 0 }, torch::TensorOptions().dtype(JIdxScalarType)); + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retData, retJoffsets, retJidx, + jlidx, retJoffsets.size(0) - 1); +} +// This corresponds to indexing with an integer +// i.e. jt = JaggedTensor([...]) +// jt[2] -> JaggedTensor([...]) where the 3rd list is selected JaggedTensor -jaggedTensorIndexMultiListCuda(const JaggedTensor &jt, int64_t idxVal) { +jaggedTensorIndexIntCuda(const JaggedTensor &jt, int64_t idxVal) { if (idxVal < 0) { idxVal += jt.num_outer_lists(); } @@ -77,6 +517,10 @@ jaggedTensorIndexMultiListCuda(const JaggedTensor &jt, int64_t idxVal) { " is out of bounds for JaggedTensor with ", jt.num_outer_lists(), " elements"); + if (jt.jlidx().size(0) == 0) { + return jaggedTensorIndexIntOneList(jt, idxVal); + } + torch::Tensor joffsets = jt.joffsets(); torch::Tensor jdata = jt.jdata(); torch::Tensor jlidx = jt.jlidx(); @@ -93,8 +537,12 @@ jaggedTensorIndexMultiListCuda(const JaggedTensor &jt, int64_t idxVal) { auto inJOffsetsAcc = joffsets.packed_accessor32(); auto offsetsAndRangeAcc = offsetsAndRange.packed_accessor32(); - const int numBlocks = GET_BLOCKS(joffsets.size(0), 1024); - getJOffsetsMask<<>>(idxVal, inJLidxAcc, inJOffsetsAcc, offsetsAndRangeAcc); + + const int64_t MAX_BLOCKS = 4194302; // floor((2^32 - 1) / 1024) + const int64_t numBlocks = GET_BLOCKS(joffsets.size(0), 1024); + TORCH_INTERNAL_ASSERT(numBlocks < MAX_BLOCKS, "Too many blocks"); + getJOffsetsIndexMask<<>>(idxVal, inJLidxAcc, inJOffsetsAcc, + offsetsAndRangeAcc); C10_CUDA_KERNEL_LAUNCH_CHECK(); offsetsAndRange = offsetsAndRange.cpu(); @@ -129,8 +577,11 @@ jaggedTensorIndexMultiListCuda(const JaggedTensor &jt, int64_t idxVal) { retListIdx, retNumOuterLists); } +// This corresponds to indexing with an integer +// i.e. jt = JaggedTensor([...]) +// jt[2] -> JaggedTensor([...]) where the 3rd list is selected JaggedTensor -jaggedTensorIndexMultiListCpu(const JaggedTensor &jt, int64_t idxVal) { +jaggedTensorIndexIntCpu(const JaggedTensor &jt, int64_t idxVal) { if (idxVal < 0) { idxVal += jt.num_outer_lists(); } @@ -138,6 +589,10 @@ jaggedTensorIndexMultiListCpu(const JaggedTensor &jt, int64_t idxVal) { " is out of bounds for JaggedTensor with ", jt.num_outer_lists(), " elements"); + if (jt.jlidx().size(0) == 0) { + return jaggedTensorIndexIntOneList(jt, idxVal); + } + torch::Tensor joffsets = jt.joffsets(); torch::Tensor jdata = jt.jdata(); torch::Tensor jlidx = jt.jlidx(); @@ -181,49 +636,58 @@ jaggedTensorIndexMultiListCpu(const JaggedTensor &jt, int64_t idxVal) { retListIdx, retNumOuterLists); } +// This corresponds to indexing with an integer +// i.e. 
jt = JaggedTensor([...]) +// jt[2] -> JaggedTensor([...]) where the 3rd list is selected +template <> JaggedTensor -jaggedTensorIndexOneList(const JaggedTensor &jt, int64_t idxVal) { - if (idxVal < 0) { - idxVal += jt.num_outer_lists(); - } - TORCH_CHECK_INDEX(idxVal >= 0 && idxVal < jt.num_outer_lists(), "Index ", idxVal, - " is out of bounds for JaggedTensor with ", jt.num_outer_lists(), - " elements"); - - torch::Tensor joffsets = jt.joffsets(); - torch::Tensor jdata = jt.jdata(); - torch::Tensor jlidx = jt.jlidx(); - - TORCH_CHECK(jt.ldim() == 1, "bad list indexes. this should never happen"); - const JOffsetsType startIdx = joffsets[idxVal].item(); - const JOffsetsType endIdx = joffsets[idxVal + 1].item(); - const torch::Tensor retJoffsets = - torch::tensor({ JOffsetsType(0), endIdx - startIdx }, - torch::TensorOptions().dtype(JOffsetsScalarType).device(jdata.device())); - const torch::Tensor retData = jdata.index({ torch::indexing::Slice(startIdx, endIdx) }); - const torch::Tensor retJidx = torch::empty({ 0 }, torch::TensorOptions().dtype(JIdxScalarType)); - return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retData, retJoffsets, retJidx, - jlidx, retJoffsets.size(0) - 1); +dispatchJaggedTensorIndexInt(const JaggedTensor &jt, int64_t idxVal) { + return jaggedTensorIndexIntCpu(jt, idxVal); +} +template <> +JaggedTensor +dispatchJaggedTensorIndexInt(const JaggedTensor &jt, int64_t idxVal) { + c10::cuda::CUDAGuard deviceGuard(jt.device()); + return jaggedTensorIndexIntCuda(jt, idxVal); } +// This corresponds to indexing with a slice +// i.e. jt = JaggedTensor([...]) +// jt[2:11:4] -> JaggedTensor([...]) where every fourth entry from the third to the tenth list +// (inclusive) is selected template <> JaggedTensor -dispatchJaggedTensorIndex(const JaggedTensor &jt, int64_t idxVal) { - if (jt.jlidx().size(0) == 0) { - return jaggedTensorIndexOneList(jt, idxVal); - } else { - return jaggedTensorIndexMultiListCpu(jt, idxVal); - } +dispatchJaggedTensorIndexSlice(const JaggedTensor &jt, int64_t start, int64_t end, + int64_t step) { + return jaggedTensorIndexSliceCpu(jt, start, end, step); +} +template <> +JaggedTensor +dispatchJaggedTensorIndexSlice(const JaggedTensor &jt, int64_t start, int64_t end, + int64_t step) { + c10::cuda::CUDAGuard deviceGuard(jt.device()); + return jaggedTensorIndexSliceCuda(jt, start, end, step); } +// This corresponds to indexing with a JaggedTensor. i.e. using each tensor in an indexing +// JaggedTensor to index the corresponding tensor in the JaggedTensor +// i.e. 
jt = JaggedTensor([[t_11, t_12], [t_21, t_22, t_23], ...]) +// indices = JaggedTensor([[i_11, i_12], [i_21, i_22, i_23], ...]) +// jt[indices] -> JaggedTensor([[t_11[i_11], t_12[i_12]], [t_21[i_21], t_22[i_22], t_23[i_23]], +// ...]) +// Here indices can be integers or a boolean mask template <> JaggedTensor -dispatchJaggedTensorIndex(const JaggedTensor &jt, int64_t idxVal) { - if (jt.jlidx().size(0) == 0) { - return jaggedTensorIndexOneList(jt, idxVal); - } else { - return jaggedTensorIndexMultiListCuda(jt, idxVal); - } +dispatchJaggedTensorIndexJaggedTensor(const JaggedTensor &jt, + const JaggedTensor &idx) { + return jaggedTensorIndexJaggedTensorImpl(jt, idx); +} +template <> +JaggedTensor +dispatchJaggedTensorIndexJaggedTensor(const JaggedTensor &jt, + const JaggedTensor &idx) { + c10::cuda::CUDAGuard deviceGuard(jt.device()); + return jaggedTensorIndexJaggedTensorImpl(jt, idx); } } // namespace ops diff --git a/fvdb/src/detail/ops/Ops.h b/fvdb/src/detail/ops/Ops.h index ce5ed89ca3..d3623b701c 100644 --- a/fvdb/src/detail/ops/Ops.h +++ b/fvdb/src/detail/ops/Ops.h @@ -15,7 +15,14 @@ namespace detail { namespace ops { template -JaggedTensor dispatchJaggedTensorIndex(const JaggedTensor &jt, int64_t idxVal); +JaggedTensor dispatchJaggedTensorIndexInt(const JaggedTensor &jt, int64_t idxVal); + +template +JaggedTensor dispatchJaggedTensorIndexSlice(const JaggedTensor &jt, int64_t start, int64_t end, + int64_t step); + +template +JaggedTensor dispatchJaggedTensorIndexJaggedTensor(const JaggedTensor &jt, const JaggedTensor &idx); template JaggedTensor dispatchJCat0(const std::vector &tensors); diff --git a/fvdb/src/detail/ops/VolumeRender.cu b/fvdb/src/detail/ops/VolumeRender.cu index 217eebd4d8..8d88619ecd 100644 --- a/fvdb/src/detail/ops/VolumeRender.cu +++ b/fvdb/src/detail/ops/VolumeRender.cu @@ -260,11 +260,12 @@ dispatchVolumeRender( // auto total_samples = torch::zeros({numRays}, // torch::dtype(torch::kLong).device(sigmas.device())); - const int64_t threads = 1024, blocks = (numRays + threads - 1) / threads; + const int64_t NUM_THREADS = 1024; + const int64_t NUM_BLOCKS = GET_BLOCKS(numRays, NUM_THREADS); AT_DISPATCH_FLOATING_TYPES_AND_HALF( sigmas.scalar_type(), "volumeRender", ([&] { - volumeRender<<>>( + volumeRender<<>>( sigmas.packed_accessor32(), rgbs.packed_accessor32(), deltas.packed_accessor32(), @@ -324,11 +325,12 @@ dispatchVolumeRenderBackward(const torch::Tensor dLdOpacity, torch::Tensor dLdWs_times_ws = (dLdWs * ws); // auxiliary input - const int64_t threads = 1024, blocks = (numRays + threads - 1) / threads; + const int64_t NUM_THREADS = 1024; + const int64_t NUM_BLOCKS = GET_BLOCKS(numRays, NUM_THREADS); AT_DISPATCH_FLOATING_TYPES_AND_HALF( sigmas.scalar_type(), "volumeRenderBackward", ([&] { - volumeRenderBackward<<>>( + volumeRenderBackward<<>>( dLdOpacity.packed_accessor32(), dLdDepth.packed_accessor32(), // dLdDepthSq.packed_accessor32(), diff --git a/fvdb/src/detail/utils/Utils.h b/fvdb/src/detail/utils/Utils.h index 724ab81d12..9cca57d472 100644 --- a/fvdb/src/detail/utils/Utils.h +++ b/fvdb/src/detail/utils/Utils.h @@ -12,6 +12,7 @@ #include +#include #include #include @@ -295,6 +296,31 @@ StringToTorchScalarType(std::string dtypeStr) { TORCH_CHECK(false, "Invalid dtype string " + dtypeStr); } +struct RAIIDeviceGuard { + RAIIDeviceGuard(torch::Device device) { + if (device.is_cuda()) { + mGuard = new c10::cuda::CUDAGuard(device.index()); + } + } + + RAIIDeviceGuard(torch::Device device1, torch::Device device2) { + if (device1.is_cuda()) { + mGuard = new 
c10::cuda::CUDAGuard(device1.index()); + } else if (device2.is_cuda()) { + mGuard = new c10::cuda::CUDAGuard(device2.index()); + } + } + + RAIIDeviceGuard(const RAIIDeviceGuard &) = delete; + + RAIIDeviceGuard &operator=(const RAIIDeviceGuard &) = delete; + + ~RAIIDeviceGuard() { delete mGuard; } + + private: + c10::cuda::CUDAGuard *mGuard = nullptr; +}; + } // namespace detail } // namespace fvdb diff --git a/fvdb/src/python/GridBatchBinding.cpp b/fvdb/src/python/GridBatchBinding.cpp index 9333104ad3..93c87558db 100644 --- a/fvdb/src/python/GridBatchBinding.cpp +++ b/fvdb/src/python/GridBatchBinding.cpp @@ -624,6 +624,9 @@ bind_grid_batch(py::module &m) { .def("to", py::overload_cast(&fvdb::GridBatch::to, py::const_), py::arg("to_grid")) + .def("cpu", [](const fvdb::GridBatch &self) { return self.to(torch::kCPU); }) + .def("cuda", [](const fvdb::GridBatch &self) { return self.to(torch::kCUDA); }) + // .def("clone", &fvdb::GridBatch::clone) // TODO: We totally want this .def( diff --git a/fvdb/tests/unit/test_jagged_tensor.py b/fvdb/tests/unit/test_jagged_tensor.py index d15b26eb78..f599039d0a 100644 --- a/fvdb/tests/unit/test_jagged_tensor.py +++ b/fvdb/tests/unit/test_jagged_tensor.py @@ -457,6 +457,23 @@ def test_indexing(self, device, dtype): # self.check_lshape(gridbatch[9:8:1].ijk, ijk_list[9:8:1]) self.check_lshape(gridbatch.ijk[9:8:1], ijk_list[9:8:1]) + self.assertTrue(torch.equal(gridbatch[9:8:2].ijk.jdata, gridbatch.ijk[9:8:1].jdata)) + # An empty grid returns an ijk JaggedTensor with one thing in it so we can't quite compare! + # self.check_lshape(gridbatch[9:8:1].ijk, ijk_list[9:8:1]) + self.check_lshape(gridbatch.ijk[9:8:2], ijk_list[9:8:2]) + + self.assertTrue(torch.equal(gridbatch[-13:8:2].ijk.jdata, gridbatch.ijk[-13:8:2].jdata)) + self.check_lshape(gridbatch[-13:8:2].ijk, ijk_list[-13:8:2]) + self.check_lshape(gridbatch.ijk[-13:8:2], ijk_list[-13:8:2]) + + self.assertTrue(torch.equal(gridbatch[4:17:3].ijk.jdata, gridbatch.ijk[4:17:3].jdata)) + self.check_lshape(gridbatch[4:17:3].ijk, ijk_list[4:17:3]) + self.check_lshape(gridbatch.ijk[4:17:3], ijk_list[4:17:3]) + + self.assertTrue(torch.equal(gridbatch[4:15:4].ijk.jdata, gridbatch.ijk[4:15:4].jdata)) + self.check_lshape(gridbatch[4:15:4].ijk, ijk_list[4:15:4]) + self.check_lshape(gridbatch.ijk[4:15:4], ijk_list[4:15:4]) + self.assertTrue(torch.equal(gridbatch.ijk.jdata, gridbatch.ijk[...].jdata)) self.check_lshape(gridbatch.ijk, ijk_list) self.check_lshape(gridbatch.ijk[...], ijk_list) @@ -469,22 +486,22 @@ def test_indexing(self, device, dtype): self.check_lshape(gridbatch[::].ijk, ijk_list[::]) self.check_lshape(gridbatch.ijk[::], ijk_list[::]) - with self.assertRaises(IndexError): - print(gridbatch.ijk[9:8:2]) + with self.assertRaises(ValueError): + print(gridbatch.ijk[9:8:0]) - with self.assertRaises(IndexError): + with self.assertRaises(ValueError): print(gridbatch.ijk[9:8:-1]) - with self.assertRaises(IndexError): + with self.assertRaises(ValueError): print(gridbatch.ijk[None]) - with self.assertRaises(IndexError): + with self.assertRaises(ValueError): print(gridbatch.ijk[9:8:-1]) - with self.assertRaises(IndexError): + with self.assertRaises(ValueError): print(gridbatch.ijk[::-1]) - with self.assertRaises(IndexError): + with self.assertRaises(ValueError): print(gridbatch.ijk[::-3]) @parameterized.expand(all_device_dtype_combos) @@ -1531,14 +1548,21 @@ def test_list_of_lists_indexing(self): self.assertTrue(torch.all(lij == lt2[i][j]).item()) self.assertTrue(torch.all(jt[i][j].jdata == lt2[i][j]).item()) - def 
test_list_of_lists_slicing(self): + @parameterized.expand(["cuda", "cpu"]) + def test_list_of_lists_slicing(self, device): lt = [ - [torch.randn(np.random.randint(100, 200), 7) for _ in range(int(l.item()))] - for l in torch.randint(3, 17, (7,)) + [torch.randn(np.random.randint(100, 200), 7).to(device) for _ in range(int(l.item()))] + for l in torch.randint(3, 5, (10,)) ] jt = fvdb.JaggedTensor(lt) self.check_lshape(jt, lt) + def check_eq(jt_, lt_): + for i, li in enumerate(lt_): + self.assertEqual(len(li), len(jt_[i].unbind())) + for j, lij in enumerate(li): + self.assertTrue(torch.all(lij == jt_[i][j].jdata).item()) + lt2 = jt.unbind() self.check_lshape(jt, lt2) self.assertEqual(len(lt), len(lt2)) @@ -1551,42 +1575,27 @@ def test_list_of_lists_slicing(self): jt2 = jt[2:3] lt2 = lt[2:3] self.check_lshape(jt2, lt2) - for i, li in enumerate(lt2): - self.assertEqual(len(li), len(jt2[i].unbind())) - for j, lij in enumerate(li): - self.assertTrue(torch.all(lij == jt2[i][j].jdata).item()) + check_eq(jt2, lt2) jt2 = jt[:4] lt2 = lt[:4] self.check_lshape(jt2, lt2) - for i, li in enumerate(lt2): - self.assertEqual(len(li), len(jt2[i].unbind())) - for j, lij in enumerate(li): - self.assertTrue(torch.all(lij == jt2[i][j].jdata).item()) + check_eq(jt2, lt2) jt2 = jt[:] lt2 = lt[:] self.check_lshape(jt2, lt2) - for i, li in enumerate(lt2): - self.assertEqual(len(li), len(jt2[i].unbind())) - for j, lij in enumerate(li): - self.assertTrue(torch.all(lij == jt2[i][j].jdata).item()) + check_eq(jt2, lt2) jt2 = jt[:-1] lt2 = lt[:-1] self.check_lshape(jt2, lt2) - for i, li in enumerate(lt2): - self.assertEqual(len(li), len(jt2[i].unbind())) - for j, lij in enumerate(li): - self.assertTrue(torch.all(lij == jt2[i][j].jdata).item()) + check_eq(jt2, lt2) jt2 = jt[-5:] lt2 = lt[-5:] self.check_lshape(jt2, lt2) - for i, li in enumerate(lt2): - self.assertEqual(len(li), len(jt2[i].unbind())) - for j, lij in enumerate(li): - self.assertTrue(torch.all(lij == jt2[i][j].jdata).item()) + check_eq(jt2, lt2) jt2 = jt[-1] lt2 = lt[-1] @@ -1599,42 +1608,204 @@ def test_list_of_lists_slicing(self): jt2 = jt[1:1] lt2 = lt[1:1] self.check_lshape(jt2, lt2) - for i, li in enumerate(lt2): - self.assertEqual(len(li), len(jt2[i].unbind())) - for j, lij in enumerate(li): - self.assertTrue(torch.all(lij == jt2[i][j].jdata).item()) + check_eq(jt2, lt2) jt2 = jt[-1:1] lt2 = lt[-1:1] self.check_lshape(jt2, lt2) - for i, li in enumerate(lt2): - self.assertEqual(len(li), len(jt2[i].unbind())) - for j, lij in enumerate(li): - self.assertTrue(torch.all(lij == jt2[i][j].jdata).item()) + check_eq(jt2, lt2) jt2 = jt[-5:-1] lt2 = lt[-5:-1] self.check_lshape(jt2, lt2) - for i, li in enumerate(lt2): - self.assertEqual(len(li), len(jt2[i].unbind())) - for j, lij in enumerate(li): - self.assertTrue(torch.all(lij == jt2[i][j].jdata).item()) + check_eq(jt2, lt2) jt2 = jt[-5000:-1] lt2 = lt[-5000:-1] self.check_lshape(jt2, lt2) - for i, li in enumerate(lt2): - self.assertEqual(len(li), len(jt2[i].unbind())) - for j, lij in enumerate(li): - self.assertTrue(torch.all(lij == jt2[i][j].jdata).item()) + check_eq(jt2, lt2) jt2 = jt[-5000:5000] lt2 = lt[-5000:5000] self.check_lshape(jt2, lt2) - for i, li in enumerate(lt2): - self.assertEqual(len(li), len(jt2[i].unbind())) - for j, lij in enumerate(li): - self.assertTrue(torch.all(lij == jt2[i][j].jdata).item()) + check_eq(jt2, lt2) + + jt2 = jt[2:8:2] + lt2 = lt[2:8:2] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[3:-1:3] + lt2 = lt[3:-1:3] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) 
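+
+ # The stepped slices below are expected to behave exactly like slicing the plain Python
+ # list of lists: out-of-range bounds are clamped (lt has 10 outer lists, so lt[3:11:4]
+ # acts like lt[3:10:4]) and reversed ranges such as lt[3:2:4] yield an empty result.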
+ + jt2 = jt[3:11:4] + lt2 = lt[3:11:4] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[3:2:4] + lt2 = lt[3:2:4] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + @parameterized.expand(["cuda", "cpu"]) + def test_slicing_list_of_lists_small(self, device): + lt = [ + [torch.randn(0, 7, device=device), torch.randn(2, 7, device=device), torch.randn(0, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(1, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(1, 7, device=device), torch.randn(0, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(0, 7, device=device), torch.randn(0, 7, device=device)], + [torch.randn(0, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(0, 7, device=device), torch.randn(0, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(0, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(0, 7, device=device), torch.randn(0, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(0, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(1, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(0, 7, device=device), torch.randn(0, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(0, 7, device=device)], + [torch.randn(0, 7, device=device), torch.randn(1, 7, device=device), torch.randn(0, 7, device=device)], + ] + jt = fvdb.JaggedTensor(lt) + + def check_eq(jt_, lt_): + for i, li in enumerate(lt_): + self.assertEqual(len(li), len(jt_[i].unbind())) + for j, lij in enumerate(li): + self.assertTrue(torch.all(lij == jt_[i][j].jdata).item()) + + jt2 = jt[2:3] + lt2 = lt[2:3] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[2:4] + lt2 = lt[2:4] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[:4] + lt2 = lt[:4] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[:] + lt2 = lt[:] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[:-1] + lt2 = lt[:-1] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[-5:] + lt2 = lt[-5:] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[1:1] + lt2 = lt[1:1] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[-1:1] + lt2 = lt[-1:1] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[-5:-1] + lt2 = lt[-5:-1] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[-5000:-1] + lt2 = lt[-5000:-1] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[-5000:5000] + lt2 = lt[-5000:5000] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[2:8:2] + lt2 = lt[2:8:2] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[3:-1:3] + lt2 = lt[3:-1:3] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[3:11:4] + lt2 = lt[3:11:4] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + jt2 = jt[3:2:4] + lt2 = lt[3:2:4] + self.check_lshape(jt2, lt2) + check_eq(jt2, lt2) + + @parameterized.expand(["cuda", "cpu"]) + def test_jagged_tensor_jagged_tensor_indexing_single_tensor_list(self, device): + t1 = torch.randn(100, 3, device=device) + l1 = [t1] + jt1 = fvdb.JaggedTensor(l1) + pmt1 = torch.randperm(100, device=device) + lpmt1 = [pmt1] + jpmt1 = fvdb.JaggedTensor(lpmt1) + jt_permuted = jt1[jpmt1] + self.assertTrue(torch.all(jt_permuted.jdata == t1[pmt1]).item()) + + t1 = torch.randn(100, 3, device=device) + l1 = [[t1]] + jt1 = fvdb.JaggedTensor(l1) + pmt1 = torch.randperm(100, device=device) + lpmt1 = [[pmt1]] + jpmt1 = 
fvdb.JaggedTensor(lpmt1) + jt_permuted = jt1[jpmt1] + self.assertTrue(torch.all(jt_permuted.jdata == t1[pmt1]).item()) + + t1 = torch.randn(10, 3, device=device) + l1 = [torch.zeros(0, 3, device=device), t1, torch.zeros(0, 3, device=device)] + jt1 = fvdb.JaggedTensor(l1) + pmt1 = torch.randperm(10, device=device) + empty_idx = torch.zeros(0, dtype=pmt1.dtype, device=device) + lpmt1 = [empty_idx, pmt1, empty_idx] + jpmt1 = fvdb.JaggedTensor(lpmt1) + jt_permuted = jt1[jpmt1] + self.assertEqual(jt_permuted.lshape, [0, 10, 0]) + self.assertTrue(torch.all(jt_permuted.jdata == t1[pmt1]).item()) + + t1 = torch.randn(10, 3, device=device) + empty_data = torch.zeros(0, 3, device=device) + l1 = [ + [empty_data, t1, empty_data], + [t1, empty_data], + [empty_data], + [empty_data, empty_data, t1, empty_data], + [t1], + ] + jt1 = fvdb.JaggedTensor(l1) + pmt1 = torch.randperm(10, device=device) + empty_idx = torch.zeros(0, dtype=pmt1.dtype, device=device) + lpmt1 = [ + [empty_idx, pmt1, empty_idx], + [pmt1, empty_idx], + [empty_idx], + [empty_idx, empty_idx, pmt1, empty_idx], + [pmt1], + ] + jpmt1 = fvdb.JaggedTensor(lpmt1) + jt_permuted = jt1[jpmt1] + self.assertEqual(jt_permuted.lshape, [[0, 10, 0], [10, 0], [0], [0, 0, 10, 0], [10]]) + for i, jtpi in enumerate(jt_permuted): + for j, jtpij in enumerate(jtpi): + self.assertTrue(torch.all(jt1[i][j].jdata[jpmt1[i][j].jdata] == jtpij.jdata).item()) @parameterized.expand(["cuda", "cpu"]) def test_jagged_tensor_integer_indexing(self, device):
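+ # Integer indexing selects a single outer list: jt[i] gathers that list's elements into a new
+ # JaggedTensor, and negative indices count from the end, mirroring Python list indexing.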