From c9aee5964a977e6ca268704119f589bfb0695b54 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 21 Aug 2024 10:47:20 +1200 Subject: [PATCH] Formatting changes to match OpenVDB guidelines Added examples Updated README and environment/docker updates Signed-off-by: Jonathan Swartz --- fvdb/Dockerfile | 57 +- fvdb/README.md | 74 +- fvdb/docs/conf.py | 32 +- fvdb/env/build_environment.yml | 15 +- fvdb/env/learn_environment.yml | 1 + fvdb/env/test_environment.yml | 17 +- fvdb/examples/common.py | 190 + fvdb/examples/compare_conv_speed.py | 75 + fvdb/examples/grid_building.py | 155 + fvdb/examples/grid_subdivide_coarsen.py | 56 + fvdb/examples/mutable_grids.py | 107 + fvdb/examples/overfit_sdf.py | 116 + fvdb/examples/ray_segment_marching.py | 102 + fvdb/examples/ray_voxel_marching.py | 100 + fvdb/examples/sample_trilinear.py | 67 + fvdb/examples/splat_trilinear.py | 54 + fvdb/examples/subdivide.py | 74 + fvdb/examples/uniform_ray_marching.py | 124 + fvdb/examples/voxel_neighborhood.py | 49 + fvdb/fvdb/_Cpp.pyi | 415 +- fvdb/fvdb/__init__.py | 24 +- fvdb/fvdb/nn/modules.py | 100 +- fvdb/fvdb/nn/vdbtensor.py | 24 +- fvdb/fvdb/utils/__init__.py | 2 +- fvdb/fvdb/utils/build_ext.py | 38 +- fvdb/scripts/rename_wheels.py | 6 +- fvdb/setup.py | 211 +- fvdb/src/Config.cpp | 15 +- fvdb/src/Config.h | 17 +- fvdb/src/FVDB.cpp | 184 +- fvdb/src/FVDB.h | 259 +- fvdb/src/GridBatch.cpp | 1173 +-- fvdb/src/GridBatch.h | 857 ++- fvdb/src/JaggedTensor.cpp | 980 ++- fvdb/src/JaggedTensor.h | 690 +- fvdb/src/SparseConvPackInfo.cpp | 262 +- fvdb/src/SparseConvPackInfo.h | 197 +- fvdb/src/Types.h | 128 +- fvdb/src/detail/GridBatchImpl.cu | 470 +- fvdb/src/detail/GridBatchImpl.h | 438 +- fvdb/src/detail/TorchDeviceBuffer.cpp | 173 +- fvdb/src/detail/TorchDeviceBuffer.h | 115 +- fvdb/src/detail/TypesImpl.h | 218 +- fvdb/src/detail/VoxelCoordTransform.h | 202 +- fvdb/src/detail/autograd/Attention.cpp | 27 +- fvdb/src/detail/autograd/Attention.h | 28 +- fvdb/src/detail/autograd/Autograd.h | 25 +- fvdb/src/detail/autograd/AvgPoolGrid.cpp | 78 +- fvdb/src/detail/autograd/AvgPoolGrid.h | 23 +- fvdb/src/detail/autograd/FillToGrid.h | 71 +- fvdb/src/detail/autograd/JaggedReduce.cpp | 110 +- fvdb/src/detail/autograd/JaggedReduce.h | 40 +- fvdb/src/detail/autograd/MaxPoolGrid.cpp | 78 +- fvdb/src/detail/autograd/MaxPoolGrid.h | 25 +- fvdb/src/detail/autograd/ReadFromDense.h | 81 +- fvdb/src/detail/autograd/ReadIntoDense.cpp | 96 +- fvdb/src/detail/autograd/ReadIntoDense.h | 23 +- fvdb/src/detail/autograd/SampleGrid.cpp | 141 +- fvdb/src/detail/autograd/SampleGrid.h | 42 +- .../detail/autograd/SparseConvolutionHalo.cpp | 76 +- .../detail/autograd/SparseConvolutionHalo.h | 24 +- .../autograd/SparseConvolutionImplicitGEMM.h | 206 +- .../autograd/SparseConvolutionKernelMap.h | 160 +- fvdb/src/detail/autograd/SplatIntoGrid.cpp | 106 +- fvdb/src/detail/autograd/SplatIntoGrid.h | 39 +- fvdb/src/detail/autograd/TransformPoints.cpp | 64 +- fvdb/src/detail/autograd/TransformPoints.h | 26 +- fvdb/src/detail/autograd/UpsampleGrid.cpp | 76 +- fvdb/src/detail/autograd/UpsampleGrid.h | 22 +- fvdb/src/detail/autograd/VolumeRender.cpp | 108 +- fvdb/src/detail/autograd/VolumeRender.h | 28 +- fvdb/src/detail/build/Build.h | 114 +- fvdb/src/detail/build/CoarseFromFine.cpp | 43 +- fvdb/src/detail/build/ConvGrid.cpp | 83 +- fvdb/src/detail/build/DenseGrid.cpp | 60 +- fvdb/src/detail/build/EmptyGrid.cpp | 17 +- fvdb/src/detail/build/FineFromCoarse.cpp | 49 +- fvdb/src/detail/build/FromMesh.cpp | 75 +- 
.../build/NearestNeighborGridFromPoints.cpp | 145 +- .../src/detail/build/PaddedGridFromCoords.cpp | 116 +- fvdb/src/detail/build/PaddedGridFromGrid.cpp | 65 +- .../src/detail/build/PaddedGridFromPoints.cpp | 131 +- fvdb/src/detail/io/IO.h | 48 +- fvdb/src/detail/io/LoadNanovdb.cpp | 646 +- fvdb/src/detail/io/SaveNanoVDB.cpp | 402 +- fvdb/src/detail/ops/ActiveGridGoords.cu | 113 +- .../detail/ops/ActiveVoxelsInBoundsMask.cu | 171 +- fvdb/src/detail/ops/BuildDeviceGrid.cu | 297 +- fvdb/src/detail/ops/CoordsInGrid.cu | 72 +- fvdb/src/detail/ops/CountEnabledVoxels.cu | 89 +- fvdb/src/detail/ops/CubesInGrid.cu | 151 +- fvdb/src/detail/ops/DownsampleGridAvgPool.cu | 316 +- fvdb/src/detail/ops/DownsampleGridMaxPool.cu | 299 +- fvdb/src/detail/ops/EnabledMask.cu | 55 +- fvdb/src/detail/ops/FillToGrid.cu | 98 +- fvdb/src/detail/ops/GridEdgeNetwork.cu | 138 +- fvdb/src/detail/ops/IjkToIndex.cu | 70 +- fvdb/src/detail/ops/IjkToInvIndex.cu | 74 +- fvdb/src/detail/ops/JCat0.cu | 168 +- fvdb/src/detail/ops/JIdxForGrid.cu | 65 +- fvdb/src/detail/ops/JIdxForJOffsets.cu | 50 +- fvdb/src/detail/ops/JOffsetsFromJIdx.cu | 98 +- fvdb/src/detail/ops/JaggedTensorIndex.cu | 185 +- fvdb/src/detail/ops/MarchingCubes.cu | 346 +- fvdb/src/detail/ops/Ops.h | 439 +- fvdb/src/detail/ops/PaddedIJKForMesh.cu | 235 +- fvdb/src/detail/ops/PointsInGrid.cu | 86 +- .../src/detail/ops/RayImplicitIntersection.cu | 179 +- fvdb/src/detail/ops/ReadFromDense.cu | 153 +- fvdb/src/detail/ops/ReadIntoDense.cu | 139 +- fvdb/src/detail/ops/SampleGridBezier.cu | 121 +- .../detail/ops/SampleGridBezierWithGrad.cu | 147 +- .../ops/SampleGridBezierWithGradBackward.cu | 141 +- fvdb/src/detail/ops/SampleGridTrilinear.cu | 117 +- .../detail/ops/SampleGridTrilinearWithGrad.cu | 135 +- .../SampleGridTrilinearWithGradBackward.cu | 141 +- fvdb/src/detail/ops/SampleRaysUniform.cu | 434 +- .../detail/ops/ScaledDotProductAttention.cu | 191 +- fvdb/src/detail/ops/SegmentsAlongRays.cu | 358 +- fvdb/src/detail/ops/SetMasked.cu | 70 +- fvdb/src/detail/ops/SplatIntoGridBezier.cu | 148 +- fvdb/src/detail/ops/SplatIntoGridTrilinear.cu | 147 +- fvdb/src/detail/ops/TransformPointToGrid.cu | 270 +- fvdb/src/detail/ops/UpsampleGridNearest.cu | 297 +- fvdb/src/detail/ops/VolumeRender.cu | 551 +- fvdb/src/detail/ops/VoxelNeighborhood.cu | 102 +- fvdb/src/detail/ops/VoxelsAlongRays.cu | 419 +- fvdb/src/detail/ops/VoxelsForGridBuilding.cu | 707 +- .../detail/ops/convolution/backend/ConvOps.h | 103 +- .../backend/MESparseConvolution.cu | 1322 ++-- .../backend/SparseConvolutionCutlass.cu | 692 +- .../backend/SparseConvolutionHalo.cu | 376 +- .../backend/SparseConvolutionHaloGrad.cu | 355 +- .../backend/SparseConvolutionImplicitGEMM.cu | 5782 ++++++++------- .../SparseConvolutionImplicitGEMMGrad.cu | 5965 ++++++++------- ...SparseConvolutionImplicitGEMMGradSorted.cu | 6229 ++++++++-------- .../SparseConvolutionImplicitGEMMSorted.cu | 6534 +++++++++-------- .../backend/SparseConvolutionKernelMap.cu | 505 +- .../backend/SparseConvolutionLggs.cu | 231 +- .../convolution/pack_info/BrickHaloBuffer.cu | 172 +- .../pack_info/ConvolutionKernelMap.cu | 132 +- .../pack_info/IGEMMBitOperations.cu | 125 +- .../ops/convolution/pack_info/PackInfoOps.h | 36 +- fvdb/src/detail/ops/jagged/JaggedOps.h | 25 +- fvdb/src/detail/ops/jagged/JaggedReduce.cu | 117 +- fvdb/src/detail/ops/jagged/JaggedSort.cu | 110 +- .../utils/BezierInterpolationIterator.h | 84 +- .../BezierInterpolationWithGradIterator.h | 86 +- fvdb/src/detail/utils/MarchingCubesData.h | 468 +- 
.../utils/TrilinearInterpolationIterator.h | 87 +- .../TrilinearInterpolationWithGradIterator.h | 129 +- fvdb/src/detail/utils/Utils.h | 271 +- fvdb/src/detail/utils/cuda/Atomics.cuh | 514 +- fvdb/src/detail/utils/cuda/Utils.cuh | 481 +- .../utils/nanovdb/ActiveVoxelIterator.h | 164 +- .../detail/utils/nanovdb/CustomAccessors.h | 254 +- fvdb/src/detail/utils/nanovdb/HDDAIterators.h | 191 +- fvdb/src/detail/utils/nanovdb/Printing.h | 17 +- .../utils/nanovdb/TorchNanoConversions.h | 81 +- fvdb/src/python/Bindings.cpp | 236 +- fvdb/src/python/GridBatchBinding.cpp | 417 +- fvdb/src/python/JaggedTensorBinding.cpp | 16 +- fvdb/src/python/TypeCasters.h | 197 +- fvdb/tests/benchmark/comparative_benchmark.py | 52 +- fvdb/tests/benchmark/conftest.py | 1 + .../tests/benchmark/fvdb_benchmark/configs.py | 69 +- .../tests/benchmark/fvdb_benchmark/dataset.py | 32 +- .../fvdb_benchmark/model/minkunet.py | 143 +- .../benchmark/fvdb_benchmark/model/updown.py | 10 +- .../benchmark/fvdb_benchmark/model/xcube.py | 123 +- fvdb/tests/benchmark/fvdb_benchmark/utils.py | 15 +- .../tests/benchmark/fvdb_benchmark/wrapper.py | 158 +- fvdb/tests/benchmark/test_conv.py | 16 +- fvdb/tests/unit/common.py | 32 +- fvdb/tests/unit/nkfw_api/backend/__init__.py | 7 +- fvdb/tests/unit/nkfw_api/backend/abc.py | 68 +- fvdb/tests/unit/nkfw_api/backend/fvdb.py | 161 +- .../tests/unit/nkfw_api/backend/hash_table.py | 239 +- fvdb/tests/unit/nkfw_api/ext/__init__.py | 16 +- fvdb/tests/unit/test_accessors.py | 15 +- fvdb/tests/unit/test_basic_ops.py | 543 +- fvdb/tests/unit/test_batching.py | 300 +- fvdb/tests/unit/test_conv.py | 442 +- fvdb/tests/unit/test_dense_interface.py | 143 +- fvdb/tests/unit/test_dual.py | 47 +- fvdb/tests/unit/test_empty_grids.py | 154 +- fvdb/tests/unit/test_io.py | 167 +- fvdb/tests/unit/test_jagged_tensor.py | 344 +- fvdb/tests/unit/test_mutable_grids.py | 30 +- fvdb/tests/unit/test_nkfw_api.py | 309 +- fvdb/tests/unit/test_nn.py | 179 +- fvdb/tests/unit/test_ray_marching.py | 197 +- fvdb/tests/unit/test_sample.py | 485 +- 193 files changed, 32114 insertions(+), 25949 deletions(-) create mode 100644 fvdb/examples/common.py create mode 100644 fvdb/examples/compare_conv_speed.py create mode 100644 fvdb/examples/grid_building.py create mode 100644 fvdb/examples/grid_subdivide_coarsen.py create mode 100644 fvdb/examples/mutable_grids.py create mode 100644 fvdb/examples/overfit_sdf.py create mode 100644 fvdb/examples/ray_segment_marching.py create mode 100644 fvdb/examples/ray_voxel_marching.py create mode 100644 fvdb/examples/sample_trilinear.py create mode 100644 fvdb/examples/splat_trilinear.py create mode 100644 fvdb/examples/subdivide.py create mode 100644 fvdb/examples/uniform_ray_marching.py create mode 100644 fvdb/examples/voxel_neighborhood.py diff --git a/fvdb/Dockerfile b/fvdb/Dockerfile index 4b3aa750dc..8485af1e77 100644 --- a/fvdb/Dockerfile +++ b/fvdb/Dockerfile @@ -1,52 +1,17 @@ -ARG CUDA_VERSION=12.1.1 -ARG CUDNN_VERSION=8 +FROM nvcr.io/nvidia/pytorch:24.04-py3 -FROM nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu20.04 - -ENV PATH /usr/local/cuda/bin:$PATH -ENV LD_LIBRARY_PATH /usr/lib64:/usr/local/cuda/lib64:/usr/local/cuda/lib:${LD_LIBRARY_PATH} - -# # nvidia-container-runtime -ENV NVIDIA_VISIBLE_DEVICES all -ENV NVIDIA_DRIVER_CAPABILITIES compute,utility,graphics - -RUN echo "Acquire { https::Verify-Peer false }" > /etc/apt/apt.conf.d/99verify-peer.conf \ - && if [ -f /etc/apt/sources.list.d/cuda.list ]; then \ - rm /etc/apt/sources.list.d/cuda.list; \ - fi \ - && if [ -f 
/etc/apt/sources.list.d/nvidia-ml.list ]; then \ - rm /etc/apt/sources.list.d/nvidia-ml.list; \ - fi \ - && apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ca-certificates \ - && rm /etc/apt/apt.conf.d/99verify-peer.conf \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - wget \ - rsync \ - vim \ - git \ - curl \ - ninja-build \ - cmake \ - build-essential \ - xauth \ - openssh-server \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* - -RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ - bash ~/miniconda.sh -b -p /opt/conda && \ - rm ~/miniconda.sh - -ENV PATH /opt/conda/bin:$PATH -ENV TORCH_CUDA_ARCH_LIST "6.1;7.0;7.5;8.0;8.6+PTX" +ARG MODE=production +RUN echo "Building fVDB container in $MODE mode" # used for cross-compilation in docker build ENV FORCE_CUDA=1 WORKDIR /fvdb -COPY env/test_environment.yml . - -RUN /opt/conda/bin/conda env create -f test_environment.yml \ - && /opt/conda/bin/conda clean -ya \ - && /opt/conda/bin/conda init bash +COPY . . +RUN pip install --no-cache-dir -r env/build_requirements.txt + +RUN if [ "$MODE" = "production" ]; then \ + MAX_JOBS=$(free -g | awk '/^Mem:/{jobs=int($4/2.5); if(jobs<1) jobs=1; print jobs}') \ + TORCH_CUDA_ARCH_LIST="6.1;7.0;7.5;8.0;8.6+PTX" \ + python setup.py install; \ + fi \ No newline at end of file diff --git a/fvdb/README.md b/fvdb/README.md index f9429145cc..36d50052d8 100644 --- a/fvdb/README.md +++ b/fvdb/README.md @@ -46,7 +46,62 @@ conda activate fvdb_learn ## Building *f*VDB from Source -*f*VDB is a Python library implemented as a C++ Pytorch extension. + +### Environment Management +ƒVDB is a Python library implemented as a C++ PyTorch extension. Of course you can build ƒVDB in whatever environment suits you, but we provide two paths to constructing reliable environments for building and running ƒVDB: using [docker](#setting-up-a-docker-container) and using [conda](#setting-up-a-conda-environment). + +`conda` tends to be more flexible since you can dynamically reconfigure toolchains and modules to suit your larger project, but this can also be a more brittle experience compared to using a virtualized `docker` container. Using `conda` is generally recommended for development and testing, while using `docker` is recommended for CI/CD and deployment. + +#### Setting up a Docker Container + +Running a docker container is a great way to ensure that you have a consistent environment for building and running ƒVDB. + +Our provided [`Dockerfile`](Dockerfile) has two modes for building the image: `dev` and `production`. `production` constructs an image capable of building ƒVDB, builds and installs the ƒVDB libraries, and is ready for you to start running Python code that uses the `fvdb` module. `dev` mode constructs an image which is ready to build ƒVDB but does not build the ƒVDB libraries. + +Building the docker image in `production` mode is the default and is as simple as running the following command from the root of this repository: +```shell +# Build the docker image in production mode +docker build -t fvdb/prod . +``` + +Building the docker image in `dev` mode is done by setting the `MODE` build argument to `dev`: +```shell +# Build the docker image in dev mode +docker build --build-arg MODE=dev -t fvdb/dev . 
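+ +# For example, you could then build a local checkout inside the dev image +# (one possible workflow; adjust the mount and command to your setup): +docker run -it --gpus all --rm \ +    --mount type=bind,source="$(pwd)",target=/fvdb \ +    fvdb/dev:latest \ +    python setup.py develop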
+``` + +Running the docker container is done with the following command: +```shell +# Run an interactive bash shell (or replace with your command) +docker run -it --gpus all --rm \ + fvdb/dev:latest \ + /bin/bash +``` + + +#### Setting up a Conda Environment + +In order to get resolved package versions in your conda environment consistent with our testing, it is necessary to configure your `.condarc` since not all package resolving behaviour can be controlled with an `environment.yml` file. We recommend using `strict` channel priority in your conda configuration. This can be done by running the following command: + +```shell +conda config --set channel_priority strict +``` + +Further, it is recommended not to mix the `defaults` and `conda-forge` package channels when resolving environments. We have generally used `conda-forge` as the primary channel for our dependencies. You can remove the `defaults` channel and add `conda-forge` with the following command: + +```shell +conda config --remove channels defaults +conda config --add channels conda-forge +``` + +With these changes, your `.condarc` file should look like the following: + +```yaml +channel_priority: strict +channels: + - conda-forge +``` + **(Optional) Install libMamba for a huge quality of life improvement when using Conda** ``` @@ -55,7 +110,6 @@ conda install -n base conda-libmamba-solver conda config --set solver libmamba ``` -### Conda Environment Next, create the `fvdb` conda environment by running the following command from the root of this repository, and then grabbing a ☕: ```shell @@ -106,22 +160,6 @@ sphinx-build -E -a docs/ build/sphinx open build/sphinx/index.html ``` -### Docker Image - -To build and test *f*VDB, we have the dockerfile available: -```shell -# Build fvdb -docker build . -t fvdb-dev -# Run fvdb (or replace with your command) -docker run -it --gpus all --rm \ - --user $(id -u):$(id -g) \ - --mount type=bind,source="$HOME/.ssh",target=/root/.ssh \ - --mount type=bind,source="$(pwd)",target=/fvdb \ - fvdb-dev:latest \ - conda run -n fvdb_test --no-capture-output python setup.py develop -``` - - ## Usage Examples diff --git a/fvdb/docs/conf.py b/fvdb/docs/conf.py index b1d0ae3be3..9f15a297e0 100644 --- a/fvdb/docs/conf.py +++ b/fvdb/docs/conf.py @@ -9,14 +9,15 @@ import os import sys -sys.path.insert(0, os.path.abspath('..')) + +sys.path.insert(0, os.path.abspath("..")) # -- Project information ----------------------------------------------------- -project = 'fVDB' -copyright = '2023, NVIDIA Corporation' -author = 'NVIDIA Corporation' +project = "fVDB" +copyright = "2023, NVIDIA Corporation" +author = "NVIDIA Corporation" # -- General configuration --------------------------------------------------- @@ -24,12 +25,7 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.viewcode', - 'sphinx.ext.napoleon', - 'myst_parser' -] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode", "sphinx.ext.napoleon", "myst_parser"] myst_enable_extensions = [ "amsmath", @@ -49,28 +45,26 @@ ] # Fix return-type in google-style docstrings -napoleon_custom_sections = [('Returns', 'params_style')] +napoleon_custom_sections = [("Returns", "params_style")] # Add any paths that contain templates here, relative to this directory. 
-templates_path = ['_templates'] +templates_path = ["_templates"] -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] -autodoc_default_options = { - 'undoc-members': 'forward, extra_repr' -} +autodoc_default_options = {"undoc-members": "forward, extra_repr"} # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -80,6 +74,7 @@ # -- Custom hooks ------------------------------------------------------------ + def process_signature(app, what, name, obj, options, signature, return_annotation): if signature is not None: signature = signature.replace("._Cpp", "") @@ -91,5 +86,6 @@ def process_signature(app, what, name, obj, options, signature, return_annotatio return signature, return_annotation + def setup(app): app.connect("autodoc-process-signature", process_signature) diff --git a/fvdb/env/build_environment.yml b/fvdb/env/build_environment.yml index f5fcff7320..19b8436c9f 100644 --- a/fvdb/env/build_environment.yml +++ b/fvdb/env/build_environment.yml @@ -1,7 +1,8 @@ name: fvdb_build channels: - - nvidia/label/cuda-12.1.0 - pytorch + - nvidia + - conda-forge dependencies: - python=3.10 - pytorch::pytorch=2.2 @@ -11,14 +12,14 @@ dependencies: - ca-certificates - certifi - openssl - - nvidia/label/cuda-12.1.0::cuda - - nvidia/label/cuda-12.1.0::cuda-tools - - nvidia/label/cuda-12.1.0::cuda-nvcc - - nvidia/label/cuda-12.1.0::cuda-cccl - - nvidia/label/cuda-12.1.0::cuda-libraries-static + - cuda-toolkit=12.1 + - cuda-compiler=12.1 + - cuda-nvcc=12.1 + - cuda-cccl=12.1 + - cuda-libraries-static=12.1 - gcc_linux-64=11 - gxx_linux-64=11 - - setuptools + - setuptools>=68.2.2 - cmake - make - ninja diff --git a/fvdb/env/learn_environment.yml b/fvdb/env/learn_environment.yml index 976bb5dc3c..b84607d13c 100644 --- a/fvdb/env/learn_environment.yml +++ b/fvdb/env/learn_environment.yml @@ -28,6 +28,7 @@ dependencies: - pytest-benchmark - polyscope - numpy<2 + - pyrender - pip: - point-cloud-utils - linkify-it-py diff --git a/fvdb/env/test_environment.yml b/fvdb/env/test_environment.yml index c3175190b8..d4e6f792d4 100644 --- a/fvdb/env/test_environment.yml +++ b/fvdb/env/test_environment.yml @@ -1,28 +1,29 @@ name: fvdb_test channels: - pyg - - nvidia/label/cuda-12.1.0 - pytorch + - nvidia + - conda-forge dependencies: - python=3.10 - pytorch::pytorch=2.2 - pytorch::pytorch-cuda=12.1 - tensorboard - - pip + - pip>=23.3.1 - git - gitpython - ca-certificates - certifi - openssl - - nvidia/label/cuda-12.1.0::cuda - - nvidia/label/cuda-12.1.0::cuda-tools - - nvidia/label/cuda-12.1.0::cuda-nvcc - - nvidia/label/cuda-12.1.0::cuda-cccl - - nvidia/label/cuda-12.1.0::cuda-libraries-static + - cuda-toolkit=12.1 + - cuda-compiler=12.1 + - cuda-nvcc=12.1 + - cuda-cccl=12.1 + - cuda-libraries-static=12.1 - parameterized - gcc_linux-64=11 - gxx_linux-64=11 - - setuptools + - setuptools>=68.2.2 - cmake - make - ninja diff --git a/fvdb/examples/common.py 
b/fvdb/examples/common.py new file mode 100644 index 0000000000..9eb5a22e07 --- /dev/null +++ b/fvdb/examples/common.py @@ -0,0 +1,190 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import hashlib +import logging +import timeit +from pathlib import Path +from typing import List, Tuple, Union + +import git +import git.repo +import numpy as np +import point_cloud_utils as pcu +import polyscope as ps +import torch +from git.exc import InvalidGitRepositoryError + +from fvdb import GridBatch + + +def _clone_fvdb_example_data(): + def is_git_repo(repo_path: str): + is_repo = False + try: + _ = git.repo.Repo(repo_path) + is_repo = True + except InvalidGitRepositoryError: + is_repo = False + + return is_repo + + git_tag = "main" + git_url = "git@github.com:voxel-foundation/fvdb-example-data.git" + repo_root = Path(__file__).resolve().parent.parent + external_path = repo_root / "external" + if not external_path.exists(): + external_path.mkdir() + elif not external_path.is_dir(): + raise RuntimeError(f"External path {external_path} exists but is not a directory") + + repo_path = external_path / "fvdb_example_data" + if repo_path.exists() and repo_path.is_dir(): + if is_git_repo(str(repo_path)): + repo = git.repo.Repo(repo_path) + repo.git.checkout(git_tag) + else: + raise ValueError(f"A path {repo_path} exists but is not a git repo") + else: + repo = git.repo.Repo.clone_from(git_url, repo_path) + repo.git.checkout(git_tag) + + return repo_path, repo + + +def get_fvdb_example_data_path(): + repo_path, _ = _clone_fvdb_example_data() + return repo_path + + +def get_md5_checksum(file_path: Path): + md5_hash = hashlib.md5() + with open(file_path, "rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + md5_hash.update(byte_block) + return md5_hash.hexdigest() + + +def make_grid_from_points(pts: torch.Tensor, padding, vox_size, vox_origin) -> GridBatch: + logging.info("Building GridBatch from points...") + start = timeit.default_timer() + grid = GridBatch(device=pts.device) + grid.set_from_points(pts, [-padding] * 3, [padding] * 3, voxel_sizes=vox_size, origins=vox_origin) + torch.cuda.synchronize() + logging.info(f"Done in {timeit.default_timer() - start}s") + logging.info(f"GridBatch has {grid.total_voxels} voxels") + + return grid + + +def make_ray_grid( + nrays: int, + origin: Union[torch.Tensor, Tuple, List], + minb=(-0.3, -0.3), + maxb=(0.3, 0.3), + device: Union[str, torch.device] = "cpu", + dtype=torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + ray_o = torch.tensor([origin] * nrays**2) + + ray_d = torch.from_numpy( + np.stack( + [a.ravel() for a in np.mgrid[minb[0] : maxb[0] : nrays * 1j, minb[1] : maxb[1] : nrays * 1j]] + + [np.ones(nrays**2)], + axis=-1, + ).astype(np.float32) + ) + ray_d /= torch.norm(ray_d, dim=-1, keepdim=True) + + ray_o, ray_d = ray_o.to(device).to(dtype), ray_d.to(device).to(dtype) + + return ray_o, ray_d + + +def load_pointcloud( + data_path, + skip_every=1, + shuffle=False, + device=torch.device("cuda"), + dtype=torch.float32, +) -> torch.Tensor: + logging.info(f"Loading pointlcoud {data_path}...") + start = timeit.default_timer() + pts = pcu.load_mesh_v(data_path) + if shuffle: + pts = pts[np.random.permutation(pts.shape[0])] + pts = pts[::skip_every] + logging.info(f"Done in {timeit.default_timer() - start}s") + return torch.from_numpy(pts).to(device).to(dtype) + + +def load_mesh( + data_path, skip_every=1, mode="vn", device=torch.device("cuda"), dtype=torch.float32 +) -> List[torch.Tensor]: + 
logging.info(f"Loading mesh {data_path}...") + start = timeit.default_timer() + if mode == "v": + attrs = [pcu.load_mesh_v(data_path)] + elif mode == "vf": + attrs = pcu.load_mesh_vf(data_path) + elif mode == "vn": + attrs = pcu.load_mesh_vn(data_path) + else: + raise ValueError(f"Unsupported mode {mode}") + + attrs = [torch.from_numpy(a[::skip_every]).to(device).to(dtype) for a in attrs] + logging.info(f"Done in {timeit.default_timer() - start}s") + return attrs + + +def load_dragon_mesh(skip_every=1, mode="vn", device=torch.device("cuda"), dtype=torch.float32) -> List[torch.Tensor]: + data_path = get_fvdb_example_data_path() / "meshes" / "dragon.ply" + if get_md5_checksum(data_path) != "0222e7d2147eebcb2eacdaf6263a9512": + raise ValueError(f"Checksum for {data_path} is incorrect") + return load_mesh(data_path, mode=mode, skip_every=skip_every, device=device, dtype=dtype) + + +def load_happy_mesh(skip_every=1, mode="vn", device=torch.device("cuda"), dtype=torch.float32) -> List[torch.Tensor]: + data_path = get_fvdb_example_data_path() / "meshes" / "happy.ply" + if get_md5_checksum(data_path) != "5cfe3c9c0b58bad9a77b47ae04454160": + raise ValueError(f"Checksum for {data_path} is incorrect") + return load_mesh(data_path, mode=mode, skip_every=skip_every, device=device, dtype=dtype) + + +def load_bunny_mesh(skip_every=1, mode="vn", device=torch.device("cuda"), dtype=torch.float32) -> List[torch.Tensor]: + data_path = get_fvdb_example_data_path() / "meshes" / "bunny.ply" + if get_md5_checksum(data_path) != "fe2f062a8e22b7dab895a1945c32cd58": + raise ValueError(f"Checksum for {data_path} is incorrect") + return load_mesh(data_path, mode=mode, skip_every=skip_every, device=device, dtype=dtype) + + +def load_car_1(skip_every=1, mode="vn", device=torch.device("cuda"), dtype=torch.float32) -> List[torch.Tensor]: + data_path = get_fvdb_example_data_path() / "meshes" / "car-mesh-1.ply" + if get_md5_checksum(data_path) != "e96d59a5ee392a40442ca510c0ab8f17": + raise ValueError(f"Checksum for {data_path} is incorrect") + return load_mesh(data_path, mode=mode, skip_every=skip_every, device=device, dtype=dtype) + + +def load_car_2(skip_every=1, mode="vn", device=torch.device("cuda"), dtype=torch.float32) -> List[torch.Tensor]: + data_path = get_fvdb_example_data_path() / "meshes" / "car-mesh-2.ply" + if get_md5_checksum(data_path) != "e7bcf0922518f6b43930e155a188a3a8": + raise ValueError(f"Checksum for {data_path} is incorrect") + return load_mesh(data_path, mode=mode, skip_every=skip_every, device=device, dtype=dtype) + + +def plot_ray_segments(ray_o, ray_d, times, plot_every=1): + for i in range(0, ray_o.shape[0], plot_every): + t0s = times[i].jdata[:, 0].unsqueeze(-1) + t1s = times[i].jdata[:, 1].unsqueeze(-1) + roi = ray_o[i].unsqueeze(0) + rdi = ray_d[i].unsqueeze(0) + rp = torch.cat([roi + t0s * rdi, roi + t1s * rdi]) + re = torch.stack( + [torch.arange(t0s.shape[0]), torch.arange(t0s.shape[0]) + t0s.shape[0]], + dim=-1, + ) + + ray_segs = ps.register_curve_network(f"ray segments {i}", rp, re, radius=0.001) + rv = torch.zeros(re.shape[0]) + rv[::2] = 1.0 + ray_segs.add_scalar_quantity(f"segment colors {i}", rv, defined_on="edges", enabled=True, cmap="jet") diff --git a/fvdb/examples/compare_conv_speed.py b/fvdb/examples/compare_conv_speed.py new file mode 100644 index 0000000000..cf2c9e44a6 --- /dev/null +++ b/fvdb/examples/compare_conv_speed.py @@ -0,0 +1,75 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import time + +import numpy as np +import torch 
+import tqdm +from common import load_dragon_mesh + +from fvdb import GridBatch + + +def benchmark_inplace_conv(grid: GridBatch, in_feature, in_kernel): + start_time = time.perf_counter() + out_feature = grid.sparse_conv_halo(in_feature, in_kernel) + torch.cuda.synchronize() + return time.perf_counter() - start_time + + +def benchmark_kmap_conv(grid: GridBatch, in_feature, in_kernel): + start_time = time.perf_counter() + kmap, _ = grid.sparse_conv_kernel_map(kernel_size=in_kernel.size(-1), stride=1) + kmap.build_gather_scatter() + torch.cuda.synchronize() + + kmap_time = time.perf_counter() + out_feature = kmap.sparse_conv_3d(in_feature, in_kernel) + torch.cuda.synchronize() + + return kmap_time - start_time, time.perf_counter() - kmap_time + + +def main(): + device = torch.device("cuda") + dtype = torch.float32 + kernel_size = 3 + in_channel, out_channel = 128, 64 + + vox_size = 0.005 + vox_origin = (0.0, 0.0, 0.0) + p, n = load_dragon_mesh(device=device, dtype=dtype) + + index0 = GridBatch(device=device) + index0.set_from_points(p, [-1, -1, -1], [1, 1, 1], voxel_sizes=vox_size, origins=vox_origin) + + grid_feats = torch.rand((index0.total_voxels, in_channel), device=device, dtype=dtype) * 0.5 + 0.5 + kernels = ( + torch.rand(out_channel, in_channel, kernel_size, kernel_size, kernel_size, dtype=dtype, device=device) * 0.5 + + 0.5 + ) + + torch.cuda.synchronize() + + inplace_time = [] + kmap_time = [] + conv_time = [] + + for iter in tqdm.trange(100): + inplace = benchmark_inplace_conv(index0, grid_feats, kernels) + kmap, conv = benchmark_kmap_conv(index0, grid_feats, kernels) + inplace_time.append(inplace) + kmap_time.append(kmap) + conv_time.append(conv) + + inplace_time, kmap_time, conv_time = inplace_time[5:], kmap_time[5:], conv_time[5:] + + print(f"Num voxels = {index0.num_voxels}, channel = {in_channel} -> {out_channel}, device = {device}") + print(f"Convolution Inplace {np.mean(inplace_time):.4f} +/- {np.std(inplace_time):.4f}") + print(f"Kmap {np.mean(kmap_time):.4f} +/- {np.std(kmap_time):.4f}") + print(f"Kmap Convolution {np.mean(conv_time):.4f} +/- {np.std(conv_time):.4f}") + + +if __name__ == "__main__": + main() diff --git a/fvdb/examples/grid_building.py b/fvdb/examples/grid_building.py new file mode 100644 index 0000000000..6213ec6462 --- /dev/null +++ b/fvdb/examples/grid_building.py @@ -0,0 +1,155 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +from pathlib import Path + +import numpy as np +import point_cloud_utils as pcu +import polyscope as ps +import torch +from common import load_car_1, load_car_2 + +import fvdb +from fvdb import JaggedTensor +from fvdb.nn import VDBTensor + +voxel_size_1 = 0.02 +voxel_size_2 = 0.03 + + +def build_from_pointcloud(pcd_1: np.ndarray, pcd_2: np.ndarray): + # Assemble point clouds into JaggedTensor + pcd_jagged = JaggedTensor([torch.from_numpy(pcd_1).float().cuda(), torch.from_numpy(pcd_2).float().cuda()]) + voxel_sizes = [[voxel_size_1, voxel_size_1, voxel_size_1], [voxel_size_2, voxel_size_2, voxel_size_2]] + + # Method 1: + grid_a1 = fvdb.sparse_grid_from_points(pcd_jagged, voxel_sizes=voxel_sizes, origins=[0.0] * 3) + + # Method 2: + grid_a2 = fvdb.GridBatch(device=pcd_jagged.device) + grid_a2.set_from_points(pcd_jagged, voxel_sizes=voxel_sizes, origins=[0.0] * 3) + + # Visualization + gv_a1, ge_a1 = grid_a1.viz_edge_network + ps.remove_all_structures() + ps.register_point_cloud("pcd_1", pcd_1, enabled=True, radius=0.01) + ps.register_curve_network( + "grid_a1", 
gv_a1[0].jdata.cpu().numpy(), ge_a1[0].jdata.cpu().numpy(), enabled=True, radius=0.004 + ) + ps.register_point_cloud("pcd_2", pcd_2, enabled=True, radius=0.01) + ps.register_curve_network( + "grid_a2", gv_a1[1].jdata.cpu().numpy(), ge_a1[1].jdata.cpu().numpy(), enabled=True, radius=0.004 + ) + ps.show() + + # Build grid from containing nearest voxels to the points + grid_b = fvdb.sparse_grid_from_nearest_voxels_to_points(pcd_jagged, voxel_sizes=voxel_sizes, origins=[0.0] * 3) + + # Visualization + gv_b, ge_b = grid_b.viz_edge_network + ps.remove_all_structures() + ps.register_point_cloud("pcd_1", pcd_1, enabled=True, radius=0.01) + ps.register_curve_network( + "grid_b1", gv_b[0].jdata.cpu().numpy(), ge_b[0].jdata.cpu().numpy(), enabled=True, radius=0.004 + ) + ps.register_point_cloud("pcd_2", pcd_2, enabled=True, radius=0.01) + ps.register_curve_network( + "grid_b2", gv_b[1].jdata.cpu().numpy(), ge_b[1].jdata.cpu().numpy(), enabled=True, radius=0.004 + ) + ps.show() + + +def build_from_coordinates(coords_1: np.ndarray, coords_2: np.ndarray): + coords_jagged = JaggedTensor([torch.from_numpy(coords_1).long().cuda(), torch.from_numpy(coords_2).long().cuda()]) + voxel_sizes = [[voxel_size_1, voxel_size_1, voxel_size_1], [voxel_size_2, voxel_size_2, voxel_size_2]] + + grid = fvdb.sparse_grid_from_ijk(coords_jagged, voxel_sizes=voxel_sizes, origins=[0.0] * 3) + + # Visualization + grid_mesh_1 = pcu.voxel_grid_geometry( + grid.ijk[0].jdata.cpu().numpy(), grid.voxel_sizes[0].cpu().numpy(), gap_fraction=0.1 + ) + grid_mesh_2 = pcu.voxel_grid_geometry( + grid.ijk[1].jdata.cpu().numpy(), grid.voxel_sizes[1].cpu().numpy(), gap_fraction=0.1 + ) + ps.remove_all_structures() + ps.register_surface_mesh("grid_1", grid_mesh_1[0], grid_mesh_1[1], enabled=True) + ps.register_surface_mesh("grid_2", grid_mesh_2[0], grid_mesh_2[1], enabled=True) + ps.show() + + +def build_from_mesh(mesh_1_vf, mesh_2_vf): + mesh_1_v, mesh_1_f = mesh_1_vf + mesh_2_v, mesh_2_f = mesh_2_vf + + mesh_v_jagged = JaggedTensor([torch.from_numpy(mesh_1_v).float().cuda(), torch.from_numpy(mesh_2_v).float().cuda()]) + mesh_f_jagged = JaggedTensor( + [ + torch.from_numpy(mesh_1_f.astype(np.int64)).long().cuda(), + torch.from_numpy(mesh_2_f.astype(np.int64)).long().cuda(), + ] + ) + + voxel_sizes = [[voxel_size_1, voxel_size_1, voxel_size_1], [voxel_size_2, voxel_size_2, voxel_size_2]] + grid = fvdb.sparse_grid_from_mesh(mesh_v_jagged, mesh_f_jagged, voxel_sizes=voxel_sizes, origins=[0.0] * 3) + + # Visualization + gv, ge = grid.viz_edge_network + ps.remove_all_structures() + ps.register_surface_mesh("mesh_1", mesh_1_v, mesh_1_f, enabled=True) + ps.register_curve_network( + "grid_1", gv[0].jdata.cpu().numpy(), ge[0].jdata.cpu().numpy(), enabled=True, radius=0.004 + ) + ps.register_surface_mesh("mesh_2", mesh_2_v, mesh_2_f, enabled=True) + ps.register_curve_network( + "grid_2", gv[1].jdata.cpu().numpy(), ge[1].jdata.cpu().numpy(), enabled=True, radius=0.004 + ) + ps.show() + + +def build_from_dense(): + grid = fvdb.sparse_grid_from_dense(num_grids=1, dense_dims=[32, 32, 32], device="cuda") + + # Easy way to initialize a VDBTensor from a torch 3D tensor [B, D, H, W, C] + dense_data = torch.ones(2, 32, 32, 32, 16).cuda() + sparse_data = VDBTensor.from_dense(dense_data, voxel_sizes=[0.1] * 3) + dense_data_back = sparse_data.to_dense() + assert torch.all(dense_data == dense_data_back) + + # Visualization + grid_mesh = pcu.voxel_grid_geometry( + grid.ijk[0].jdata.cpu().numpy(), grid.voxel_sizes[0].cpu().numpy(), gap_fraction=0.1 + ) + 
ps.remove_all_structures() + ps.register_surface_mesh("grid_1", grid_mesh[0], grid_mesh[1], enabled=True) + ps.show() + + +if __name__ == "__main__": + ps.init() + ps.set_ground_plane_mode("shadow_only") + ps.set_navigation_style("free") + + base_path = Path(__file__).parent.parent + + mesh_1_v, mesh_1_f = load_car_1(mode="vf", device=torch.device("cpu")) + mesh_2_v, mesh_2_f = load_car_2(mode="vf", device=torch.device("cpu")) + + mesh_1_v, mesh_1_f = mesh_1_v.numpy(), mesh_1_f.numpy().astype(np.int64) + mesh_2_v, mesh_2_f = mesh_2_v.numpy(), mesh_2_f.numpy().astype(np.int64) + + mesh_2_v[:, 2] += 0.8 + + fi1, bc1 = pcu.sample_mesh_random(mesh_1_v, mesh_1_f, 10000) + fi2, bc2 = pcu.sample_mesh_random(mesh_2_v, mesh_2_f, 10000) + + pcd_1 = pcu.interpolate_barycentric_coords(mesh_1_f, fi1, bc1, mesh_1_v) + pcd_2 = pcu.interpolate_barycentric_coords(mesh_2_f, fi2, bc2, mesh_2_v) + + ijk_1 = np.unique(np.floor(pcd_1 / voxel_size_1).astype(np.int64), axis=0) + ijk_2 = np.unique(np.floor(pcd_2 / voxel_size_2).astype(np.int64), axis=0) + + build_from_pointcloud(pcd_1, pcd_2) + build_from_mesh((mesh_1_v, mesh_1_f), (mesh_2_v, mesh_2_f)) + build_from_coordinates(ijk_1, ijk_2) + build_from_dense() diff --git a/fvdb/examples/grid_subdivide_coarsen.py b/fvdb/examples/grid_subdivide_coarsen.py new file mode 100644 index 0000000000..a4dec20498 --- /dev/null +++ b/fvdb/examples/grid_subdivide_coarsen.py @@ -0,0 +1,56 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import uuid + +import numpy as np +import point_cloud_utils as pcu +import polyscope as ps +import torch +from common import load_dragon_mesh + +import fvdb + + +def visualize_grid(a: fvdb.GridBatch, offset: float): + assert a.grid_count == 1 + mesh_a = pcu.voxel_grid_geometry(a.ijk[0].jdata.cpu().numpy(), a.voxel_sizes[0].cpu().numpy(), gap_fraction=0.1) + ps.register_surface_mesh( + str(uuid.uuid4()), + mesh_a[0] + np.array([0.0, 0.0, offset]) - a.voxel_sizes[0].cpu().numpy()[None, :] / 2.0, + mesh_a[1], + enabled=True, + ) + + +if __name__ == "__main__": + ps.init() + ps.set_ground_plane_mode("shadow_only") + ps.set_navigation_style("free") + + [p] = load_dragon_mesh(mode="v", device=torch.device("cuda")) + + grid_origin = fvdb.sparse_grid_from_points(p, voxel_sizes=[0.005] * 3, origins=[0.0] * 3) + visualize_grid(grid_origin, 0.0) + + grid_subdivided = grid_origin.subdivided_grid(2) + visualize_grid(grid_subdivided, 0.15) + + grid_coarsened = grid_origin.coarsened_grid(2) + visualize_grid(grid_coarsened, 0.3) + + ps.show() + + grid_dual = grid_origin.dual_grid() + + grid_dual_gv, grid_dual_ge = grid_dual.viz_edge_network + ps.remove_all_structures() + visualize_grid(grid_origin, 0.0) + ps.register_curve_network( + str(uuid.uuid4()), + grid_dual_gv[0].jdata.cpu().numpy(), + grid_dual_ge[0].jdata.cpu().numpy(), + enabled=True, + radius=0.004, + ) + ps.show() diff --git a/fvdb/examples/mutable_grids.py b/fvdb/examples/mutable_grids.py new file mode 100644 index 0000000000..6af4eeafaa --- /dev/null +++ b/fvdb/examples/mutable_grids.py @@ -0,0 +1,107 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +from pathlib import Path + +import point_cloud_utils as pcu +import polyscope as ps +import torch +from common import load_car_1, load_car_2 + +import fvdb +from fvdb import GridBatch, JaggedTensor + + +def visualize_grid_color(grid: GridBatch, rgb: JaggedTensor, ignore_disabled: bool = False): + for b in range(grid.grid_count): + grid_mask = 
grid.enabled_mask[b].jdata.cpu().numpy() + if ignore_disabled: + grid_mask.fill(True) + + grid_mesh = pcu.voxel_grid_geometry( + grid.ijk[b].jdata.cpu().numpy()[grid_mask], grid.voxel_sizes[b].cpu().numpy(), gap_fraction=0.1 + ) + grid_color = rgb[b].jdata.cpu().numpy()[grid_mask].repeat(8, axis=0).reshape(-1, 3) + + ps.register_surface_mesh(f"grid_{b}", grid_mesh[0], grid_mesh[1], enabled=True).add_color_quantity( + "color", grid_color, enabled=True + ) + + +if __name__ == "__main__": + ps.init() + ps.set_ground_plane_mode("shadow_only") + ps.set_navigation_style("free") + + base_path = Path(__file__).parent.parent + + mesh_1_v, mesh_1_f = load_car_1(mode="vf") + mesh_2_v, mesh_2_f = load_car_2(mode="vf") + + mesh_1_f, mesh_2_f = mesh_1_f.long(), mesh_2_f.long() + mesh_2_v[:, 2] += 0.8 + + mesh_v_jagged = JaggedTensor([mesh_1_v, mesh_2_v]) + mesh_f_jagged = JaggedTensor([mesh_1_f, mesh_2_f]) + + fi1, bc1 = pcu.sample_mesh_random(mesh_1_v.cpu().numpy(), mesh_1_f.cpu().numpy(), 10000) + fi2, bc2 = pcu.sample_mesh_random(mesh_2_v.cpu().numpy(), mesh_2_f.cpu().numpy(), 10000) + + pcd_1 = pcu.interpolate_barycentric_coords(mesh_1_f.cpu().numpy(), fi1, bc1, mesh_1_v.cpu().numpy()) + pcd_2 = pcu.interpolate_barycentric_coords(mesh_2_f.cpu().numpy(), fi2, bc2, mesh_2_v.cpu().numpy()) + pcd_jagged = JaggedTensor([torch.from_numpy(pcd_1).float().cuda(), torch.from_numpy(pcd_2).float().cuda()]) + + # Grid creation + grid = fvdb.sparse_grid_from_mesh( + mesh_v_jagged, mesh_f_jagged, voxel_sizes=[0.01] * 3, origins=[0.0] * 3, mutable=True + ) + feature = grid.grid_to_world(grid.ijk.float()) + feature.jdata = (feature.jdata - feature.jdata.min(dim=0).values) / ( + feature.jdata.max(dim=0).values - feature.jdata.min(dim=0).values + ) + + # Visualization + ps.remove_all_structures() + visualize_grid_color(grid, feature) + ps.show() + + # Get the IJK coordinates to be disabled + disable_ijk = grid.ijk.rmask(feature.jdata[:, 0] > 0.5) + grid.disable_ijk(disable_ijk) + + # Visualize disable mask + enabled_mask = grid.enabled_mask + ps.remove_all_structures() + visualize_grid_color( + grid, feature.jagged_like(enabled_mask.jdata.unsqueeze(1).repeat(1, 3).float()), ignore_disabled=True + ) + ps.show() + + # Sample features onto points + pts_feature = grid.sample_trilinear(pcd_jagged, feature) + + # Visualize (disabled grid will no longer function) + ps.remove_all_structures() + ps.register_point_cloud("pcd_1", pcd_1, enabled=True).add_color_quantity( + "feature", pts_feature[0].jdata.cpu().numpy(), enabled=True + ) + ps.register_point_cloud("pcd_2", pcd_2, enabled=True).add_color_quantity( + "feature", pts_feature[1].jdata.cpu().numpy(), enabled=True + ) + ps.show() + + # We could enable those IJK back + grid.enable_ijk(disable_ijk) + + # Sample features onto points + pts_feature = grid.sample_trilinear(pcd_jagged, feature) + + # Visualize (this time we got the original features back) + ps.remove_all_structures() + ps.register_point_cloud("pcd_1", pcd_1, enabled=True).add_color_quantity( + "feature", pts_feature[0].jdata.cpu().numpy(), enabled=True + ) + ps.register_point_cloud("pcd_2", pcd_2, enabled=True).add_color_quantity( + "feature", pts_feature[1].jdata.cpu().numpy(), enabled=True + ) + ps.show() diff --git a/fvdb/examples/overfit_sdf.py b/fvdb/examples/overfit_sdf.py new file mode 100644 index 0000000000..69be3eb49d --- /dev/null +++ b/fvdb/examples/overfit_sdf.py @@ -0,0 +1,116 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import os +import logging + 
+import numpy as np +import point_cloud_utils as pcu +import polyscope as ps +import torch +import tqdm +from common import load_happy_mesh + +from fvdb import GridBatch + + +def prepare_sdf(npts, ng): + logging.info("Loading data...") + v, f = load_happy_mesh(mode="vf", device=torch.device("cpu")) + v -= v.amin(0) + v /= v.amax() + v -= 0.5 * v.amax() + v = v.numpy() + f = f.type(torch.int32).numpy() + + n = pcu.estimate_mesh_vertex_normals(v, f) + fid, bc = pcu.sample_mesh_poisson_disk(v, f, npts) + pts = pcu.interpolate_barycentric_coords(f, fid, bc, v) + nms = pcu.interpolate_barycentric_coords(f, fid, bc, n) + logging.info("Done") + + logging.info("Generating grid samples") + gpts = np.stack( + [ + a.ravel() + for a in np.mgrid[ + v.min(0)[0] * 1.05 : v.max(0)[0] * 1.05 : ng * 1j, + v.min(0)[1] * 1.05 : v.max(0)[1] * 1.05 : ng * 1j, + v.min(0)[2] * 1.05 : v.max(0)[2] * 1.05 : ng * 1j, + ] + ], + axis=-1, + ).astype(pts.dtype) + logging.info("Done") + + logging.info("Computing SDF") + sdf, _, _ = pcu.signed_distance_to_mesh(gpts, v, f) + logging.info("Done") + + return pts, nms, gpts, sdf + + +def main(): + torch.random.manual_seed(5) + logging.basicConfig(level=logging.INFO) + logging.addLevelName(logging.INFO, "\033[1;32m%s\033[1;0m" % logging.getLevelName(logging.INFO)) + device = torch.device("cuda") + dtype = torch.float32 + vox_size = 0.005 + vox_origin = np.zeros(3) + vox_pad = 1 + ng = 256 + npts = 10_000 + num_iters = 750 + + # Cache the sdf data so we don't have to recompute it every single time + if not os.path.exists("sdf.npz"): + pts, nms, gpts, sdf = prepare_sdf(npts, ng) + np.savez("sdf.npz", pts=pts, nms=nms, gpts=gpts, sdf=sdf) + else: + dat = np.load("sdf.npz") + pts, nms, gpts, sdf = dat["pts"], dat["nms"], dat["gpts"], dat["sdf"] + + p, n = torch.from_numpy(pts).to(device).to(dtype), torch.from_numpy(nms).to(device).to(dtype) + + grid = GridBatch(device=device) + + grid.set_from_points(p, [-vox_pad] * 3, [vox_pad] * 3, vox_size, vox_origin) + dual_index = grid.dual_grid() + + mask = grid.points_in_active_voxel(torch.from_numpy(gpts).to(dtype).to(device)).jdata.cpu().numpy() + vol_pts = torch.from_numpy(gpts[mask]).to(device=device, dtype=dtype) + vol_sdf = torch.from_numpy(sdf[mask]).to(device=device, dtype=dtype).unsqueeze(-1) + + features = torch.randn(dual_index.total_voxels, 1).to(device).to(dtype) + features.requires_grad = True + + optimizer = torch.optim.Adam([features], lr=1e-2) + + # This should converge to around 2e-8 loss + pbar = tqdm.tqdm(range(num_iters)) + for _ in pbar: + optimizer.zero_grad() + vp_idx = torch.randperm(vol_pts.shape[0]) + vpts = vol_pts[vp_idx] + vsdf = vol_sdf[vp_idx] + + samp_sdf = dual_index.sample_trilinear(vpts, features).jdata + + loss = torch.nn.functional.mse_loss(samp_sdf, vsdf) + loss.backward() + pbar.set_postfix({"Loss": loss.item()}) + optimizer.step() + + ps.init() + pred_sdf = dual_index.sample_trilinear(vol_pts, features).jdata + assert isinstance(pred_sdf, torch.Tensor) + vol_pc = ps.register_point_cloud("pts", vol_pts.cpu().numpy()) + vol_pc.add_scalar_quantity("sdf_pred", pred_sdf.squeeze().detach().cpu().numpy()) + vol_pc.add_scalar_quantity("sdf_gt", vol_sdf.squeeze().detach().cpu().numpy()) + vol_pc.add_scalar_quantity("delta", (vol_sdf - pred_sdf).squeeze().abs().detach().cpu().numpy()) + ps.show() + + +if __name__ == "__main__": + main() diff --git a/fvdb/examples/ray_segment_marching.py b/fvdb/examples/ray_segment_marching.py new file mode 100644 index 0000000000..bfc9f5ae5c --- /dev/null +++ 
b/fvdb/examples/ray_segment_marching.py @@ -0,0 +1,102 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import logging +import timeit + +from fvdb import GridBatch, JaggedTensor +import torch +import polyscope as ps + +from common import load_dragon_mesh, make_ray_grid + + +def main(): + torch.random.manual_seed(5) + logging.basicConfig(level=logging.INFO) + logging.addLevelName(logging.INFO, "\033[1;32m%s\033[1;0m" % logging.getLevelName(logging.INFO)) + + device = torch.device("cuda") + dtype = torch.float32 + vox_size = 0.04 + vox_origin = torch.zeros(3).to(device) + + N = 10 # Maximum number of segments to intersect along ray + nrays = 100 + plot_every = 20 + batch_size = 2 + + p, n = load_dragon_mesh(device=device, dtype=dtype) + + p, n = load_dragon_mesh(device=device, dtype=dtype) + p -= p.mean(0) + p /= 10.0 + p = torch.concatenate( + [ + p, + p + 2 * torch.tensor([0, 0, 0.48], device=p.device), + p + 2 * torch.tensor([0, 0, 0.96], device=p.device), + # p + 1 * torch.tensor([0, 0, 1.44], device=p.device), + ] + ) + n = torch.concatenate([n, n, n]) + + ray_o, ray_d = make_ray_grid(nrays, [0.0, 0.1, -0.1], device=device, dtype=dtype) + pmt = torch.randperm(ray_o.shape[0]).to(device) + ray_o, ray_d = ray_o[pmt], ray_d[pmt] + + p, n = JaggedTensor([p] * batch_size), JaggedTensor([n] * batch_size) + ray_o, ray_d = JaggedTensor([ray_o] * batch_size), JaggedTensor([ray_d] * batch_size) + + grid = GridBatch(device=device) + grid.set_from_points(p, [-1] * 3, [1] * 3, voxel_sizes=vox_size, origins=vox_origin) + + gc, ge = grid.viz_edge_network + + logging.info(f"Tracing {nrays ** 2} Ray Segments...") + start = timeit.default_timer() + segments = grid.segments_along_rays(ray_o, ray_d, N, eps=1e-5) + if p.is_cuda: + torch.cuda.synchronize() + logging.info(f"Done in {timeit.default_timer() - start}s!") + + ps.init() + ps.set_ground_plane_mode("shadow_only") + + for b_i in range(batch_size): + ps.register_point_cloud("points", p[b_i].jdata.cpu(), radius=0.00025, point_render_mode="quad") + for i in range(0, len(ray_o[b_i].jdata), plot_every): + roi = ray_o[b_i].jdata[i].unsqueeze(0) # [1, 3] + rdi = ray_d[b_i].jdata[i].unsqueeze(0) # [1, 3] + segsi = segments[b_i][i].jdata # [N, 2] + + if segsi.numel() == 0: + continue + + rp = torch.cat( + [ + roi + segsi[:, 0].unsqueeze(-1) * rdi, + roi + segsi[:, 1].unsqueeze(-1) * rdi, + ] + ) + re = torch.stack([torch.arange(segsi.shape[0]), torch.arange(segsi.shape[0]) + segsi.shape[0]], dim=-1) + + ray_segs = ps.register_curve_network(f"ray segments {i}", rp.cpu(), re.cpu(), radius=0.00175) + rv = torch.zeros(re.shape[0]) + rv[::2] = 1.0 + ray_segs.add_scalar_quantity(f"segment colors {i}", rv.cpu(), defined_on="edges", enabled=True, cmap="jet") + + ps.register_point_cloud("grid corners", gc.jdata.cpu(), enabled=True, radius=0.00025, point_render_mode="quad") + ps.register_curve_network( + "grid edges", gc.jdata.cpu(), ge.jdata.cpu(), enabled=True, radius=0.00025, transparency=0.7 + ) + + # ray_dir_points = torch.cat([ray_o, ray_o + 0.5 * ray_d]) + # ray_dir_edges = torch.stack([torch.arange(ray_o.shape[0]), torch.arange(ray_o.shape[0]) + ray_o.shape[0]], dim=-1) + # ps.register_curve_network("ray directions", ray_dir_points, ray_dir_edges, radius=0.0005) + # ps.register_point_cloud("ray origins", ray_o, radius=0.01) + ps.show() + + +if __name__ == "__main__": + main() diff --git a/fvdb/examples/ray_voxel_marching.py b/fvdb/examples/ray_voxel_marching.py new file mode 100644 index 0000000000..a5ed3ee908 --- 
/dev/null +++ b/fvdb/examples/ray_voxel_marching.py @@ -0,0 +1,100 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import timeit +import logging + +import polyscope as ps +import torch + +import fvdb +from fvdb import JaggedTensor, GridBatch + +from common import load_dragon_mesh, make_ray_grid, plot_ray_segments + + +def main(): + logging.basicConfig(level=logging.INFO) + logging.addLevelName(logging.INFO, "\033[1;32m%s\033[1;0m" % logging.getLevelName(logging.INFO)) + device = torch.device("cuda") + dtype = torch.float32 + vox_size = 0.04 + vox_origin = [0.0, 0.0, 0.0] + + nrays = 1024 # 100 x 100 rays + plot_every = 512 # only plot every n rays + max_voxels = 20 # maximum number of voxels to intersect along ray + + p, n = load_dragon_mesh(device=device, dtype=dtype) + p -= p.mean(0) + p /= 10.0 + p = torch.concatenate( + [ + p, + p + 2 * torch.tensor([0, 0, 0.48], device=p.device), + p + 2 * torch.tensor([0, 0, 0.96], device=p.device), + # p + 1 * torch.tensor([0, 0, 1.44], device=p.device), + ] + ) + n = torch.concatenate([n, n, n]) + + batch_size = 2 + + p = fvdb.JaggedTensor([p] * batch_size) + n = fvdb.JaggedTensor([n] * batch_size) + + grid = GridBatch(device=device) + grid.set_from_points(p, [-1] * 3, [1] * 3, voxel_sizes=vox_size, origins=vox_origin) + + logging.info(f"Created {len(grid)} grids with {grid.total_voxels} total voxels") + gc, ge = grid.viz_edge_network + + ray_o, ray_d = make_ray_grid(nrays, [0.0, 0.0, -0.1], device=device, dtype=dtype) + pmt = torch.randperm(ray_o.shape[0]).to(device) + ray_o, ray_d = ray_o[pmt], ray_d[pmt] + + ray_o, ray_d = fvdb.JaggedTensor([ray_o] * batch_size), fvdb.JaggedTensor([ray_d] * batch_size) + + logging.info(f"Tracing {nrays ** 2} Rays Per Grid...") + start = timeit.default_timer() + vox, times = grid.voxels_along_rays(ray_o, ray_d, max_voxels, 1e-4) + if p.jdata.is_cuda: + torch.cuda.synchronize() + logging.info(f"Done in {timeit.default_timer() - start}s") + + logging.info(f"There are {len(vox)} sets of intersections in the batch") + for i, visect in enumerate(vox): + logging.info(f"There are {len(visect)} rays in the {i}th set of intersections") + + logging.info("Plotting") + ps.init() + for i in range(batch_size): + p_i = p[i].jdata.cpu() + ray_o_i, ray_d_i = ray_o[i].jdata.cpu(), ray_d[i].jdata.cpu() + times_i = times[i].cpu() + gc_i, ge_i = gc[i].cpu(), ge[i].cpu() + + ps.set_ground_plane_mode("shadow_only") + + ps.register_point_cloud("points", p_i, radius=0.00025) + logging.info("About to plot ray segments") + plot_ray_segments(ray_o_i, ray_d_i, times_i, plot_every) + logging.info("Plotted Ray Segments") + + logging.info(f"Creating a new grid of only the voxels intersected by this ray") + isected_grid = fvdb.sparse_grid_from_ijk(vox[i].jflatten(), voxel_sizes=vox_size, origins=vox_origin) + logging.info(f"Created {len(isected_grid)} grids with {isected_grid.total_voxels} total voxels") + iv, ie = isected_grid.viz_edge_network + ps.register_curve_network("intersected voxels", iv.jdata.cpu(), ie.jdata.cpu(), enabled=True, radius=0.0009) + ps.register_point_cloud("grid corners", gc_i.jdata, enabled=True, radius=0.001) + ps.register_curve_network("grid edges", gc_i.jdata, ge_i.jdata, enabled=True, radius=0.00015, transparency=0.7) + + # ray_dir_points = torch.cat([ray_o_i, ray_o_i + times_i.jdata.max() * ray_d_i]) + # ray_dir_edges = torch.stack([torch.arange(ray_o_i.shape[0]), torch.arange(ray_o_i.shape[0]) + ray_o_i.shape[0]], dim=-1) + # ps.register_curve_network("ray directions", 
ray_dir_points, ray_dir_edges, radius=0.0005) + # ps.register_point_cloud("ray origins", ray_o, radius=0.01) + ps.show() + + +if __name__ == "__main__": + main() diff --git a/fvdb/examples/sample_trilinear.py b/fvdb/examples/sample_trilinear.py new file mode 100644 index 0000000000..8711691c30 --- /dev/null +++ b/fvdb/examples/sample_trilinear.py @@ -0,0 +1,67 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import timeit +import logging + +import polyscope as ps +import torch +from fvdb import GridBatch + +from common import load_dragon_mesh + + +def main(): + logging.basicConfig(level=logging.INFO) + logging.addLevelName(logging.INFO, "\033[1;32m%s\033[1;0m" % logging.getLevelName(logging.INFO)) + device = torch.device("cuda") + dtype = torch.float32 + vox_size = 0.0025 + vox_origin = torch.zeros(3) + + p, n = load_dragon_mesh(skip_every=1, device=device, dtype=dtype) + + index = GridBatch(device=device) + index.set_from_points(p, voxel_sizes=vox_size, origins=vox_origin) + index_dual = index.dual_grid() + + nsplat = index.splat_trilinear(p, n) + gp = index.ijk + gd = index_dual.ijk + gp = index.grid_to_world(gp.type(dtype)) + gd = index_dual.grid_to_world(gd.type(dtype)) + + features = torch.ones(index_dual.total_voxels, 32).to(device).to(dtype) * torch.norm( + gd.jdata.type(dtype), dim=-1, keepdim=True + ) + features.requires_grad = True + + logging.info("Sampling features....") + start = timeit.default_timer() + features_trilerp = index_dual.sample_trilinear(p, features) + if features.is_cuda: + torch.cuda.synchronize() + logging.info(f"Done in {timeit.default_timer() - start}s!") + loss = features_trilerp.jdata.sum() + loss.backward() + + p, n = p.cpu(), n.cpu() + nsplat = nsplat.cpu() + gp, gd = gp.cpu(), gd.cpu() + features = features.detach().cpu() + features_trilerp = features_trilerp.detach().cpu() + + ps.init() + dual_grid_pts = ps.register_point_cloud("dual grid corners", gd.jdata, radius=0.001) + dual_grid_pts.add_scalar_quantity("feature norms", torch.norm(features, dim=-1), enabled=True) + + primal_grid_pts = ps.register_point_cloud("primal grid corners", gp.jdata, radius=0.0005) + primal_grid_pts.add_vector_quantity("splatted normals", nsplat.jdata, enabled=True, length=0.05, radius=0.001) + + surf_pts = ps.register_point_cloud("points", p, radius=0.0035) + surf_pts.add_scalar_quantity("sampled feature norms", torch.norm(features_trilerp.jdata, dim=-1), enabled=True) + ps.show() + + +if __name__ == "__main__": + main() diff --git a/fvdb/examples/splat_trilinear.py b/fvdb/examples/splat_trilinear.py new file mode 100644 index 0000000000..415fe338c4 --- /dev/null +++ b/fvdb/examples/splat_trilinear.py @@ -0,0 +1,54 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import timeit + +import logging +import polyscope as ps +import torch +from fvdb import GridBatch + +from common import load_dragon_mesh + + +def main(): + logging.basicConfig(level=logging.INFO) + logging.addLevelName(logging.INFO, "\033[1;32m%s\033[1;0m" % logging.getLevelName(logging.INFO)) + device = torch.device("cuda") + dtype = torch.float32 + + vox_size = 0.0025 + vox_origin = (0, 0, 0) + + p, n = load_dragon_mesh(skip_every=1, device=device, dtype=dtype) + + index = GridBatch(device=device) + index.set_from_points(p, voxel_sizes=vox_size, origins=vox_origin) + index_dual = index.dual_grid() + + logging.info("Splatting into grid...") + start = timeit.default_timer() + nsplat = index.splat_trilinear(p, n) + if p.is_cuda: + 
torch.cuda.synchronize() + logging.info(f"Done in {timeit.default_timer() - start}s!") + + gp = index.ijk + gd = index_dual.ijk + gp = index.grid_to_world(gp.type(dtype)) + gd = index_dual.grid_to_world(gd.type(dtype)) + + p, n = p.cpu(), n.cpu() + nsplat = nsplat.cpu() + gp, gd = gp.cpu(), gd.cpu() + + ps.init() + ps.register_point_cloud("points", p, radius=0.00075) + grid_pts = ps.register_point_cloud("vox coords", gp.jdata, radius=0.0005) + + grid_pts.add_vector_quantity("splatted normals", nsplat.jdata, enabled=True, length=0.05, radius=0.001) + ps.show() + + +if __name__ == "__main__": + main() diff --git a/fvdb/examples/subdivide.py b/fvdb/examples/subdivide.py new file mode 100644 index 0000000000..541a768f46 --- /dev/null +++ b/fvdb/examples/subdivide.py @@ -0,0 +1,74 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import time +import logging + +import polyscope as ps +import torch +from fvdb import GridBatch + +from common import load_dragon_mesh + + +def main(): + logging.basicConfig(level=logging.INFO) + logging.addLevelName(logging.INFO, "\033[1;32m%s\033[1;0m" % logging.getLevelName(logging.INFO)) + device = torch.device("cuda") + dtype = torch.float32 + + vox_size = 0.01 + vox_origin = (0.0, 0.0, 0.0) + p, n = load_dragon_mesh(device=device, dtype=dtype) + + index0 = GridBatch(device) + index0.set_from_points(p, [-1, -1, -1], [1, 1, 1], vox_size, vox_origin) + grids = [index0] + + logging.info("Splatting into grid...") + start = time.time() + nsplat = index0.splat_trilinear(p, n) + if device == "cuda": + torch.cuda.synchronize() + logging.info(f"Done in {time.time() - start}s!") + + logging.info("Building subdivided grids") + start = time.time() + for i in range(2): + subdiv_factor = i + 1 + mask = torch.rand(grids[i].total_voxels, device=device) > 0.5 + grids.append(grids[-1].subdivided_grid(subdiv_factor, mask)) + assert mask.sum().item() * subdiv_factor**3 == grids[-1].total_voxels + if device == "cuda": + torch.cuda.synchronize() + logging.info(f"Done in {time.time() - start}s!") + + p, n = p.cpu(), n.cpu() + + ps.init() + ps.register_point_cloud("points", p, radius=0.00075) + + for i, index in enumerate(grids): + dual_index = index.dual_grid() + gp = index.ijk.jdata + gd = dual_index.ijk.jdata + dual_v, dual_e = index.viz_edge_network + + dual_v = dual_v.jdata.cpu() + dual_e = dual_e.jdata.cpu() + gp = index.grid_to_world(gp.to(dtype)).cpu() + gd = dual_index.grid_to_world(gd.to(dtype)).cpu() + gp, gd = gp.cpu().jdata, gd.cpu().jdata + + ps.register_curve_network(f"grid edges {i}", dual_v.cpu(), dual_e.cpu(), enabled=True, radius=0.0005) + ps.register_point_cloud(f"vox corners {i}", gd, radius=0.0005 * (i + 1)) + if i == 0: + grid_pts = ps.register_point_cloud("vox centers", gp, radius=0.0005) + grid_pts.add_vector_quantity( + "splatted normals", nsplat.jdata.cpu(), enabled=True, length=0.05, radius=0.001 + ) + ps.show() + + +if __name__ == "__main__": + main() diff --git a/fvdb/examples/uniform_ray_marching.py b/fvdb/examples/uniform_ray_marching.py new file mode 100644 index 0000000000..d6176d7a46 --- /dev/null +++ b/fvdb/examples/uniform_ray_marching.py @@ -0,0 +1,124 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import os +import time +import logging +import timeit + +import numpy as np +import point_cloud_utils as pcu +import polyscope as ps +import torch +from fvdb import GridBatch, JaggedTensor +import fvdb + +from common import load_dragon_mesh, make_ray_grid, 
plot_ray_segments + + +def main(): + torch.random.manual_seed(5) + logging.basicConfig(level=logging.INFO) + logging.addLevelName(logging.INFO, "\033[1;32m%s\033[1;0m" % logging.getLevelName(logging.INFO)) + + device = torch.device("cuda") + dtype = torch.float32 + vox_size = 0.04 + vox_origin = torch.zeros(3).to(device) + step_size = 0.5 * vox_size + + N = 10 # Maximum number of segments to intersect along ray + nrays = 100 + plot_every = 20 + batch_size = 2 + + p, n = load_dragon_mesh(device=device, dtype=dtype) + + p, n = load_dragon_mesh(device=device, dtype=dtype) + p -= p.mean(0) + p /= 10.0 + p = torch.concatenate( + [ + p, + p + 2 * torch.tensor([0, 0, 0.48], device=p.device), + p + 2 * torch.tensor([0, 0, 0.96], device=p.device), + # p + 1 * torch.tensor([0, 0, 1.44], device=p.device), + ] + ) + n = torch.concatenate([n, n, n]) + + ray_o, ray_d = make_ray_grid(nrays, [0.0, 0.1, -0.1], device=device, dtype=dtype) + pmt = torch.randperm(ray_o.shape[0]).to(device) + ray_o, ray_d = ray_o[pmt], ray_d[pmt] + + p, n = JaggedTensor([p] * batch_size), JaggedTensor([n] * batch_size) + ray_o, ray_d = JaggedTensor([ray_o] * batch_size), JaggedTensor([ray_d] * batch_size) + + grid = GridBatch(device=device) + grid.set_from_points(p, [-1] * 3, [1] * 3, voxel_sizes=vox_size, origins=vox_origin) + + gc, ge = grid.viz_edge_network + + logging.info(f"Tracing {nrays ** 2} Ray Segments...") + start = timeit.default_timer() + segments = grid.segments_along_rays(ray_o, ray_d, N, eps=1e-5) + if p.is_cuda: + torch.cuda.synchronize() + logging.info(f"Done in {timeit.default_timer() - start}s!") + + tmin = fvdb.jzeros(ray_o.lshape, device=device, dtype=dtype) + tmax = fvdb.jones(ray_o.lshape, device=device, dtype=dtype) * 1e10 + + logging.info(f"Generating samples for {ray_o.rshape[0]} Ray Segments...") + start = timeit.default_timer() + ray_ts = grid.uniform_ray_samples(ray_o, ray_d, tmin, tmax, step_size, eps=1e-4) + if p.is_cuda: + torch.cuda.synchronize() + logging.info(f"Done in {timeit.default_timer() - start}s!") + + print(ray_ts.eshape) + ps.init() + ps.set_ground_plane_mode("shadow_only") + + for b_i in range(batch_size): + ps.register_point_cloud("points", p[b_i].jdata.cpu(), radius=0.00025, point_render_mode="quad") + for i in range(0, len(ray_o[b_i].jdata), plot_every): + roi = ray_o[b_i].jdata[i].unsqueeze(0) # [1, 3] + rdi = ray_d[b_i].jdata[i].unsqueeze(0) # [1, 3] + segsi = segments[b_i][i].jdata # [N, 2] + + if segsi.numel() == 0: + continue + + rp = torch.cat( + [ + roi + segsi[:, 0].unsqueeze(-1) * rdi, + roi + segsi[:, 1].unsqueeze(-1) * rdi, + ] + ) + re = torch.stack([torch.arange(segsi.shape[0]), torch.arange(segsi.shape[0]) + segsi.shape[0]], dim=-1) + + # ray_segs = ps.register_curve_network(f"ray segments {i}", rp.cpu(), re.cpu(), radius=0.00075) + + ray_ts_i = ray_ts[b_i][i].jdata + ray_ts_i = 0.5 * (ray_ts_i[:, 0] + ray_ts_i[:, 1]) + ray_samples = roi + ray_ts_i.unsqueeze(-1) * rdi + ps.register_point_cloud(f"ray samples {i}", ray_samples.cpu(), radius=0.0015) + # rv = torch.zeros(re.shape[0]) + # rv[::2] = 1.0 + # ray_segs.add_scalar_quantity(f"segment colors {i}", rv.cpu(), defined_on="edges", enabled=True, cmap="jet") + + ps.register_point_cloud("grid corners", gc.jdata.cpu(), enabled=True, radius=0.00025, point_render_mode="quad") + ps.register_curve_network( + "grid edges", gc.jdata.cpu(), ge.jdata.cpu(), enabled=True, radius=0.00025, transparency=0.7 + ) + + # ray_dir_points = torch.cat([ray_o, ray_o + 0.5 * ray_d]) + # ray_dir_edges = 
torch.stack([torch.arange(ray_o.shape[0]), torch.arange(ray_o.shape[0]) + ray_o.shape[0]], dim=-1) + # ps.register_curve_network("ray directions", ray_dir_points, ray_dir_edges, radius=0.0005) + # ps.register_point_cloud("ray origins", ray_o, radius=0.01) + ps.show() + + +if __name__ == "__main__": + main() diff --git a/fvdb/examples/voxel_neighborhood.py b/fvdb/examples/voxel_neighborhood.py new file mode 100644 index 0000000000..c69576768f --- /dev/null +++ b/fvdb/examples/voxel_neighborhood.py @@ -0,0 +1,49 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: MPL-2.0 +# +import numpy as np +import polyscope as ps +import torch +from common import load_dragon_mesh + +from fvdb import GridBatch, sparse_grid_from_ijk + + +def main(): + device = "cuda" + + vox_size = 0.0075 + vox_origin = (0, 0, 0) + N = 1 + + [p] = load_dragon_mesh(mode="v", skip_every=N, device=torch.device(device)) + + index = GridBatch(device=device) + index.set_from_points(p, [-1, -1, -1], [1, 1, 1], vox_size, vox_origin) + + primal_voxels = index.ijk.jdata + + nhood = index.neighbor_indexes(primal_voxels, 1, 0).jdata + + ps.init() + for _ in range(10): + randvox = np.random.randint(nhood.shape[0]) + + voxijk = primal_voxels[randvox] + nbrs = primal_voxels[nhood[randvox][nhood[randvox] >= 0]] + print(nhood[randvox]) + nhood_ijk = torch.cat([voxijk.unsqueeze(0), nbrs], dim=0) + + vp, ve = index.viz_edge_network + vp, ve = vp.jdata, ve.jdata + + vi, vei = sparse_grid_from_ijk(nhood_ijk, voxel_sizes=vox_size, origins=vox_origin).viz_edge_network + vi, vei = vi.jdata, vei.jdata + + ps.register_curve_network("vox", vp.cpu().numpy(), ve.cpu().numpy(), radius=0.0025) + ps.register_curve_network("nhd", vi.cpu().numpy(), vei.cpu().numpy(), radius=0.005) + ps.show() + + +if __name__ == "__main__": + main() diff --git a/fvdb/fvdb/_Cpp.pyi b/fvdb/fvdb/_Cpp.pyi index 30d2d9fea5..c262a47c5b 100644 --- a/fvdb/fvdb/_Cpp.pyi +++ b/fvdb/fvdb/_Cpp.pyi @@ -6,16 +6,31 @@ import numpy import torch from enum import Enum - Numeric = Union[int, float] TorchDeviceOrString = Union[torch.device, str] -Vec3iBatch = Union[torch.Tensor, numpy.ndarray, List[int], List[List[int]], - Tuple[int, int, int], List[Tuple[int, int, int]]] -Vec3dBatch = Union[torch.Tensor, numpy.ndarray, List[float], List[List[float]], - Tuple[float, float, float], List[Tuple[float, float, float]], Vec3iBatch] -Vec3dBatchOrScalar = Union[torch.Tensor, numpy.ndarray, List[float], List[List[float]], - Tuple[float, float, float], List[Tuple[float, float, float]], - float, Vec3iBatch, int] +Vec3iBatch = Union[ + torch.Tensor, numpy.ndarray, List[int], List[List[int]], Tuple[int, int, int], List[Tuple[int, int, int]] +] +Vec3dBatch = Union[ + torch.Tensor, + numpy.ndarray, + List[float], + List[List[float]], + Tuple[float, float, float], + List[Tuple[float, float, float]], + Vec3iBatch, +] +Vec3dBatchOrScalar = Union[ + torch.Tensor, + numpy.ndarray, + List[float], + List[List[float]], + Tuple[float, float, float], + List[Tuple[float, float, float]], + float, + Vec3iBatch, + int, +] Vec3i = Union[torch.Tensor, numpy.ndarray, List[int], Tuple[int, int, int]] Vec3d = Union[torch.Tensor, numpy.ndarray, List[float], Tuple[float, float, float]] @@ -46,7 +61,6 @@ class JaggedTensor: def type(self, arg0: torch.dtype) -> JaggedTensor: ... def to(self, device: TorchDeviceOrString | torch.dtype) -> JaggedTensor: ... def rmask(self, mask: torch.Tensor) -> JaggedTensor: ... - def __add__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... 
def __sub__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __mul__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... @@ -54,7 +68,6 @@ class JaggedTensor: def __truediv__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __floordiv__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __mod__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... - def __iadd__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __isub__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __imul__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... @@ -62,32 +75,26 @@ class JaggedTensor: def __itruediv__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __ifloordiv__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __imod__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... - def __gt__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __ge__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __lt__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __le__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __eq__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... def __ne__(self, other: Union[int, float, JaggedTensor]) -> JaggedTensor: ... - def __getitem__(self, idx: Index | JaggedTensor) -> JaggedTensor: ... def __iter__(self) -> Iterator[JaggedTensor]: ... def __len__(self) -> int: ... - def __getstate__(self) -> tuple: ... def __setstate__(self, arg0: tuple) -> None: ... - def detach(self) -> JaggedTensor: ... def requires_grad_(self, requires_grad: bool) -> JaggedTensor: ... def jagged_like(self, data: torch.Tensor) -> JaggedTensor: ... def clone(self) -> JaggedTensor: ... - def sqrt(self) -> JaggedTensor: ... def abs(self) -> JaggedTensor: ... def round(self, decimals: int = ...) -> JaggedTensor: ... def floor(self) -> JaggedTensor: ... def ceil(self) -> JaggedTensor: ... - def sqrt_(self) -> JaggedTensor: ... def abs_(self) -> JaggedTensor: ... def round_(self, decimals: int = ...) -> JaggedTensor: ... @@ -95,18 +102,13 @@ class JaggedTensor: def ceil_(self) -> JaggedTensor: ... # def jagged_argsort(self) -> JaggedTensor: ... - def jsum(self, dim : int = 0, keepdim : bool = False) -> JaggedTensor: ... - def jmin(self, dim : int = 0, keepdim : bool = False) -> Tuple[JaggedTensor, JaggedTensor]: ... - def jmax(self, dim : int = 0, keepdim : bool = False) -> Tuple[JaggedTensor, JaggedTensor]: ... - + def jsum(self, dim: int = 0, keepdim: bool = False) -> JaggedTensor: ... + def jmin(self, dim: int = 0, keepdim: bool = False) -> Tuple[JaggedTensor, JaggedTensor]: ... + def jmax(self, dim: int = 0, keepdim: bool = False) -> Tuple[JaggedTensor, JaggedTensor]: ... def jreshape(self, lshape: LShapeSpec) -> JaggedTensor: ... - def jreshape_as(self, other: JaggedTensor) -> JaggedTensor: ... - def jflatten(self, dim: int = 0) -> JaggedTensor: ... - def unbind(self) -> Union[List[torch.Tensor], List[List[torch.Tensor]]]: ... - @property def num_tensors(self) -> int: ... @property @@ -135,19 +137,18 @@ class JaggedTensor: def edim(self) -> int: ... @property def requires_grad(self) -> bool: ... - @staticmethod def from_data_and_indices(data: torch.Tensor, indices: torch.Tensor, num_tensors: int) -> JaggedTensor: ... 
- @staticmethod - def from_data_indices_and_list_ids(data: torch.Tensor, indices: torch.Tensor, list_ids: torch.Tensor, num_tensors: int) -> JaggedTensor: ... - + def from_data_indices_and_list_ids( + data: torch.Tensor, indices: torch.Tensor, list_ids: torch.Tensor, num_tensors: int + ) -> JaggedTensor: ... @staticmethod def from_data_and_offsets(data: torch.Tensor, offsets: torch.Tensor) -> JaggedTensor: ... - @staticmethod - def from_data_offsets_and_list_ids(data: torch.Tensor, offsets: torch.Tensor, list_ids: torch.Tensor) -> JaggedTensor: ... - + def from_data_offsets_and_list_ids( + data: torch.Tensor, offsets: torch.Tensor, list_ids: torch.Tensor + ) -> JaggedTensor: ... JaggedTensorOrTensor = Union[torch.Tensor, JaggedTensor] @@ -212,7 +213,6 @@ class GridBatch: def total_bbox(self) -> torch.IntTensor: ... @property def address(self) -> int: ... - def voxel_size_at(self, bi: int) -> torch.FloatTensor: ... def origin_at(self, bi: int) -> torch.FloatTensor: ... def num_voxels_at(self, bi: int) -> int: ... @@ -221,71 +221,157 @@ class GridBatch: def cum_enabled_voxels_at(self, bi: int) -> int: ... def bbox_at(self, bi: int) -> torch.IntTensor: ... def dual_bbox_at(self, bi: int) -> torch.IntTensor: ... - def jagged_like(self, data: torch.Tensor, ignore_disabled: bool = ...) -> JaggedTensor: ... - def set_global_origin(self, origin: Vec3d) -> None: ... def set_global_voxel_size(self, voxel_size: Vec3dOrScalar) -> None: ... - - def set_from_dense_grid(self, num_grids: int, dense_dims: Vec3i, ijk_min: Vec3i = ..., voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ..., mask: Optional[torch.Tensor] = ...) -> None: ... - def set_from_ijk(self, ijk: JaggedTensorOrTensor, pad_min: Vec3i = ..., pad_max: Vec3i = ..., voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ...) -> None: ... - def set_from_nearest_voxels_to_points(self, points: JaggedTensorOrTensor, voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ...) -> None: ... - def set_from_points(self, points: JaggedTensorOrTensor, pad_min: Vec3i = ..., pad_max: Vec3i = ..., voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ...) -> None: ... - def set_from_mesh(self, mesh_vertices: JaggedTensorOrTensor, mesh_faces: JaggedTensorOrTensor, voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ...) -> None: ... - + def set_from_dense_grid( + self, + num_grids: int, + dense_dims: Vec3i, + ijk_min: Vec3i = ..., + voxel_sizes: Vec3dBatchOrScalar = ..., + origins: Vec3dBatch = ..., + mask: Optional[torch.Tensor] = ..., + ) -> None: ... + def set_from_ijk( + self, + ijk: JaggedTensorOrTensor, + pad_min: Vec3i = ..., + pad_max: Vec3i = ..., + voxel_sizes: Vec3dBatchOrScalar = ..., + origins: Vec3dBatch = ..., + ) -> None: ... + def set_from_nearest_voxels_to_points( + self, points: JaggedTensorOrTensor, voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ... + ) -> None: ... + def set_from_points( + self, + points: JaggedTensorOrTensor, + pad_min: Vec3i = ..., + pad_max: Vec3i = ..., + voxel_sizes: Vec3dBatchOrScalar = ..., + origins: Vec3dBatch = ..., + ) -> None: ... + def set_from_mesh( + self, + mesh_vertices: JaggedTensorOrTensor, + mesh_faces: JaggedTensorOrTensor, + voxel_sizes: Vec3dBatchOrScalar = ..., + origins: Vec3dBatch = ..., + ) -> None: ... def read_from_dense(self, dense_data: torch.Tensor, dense_origins: Vec3iBatch = ...) -> JaggedTensor: ... - def read_into_dense(self, sparse_data: JaggedTensorOrTensor, min_coord: Optional[Vec3iBatch] = ..., grid_size: Optional[Vec3i] = ...) 
-> torch.Tensor: ... - - def clip(self, features: JaggedTensorOrTensor, ijk_min: Vec3iBatch, ijk_max: Vec3iBatch) -> Tuple[JaggedTensor, GridBatch]: ... + def read_into_dense( + self, sparse_data: JaggedTensorOrTensor, min_coord: Optional[Vec3iBatch] = ..., grid_size: Optional[Vec3i] = ... + ) -> torch.Tensor: ... + def clip( + self, features: JaggedTensorOrTensor, ijk_min: Vec3iBatch, ijk_max: Vec3iBatch + ) -> Tuple[JaggedTensor, GridBatch]: ... def clipped_grid(self, ijk_min: Vec3iBatch, ijk_max: Vec3iBatch) -> GridBatch: ... - def dual_grid(self, exclude_border: bool = False) -> GridBatch: ... - - def fill_to_grid(self, features: JaggedTensor, other_grid: GridBatch, default_value: float = ...) -> JaggedTensor: ... - + def fill_to_grid( + self, features: JaggedTensor, other_grid: GridBatch, default_value: float = ... + ) -> JaggedTensor: ... def coarsened_grid(self, coarsening_factor: Vec3iOrScalar) -> GridBatch: ... def subdivided_grid(self, subdiv_factor: Vec3iOrScalar, mask: JaggedTensorOrTensor = ...) -> GridBatch: ... - - def max_pool(self, pool_factor: Vec3iOrScalar, data: JaggedTensorOrTensor, stride: Vec3iOrScalar = 0, coarse_grid: Optional[GridBatch] = None) -> Tuple[JaggedTensor, GridBatch]: ... - def avg_pool(self, pool_factor: Vec3iOrScalar, data: JaggedTensorOrTensor, stride: Vec3iOrScalar = 0, coarse_grid: Optional[GridBatch] = None) -> Tuple[JaggedTensor, GridBatch]: ... - def subdivide(self, subdiv_factor: Vec3iOrScalar, data: JaggedTensorOrTensor, mask: Optional[JaggedTensorOrTensor] = None, fine_grid: Optional[GridBatch] = None) -> Tuple[JaggedTensor, GridBatch]: ... - + def max_pool( + self, + pool_factor: Vec3iOrScalar, + data: JaggedTensorOrTensor, + stride: Vec3iOrScalar = 0, + coarse_grid: Optional[GridBatch] = None, + ) -> Tuple[JaggedTensor, GridBatch]: ... + def avg_pool( + self, + pool_factor: Vec3iOrScalar, + data: JaggedTensorOrTensor, + stride: Vec3iOrScalar = 0, + coarse_grid: Optional[GridBatch] = None, + ) -> Tuple[JaggedTensor, GridBatch]: ... + def subdivide( + self, + subdiv_factor: Vec3iOrScalar, + data: JaggedTensorOrTensor, + mask: Optional[JaggedTensorOrTensor] = None, + fine_grid: Optional[GridBatch] = None, + ) -> Tuple[JaggedTensor, GridBatch]: ... def disable_ijk(self, ijk: JaggedTensorOrTensor) -> None: ... def enable_ijk(self, ijk: JaggedTensorOrTensor) -> None: ... - def points_in_active_voxel(self, xyz: JaggedTensorOrTensor, ignore_disabled: bool = False) -> JaggedTensor: ... def coords_in_active_voxel(self, ijk: JaggedTensorOrTensor, ignore_disabled: bool = False) -> JaggedTensor: ... - def cubes_in_grid(self, cube_centers: JaggedTensorOrTensor, cube_min: Vec3dOrScalar = 0.0, cube_max: Vec3dOrScalar = 0.0, ignore_disabled: bool = False) -> JaggedTensor: ... - def cubes_intersect_grid(self, cube_centers: JaggedTensorOrTensor, cube_min: Vec3dOrScalar = 0.0, cube_max: Vec3dOrScalar = 0.0, ignore_disabled: bool = False) -> JaggedTensor: ... - + def cubes_in_grid( + self, + cube_centers: JaggedTensorOrTensor, + cube_min: Vec3dOrScalar = 0.0, + cube_max: Vec3dOrScalar = 0.0, + ignore_disabled: bool = False, + ) -> JaggedTensor: ... + def cubes_intersect_grid( + self, + cube_centers: JaggedTensorOrTensor, + cube_min: Vec3dOrScalar = 0.0, + cube_max: Vec3dOrScalar = 0.0, + ignore_disabled: bool = False, + ) -> JaggedTensor: ... def ijk_to_index(self, ijk: JaggedTensorOrTensor, cumulative: bool = False) -> JaggedTensor: ... def ijk_to_inv_index(self, ijk: JaggedTensorOrTensor, cumulative: bool = False) -> JaggedTensor: ... 
def neighbor_indexes(self, ijk: JaggedTensorOrTensor, extent: int, bitshift: int = 0) -> JaggedTensor: ... - - def splat_bezier(self, points: JaggedTensorOrTensor, points_data: JaggedTensorOrTensor) -> JaggedTensor: ... + def splat_bezier(self, points: JaggedTensorOrTensor, points_data: JaggedTensorOrTensor) -> JaggedTensor: ... def splat_trilinear(self, points: JaggedTensorOrTensor, points_data: JaggedTensorOrTensor) -> JaggedTensor: ... def sample_bezier(self, points: JaggedTensorOrTensor, voxel_data: JaggedTensorOrTensor) -> JaggedTensor: ... - def sample_bezier_with_grad(self, points: JaggedTensorOrTensor, voxel_data: JaggedTensorOrTensor) -> Tuple[JaggedTensor, JaggedTensor]: ... + def sample_bezier_with_grad( + self, points: JaggedTensorOrTensor, voxel_data: JaggedTensorOrTensor + ) -> Tuple[JaggedTensor, JaggedTensor]: ... def sample_trilinear(self, points: JaggedTensorOrTensor, voxel_data: JaggedTensorOrTensor) -> JaggedTensor: ... - def sample_trilinear_with_grad(self, points: JaggedTensorOrTensor, voxel_data: JaggedTensorOrTensor) -> Tuple[JaggedTensor, JaggedTensor]: ... - - - def segments_along_rays(self, ray_origins: JaggedTensorOrTensor, ray_directions: JaggedTensorOrTensor, max_segments: int, eps: float = 0.0, ignore_masked: bool = False) -> JaggedTensor: ... - def voxels_along_rays(self, ray_origins: JaggedTensorOrTensor, ray_directions: JaggedTensorOrTensor, max_voxels: int, eps: float = 0.0, return_ijk: bool = True, cumulative: bool = False) -> Tuple[JaggedTensor, JaggedTensor]: ... - def uniform_ray_samples(self, ray_origins: JaggedTensorOrTensor, ray_directions: JaggedTensorOrTensor, t_min: JaggedTensorOrTensor, t_max: JaggedTensorOrTensor, step_size: float, cone_angle: float = 0.0, include_end_segments : bool = True, return_midpoints: bool = False, eps: float = 0.0) -> JaggedTensor: ... - def ray_implicit_intersection(self, ray_origins: JaggedTensorOrTensor, ray_directions: JaggedTensorOrTensor, grid_scalars: JaggedTensorOrTensor, eps: float = 0.0) -> JaggedTensor: ... - + def sample_trilinear_with_grad( + self, points: JaggedTensorOrTensor, voxel_data: JaggedTensorOrTensor + ) -> Tuple[JaggedTensor, JaggedTensor]: ... + def segments_along_rays( + self, + ray_origins: JaggedTensorOrTensor, + ray_directions: JaggedTensorOrTensor, + max_segments: int, + eps: float = 0.0, + ignore_masked: bool = False, + ) -> JaggedTensor: ... + def voxels_along_rays( + self, + ray_origins: JaggedTensorOrTensor, + ray_directions: JaggedTensorOrTensor, + max_voxels: int, + eps: float = 0.0, + return_ijk: bool = True, + cumulative: bool = False, + ) -> Tuple[JaggedTensor, JaggedTensor]: ... + def uniform_ray_samples( + self, + ray_origins: JaggedTensorOrTensor, + ray_directions: JaggedTensorOrTensor, + t_min: JaggedTensorOrTensor, + t_max: JaggedTensorOrTensor, + step_size: float, + cone_angle: float = 0.0, + include_end_segments: bool = True, + return_midpoints: bool = False, + eps: float = 0.0, + ) -> JaggedTensor: ... + def ray_implicit_intersection( + self, + ray_origins: JaggedTensorOrTensor, + ray_directions: JaggedTensorOrTensor, + grid_scalars: JaggedTensorOrTensor, + eps: float = 0.0, + ) -> JaggedTensor: ... def grid_to_world(self, ijk: JaggedTensorOrTensor) -> JaggedTensor: ... def world_to_grid(self, ijk: JaggedTensorOrTensor) -> JaggedTensor: ... - - def marching_cubes(self, field: JaggedTensorOrTensor, level: float = 0.0) -> Tuple[JaggedTensor, JaggedTensor, JaggedTensor]: ... 
- - def sparse_conv_kernel_map(self, kernel_size: Union[int, Sequence], stride: Union[int, Sequence], target_grid: Optional[GridBatch] = None) -> Tuple[SparseConvPackInfo, GridBatch]: ... + def marching_cubes( + self, field: JaggedTensorOrTensor, level: float = 0.0 + ) -> Tuple[JaggedTensor, JaggedTensor, JaggedTensor]: ... + def sparse_conv_kernel_map( + self, kernel_size: Union[int, Sequence], stride: Union[int, Sequence], target_grid: Optional[GridBatch] = None + ) -> Tuple[SparseConvPackInfo, GridBatch]: ... def sparse_conv_halo(self, input: JaggedTensorOrTensor, weight: torch.Tensor, variant: int = 8) -> JaggedTensor: ... - def is_contiguous(self) -> bool: ... def contiguous(self) -> GridBatch: ... - @overload def to(self, device: TorchDeviceOrString) -> GridBatch: ... @overload @@ -294,7 +380,6 @@ class GridBatch: def to(self, to_jtensor: JaggedTensor) -> GridBatch: ... @overload def to(self, to_grid: GridBatch) -> GridBatch: ... - @overload def __getitem__(self, arg0: int) -> GridBatch: ... @overload @@ -307,26 +392,37 @@ class GridBatch: def __getitem__(self, arg0: torch.Tensor) -> GridBatch: ... @overload def __getitem__(self, arg0: numpy.ndarray) -> GridBatch: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[GridBatch]: ... - def __getstate__(self) -> tuple: ... def __setstate__(self, arg0: tuple) -> None: ... - class ConvPackBackend(Enum): GATHER_SCATTER = 0 IGEMM = 1 CUTLASS = 2 LGGS = 3 - class SparseConvPackInfo: - def __init__(self, kernel_size: Vec3iOrScalar, stride: Vec3iOrScalar, source_grid: GridBatch, target_grid: Optional[GridBatch]) -> None: ... - def sparse_conv_3d(self, input: JaggedTensorOrTensor, weights: torch.Tensor, backend: ConvPackBackend = ConvPackBackend.GATHER_SCATTER) -> JaggedTensor: ... - def sparse_transpose_conv_3d(self, input: JaggedTensorOrTensor, weights: torch.Tensor, backend: ConvPackBackend = ConvPackBackend.GATHER_SCATTER) -> JaggedTensor: ... + def __init__( + self, + kernel_size: Vec3iOrScalar, + stride: Vec3iOrScalar, + source_grid: GridBatch, + target_grid: Optional[GridBatch], + ) -> None: ... + def sparse_conv_3d( + self, + input: JaggedTensorOrTensor, + weights: torch.Tensor, + backend: ConvPackBackend = ConvPackBackend.GATHER_SCATTER, + ) -> JaggedTensor: ... + def sparse_transpose_conv_3d( + self, + input: JaggedTensorOrTensor, + weights: torch.Tensor, + backend: ConvPackBackend = ConvPackBackend.GATHER_SCATTER, + ) -> JaggedTensor: ... @property def kernel_size(self) -> Tuple: ... @property @@ -357,73 +453,126 @@ class SparseConvPackInfo: def halo_index_buffer(self) -> torch.Tensor: ... @property def output_index_buffer(self) -> torch.Tensor: ... - @property def block_kernel_ranges(self) -> torch.Tensor: ... @property def block_kernel_rel_out_idx(self) -> torch.Tensor: ... @property def block_kernel_in_idx(self) -> torch.Tensor: ... - @property def source_grid(self) -> GridBatch: ... @property def stride(self) -> Tuple: ... @property def target_grid(self) -> GridBatch: ... - def build_gather_scatter(self, use_me: bool = False) -> None: ... - def build_implicit_gemm(self, sorted: bool = False, split_mask_num: int = 1, - training: bool = False, split_mask_num_bwd: int = 1, - use_tf32: bool = False) -> None: ... + def build_implicit_gemm( + self, + sorted: bool = False, + split_mask_num: int = 1, + training: bool = False, + split_mask_num_bwd: int = 1, + use_tf32: bool = False, + ) -> None: ... def build_cutlass(self, benchmark: bool = False) -> None: ... def build_lggs(self) -> None: ... 
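For readers skimming these stubs, a minimal sketch of how GridBatch.sparse_conv_kernel_map, SparseConvPackInfo.build_gather_scatter and sparse_conv_3d are meant to compose. This is not part of the patch: the voxel size, feature width and the (kernel_volume, in_channels, out_channels) weight layout are illustrative assumptions only.

import torch
import fvdb

# Build a sparse grid from random points and attach per-voxel features.
pts = fvdb.JaggedTensor([torch.rand(10_000, 3, device="cuda")])
grid = fvdb.sparse_grid_from_points(pts, voxel_sizes=0.01, origins=[0.0, 0.0, 0.0])
feats = grid.jagged_like(torch.randn(grid.total_voxels, 16, device="cuda"))

# Build the kernel map once (stride 1 keeps the same topology), then reuse it
# for every convolution over this grid pair.
packinfo, out_grid = grid.sparse_conv_kernel_map(3, 1)
packinfo.build_gather_scatter()

weights = torch.randn(27, 16, 32, device="cuda")  # assumed (kernel_volume, in, out) layout
out_feats = packinfo.sparse_conv_3d(feats, weights, backend=fvdb.ConvPackBackend.GATHER_SCATTER)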
@overload def jcat(grid_batches: List[GridBatch]) -> GridBatch: ... - @overload def jcat(jagged_tensors: List[JaggedTensorOrTensor], dim: int | None = ...) -> JaggedTensor: ... - -def sparse_grid_from_ijk(ijk: JaggedTensorOrTensor, pad_min: Vec3i = ..., pad_max: Vec3i = ..., voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ..., mutable: bool = ...) -> GridBatch: ... -def sparse_grid_from_nearest_voxels_to_points(points: JaggedTensorOrTensor, voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ..., mutable: bool = ...) -> GridBatch: ... -def sparse_grid_from_points(points: JaggedTensorOrTensor, pad_min: Vec3i = ..., pad_max: Vec3i = ..., voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ..., mutable: bool = ...) -> GridBatch: ... -def sparse_grid_from_dense(num_grids: int, dense_dims: Vec3i, ijk_min: Vec3i = ..., voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ..., device: TorchDeviceOrString = ..., mutable: bool = ...) -> GridBatch: ... -def sparse_grid_from_mesh(vertices: JaggedTensorOrTensor, faces: JaggedTensorOrTensor, voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ..., device: TorchDeviceOrString = ..., mutable: bool = ...) -> GridBatch: ... - -def volume_render(sigmas: torch.Tensor, rgbs: torch.Tensor, deltaTs: torch.Tensor, ts: torch.Tensor, packInfo: torch.Tensor, transmittanceThresh: float) -> List[torch.Tensor]: ... - -def load(path: str, grid_id: Optional[GridIdentifier] = None, device: TorchDeviceOrString = 'cpu', verbose: bool = False) -> Tuple[GridBatch, JaggedTensor, list[str]]: ... -def save(path: str, grid: GridBatch, data: Optional[JaggedTensorOrTensor] = None, names: Optional[Union[str , List[str]]] = None, compressed: bool = False, verbose: bool = False): ... - - -def jrand(lsizes: LShapeSpec, - rsizes: RShapeSpec | None = None, - dtype: torch.dtype | None = None, - device: TorchDeviceOrString | None = None, - requires_grad: bool = False, - pin_memory: bool = False) -> JaggedTensor: ... -def jrandn(lsizes: LShapeSpec, - rsizes: RShapeSpec | None = None, - dtype: torch.dtype | None = None, - device: TorchDeviceOrString | None = None, - requires_grad: bool = False, - pin_memory: bool = False) -> JaggedTensor: ... -def jones(lsizes: LShapeSpec, - rsizes: RShapeSpec | None = None, - dtype: torch.dtype | None = None, - device: TorchDeviceOrString | None = None, - requires_grad: bool = False, - pin_memory: bool = False) -> JaggedTensor: ... -def jzeros(lsizes: LShapeSpec, - rsizes: RShapeSpec | None = None, - dtype: torch.dtype | None = None, - device: TorchDeviceOrString | None = None, - requires_grad: bool = False, - pin_memory: bool = False) -> JaggedTensor: ... -def jempty(lsizes: LShapeSpec, - rsizes: RShapeSpec | None = None, - dtype: torch.dtype | None = None, - device: TorchDeviceOrString | None = None, - requires_grad: bool = False, - pin_memory: bool = False) -> JaggedTensor: ... \ No newline at end of file +def sparse_grid_from_ijk( + ijk: JaggedTensorOrTensor, + pad_min: Vec3i = ..., + pad_max: Vec3i = ..., + voxel_sizes: Vec3dBatchOrScalar = ..., + origins: Vec3dBatch = ..., + mutable: bool = ..., +) -> GridBatch: ... +def sparse_grid_from_nearest_voxels_to_points( + points: JaggedTensorOrTensor, voxel_sizes: Vec3dBatchOrScalar = ..., origins: Vec3dBatch = ..., mutable: bool = ... +) -> GridBatch: ... 
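The module-level sparse_grid_from_* helpers declared here mirror the GridBatch.set_from_* methods above. A brief hedged sketch of the point-based constructor together with the dual/coarsen/subdivide calls declared earlier in this stub (voxel size and point count are illustrative only):

import torch
import fvdb

pts = torch.rand(10_000, 3, device="cuda")
grid = fvdb.sparse_grid_from_points(pts, voxel_sizes=0.01, origins=[0.0, 0.0, 0.0])

dual = grid.dual_grid()           # grid whose voxels sit on the primal voxel corners
coarse = grid.coarsened_grid(2)   # 2x coarser topology
fine = grid.subdivided_grid(2)    # 2x finer topology
print(grid.total_voxels, dual.total_voxels, coarse.total_voxels, fine.total_voxels)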
+def sparse_grid_from_points( + points: JaggedTensorOrTensor, + pad_min: Vec3i = ..., + pad_max: Vec3i = ..., + voxel_sizes: Vec3dBatchOrScalar = ..., + origins: Vec3dBatch = ..., + mutable: bool = ..., +) -> GridBatch: ... +def sparse_grid_from_dense( + num_grids: int, + dense_dims: Vec3i, + ijk_min: Vec3i = ..., + voxel_sizes: Vec3dBatchOrScalar = ..., + origins: Vec3dBatch = ..., + device: TorchDeviceOrString = ..., + mutable: bool = ..., +) -> GridBatch: ... +def sparse_grid_from_mesh( + vertices: JaggedTensorOrTensor, + faces: JaggedTensorOrTensor, + voxel_sizes: Vec3dBatchOrScalar = ..., + origins: Vec3dBatch = ..., + device: TorchDeviceOrString = ..., + mutable: bool = ..., +) -> GridBatch: ... +def volume_render( + sigmas: torch.Tensor, + rgbs: torch.Tensor, + deltaTs: torch.Tensor, + ts: torch.Tensor, + packInfo: torch.Tensor, + transmittanceThresh: float, +) -> List[torch.Tensor]: ... +def load( + path: str, grid_id: Optional[GridIdentifier] = None, device: TorchDeviceOrString = "cpu", verbose: bool = False +) -> Tuple[GridBatch, JaggedTensor, list[str]]: ... +def save( + path: str, + grid: GridBatch, + data: Optional[JaggedTensorOrTensor] = None, + names: Optional[Union[str, List[str]]] = None, + compressed: bool = False, + verbose: bool = False, +): ... +def jrand( + lsizes: LShapeSpec, + rsizes: RShapeSpec | None = None, + dtype: torch.dtype | None = None, + device: TorchDeviceOrString | None = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> JaggedTensor: ... +def jrandn( + lsizes: LShapeSpec, + rsizes: RShapeSpec | None = None, + dtype: torch.dtype | None = None, + device: TorchDeviceOrString | None = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> JaggedTensor: ... +def jones( + lsizes: LShapeSpec, + rsizes: RShapeSpec | None = None, + dtype: torch.dtype | None = None, + device: TorchDeviceOrString | None = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> JaggedTensor: ... +def jzeros( + lsizes: LShapeSpec, + rsizes: RShapeSpec | None = None, + dtype: torch.dtype | None = None, + device: TorchDeviceOrString | None = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> JaggedTensor: ... +def jempty( + lsizes: LShapeSpec, + rsizes: RShapeSpec | None = None, + dtype: torch.dtype | None = None, + device: TorchDeviceOrString | None = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> JaggedTensor: ... diff --git a/fvdb/fvdb/__init__.py b/fvdb/fvdb/__init__.py index 0f84cbb3d9..10c1961447 100644 --- a/fvdb/fvdb/__init__.py +++ b/fvdb/fvdb/__init__.py @@ -11,18 +11,30 @@ from . import utils from ._Cpp import JaggedTensor, GridBatch, SparseConvPackInfo, ConvPackBackend -from ._Cpp import (jcat, sparse_grid_from_ijk, sparse_grid_from_points, - sparse_grid_from_nearest_voxels_to_points, - sparse_grid_from_dense, sparse_grid_from_mesh, load, save, - scaled_dot_product_attention, config, - jrand, jrandn, jones, jzeros, jempty) +from ._Cpp import ( + jcat, + sparse_grid_from_ijk, + sparse_grid_from_points, + sparse_grid_from_nearest_voxels_to_points, + sparse_grid_from_dense, + sparse_grid_from_mesh, + load, + save, + scaled_dot_product_attention, + config, + jrand, + jrandn, + jones, + jzeros, + jempty, +) # The following import needs to come after the GridBatch and JaggedTensor imports # immediately above in order to avoid a circular dependency error. from . 
import nn -__version__ = '0.0.1' +__version__ = "0.0.1" __version_info__ = (0, 0, 1) __all__ = [ diff --git a/fvdb/fvdb/nn/modules.py b/fvdb/fvdb/nn/modules.py index 830be512ec..7917bde6ac 100644 --- a/fvdb/fvdb/nn/modules.py +++ b/fvdb/fvdb/nn/modules.py @@ -16,9 +16,11 @@ def fvnn_module(module): # Register class as a module in fvdb.nn old_forward = module.forward + def _forward(self, *args, **kwargs): with record_function(repr(self)): return old_forward(self, *args, **kwargs) + module.forward = _forward return module @@ -26,6 +28,7 @@ def _forward(self, *args, **kwargs): GridOrVDBTensor = Union[fvdb.GridBatch, VDBTensor] ListOrInt = Union[int, List[int]] + @fvnn_module class MaxPool(nn.Module): r"""Applies a 3D max pooling over an input signal. @@ -45,8 +48,7 @@ def __init__(self, kernel_size: ListOrInt, stride: Optional[ListOrInt] = None): self.kernel_size = kernel_size self.stride = stride or self.kernel_size - def forward(self, input: VDBTensor, - ref_coarse_data: Optional[GridOrVDBTensor] = None) -> VDBTensor: + def forward(self, input: VDBTensor, ref_coarse_data: Optional[GridOrVDBTensor] = None) -> VDBTensor: if isinstance(ref_coarse_data, VDBTensor): coarse_grid, coarse_kmap = ref_coarse_data.grid, ref_coarse_data.kmap elif isinstance(ref_coarse_data, fvdb.GridBatch): @@ -55,16 +57,13 @@ def forward(self, input: VDBTensor, coarse_grid, coarse_kmap = None, None new_feature, new_grid = input.grid.max_pool( - self.kernel_size, input.feature, stride=self.stride, - coarse_grid=coarse_grid + self.kernel_size, input.feature, stride=self.stride, coarse_grid=coarse_grid ) new_feature.jdata[torch.isinf(new_feature.jdata)] = 0.0 return VDBTensor(new_grid, new_feature, kmap=coarse_kmap) def extra_repr(self) -> str: - return "kernel_size={kernel_size}, stride={stride}".format( - kernel_size=self.kernel_size, stride=self.stride - ) + return "kernel_size={kernel_size}, stride={stride}".format(kernel_size=self.kernel_size, stride=self.stride) @fvnn_module @@ -76,13 +75,13 @@ class AvgPool(nn.Module): stride: the stride of the window. 
Default value is :attr:`kernel_size` """ + def __init__(self, kernel_size: ListOrInt, stride: Optional[ListOrInt] = None): super().__init__() self.kernel_size = kernel_size self.stride = stride or self.kernel_size - def forward(self, input: VDBTensor, - ref_coarse_data: Optional[GridOrVDBTensor] = None) -> VDBTensor: + def forward(self, input: VDBTensor, ref_coarse_data: Optional[GridOrVDBTensor] = None) -> VDBTensor: if isinstance(ref_coarse_data, VDBTensor): coarse_grid, coarse_kmap = ref_coarse_data.grid, ref_coarse_data.kmap elif isinstance(ref_coarse_data, fvdb.GridBatch): @@ -91,15 +90,12 @@ def forward(self, input: VDBTensor, coarse_grid, coarse_kmap = None, None new_feature, new_grid = input.grid.avg_pool( - self.kernel_size, input.feature, stride=self.stride, - coarse_grid=coarse_grid + self.kernel_size, input.feature, stride=self.stride, coarse_grid=coarse_grid ) return VDBTensor(new_grid, new_feature, kmap=coarse_kmap) def extra_repr(self) -> str: - return "kernel_size={kernel_size}, stride={stride}".format( - kernel_size=self.kernel_size, stride=self.stride - ) + return "kernel_size={kernel_size}, stride={stride}".format(kernel_size=self.kernel_size, stride=self.stride) @fvnn_module @@ -109,13 +105,13 @@ class UpsamplingNearest(nn.Module): Args: scale_factor: the upsampling factor """ + def __init__(self, scale_factor: ListOrInt): super().__init__() self.scale_factor = scale_factor def forward( - self, input: VDBTensor, mask: Optional[JaggedTensor] = None, - ref_fine_data: Optional[GridOrVDBTensor] = None + self, input: VDBTensor, mask: Optional[JaggedTensor] = None, ref_fine_data: Optional[GridOrVDBTensor] = None ) -> VDBTensor: if isinstance(ref_fine_data, VDBTensor): fine_grid, fine_kmap = ref_fine_data.grid, ref_fine_data.kmap @@ -124,9 +120,7 @@ def forward( else: fine_grid, fine_kmap = None, None - new_feature, new_grid = input.grid.subdivide( - self.scale_factor, input.feature, mask, fine_grid=fine_grid - ) + new_feature, new_grid = input.grid.subdivide(self.scale_factor, input.feature, mask, fine_grid=fine_grid) return VDBTensor(new_grid, new_feature, kmap=fine_kmap) def extra_repr(self) -> str: @@ -141,6 +135,7 @@ class FillToGrid(nn.Module): Args: default_value: the default value to fill in the new grid. 
""" + def __init__(self, default_value: float = 0.0) -> None: super().__init__() self.default_value = default_value @@ -172,9 +167,21 @@ class SparseConv3d(nn.Module): """ CUTLASS_SUPPORTED_CHANNELS = [ - (32, 64), (64, 128), (128, 256), (32, 32), (64, 64), (128, 128), - (256, 256), (128, 64), (64, 32), (256, 128), (384, 256), (192, 128), - (256, 512), (512, 256), (512, 512) + (32, 64), + (64, 128), + (128, 256), + (32, 32), + (64, 64), + (128, 128), + (256, 256), + (128, 64), + (64, 32), + (256, 128), + (384, 256), + (192, 128), + (256, 512), + (512, 256), + (512, 512), ] """ @@ -200,7 +207,7 @@ def __init__( kernel_size: Union[int, Sequence] = 3, stride: Union[int, Sequence] = 1, bias: bool = True, - transposed: bool = False + transposed: bool = False, ) -> None: super().__init__() @@ -251,10 +258,7 @@ def extra_repr(self) -> str: return s.format(**self.__dict__) def reset_parameters(self) -> None: - std = 1 / math.sqrt( - (self.out_channels if self.transposed else self.in_channels) - * self.kernel_volume - ) + std = 1 / math.sqrt((self.out_channels if self.transposed else self.in_channels) * self.kernel_volume) self.weight.data.uniform_(-std, std) if self.bias is not None: self.bias.data.uniform_(-std, std) @@ -263,9 +267,12 @@ def _dispatch_conv(self, in_feature, in_grid, in_kmap, out_grid): backend = self.backend - if backend == "cutlass" and ((not self.weight.is_cuda) or - (self.in_channels, self.out_channels) not in self.CUTLASS_SUPPORTED_CHANNELS): - print(f"Cutlass backend does not support {self.in_channels} -> {self.out_channels} convolutions, falling back to default") + if backend == "cutlass" and ( + (not self.weight.is_cuda) or (self.in_channels, self.out_channels) not in self.CUTLASS_SUPPORTED_CHANNELS + ): + print( + f"Cutlass backend does not support {self.in_channels} -> {self.out_channels} convolutions, falling back to default" + ) backend = "default" if backend == "lggs" and ((self.in_channels, self.out_channels) not in [(128, 128)]): @@ -287,9 +294,7 @@ def _dispatch_conv(self, in_feature, in_grid, in_kmap, out_grid): min_coord = in_grid.ijk.jdata.min(axis=0).values # BWHDC -> BCDHW dense_feature = in_grid.read_into_dense(in_feature, min_coord=min_coord).permute(0, 4, 3, 2, 1) - dense_feature = torch.nn.functional.conv3d( - dense_feature, self.weight, padding=1, stride=1 - ) + dense_feature = torch.nn.functional.conv3d(dense_feature, self.weight, padding=1, stride=1) # BCDHW -> BWHDC dense_feature = dense_feature.permute(0, 4, 3, 2, 1).contiguous() dense_feature = in_grid.read_from_dense(dense_feature, dense_origins=min_coord) @@ -305,13 +310,9 @@ def _dispatch_conv(self, in_feature, in_grid, in_kmap, out_grid): else: if self.transposed: assert out_grid is not None - kmap, _ = out_grid.sparse_conv_kernel_map( - self.kernel_size, self.stride, in_grid - ) + kmap, _ = out_grid.sparse_conv_kernel_map(self.kernel_size, self.stride, in_grid) else: - kmap, out_grid = in_grid.sparse_conv_kernel_map( - self.kernel_size, self.stride, out_grid - ) + kmap, out_grid = in_grid.sparse_conv_kernel_map(self.kernel_size, self.stride, out_grid) out_kmap = kmap if can_cache else None @@ -335,17 +336,20 @@ def _build_kmap_and_convert_backend(self, kmap: fvdb.SparseConvPackInfo, backend elif backend == "igemm_mode0": kmap.build_implicit_gemm( - sorted=False, split_mask_num=1, training=self.training, split_mask_num_bwd=3, use_tf32=self.allow_tf32) + sorted=False, split_mask_num=1, training=self.training, split_mask_num_bwd=3, use_tf32=self.allow_tf32 + ) return fvdb.ConvPackBackend.IGEMM elif 
backend == "igemm_mode1": kmap.build_implicit_gemm( - sorted=True, split_mask_num=1, training=self.training, split_mask_num_bwd=3, use_tf32=self.allow_tf32) + sorted=True, split_mask_num=1, training=self.training, split_mask_num_bwd=3, use_tf32=self.allow_tf32 + ) return fvdb.ConvPackBackend.IGEMM elif backend == "igemm_mode2": kmap.build_implicit_gemm( - sorted=True, split_mask_num=3, training=self.training, split_mask_num_bwd=3, use_tf32=self.allow_tf32) + sorted=True, split_mask_num=3, training=self.training, split_mask_num_bwd=3, use_tf32=self.allow_tf32 + ) return fvdb.ConvPackBackend.IGEMM elif backend == "lggs": @@ -368,9 +372,7 @@ def forward( out_grid, out_kmap = in_grid, in_kmap else: - out_grid, out_feature, out_kmap = self._dispatch_conv( - in_feature, in_grid, in_kmap, out_grid - ) + out_grid, out_feature, out_kmap = self._dispatch_conv(in_feature, in_grid, in_kmap, out_grid) if self.bias is not None: out_feature.jdata = out_feature.jdata + self.bias @@ -383,6 +385,7 @@ class GroupNorm(nn.GroupNorm): r"""Applies Group Normalization over a VDBTensor. See :class:`~torch.nn.GroupNorm` for detailed information. """ + def forward(self, input: VDBTensor) -> VDBTensor: num_channels = input.feature.jdata.size(1) assert num_channels == self.num_channels, "Input feature should have the same number of channels as GroupNorm" @@ -393,13 +396,13 @@ def forward(self, input: VDBTensor) -> VDBTensor: result_data = torch.empty_like(flat_data) for b in range(num_batches): - feat = flat_data[flat_offsets[b]:flat_offsets[b+1]] + feat = flat_data[flat_offsets[b] : flat_offsets[b + 1]] if feat.size(0) != 0: feat = feat.transpose(0, 1).reshape(1, num_channels, -1) feat = super().forward(feat) feat = feat.reshape(num_channels, -1).transpose(0, 1) - result_data[flat_offsets[b]:flat_offsets[b+1]] = feat + result_data[flat_offsets[b] : flat_offsets[b + 1]] = feat return VDBTensor(input.grid, input.grid.jagged_like(result_data), input.kmap) @@ -409,6 +412,7 @@ class BatchNorm(nn.BatchNorm1d): r"""Applies Batch Normalization over a VDBTensor. See :class:`~torch.nn.BatchNorm1d` for detailed information. 
""" + def forward(self, input: VDBTensor) -> VDBTensor: num_channels = input.feature.jdata.size(1) assert num_channels == self.num_features, "Input feature should have the same number of channels as BatchNorm" @@ -420,7 +424,7 @@ def forward(self, input: VDBTensor) -> VDBTensor: class ElementwiseMixin: def forward(self, input: VDBTensor) -> VDBTensor: assert isinstance(input, VDBTensor), "Input should have type VDBTensor" - res = super().forward(input.feature.jdata) # type: ignore + res = super().forward(input.feature.jdata) # type: ignore return VDBTensor(input.grid, input.feature.jagged_like(res), input.kmap) diff --git a/fvdb/fvdb/nn/vdbtensor.py b/fvdb/fvdb/nn/vdbtensor.py index 4c0fd42ecf..9352a069ee 100644 --- a/fvdb/fvdb/nn/vdbtensor.py +++ b/fvdb/fvdb/nn/vdbtensor.py @@ -37,9 +37,11 @@ def __post_init__(self): if self.grid.total_voxels != self.feature.jdata.size(0): raise ValueError("grid and feature should have the same total voxel count") if self.kmap is not None: - if not (self.same_grid(self.kmap.source_grid, self.grid) and - self.same_grid(self.kmap.target_grid, self.grid) and - self.kmap.stride == (1, 1, 1)): + if not ( + self.same_grid(self.kmap.source_grid, self.grid) + and self.same_grid(self.kmap.target_grid, self.grid) + and self.kmap.stride == (1, 1, 1) + ): raise ValueError("kmap should operate on the same grid as this tensor") @staticmethod @@ -50,10 +52,10 @@ def type(self, arg0: torch.dtype): return VDBTensor(self.grid, self.feature.type(arg0)) def cpu(self): - return VDBTensor(self.grid.to('cpu'), self.feature.cpu()) + return VDBTensor(self.grid.to("cpu"), self.feature.cpu()) def cuda(self): - return VDBTensor(self.grid.to('cuda'), self.feature.cuda()) + return VDBTensor(self.grid.to("cuda"), self.feature.cuda()) def to(self, device: Any): return VDBTensor(self.grid.to(device), self.feature.to(device)) @@ -107,8 +109,8 @@ def cat(tensors: List[Union["VDBTensor", JaggedTensor, torch.Tensor]], dim: int assert len(tensors) > 0, "At least one tensor should be provided" if dim == 0: assert all(isinstance(t, VDBTensor) for t in tensors), "All tensors should be of type VDBTensor" - new_grid = fvdb.jcat([t.grid for t in tensors]) # type: ignore - new_feature = new_grid.jagged_like(torch.cat([t.feature.jdata for t in tensors])) # type: ignore + new_grid = fvdb.jcat([t.grid for t in tensors]) # type: ignore + new_feature = new_grid.jagged_like(torch.cat([t.feature.jdata for t in tensors])) # type: ignore return VDBTensor(new_grid, new_feature) else: return VDBTensor._feature_ops(lambda *t: torch.cat(t, dim=dim), tensors) @@ -122,8 +124,12 @@ def from_dense(dense_feature: torch.Tensor, ijk_min=None, origins=None, voxel_si if ijk_min is None: ijk_min = [0, 0, 0] grid = fvdb.sparse_grid_from_dense( - dense_feature.size(0), dense_feature.size()[1:4], ijk_min=ijk_min, - voxel_sizes=voxel_sizes, origins=origins, device=dense_feature.device + dense_feature.size(0), + dense_feature.size()[1:4], + ijk_min=ijk_min, + voxel_sizes=voxel_sizes, + origins=origins, + device=dense_feature.device, ) # Note: this would map dense_feature[0, 0, 0] to grid[ijk_min] feature = grid.read_from_dense(dense_feature.contiguous(), dense_origins=ijk_min) diff --git a/fvdb/fvdb/utils/__init__.py b/fvdb/fvdb/utils/__init__.py index 7a7f44504a..7197162af3 100644 --- a/fvdb/fvdb/utils/__init__.py +++ b/fvdb/fvdb/utils/__init__.py @@ -1,4 +1,4 @@ # Copyright Contributors to the OpenVDB Project # SPDX-License-Identifier: MPL-2.0 # -from .._Cpp import volume_render \ No newline at end of file +from .._Cpp 
import volume_render diff --git a/fvdb/fvdb/utils/build_ext.py b/fvdb/fvdb/utils/build_ext.py index 70421eab54..44c5dc506c 100644 --- a/fvdb/fvdb/utils/build_ext.py +++ b/fvdb/fvdb/utils/build_ext.py @@ -30,29 +30,29 @@ def FVDBExtension(name, sources, *args, **kwargs): :return: A :class:`torch.utils.cpp_extension.CppExtension` object. """ - libraries = kwargs.get('libraries', []) - libraries.append('fvdb') - kwargs['libraries'] = libraries + libraries = kwargs.get("libraries", []) + libraries.append("fvdb") + kwargs["libraries"] = libraries - library_dirs = kwargs.get('library_dirs', []) + library_dirs = kwargs.get("library_dirs", []) library_dirs.append(os.path.dirname(fvdb.__file__)) - kwargs['library_dirs'] = library_dirs + kwargs["library_dirs"] = library_dirs - include_dirs = kwargs.get('include_dirs', []) - include_dirs.append(os.path.join(os.path.dirname(fvdb.__file__), 'include')) + include_dirs = kwargs.get("include_dirs", []) + include_dirs.append(os.path.join(os.path.dirname(fvdb.__file__), "include")) # We also need to add this because fvdb internally will refer to their headers without the fvdb/ prefix. - include_dirs.append(os.path.join(os.path.dirname(fvdb.__file__), 'include/fvdb')) - kwargs['include_dirs'] = include_dirs - - extra_link_args = kwargs.get('extra_link_args', []) - extra_link_args.append(f'-Wl,-rpath={os.path.dirname(fvdb.__file__)}') - kwargs['extra_link_args'] = extra_link_args - - extra_compile_args = kwargs.get('extra_compile_args', {}) - extra_compile_args['nvcc'] = extra_compile_args.get('nvcc', []) - if '--extended-lambda' not in extra_compile_args['nvcc']: - extra_compile_args['nvcc'].append('--extended-lambda') - kwargs['extra_compile_args'] = extra_compile_args + include_dirs.append(os.path.join(os.path.dirname(fvdb.__file__), "include/fvdb")) + kwargs["include_dirs"] = include_dirs + + extra_link_args = kwargs.get("extra_link_args", []) + extra_link_args.append(f"-Wl,-rpath={os.path.dirname(fvdb.__file__)}") + kwargs["extra_link_args"] = extra_link_args + + extra_compile_args = kwargs.get("extra_compile_args", {}) + extra_compile_args["nvcc"] = extra_compile_args.get("nvcc", []) + if "--extended-lambda" not in extra_compile_args["nvcc"]: + extra_compile_args["nvcc"].append("--extended-lambda") + kwargs["extra_compile_args"] = extra_compile_args return cpp_extension.CUDAExtension(name, sources, *args, **kwargs) diff --git a/fvdb/scripts/rename_wheels.py b/fvdb/scripts/rename_wheels.py index d868324a23..c0d2f9da46 100644 --- a/fvdb/scripts/rename_wheels.py +++ b/fvdb/scripts/rename_wheels.py @@ -15,11 +15,7 @@ wheel = os.path.basename(wheel) filename, ext = os.path.splitext(wheel) tags = filename.split("-") - new_filename = "-".join( - tags[:-4] - + [tags[-4] + "+" + "torch" + torch_version + "+" + cuda_version] - + tags[-3:] - ) + new_filename = "-".join(tags[:-4] + [tags[-4] + "+" + "torch" + torch_version + "+" + cuda_version] + tags[-3:]) new_filename += ext print(f"Renaming {wheel} -> {new_filename}") os.rename(os.path.join("dist", wheel), os.path.join("dist", new_filename)) diff --git a/fvdb/setup.py b/fvdb/setup.py index a82b8c71c0..6de3f20d4b 100644 --- a/fvdb/setup.py +++ b/fvdb/setup.py @@ -3,25 +3,32 @@ # import os import re -import subprocess import shutil -import requests -from tqdm import tqdm +import subprocess +import tarfile from pathlib import Path import git import git.repo -from git.exc import InvalidGitRepositoryError, GitCommandError -import tarfile +import requests +from git.exc import GitCommandError, InvalidGitRepositoryError 
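The FVDBExtension helper above returns a torch extension object preconfigured with the fvdb include paths, library directory and rpath. A hedged sketch of a downstream setup.py that might use it; the package name and source file are hypothetical, only FVDBExtension and the torch build command come from this patch:

from setuptools import setup
from torch.utils import cpp_extension

from fvdb.utils.build_ext import FVDBExtension

setup(
    name="my_fvdb_op",  # hypothetical downstream package
    ext_modules=[FVDBExtension("my_fvdb_op._C", sources=["src/my_op.cu"])],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)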
from setuptools import setup from torch.utils import cpp_extension +from tqdm import tqdm + +is_conda_env = "CONDA_PREFIX" in os.environ +if is_conda_env: + os.environ["CXX"] = "x86_64-conda-linux-gnu-g++" + os.environ["NVCC_CCBIN"] = "x86_64-conda-linux-gnu-gcc" + def get_nanovdb_source_dir(): - nanovdb_source_dir = '../nanovdb' + nanovdb_source_dir = "../nanovdb" if not os.path.exists(nanovdb_source_dir): - nanovdb_source_dir = 'external/openvdb/nanovdb' + nanovdb_source_dir = "external/openvdb/nanovdb" return nanovdb_source_dir + class FVDBBuildCommand(cpp_extension.BuildExtension): @staticmethod @@ -38,7 +45,7 @@ def is_git_repo(repo_path: str): @staticmethod def download_external_dep(name: str, git_url: str, git_tag: str, recursive: bool = False): based = os.path.dirname(os.path.abspath(__file__)) - external_path = os.path.join(based, 'external') + external_path = os.path.join(based, "external") if not os.path.exists(external_path): os.makedirs(external_path, exist_ok=True) elif not os.path.isdir(external_path): @@ -53,7 +60,7 @@ def download_external_dep(name: str, git_url: str, git_tag: str, recursive: bool raise ValueError(f"A path {repo_path} exists but is not a git repo") else: if recursive: - repo = git.repo.Repo.clone_from(git_url, repo_path, multi_options=['--recursive']) + repo = git.repo.Repo.clone_from(git_url, repo_path, multi_options=["--recursive"]) else: repo = git.repo.Repo.clone_from(git_url, repo_path) repo.git.checkout(git_tag) @@ -62,89 +69,87 @@ def download_external_dep(name: str, git_url: str, git_tag: str, recursive: bool @staticmethod def build_cmake_project(base_path, cmake_args): - cmake_build_dir = os.path.join(base_path, 'build') - cmake_install_dir = os.path.join(base_path, 'install') + cmake_build_dir = os.path.join(base_path, "build") + cmake_install_dir = os.path.join(base_path, "install") os.makedirs(cmake_build_dir, exist_ok=True) os.makedirs(cmake_install_dir, exist_ok=True) - subprocess.check_call(['cmake', base_path, f'-DCMAKE_INSTALL_PREFIX={cmake_install_dir}'] + cmake_args, - cwd=cmake_build_dir) - subprocess.check_call(['cmake', '--build', '.', '--target', 'install'], - cwd=cmake_build_dir) + subprocess.check_call( + ["cmake", base_path, f"-DCMAKE_INSTALL_PREFIX={cmake_install_dir}"] + cmake_args, cwd=cmake_build_dir + ) + subprocess.check_call(["cmake", "--build", ".", "--target", "install"], cwd=cmake_build_dir) return cmake_install_dir def build_extension(self, _ext): - path = os.path.join(self.build_lib, 'fvdb') + path = os.path.join(self.build_lib, "fvdb") - if _ext.name == 'fvdb._Cpp': + if _ext.name == "fvdb._Cpp": _ext.library_dirs.append(path) super().build_extension(_ext) - if _ext.name == 'fvdb.fvdblib': - if os.path.exists(os.path.join(path, 'libfvdb.so')): - os.remove(os.path.join(path, 'libfvdb.so')) + if _ext.name == "fvdb.fvdblib": + if os.path.exists(os.path.join(path, "libfvdb.so")): + os.remove(os.path.join(path, "libfvdb.so")) # Find the .so file in the fvdb subdirectory of self.build_lib # assert that there is only a single one. 
- so_files = [os.path.join(path, t) for t in os.listdir(path) if t.endswith('.so') and t.startswith('fvdblib')] + so_files = [ + os.path.join(path, t) for t in os.listdir(path) if t.endswith(".so") and t.startswith("fvdblib") + ] assert len(so_files) == 1 # Copy the file in so_files[0] to lib/libfvdb.so - shutil.copy(so_files[0], os.path.join(path, 'libfvdb.so')) + shutil.copy(so_files[0], os.path.join(path, "libfvdb.so")) # Also copy the file to the appropriate directory if installing inplace if self.old_inplace: - build_py = self.get_finalized_command('build_py') - inplace_file, regular_file = self._get_inplace_equivalent(build_py, _ext) # type: ignore - inplace_file = os.path.join(os.path.dirname(inplace_file), 'libfvdb.so') - regular_file = os.path.join(os.path.dirname(regular_file), 'libfvdb.so') - self.copy_file(regular_file, inplace_file, level=self.verbose) # type: ignore + build_py = self.get_finalized_command("build_py") + inplace_file, regular_file = self._get_inplace_equivalent(build_py, _ext) # type: ignore + inplace_file = os.path.join(os.path.dirname(inplace_file), "libfvdb.so") + regular_file = os.path.join(os.path.dirname(regular_file), "libfvdb.so") + self.copy_file(regular_file, inplace_file, level=self.verbose) # type: ignore def run(self) -> None: # A sibling nanovdb source directory will exist if fvdb is being built as part of OpenVDB - sibling_nanovdb_dir = Path('../nanovdb') + sibling_nanovdb_dir = Path("../nanovdb") if not sibling_nanovdb_dir.exists(): openvdb_url = "https://github.com/kmuseth/openvdb.git" - self.download_external_dep( - name='openvdb', - git_url=openvdb_url, - git_tag='feature/nanovdb_v32.7' - ) + self.download_external_dep(name="openvdb", git_url=openvdb_url, git_tag="feature/nanovdb_v32.7") _, cutlass_repo = self.download_external_dep( - name='cutlass', - git_url='https://github.com/NVIDIA/cutlass.git', - git_tag='v3.4.0' + name="cutlass", git_url="https://github.com/NVIDIA/cutlass.git", git_tag="v3.4.0" ) try: # NOTE: In python <=3.8, __file__ will be a relative path and >3.8 it is an absolute path - cutlass_repo.git.apply(Path(__file__).resolve().parent / 'env' / 'cutlass.patch') + cutlass_repo.git.apply(Path(__file__).resolve().parent / "env" / "cutlass.patch") except GitCommandError as e: print(f"Failed to apply cutlass patch: {str(e)}, continuing without patching") self.download_external_dep( - name='cudnn_fe', - git_url='https://github.com/NVIDIA/cudnn-frontend', - git_tag='v1.3.0' + name="cudnn_fe", git_url="https://github.com/NVIDIA/cudnn-frontend", git_tag="v1.3.0" ) blosc_source_dir, _ = self.download_external_dep( - name='c-blosc', - git_url='https://github.com/Blosc/c-blosc.git', - git_tag='v1.21.4' + name="c-blosc", git_url="https://github.com/Blosc/c-blosc.git", git_tag="v1.21.4" + ) + self.build_cmake_project( + blosc_source_dir, + [ + "-DBUILD_SHARED=OFF", + "-DBUILD_TESTS=OFF", + "-DBUILD_FUZZERS=OFF", + "-DBUILD_BENCHMARKS=OFF", + "-DCMAKE_POSITION_INDEPENDENT_CODE=ON", + ], ) - self.build_cmake_project(blosc_source_dir, [ - "-DBUILD_SHARED=OFF", "-DBUILD_TESTS=OFF", "-DBUILD_FUZZERS=OFF", "-DBUILD_BENCHMARKS=OFF", - "-DCMAKE_POSITION_INDEPENDENT_CODE=ON" - ]) self.old_inplace = self.inplace super().run() # Find all the headers and copy them into the build directory. # This way extension modules of FVDB can include them. 
- fvdb_headers = get_header_files_recursive('src', 'fvdb') - nanovdb_headers = get_header_files_recursive(get_nanovdb_source_dir(), 'nanovdb') + fvdb_headers = get_header_files_recursive("src", "fvdb") + nanovdb_headers = get_header_files_recursive(get_nanovdb_source_dir(), "nanovdb") for header_folder, header_files in fvdb_headers + nanovdb_headers: os.makedirs(os.path.join(self.build_lib, header_folder), exist_ok=True) @@ -156,7 +161,7 @@ def run(self) -> None: def get_source_files_recursive(base_path, include_bindings=True): source_files = [] for dir_name, _, dir_files in os.walk(base_path): - if not include_bindings and os.path.basename(dir_name) == 'python': + if not include_bindings and os.path.basename(dir_name) == "python": continue cpp_files = [os.path.join(dir_name, t) for t in dir_files if t.endswith(".cpp")] cu_files = [os.path.join(dir_name, t) for t in dir_files if t.endswith(".cu")] @@ -165,11 +170,11 @@ def get_source_files_recursive(base_path, include_bindings=True): def get_header_files_recursive(base_path, new_path): - base_len = len(base_path.split('/')) + base_len = len(base_path.split("/")) source_files = [] for dir_name, _, dir_files in os.walk(base_path): header_files = [os.path.join(dir_name, t) for t in dir_files if t.endswith(".h") or t.endswith(".cuh")] - header_folder = [os.path.join('fvdb/include', new_path, *(h.split('/')[base_len:-1])) for h in header_files] + header_folder = [os.path.join("fvdb/include", new_path, *(h.split("/")[base_len:-1])) for h in header_files] # All items of header_folder should be the same if len(header_folder) > 0: @@ -179,8 +184,10 @@ def get_header_files_recursive(base_path, new_path): def download_and_install_cudnn(): - url = "https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/" + \ - "cudnn-linux-x86_64-9.1.0.70_cuda12-archive.tar.xz" + url = ( + "https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/" + + "cudnn-linux-x86_64-9.1.0.70_cuda12-archive.tar.xz" + ) cwd = os.path.dirname(os.path.abspath(__file__)) tar_filepath = os.path.join(cwd, "external/cudnn.tar.xz") folder_filepath = os.path.join(cwd, "external/cudnn") @@ -224,21 +231,21 @@ def download_and_install_cudnn(): if __name__ == "__main__": - if not os.path.exists('external'): - os.makedirs('external') + if not os.path.exists("external"): + os.makedirs("external") else: - assert os.path.isdir('external'), "external exists but is not a directory" + assert os.path.isdir("external"), "external exists but is not a directory" # Use new C++ standard for newer NVCC versions cuda_home = cpp_extension.CUDA_HOME cuda_version = None if cuda_home is not None: - cuda_version_str = subprocess.check_output([cuda_home + "/bin/nvcc", '--version']).strip().decode() - cuda_version = re.search(r'release (\d+[.]\d+)', cuda_version_str) + cuda_version_str = subprocess.check_output([cuda_home + "/bin/nvcc", "--version"]).strip().decode() + cuda_version = re.search(r"release (\d+[.]\d+)", cuda_version_str) if cuda_version is not None: cuda_version = cuda_version.group(1) - if cuda_version is not None and int(cuda_version.split('.')[0]) >= 12: + if cuda_version is not None and int(cuda_version.split(".")[0]) >= 12: cpp_std = "c++20" cudnn_include_dirs, cudnn_static_libs = download_and_install_cudnn() else: @@ -247,54 +254,66 @@ def download_and_install_cudnn(): cwd = os.path.dirname(os.path.abspath(__file__)) cpp_flags = [ - f'-std={cpp_std}', - '-Wno-unknown-pragmas', - '-Wno-class-memaccess', - '-fdiagnostics-color=always', - 
'-DNANOVDB_USE_BLOSC', + f"-std={cpp_std}", + "-Wno-unknown-pragmas", + "-Wno-class-memaccess", + "-fdiagnostics-color=always", + "-DNANOVDB_USE_BLOSC", + ] + nvcc_flags = [ + f"-std={cpp_std}", + "--extended-lambda", + "--diag-suppress=186", + "-diag-suppress=3189", ] - nvcc_flags = [f'-std={cpp_std}', '--extended-lambda', '--diag-suppress=186', '-diag-suppress=3189'] user_nvcc_flags = os.getenv("NVCC_FLAGS", "").split() nvcc_flags += user_nvcc_flags lib_ext = cpp_extension.CUDAExtension( - name='fvdb.fvdblib', - sources=get_source_files_recursive('src', include_bindings=False), - include_dirs=[os.path.join(cwd, 'src'), - os.path.join(cwd, get_nanovdb_source_dir()), - os.path.join(cwd, 'external/cutlass/include'), - os.path.join(cwd, 'external/c-blosc/install/include'), - os.path.join(cwd, 'external/cudnn_fe/include')] + cudnn_include_dirs, - extra_objects=['external/c-blosc/install/lib/libblosc.a'] + cudnn_static_libs, - extra_compile_args={'cxx': cpp_flags + ['-fvisibility=default'], - 'nvcc': nvcc_flags}, - language='c++') + name="fvdb.fvdblib", + sources=get_source_files_recursive("src", include_bindings=False), + include_dirs=[ + os.path.join(cwd, "src"), + os.path.join(cwd, get_nanovdb_source_dir()), + os.path.join(cwd, "external/cutlass/include"), + os.path.join(cwd, "external/c-blosc/install/include"), + os.path.join(cwd, "external/cudnn_fe/include"), + ] + + cudnn_include_dirs, + extra_objects=["external/c-blosc/install/lib/libblosc.a"] + cudnn_static_libs, + extra_compile_args={"cxx": cpp_flags + ["-fvisibility=default"], "nvcc": nvcc_flags}, + language="c++", + ) bind_ext = cpp_extension.CUDAExtension( - name='fvdb._Cpp', - sources=get_source_files_recursive('src/python/'), - include_dirs=[os.path.join(cwd, 'src'), - os.path.join(cwd, get_nanovdb_source_dir()), - os.path.join(cwd, 'external/cutlass/include'), - os.path.join(cwd, 'external/c-blosc/install/include')], - library_dirs=[os.path.join(cwd, 'fvdb')], - libraries=['fvdb'], - extra_link_args=['-Wl,-rpath,$ORIGIN'], - extra_compile_args={'cxx': cpp_flags + ['-fvisibility=hidden'], - 'nvcc': nvcc_flags}, - language='c++') - - def retrieve_version(file_path = "fvdb/__init__.py"): + name="fvdb._Cpp", + sources=get_source_files_recursive("src/python/"), + include_dirs=[ + os.path.join(cwd, "src"), + os.path.join(cwd, get_nanovdb_source_dir()), + os.path.join(cwd, "external/cutlass/include"), + os.path.join(cwd, "external/c-blosc/install/include"), + ], + library_dirs=[os.path.join(cwd, "fvdb")], + libraries=["fvdb"], + extra_link_args=["-Wl,-rpath,$ORIGIN"], + extra_compile_args={"cxx": cpp_flags + ["-fvisibility=hidden"], "nvcc": nvcc_flags}, + language="c++", + ) + + def retrieve_version(file_path="fvdb/__init__.py"): with open(file_path, "r") as f: for line in f: if line.startswith("__version__"): return line.split("=")[1].strip().strip("'").strip('"') return "0.0.0" - setup(name='fvdb', - version = retrieve_version(), + setup( + name="fvdb", + version=retrieve_version(), ext_modules=[lib_ext, bind_ext], - packages=['fvdb', 'fvdb.nn', 'fvdb.utils'], + packages=["fvdb", "fvdb.nn", "fvdb.utils"], include_package_data=True, - package_data={'fvdb': ['_Cpp.pyi', 'py.typed']}, - cmdclass={'build_ext': FVDBBuildCommand}) + package_data={"fvdb": ["_Cpp.pyi", "py.typed"]}, + cmdclass={"build_ext": FVDBBuildCommand}, + ) diff --git a/fvdb/src/Config.cpp b/fvdb/src/Config.cpp index c0027f498e..8554a909c3 100644 --- a/fvdb/src/Config.cpp +++ b/fvdb/src/Config.cpp @@ -7,23 +7,28 @@ namespace fvdb { Config::Config() = default; -Config& 
Config::global() { +Config & +Config::global() { static Config _config; return _config; } -void Config::setUltraSparseAcceleration(bool enabled) { +void +Config::setUltraSparseAcceleration(bool enabled) { mUltraSparseAcceleration = enabled; } -bool Config::ultraSparseAccelerationEnabled() const { +bool +Config::ultraSparseAccelerationEnabled() const { return mUltraSparseAcceleration; } -void Config::setPendanticErrorChecking(bool enabled) { +void +Config::setPendanticErrorChecking(bool enabled) { mPendanticErrorChecking = enabled; } -bool Config::pendanticErrorCheckingEnabled() const { +bool +Config::pendanticErrorCheckingEnabled() const { return mPendanticErrorChecking; } diff --git a/fvdb/src/Config.h b/fvdb/src/Config.h index 056cd5fc41..6c0677a5d8 100644 --- a/fvdb/src/Config.h +++ b/fvdb/src/Config.h @@ -1,14 +1,13 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once - +#ifndef FVDB_CONFIG_H +#define FVDB_CONFIG_H namespace fvdb { class Config { - -public: + public: Config(); void setUltraSparseAcceleration(bool enabled); @@ -17,11 +16,13 @@ class Config { void setPendanticErrorChecking(bool enabled); bool pendanticErrorCheckingEnabled() const; - static Config& global(); + static Config &global(); -private: + private: bool mUltraSparseAcceleration = false; - bool mPendanticErrorChecking = false; + bool mPendanticErrorChecking = false; }; -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_CONFIG_H \ No newline at end of file diff --git a/fvdb/src/FVDB.cpp b/fvdb/src/FVDB.cpp index c77318ac83..955aea8844 100644 --- a/fvdb/src/FVDB.cpp +++ b/fvdb/src/FVDB.cpp @@ -13,19 +13,18 @@ namespace fvdb { -std::vector volumeRender(const torch::Tensor& sigmas, const torch::Tensor& rgbs, - const torch::Tensor& deltaTs, const torch::Tensor& ts, - const torch::Tensor& jOffsets, double transmittanceThresh) { - return detail::autograd::VolumeRender::apply(sigmas, rgbs, deltaTs, ts, jOffsets, transmittanceThresh); +std::vector +volumeRender(const torch::Tensor &sigmas, const torch::Tensor &rgbs, const torch::Tensor &deltaTs, + const torch::Tensor &ts, const torch::Tensor &jOffsets, double transmittanceThresh) { + return detail::autograd::VolumeRender::apply(sigmas, rgbs, deltaTs, ts, jOffsets, + transmittanceThresh); } -JaggedTensor scaledDotProductAttention(const JaggedTensor& query, - const JaggedTensor& key, - const JaggedTensor& value, - float scale) { - - cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); - const int computeCapability = p->major * 10 + p->minor; +JaggedTensor +scaledDotProductAttention(const JaggedTensor &query, const JaggedTensor &key, + const JaggedTensor &value, float scale) { + cudaDeviceProp *p = at::cuda::getDeviceProperties(query.device().index()); + const int computeCapability = p->major * 10 + p->minor; if (computeCapability < 90) { // https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html @@ -33,134 +32,126 @@ JaggedTensor scaledDotProductAttention(const JaggedTensor& query, // - key: (N, ..., S, E) // - value: (N, ..., S, V) std::vector outList; - torch::Tensor qOffsets = query.joffsets().cpu(); - torch::Tensor kvOffsets = key.joffsets().cpu(); + torch::Tensor qOffsets = query.joffsets().cpu(); + torch::Tensor kvOffsets = key.joffsets().cpu(); for (int64_t b = 0; b < query.num_tensors(); ++b) { - int64_t qStart = qOffsets[b].item(); - int64_t qEnd = qOffsets[b+1].item(); + int64_t qStart = qOffsets[b].item(); + int64_t 
qEnd = qOffsets[b + 1].item(); int64_t kvStart = kvOffsets[b].item(); - int64_t kvEnd = kvOffsets[b+1].item(); - - torch::Tensor q = query.jdata().index({torch::indexing::Slice(qStart, qEnd)}).permute({1, 0, 2}); - torch::Tensor k = key.jdata().index({torch::indexing::Slice(kvStart, kvEnd)}).permute({1, 0, 2}); - torch::Tensor v = value.jdata().index({torch::indexing::Slice(kvStart, kvEnd)}).permute({1, 0, 2}); - - torch::Tensor out = at::native::scaled_dot_product_attention(q, k, v, {}, 0.0, false, scale); - outList.push_back(out.permute({1, 0, 2})); + int64_t kvEnd = kvOffsets[b + 1].item(); + + torch::Tensor q = + query.jdata().index({ torch::indexing::Slice(qStart, qEnd) }).permute({ 1, 0, 2 }); + torch::Tensor k = + key.jdata().index({ torch::indexing::Slice(kvStart, kvEnd) }).permute({ 1, 0, 2 }); + torch::Tensor v = value.jdata() + .index({ torch::indexing::Slice(kvStart, kvEnd) }) + .permute({ 1, 0, 2 }); + + torch::Tensor out = + at::native::scaled_dot_product_attention(q, k, v, {}, 0.0, false, scale); + outList.push_back(out.permute({ 1, 0, 2 })); } return JaggedTensor(outList); } // Custom implementation with CUDNN is only available for Hopper. - torch::Tensor qLengths = query.joffsets().index({torch::indexing::Slice(1, query.num_tensors())}); - torch::Tensor kvLengths = key.joffsets().index({torch::indexing::Slice(1, query.num_tensors())}); + torch::Tensor qLengths = + query.joffsets().index({ torch::indexing::Slice(1, query.num_tensors()) }); + torch::Tensor kvLengths = + key.joffsets().index({ torch::indexing::Slice(1, query.num_tensors()) }); torch::Tensor res = detail::autograd::Attention::apply( query.jdata(), key.jdata(), value.jdata(), qLengths, kvLengths, scale)[0]; return query.jagged_like(res); } std::tuple> -from_nanovdb(nanovdb::GridHandle& handle){ +from_nanovdb(nanovdb::GridHandle &handle) { return detail::io::fromNVDB(handle); } nanovdb::GridHandle -to_nanovdb(const GridBatch& gridBatch, - const torch::optional maybeData, - const torch::optional maybeNames){ +to_nanovdb(const GridBatch &gridBatch, const torch::optional maybeData, + const torch::optional maybeNames) { return detail::io::toNVDB(gridBatch, maybeData, maybeNames); } - -GridBatch jcat(const std::vector& vec) { - std::vector> vecHdls; - std::transform(vec.begin(), vec.end(), std::back_inserter(vecHdls), - [](const GridBatch& grid) { return grid.impl(); }); - return GridBatch(detail::GridBatchImpl::concatenate(vecHdls)); +GridBatch +jcat(const std::vector &vec) { + std::vector> vecHdls; + std::transform(vec.begin(), vec.end(), std::back_inserter(vecHdls), + [](const GridBatch &grid) { return grid.impl(); }); + return GridBatch(detail::GridBatchImpl::concatenate(vecHdls)); } -JaggedTensor jcat(const std::vector& vec, torch::optional dim) { +JaggedTensor +jcat(const std::vector &vec, torch::optional dim) { return JaggedTensor::jcat(vec, dim); } -void save(const std::string& path, - const GridBatch& gridBatch, - const torch::optional maybeData, - const torch::optional maybeNames, - bool compressed, - bool verbose) { +void +save(const std::string &path, const GridBatch &gridBatch, + const torch::optional maybeData, + const torch::optional maybeNames, bool compressed, bool verbose) { detail::io::saveNVDB(path, gridBatch, maybeData, maybeNames, compressed, verbose); } - std::tuple> -load(const std::string& path, - NanoVDBFileGridIdentifier gridIdentifier, - TorchDeviceOrString device, +load(const std::string &path, NanoVDBFileGridIdentifier gridIdentifier, TorchDeviceOrString device, bool verbose) { return 
detail::io::loadNVDB(path, gridIdentifier, device, verbose); } -GridBatch sparse_grid_from_points(const JaggedTensor& points, - const Vec3i& pad_min, - const Vec3i& pad_max, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins, - bool is_mutable) { +GridBatch +sparse_grid_from_points(const JaggedTensor &points, const Vec3i &pad_min, const Vec3i &pad_max, + const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins, + bool is_mutable) { auto ret = GridBatch(points.device(), is_mutable); ret.set_from_points(points, pad_min, pad_max, voxel_sizes, origins); return ret; } - -GridBatch sparse_grid_from_ijk(const JaggedTensor& ijk, - const Vec3i& pad_min, - const Vec3i& pad_max, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins, - bool is_mutable) { +GridBatch +sparse_grid_from_ijk(const JaggedTensor &ijk, const Vec3i &pad_min, const Vec3i &pad_max, + const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins, + bool is_mutable) { auto ret = GridBatch(ijk.device(), is_mutable); ret.set_from_ijk(ijk, pad_min, pad_max, voxel_sizes, origins); return ret; } - -GridBatch sparse_grid_from_nearest_voxels_to_points(const JaggedTensor& points, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins, - bool is_mutable) { +GridBatch +sparse_grid_from_nearest_voxels_to_points(const JaggedTensor &points, + const Vec3dBatchOrScalar &voxel_sizes, + const Vec3dBatch &origins, bool is_mutable) { auto ret = GridBatch(points.device(), is_mutable); ret.set_from_nearest_voxels_to_points(points, voxel_sizes, origins); return ret; } - -GridBatch sparse_grid_from_dense(const int64_t numGrids, - const Vec3i& denseDims, - const Vec3i& ijkMin, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins, - torch::optional mask, - TorchDeviceOrString device, bool is_mutable) { +GridBatch +sparse_grid_from_dense(const int64_t numGrids, const Vec3i &denseDims, const Vec3i &ijkMin, + const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins, + torch::optional mask, TorchDeviceOrString device, + bool is_mutable) { auto ret = GridBatch(device, is_mutable); ret.set_from_dense_grid(numGrids, denseDims, ijkMin, voxel_sizes, origins, mask); return ret; } -GridBatch sparse_grid_from_mesh(const JaggedTensor& vertices, - const JaggedTensor& faces, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins, - bool is_mutable) { +GridBatch +sparse_grid_from_mesh(const JaggedTensor &vertices, const JaggedTensor &faces, + const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins, + bool is_mutable) { auto ret = GridBatch(vertices.device(), is_mutable); ret.set_from_mesh(vertices, faces, voxel_sizes, origins); return ret; } - -std::vector jdataShape1(const std::vector& lsizes, const std::vector& rsizes) { +std::vector +jdataShape1(const std::vector &lsizes, const std::vector &rsizes) { const int64_t totalElements = std::reduce(std::execution::par, lsizes.begin(), lsizes.end()); std::vector shape; shape.reserve(rsizes.size() + 1); @@ -169,17 +160,20 @@ std::vector jdataShape1(const std::vector& lsizes, const std:: return shape; } -std::tuple> jdataShape2(const std::vector>& lsizes, const std::vector& rsizes) { +std::tuple> +jdataShape2(const std::vector> &lsizes, const std::vector &rsizes) { std::vector elementCountsPerList; std::vector tensorCountsPerList; elementCountsPerList.reserve(lsizes.size()); tensorCountsPerList.reserve(lsizes.size()); - for (const auto& l : lsizes) { + for (const auto &l: lsizes) { 
elementCountsPerList.push_back(std::reduce(std::execution::par, l.begin(), l.end())); tensorCountsPerList.push_back(l.size()); } - const int64_t totalSize = std::reduce(std::execution::par, elementCountsPerList.begin(), elementCountsPerList.end()); - const int64_t totalTensors = std::reduce(std::execution::par, tensorCountsPerList.begin(), tensorCountsPerList.end()); + const int64_t totalSize = + std::reduce(std::execution::par, elementCountsPerList.begin(), elementCountsPerList.end()); + const int64_t totalTensors = + std::reduce(std::execution::par, tensorCountsPerList.begin(), tensorCountsPerList.end()); std::vector shape; shape.reserve(rsizes.size() + 1); shape.push_back(totalSize); @@ -188,19 +182,17 @@ std::tuple> jdataShape2(const std::vector& lsizes, \ - const std::vector rsizes, \ - at::TensorOptions options) { \ - auto shape = jdataShape1(lsizes, rsizes); \ - return JaggedTensor(lsizes, FNAME(shape, options)); \ - } \ - \ - JaggedTensor JFNAME(const std::vector>& lsizes, \ - const std::vector rsizes, \ - at::TensorOptions options) { \ - auto shape = jdataShape2(lsizes, rsizes); \ - return JaggedTensor(lsizes, std::get<0>(shape), FNAME(std::get<1>(shape), options)); \ +#define __FVDB__BUILDER(FNAME, JFNAME) \ + JaggedTensor JFNAME(const std::vector &lsizes, const std::vector rsizes, \ + at::TensorOptions options) { \ + auto shape = jdataShape1(lsizes, rsizes); \ + return JaggedTensor(lsizes, FNAME(shape, options)); \ + } \ + \ + JaggedTensor JFNAME(const std::vector> &lsizes, \ + const std::vector rsizes, at::TensorOptions options) { \ + auto shape = jdataShape2(lsizes, rsizes); \ + return JaggedTensor(lsizes, std::get<0>(shape), FNAME(std::get<1>(shape), options)); \ } __FVDB__BUILDER(torch::rand, jrand) diff --git a/fvdb/src/FVDB.h b/fvdb/src/FVDB.h index a0052da2c1..e41c27cc07 100644 --- a/fvdb/src/FVDB.h +++ b/fvdb/src/FVDB.h @@ -1,37 +1,37 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_FVDB_H +#define FVDB_FVDB_H -#include - -#include "JaggedTensor.h" #include "GridBatch.h" +#include "JaggedTensor.h" #include "SparseConvPackInfo.h" #include "Types.h" +#include + namespace fvdb { -std::vector volumeRender(const torch::Tensor& sigmas, const torch::Tensor& rgbs, - const torch::Tensor& deltaTs, const torch::Tensor& ts, - const torch::Tensor& packInfo, double transmittanceThresh); +std::vector volumeRender(const torch::Tensor &sigmas, const torch::Tensor &rgbs, + const torch::Tensor &deltaTs, const torch::Tensor &ts, + const torch::Tensor &packInfo, double transmittanceThresh); -JaggedTensor scaledDotProductAttention(const JaggedTensor& query, - const JaggedTensor& key, - const JaggedTensor& value, - float scale); +JaggedTensor scaledDotProductAttention(const JaggedTensor &query, const JaggedTensor &key, + const JaggedTensor &value, float scale); /// @brief Concatenate a list of grid batches into a single grid batch /// @param vec A list of grid batches to concatenate /// @return A GridBatch representing the concatenated grid batch -GridBatch jcat(const std::vector& vec); +GridBatch jcat(const std::vector &vec); /// @brief Concatenate a list of JaggedTensor into a single JaggedTensor /// @param vec A list of JaggedTensor to concatenate -/// @param dim The dimension to concatenate along or nullptr to concatenate the outermost tensor lists +/// @param dim The dimension to concatenate along or nullptr to concatenate the outermost tensor +/// lists /// @return A JaggedTensor representing the concatenated 
JaggedTensor -JaggedTensor jcat(const std::vector& vec, torch::optional dim = torch::nullopt); - +JaggedTensor jcat(const std::vector &vec, + torch::optional dim = torch::nullopt); /// @brief Create a JaggedTensor filled with random numbers from a uniform distribution /// on the interval [0, 1) with the specified lshape an rshape @@ -39,12 +39,10 @@ JaggedTensor jcat(const std::vector& vec, torch::optional /// @param rsizes The rshape of the JaggedTensor (feature dimension of each tensor) /// @param options The options to use for the created tensor /// @return A JaggedTensor filled with random numbers from the uniform distribution on [0, 1). -JaggedTensor jrand(const std::vector& lsizes, - const std::vector rsizes = {}, - at::TensorOptions options = {}); -JaggedTensor jrand(const std::vector>& lsizes, - const std::vector rsizes = {}, +JaggedTensor jrand(const std::vector &lsizes, const std::vector rsizes = {}, at::TensorOptions options = {}); +JaggedTensor jrand(const std::vector> &lsizes, + const std::vector rsizes = {}, at::TensorOptions options = {}); /// @brief Create a JaggedTensor filled with random numbers from a normal distribution /// with mean 0 and variance 1 (also called the standard normal distribution). @@ -52,183 +50,192 @@ JaggedTensor jrand(const std::vector>& lsizes, /// @param rsizes The rshape of the JaggedTensor (feature dimension of each tensor) /// @param options The options to use for the created tensor /// @return A JaggedTensor filled with random numbers from the standard normal distribution. -JaggedTensor jrandn(const std::vector& lsizes, - const std::vector rsizes = {}, - at::TensorOptions options = {}); -JaggedTensor jrandn(const std::vector>& lsizes, - const std::vector rsizes = {}, +JaggedTensor jrandn(const std::vector &lsizes, const std::vector rsizes = {}, at::TensorOptions options = {}); +JaggedTensor jrandn(const std::vector> &lsizes, + const std::vector rsizes = {}, at::TensorOptions options = {}); /// @brief Create a JaggedTensor filled with zeros. /// @param lsizes The lshape of the JaggedTensor (number of elements per tensor) /// @param rsizes The rshape of the JaggedTensor (feature dimension of each tensor) /// @param options The options to use for the created tensor /// @return A JaggedTensor filled with zeros. -JaggedTensor jzeros(const std::vector& lsizes, - const std::vector rsizes = {}, - at::TensorOptions options = {}); -JaggedTensor jzeros(const std::vector>& lsizes, - const std::vector rsizes = {}, +JaggedTensor jzeros(const std::vector &lsizes, const std::vector rsizes = {}, at::TensorOptions options = {}); +JaggedTensor jzeros(const std::vector> &lsizes, + const std::vector rsizes = {}, at::TensorOptions options = {}); /// @brief Create a JaggedTensor filled with ones. /// @param lsizes The lshape of the JaggedTensor (number of elements per tensor) /// @param rsizes The rshape of the JaggedTensor (feature dimension of each tensor) /// @param options The options to use for the created tensor /// @return A JaggedTensor filled with ones. -JaggedTensor jones(const std::vector& lsizes, - const std::vector rsizes = {}, - at::TensorOptions options = {}); -JaggedTensor jones(const std::vector>& lsizes, - const std::vector rsizes = {}, +JaggedTensor jones(const std::vector &lsizes, const std::vector rsizes = {}, at::TensorOptions options = {}); +JaggedTensor jones(const std::vector> &lsizes, + const std::vector rsizes = {}, at::TensorOptions options = {}); /// @brief Create an empty JaggedTensor with uninitialized values. 
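A short usage sketch for the builder family declared above (jrand, jrandn, jzeros, jones, and jempty documented next), based only on the signatures in this header; the lshape/rshape values and the use of default tensor options are illustrative assumptions.

    // Sketch only: constructs jagged tensors with the factory functions declared in
    // FVDB.h and concatenates them with jcat.
    #include <fvdb/FVDB.h>

    void
    builderExample() {
        // Two lists with 100 and 250 elements each, feature dimension 3.
        fvdb::JaggedTensor a = fvdb::jrand({ 100, 250 }, { 3 });
        fvdb::JaggedTensor b = fvdb::jzeros({ 100, 250 }, { 3 });
        // With dim left unset, jcat concatenates the outermost tensor lists,
        // yielding a JaggedTensor holding four tensors.
        fvdb::JaggedTensor c = fvdb::jcat({ a, b });
    }

Leaving the dim argument of jcat unset follows the declaration above, which concatenates the outermost tensor lists rather than a data dimension.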
/// @param lsizes The lshape of the JaggedTensor (number of elements per tensor) /// @param rsizes The rshape of the JaggedTensor (feature dimension of each tensor) /// @param options The options to use for the created tensor /// @return A JaggedTensor filled with uninitialized values. -JaggedTensor jempty(const std::vector& lsizes, - const std::vector rsizes = {}, - at::TensorOptions options = {}); -JaggedTensor jempty(const std::vector>& lsizes, - const std::vector rsizes = {}, +JaggedTensor jempty(const std::vector &lsizes, const std::vector rsizes = {}, at::TensorOptions options = {}); +JaggedTensor jempty(const std::vector> &lsizes, + const std::vector rsizes = {}, at::TensorOptions options = {}); /// @brief Return a grid batch with voxels which contain a point in an input set of point clouds /// (possibly padding each voxel containing a point) /// @param points A JaggedTensor with shape [B, -1, 3] containing one point set per grid to create -/// @param pad_min A tensor of shape [3,] containing the number of voxels to pad each inserted voxel with to the left/back/bottom -/// @param pad_max A tensor of shape [3,] containing the number of voxels to pad each inserted voxel with to the right/front/top -/// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in the batch or one voxel size for all grids -/// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, 0, 0] voxel +/// @param pad_min A tensor of shape [3,] containing the number of voxels to pad each inserted voxel +/// with to the left/back/bottom +/// @param pad_max A tensor of shape [3,] containing the number of voxels to pad each inserted voxel +/// with to the right/front/top +/// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in +/// the batch or one voxel size for all grids +/// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, +/// 0, 0] voxel /// for each grid in the batch, or one origin for all grids /// @param is_mutable Whether the grid should be mutable or not /// @return A GridBatch containing the created grid batch -GridBatch sparse_grid_from_points(const JaggedTensor& points, - const Vec3i& pad_min = torch::zeros({3}, torch::kInt32), - const Vec3i& pad_max = torch::zeros({3}, torch::kInt32), - const Vec3dBatchOrScalar& voxel_sizes = 1.0, - const Vec3dBatch& origins = torch::zeros({3}), - bool is_mutable = false); - - -/// @brief Return a grid batch with the eight nearest voxels to each point in an input set of point clouds +GridBatch sparse_grid_from_points(const JaggedTensor &points, + const Vec3i &pad_min = torch::zeros({ 3 }, torch::kInt32), + const Vec3i &pad_max = torch::zeros({ 3 }, torch::kInt32), + const Vec3dBatchOrScalar &voxel_sizes = 1.0, + const Vec3dBatch &origins = torch::zeros({ 3 }), + bool is_mutable = false); + +/// @brief Return a grid batch with the eight nearest voxels to each point in an input set of point +/// clouds /// @param points A JaggedTensor with shape [B, -1, 3] containing one point set per grid to create -/// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in the batch or one voxel size for all grids -/// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, 0, 0] voxel +/// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in +/// the batch or one voxel size for all grids +/// @param 
origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, +/// 0, 0] voxel /// for each grid in the batch, or one origin for all grids /// @param is_mutable Whether the grid should be mutable or not /// @return A GridBatch containing the created grid batch -GridBatch sparse_grid_from_nearest_voxels_to_points(const JaggedTensor& points, - const Vec3dBatchOrScalar& voxel_sizes = 1.0, - const Vec3dBatch& origins = torch::zeros({3}), - bool is_mutable = false); - +GridBatch sparse_grid_from_nearest_voxels_to_points(const JaggedTensor &points, + const Vec3dBatchOrScalar &voxel_sizes = 1.0, + const Vec3dBatch &origins = torch::zeros({ 3 }), + bool is_mutable = false); /// @brief REturn a grid batch with the specified voxel coordinates (possibly with padding) -/// @param coords A JaggedTensor of shape [B, -1, 3] specifying the coordinates of each voxel to insert -/// @param pad_min A tensor of shape [3,] containing the number of voxels to pad each inserted voxel with to the left/back/bottom -/// @param pad_max A tensor of shape [3,] containing the number of voxels to pad each inserted voxel with to the right/front/top -/// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in the batch or one voxel size for all grids -/// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, 0, 0] voxel +/// @param coords A JaggedTensor of shape [B, -1, 3] specifying the coordinates of each voxel to +/// insert +/// @param pad_min A tensor of shape [3,] containing the number of voxels to pad each inserted voxel +/// with to the left/back/bottom +/// @param pad_max A tensor of shape [3,] containing the number of voxels to pad each inserted voxel +/// with to the right/front/top +/// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in +/// the batch or one voxel size for all grids +/// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, +/// 0, 0] voxel /// for each grid in the batch, or one origin for all grids /// @return A GridBatch containing the created grid batch -GridBatch sparse_grid_from_ijk(const JaggedTensor& ijk, - const Vec3i& pad_min = torch::zeros({3}, torch::kInt32), - const Vec3i& pad_max = torch::zeros({3}, torch::kInt32), - const Vec3dBatchOrScalar& voxel_sizes = 1.0, - const Vec3dBatch& origins = torch::zeros({3}), - bool is_mutable = false); - +GridBatch sparse_grid_from_ijk(const JaggedTensor &ijk, + const Vec3i &pad_min = torch::zeros({ 3 }, torch::kInt32), + const Vec3i &pad_max = torch::zeros({ 3 }, torch::kInt32), + const Vec3dBatchOrScalar &voxel_sizes = 1.0, + const Vec3dBatch &origins = torch::zeros({ 3 }), + bool is_mutable = false); /// @brief Return a grid batch densely from ijkMin to ijkMin + size /// @param numGrids The number of grids to create in the batch /// @param denseDims The size of each dense grid (shape [3,] = [W, H, D]) /// @param ijkMin The minimum ijk coordinate of each dense grid in the batch (shape [3,]) -/// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in the batch or one voxel size for all grids -/// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, 0, 0] voxel +/// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in +/// the batch or one voxel size for all grids +/// @param origins A tensor of shape [B, 3] or [3,] containing the 
world space coordinate of the [0, +/// 0, 0] voxel /// for each grid in the batch, or one origin for all grids -/// @param mask Optional mask of shape [W, H, D] to specify voxels which are included in the dense grid. +/// @param mask Optional mask of shape [W, H, D] to specify voxels which are included in the dense +/// grid. /// Note that the same mask will be re-used for all the grids in the batch. /// @param device Which device to build the grid batch on /// @param mutable If the returned grid batch should be mutable /// @return A GridBatch containing a batch of dense grids -GridBatch sparse_grid_from_dense(const int64_t numGrids, - const Vec3i& denseDims, - const Vec3i& ijkMin, - const Vec3dBatchOrScalar& voxel_sizes = 1.0, - const Vec3dBatch& origins = torch::zeros({3}), - torch::optional mask = torch::nullopt, - TorchDeviceOrString device = torch::kCPU, - bool is_mutable = false); - - -/// @brief Return a grid batch from a jagged batch of triangle meshes (i.e. each voxel intersects the mesh) -/// @param vertices A JaggedTensor of shape [B, -1, 3] containing the vertices of each mesh in the batch +GridBatch sparse_grid_from_dense(const int64_t numGrids, const Vec3i &denseDims, + const Vec3i &ijkMin, const Vec3dBatchOrScalar &voxel_sizes = 1.0, + const Vec3dBatch &origins = torch::zeros({ 3 }), + torch::optional mask = torch::nullopt, + TorchDeviceOrString device = torch::kCPU, bool is_mutable = false); + +/// @brief Return a grid batch from a jagged batch of triangle meshes (i.e. each voxel intersects +/// the mesh) +/// @param vertices A JaggedTensor of shape [B, -1, 3] containing the vertices of each mesh in the +/// batch /// @param faces A JaggedTensor of shape [B, -1, 3] containing the faces of each mesh in the batch -/// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in the batch or one voxel size for all grids -/// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, 0, 0] voxel +/// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in +/// the batch or one voxel size for all grids +/// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, +/// 0, 0] voxel /// for each grid in the batch, or one origin for all grids /// @param is_mutable Whether the grid should be mutable or not /// @return A GridBatch containing the created grid batch -GridBatch sparse_grid_from_mesh(const JaggedTensor& vertices, - const JaggedTensor& faces, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins, +GridBatch sparse_grid_from_mesh(const JaggedTensor &vertices, const JaggedTensor &faces, + const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins, bool is_mutable); /// @brief Return a grid batch, tensors of data, and names from a nanovdb grid handle /// @param handle nanovdb grid handle -/// @return A triple (gridbatch, data, names) where gridbatch is a GridBatch containing the converted grids, -/// data is a JaggedTensor containing the data of the grids, and names is a list of strings containing -/// the name of each grid +/// @return A triple (gridbatch, data, names) where gridbatch is a GridBatch containing the +/// converted grids, +/// data is a JaggedTensor containing the data of the grids, and names is a list of strings +/// containing the name of each grid std::tuple> -from_nanovdb(nanovdb::GridHandle& handle); - +from_nanovdb(nanovdb::GridHandle &handle); -/// @brief Return a nanovdb grid handle 
created from a grid batch, optional jagged tensor of data, and optional +/// @brief Return a nanovdb grid handle created from a grid batch, optional jagged tensor of data, +/// and optional /// list of names /// @param gridBatch The gridbatch to convert -/// @param maybeData Optional JaggedTensor of data to save with the grid batch (one element per voxel) -/// @param maybeNames Optional list of names for each grid in the batch (or a single name to use for every grid) -/// @return A nanovdb grid handle, whose type is inferred from the data, containing the converted grids +/// @param maybeData Optional JaggedTensor of data to save with the grid batch (one element per +/// voxel) +/// @param maybeNames Optional list of names for each grid in the batch (or a single name to use +/// for every grid) +/// @return A nanovdb grid handle, whose type is inferred from the data, containing the converted +/// grids nanovdb::GridHandle -to_nanovdb(const GridBatch& gridBatch, - const torch::optional maybeData = torch::optional(), - const torch::optional maybeNames = torch::optional()); - +to_nanovdb(const GridBatch &gridBatch, + const torch::optional maybeData = torch::optional(), + const torch::optional maybeNames = + torch::optional()); -/// @brief Save a grid batch and optional jagged tensor to a .nvdb file. Will overwrite existing files. +/// @brief Save a grid batch and optional jagged tensor to a .nvdb file. Will overwrite existing +/// files. /// @param path The path to save the file to. /// @param gridBatch The gridbatch to save -/// @param maybeData Optional JaggedTensor of data to save with the grid batch (one element per voxel) -/// @param maybeNames Optional list of names for each grid in the batch (or a single name to use for every grid) +/// @param maybeData Optional JaggedTensor of data to save with the grid batch (one element per +/// voxel) +/// @param maybeNames Optional list of names for each grid in the batch (or a single name to use for +/// every grid) /// @param compressed Whether to compress the stored grid using Blosc (https://www.blosc.org/) /// @param verbose Whether to print information about the saved grids -void save(const std::string& path, - const GridBatch& gridBatch, - const torch::optional maybeData = torch::optional(), - const torch::optional maybeNames = torch::optional(), - bool compressed = false, - bool verbose = false); - - -/// @brief Load a grid batch from a .nvdb file. This function loads each nanovdb grid into the batch as well +void save(const std::string &path, const GridBatch &gridBatch, + const torch::optional maybeData = torch::optional(), + const torch::optional maybeNames = + torch::optional(), + bool compressed = false, bool verbose = false); + +/// @brief Load a grid batch from a .nvdb file. This function loads each nanovdb grid into the batch +/// as well /// as a list of tensors containing the data at each grid in the batch /// (e.g. 
a Vec3d grid will load a [num_voxels, 3] float64 tensor) /// @param path The path to the .nvdb file to load -/// @param gridIdentifier The identifier (index, list of indices, name, list of names) to load from the file +/// @param gridIdentifier The identifier (index, list of indices, name, list of names) to load from +/// the file /// @param device Which device to load the grid batch on /// @param verbose If set to true, print information about the loaded grids -/// @return A triple (gridbatch, data, names) where gridbatch is a GridBatch containing the loaded grids, -/// data is a JaggedTensor containing the data of the grids, and names is a list of strings containing -/// the name of each grid +/// @return A triple (gridbatch, data, names) where gridbatch is a GridBatch containing the loaded +/// grids, +/// data is a JaggedTensor containing the data of the grids, and names is a list of strings +/// containing the name of each grid std::tuple> -load(const std::string& path, - NanoVDBFileGridIdentifier gridIdentifier, - TorchDeviceOrString device, +load(const std::string &path, NanoVDBFileGridIdentifier gridIdentifier, TorchDeviceOrString device, bool verbose = false); +} // namespace fvdb -} // namespace fvdb \ No newline at end of file +#endif // FVDB_FVDB_H \ No newline at end of file diff --git a/fvdb/src/GridBatch.cpp b/fvdb/src/GridBatch.cpp index f40edb19ce..2a2e0a90fd 100644 --- a/fvdb/src/GridBatch.cpp +++ b/fvdb/src/GridBatch.cpp @@ -5,12 +5,10 @@ #include "FVDB.h" #include "detail/GridBatchImpl.h" -#include "detail/build/Build.h" -#include "detail/ops/Ops.h" #include "detail/autograd/Autograd.h" +#include "detail/build/Build.h" #include "detail/io/IO.h" - - +#include "detail/ops/Ops.h" namespace fvdb { @@ -18,22 +16,22 @@ GridBatch::GridBatch(TorchDeviceOrString device, bool isMutable) { mImpl = c10::make_intrusive(device.value(), isMutable); } - GridBatch::GridBatch() { - mImpl = c10::make_intrusive(detail::build::buildEmptyGrid(torch::kCPU, false), nanovdb::Vec3d(1.0), nanovdb::Vec3d(0.0)); + mImpl = c10::make_intrusive( + detail::build::buildEmptyGrid(torch::kCPU, false), nanovdb::Vec3d(1.0), + nanovdb::Vec3d(0.0)); } - -std::pair GridBatch::max_pool(Vec3iOrScalar pool_factor, - const JaggedTensor& data, - Vec3iOrScalar stride, - torch::optional coarse_grid) const { - TORCH_CHECK_VALUE(data.ldim() == 1, - "Expected data to have 1 list dimension, i.e. be a single list of coordinate values, but got", data.ldim(), "list dimensions" - ); +std::pair +GridBatch::max_pool(Vec3iOrScalar pool_factor, const JaggedTensor &data, Vec3iOrScalar stride, + torch::optional coarse_grid) const { + TORCH_CHECK_VALUE( + data.ldim() == 1, + "Expected data to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", + data.ldim(), "list dimensions"); nanovdb::Coord pool_factor_coord = pool_factor.value(); - nanovdb::Coord stride_coord = stride.value(); + nanovdb::Coord stride_coord = stride.value(); for (int i = 0; i < 3; i += 1) { if (stride_coord[i] == 0) { @@ -51,23 +49,20 @@ std::pair GridBatch::max_pool(Vec3iOrScalar pool_factor torch::Tensor pool_data = detail::autograd::MaxPoolGrid::apply( impl(), coarse_grid_impl, pool_factor_coord, stride_coord, data.jdata())[0]; - return std::make_pair( - coarse_grid_impl->jaggedTensor(pool_data, false), - GridBatch(coarse_grid_impl) - ); + return std::make_pair(coarse_grid_impl->jaggedTensor(pool_data, false), + GridBatch(coarse_grid_impl)); } - -std::pair GridBatch::avg_pool(Vec3iOrScalar pool_factor, - const JaggedTensor& data, - Vec3iOrScalar stride, - torch::optional coarse_grid) const { - TORCH_CHECK_VALUE(data.ldim() == 1, - "Expected data to have 1 list dimension, i.e. be a single list of coordinate values, but got", data.ldim(), "list dimensions" - ); +std::pair +GridBatch::avg_pool(Vec3iOrScalar pool_factor, const JaggedTensor &data, Vec3iOrScalar stride, + torch::optional coarse_grid) const { + TORCH_CHECK_VALUE( + data.ldim() == 1, + "Expected data to have 1 list dimension, i.e. be a single list of coordinate values, but got", + data.ldim(), "list dimensions"); nanovdb::Coord pool_factor_coord = pool_factor.value(); - nanovdb::Coord stride_coord = stride.value(); + nanovdb::Coord stride_coord = stride.value(); for (int i = 0; i < 3; i += 1) { if (stride_coord[i] == 0) { @@ -85,25 +80,23 @@ std::pair GridBatch::avg_pool(Vec3iOrScalar pool_factor torch::Tensor pool_data = detail::autograd::AvgPoolGrid::apply( impl(), coarse_grid_impl, pool_factor_coord, stride_coord, data.jdata())[0]; - return std::make_pair( - coarse_grid_impl->jaggedTensor(pool_data, false), - GridBatch(coarse_grid_impl) - ); + return std::make_pair(coarse_grid_impl->jaggedTensor(pool_data, false), + GridBatch(coarse_grid_impl)); } - -std::pair GridBatch::subdivide(Vec3iOrScalar subdiv_factor, - const JaggedTensor& data, - const torch::optional mask, - torch::optional fine_grid) const { - - TORCH_CHECK_VALUE(data.ldim() == 1, - "Expected data to have 1 list dimension, i.e. be a single list of coordinate values, but got", data.ldim(), "list dimensions" - ); +std::pair +GridBatch::subdivide(Vec3iOrScalar subdiv_factor, const JaggedTensor &data, + const torch::optional mask, + torch::optional fine_grid) const { + TORCH_CHECK_VALUE( + data.ldim() == 1, + "Expected data to have 1 list dimension, i.e. be a single list of coordinate values, but got", + data.ldim(), "list dimensions"); if (mask.has_value()) { - TORCH_CHECK_VALUE(mask.value().ldim() == 1, - "Expected mask to have 1 list dimension, i.e. be a single list of coordinate values, but got", mask.value().ldim(), "list dimensions" - ); + TORCH_CHECK_VALUE( + mask.value().ldim() == 1, + "Expected mask to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", + mask.value().ldim(), "list dimensions"); } const nanovdb::Coord upsampleFactorCoord = subdiv_factor.value(); @@ -114,66 +107,70 @@ std::pair GridBatch::subdivide(Vec3iOrScalar subdiv_fac fineGrid = subdivided_grid(subdiv_factor, mask).impl(); } - torch::Tensor subdivData = detail::autograd::UpsampleGrid::apply(impl(), fineGrid, upsampleFactorCoord, data.jdata())[0]; + torch::Tensor subdivData = detail::autograd::UpsampleGrid::apply( + impl(), fineGrid, upsampleFactorCoord, data.jdata())[0]; - return std::make_pair( - fineGrid->jaggedTensor(subdivData, false), - GridBatch(fineGrid) - ); + return std::make_pair(fineGrid->jaggedTensor(subdivData, false), GridBatch(fineGrid)); } - -JaggedTensor GridBatch::read_from_dense(const torch::Tensor& dense_data, - const Vec3iBatch& dense_origins) const { - torch::Tensor retData = detail::autograd::ReadFromDense::apply(impl(), dense_data, dense_origins)[0]; +JaggedTensor +GridBatch::read_from_dense(const torch::Tensor &dense_data, const Vec3iBatch &dense_origins) const { + torch::Tensor retData = + detail::autograd::ReadFromDense::apply(impl(), dense_data, dense_origins)[0]; return impl()->jaggedTensor(retData, false); } - -torch::Tensor GridBatch::read_into_dense(const JaggedTensor& sparse_data, - const torch::optional& min_coord, - const torch::optional& grid_size) const { - TORCH_CHECK_VALUE(sparse_data.ldim() == 1, - "Expected sparse_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", sparse_data.ldim(), "list dimensions" - ); - return detail::autograd::ReadIntoDense::apply(impl(), sparse_data.jdata(), min_coord, grid_size)[0]; -} - -JaggedTensor GridBatch::fill_to_grid(const JaggedTensor& features, - const GridBatch& other_grid, - float default_value) const { - TORCH_CHECK_VALUE(features.ldim() == 1, - "Expected features to have 1 list dimension, i.e. be a single list of coordinate values, but got", features.ldim(), "list dimensions" - ); +torch::Tensor +GridBatch::read_into_dense(const JaggedTensor &sparse_data, + const torch::optional &min_coord, + const torch::optional &grid_size) const { + TORCH_CHECK_VALUE( + sparse_data.ldim() == 1, + "Expected sparse_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", + sparse_data.ldim(), "list dimensions"); + return detail::autograd::ReadIntoDense::apply(impl(), sparse_data.jdata(), min_coord, + grid_size)[0]; +} + +JaggedTensor +GridBatch::fill_to_grid(const JaggedTensor &features, const GridBatch &other_grid, + float default_value) const { + TORCH_CHECK_VALUE( + features.ldim() == 1, + "Expected features to have 1 list dimension, i.e. be a single list of coordinate values, but got", + features.ldim(), "list dimensions"); torch::Tensor retData = detail::autograd::FillToGrid::apply(other_grid.impl(), impl(), features.jdata(), default_value)[0]; return impl()->jaggedTensor(retData, false); } - -JaggedTensor GridBatch::grid_to_world(const JaggedTensor& ijk) const { - TORCH_CHECK_VALUE(ijk.ldim() == 1, - "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", ijk.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::grid_to_world(const JaggedTensor &ijk) const { + TORCH_CHECK_VALUE( + ijk.ldim() == 1, + "Expected ijk to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", + ijk.ldim(), "list dimensions"); torch::Tensor ret = detail::autograd::TransformPoints::apply( impl(), ijk, ijk.jdata(), true /*isInverse*/, false /*isDual*/)[0]; return ijk.jagged_like(ret); } - -JaggedTensor GridBatch::world_to_grid(const JaggedTensor& xyz) const { - TORCH_CHECK_VALUE(xyz.ldim() == 1, - "Expected xyz to have 1 list dimension, i.e. be a single list of coordinate values, but got", xyz.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::world_to_grid(const JaggedTensor &xyz) const { + TORCH_CHECK_VALUE( + xyz.ldim() == 1, + "Expected xyz to have 1 list dimension, i.e. be a single list of coordinate values, but got", + xyz.ldim(), "list dimensions"); torch::Tensor ret = detail::autograd::TransformPoints::apply( impl(), xyz, xyz.jdata(), false /* isInverse*/, false /*isDual*/)[0]; return xyz.jagged_like(ret); } -torch::Tensor GridBatch::grid_to_world_matrices(const torch::Dtype& dtype) const { +torch::Tensor +GridBatch::grid_to_world_matrices(const torch::Dtype &dtype) const { std::vector retTorch; for (int64_t bi = 0; bi < grid_count(); ++bi) { retTorch.emplace_back(impl()->gridToWorldMatrix(bi)); @@ -182,7 +179,8 @@ torch::Tensor GridBatch::grid_to_world_matrices(const torch::Dtype& dtype) const return torch::stack(retTorch, 0).toType(dtype); } -torch::Tensor GridBatch::world_to_grid_matrices(const torch::Dtype& dtype) const { +torch::Tensor +GridBatch::world_to_grid_matrices(const torch::Dtype &dtype) const { std::vector retTorch; for (int64_t bi = 0; bi < grid_count(); ++bi) { retTorch.emplace_back(impl()->worldToGridMatrix(bi)); @@ -191,68 +189,81 @@ torch::Tensor GridBatch::world_to_grid_matrices(const torch::Dtype& dtype) const return torch::stack(retTorch, 0).toType(dtype); } -JaggedTensor GridBatch::sample_trilinear(const JaggedTensor& points, - const JaggedTensor& voxel_data) const { - TORCH_CHECK_VALUE(points.ldim() == 1, - "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", points.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(voxel_data.ldim() == 1, - "Expected voxel_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", voxel_data.ldim(), "list dimensions" - ); - torch::Tensor ret = detail::autograd::SampleGridTrilinear::apply(impl(), points, voxel_data.jdata(), false /*returnGrad*/)[0]; +JaggedTensor +GridBatch::sample_trilinear(const JaggedTensor &points, const JaggedTensor &voxel_data) const { + TORCH_CHECK_VALUE( + points.ldim() == 1, + "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", + points.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + voxel_data.ldim() == 1, + "Expected voxel_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", + voxel_data.ldim(), "list dimensions"); + torch::Tensor ret = detail::autograd::SampleGridTrilinear::apply( + impl(), points, voxel_data.jdata(), false /*returnGrad*/)[0]; return points.jagged_like(ret); } - -std::vector GridBatch::sample_trilinear_with_grad(const JaggedTensor& points, - const JaggedTensor& voxel_data) const { - TORCH_CHECK_VALUE(points.ldim() == 1, - "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", points.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(voxel_data.ldim() == 1, - "Expected voxel_data to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", voxel_data.ldim(), "list dimensions" - ); - std::vector ret = detail::autograd::SampleGridTrilinear::apply(impl(), points, voxel_data.jdata(), true /*returnGrad*/); - - return {points.jagged_like(ret[0]), points.jagged_like(ret[1])}; -} - - -JaggedTensor GridBatch::sample_bezier(const JaggedTensor& points, - const JaggedTensor& voxel_data) const { - TORCH_CHECK_VALUE(points.ldim() == 1, - "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", points.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(voxel_data.ldim() == 1, - "Expected voxel_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", voxel_data.ldim(), "list dimensions" - ); - torch::Tensor ret = detail::autograd::SampleGridBezier::apply(impl(), points, voxel_data.jdata(), false /*returnGrad*/)[0]; +std::vector +GridBatch::sample_trilinear_with_grad(const JaggedTensor &points, + const JaggedTensor &voxel_data) const { + TORCH_CHECK_VALUE( + points.ldim() == 1, + "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", + points.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + voxel_data.ldim() == 1, + "Expected voxel_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", + voxel_data.ldim(), "list dimensions"); + std::vector ret = detail::autograd::SampleGridTrilinear::apply( + impl(), points, voxel_data.jdata(), true /*returnGrad*/); + + return { points.jagged_like(ret[0]), points.jagged_like(ret[1]) }; +} + +JaggedTensor +GridBatch::sample_bezier(const JaggedTensor &points, const JaggedTensor &voxel_data) const { + TORCH_CHECK_VALUE( + points.ldim() == 1, + "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", + points.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + voxel_data.ldim() == 1, + "Expected voxel_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", + voxel_data.ldim(), "list dimensions"); + torch::Tensor ret = detail::autograd::SampleGridBezier::apply( + impl(), points, voxel_data.jdata(), false /*returnGrad*/)[0]; return points.jagged_like(ret); } - -std::vector GridBatch::sample_bezier_with_grad(const JaggedTensor& points, - const JaggedTensor& voxel_data) const { - TORCH_CHECK_VALUE(points.ldim() == 1, - "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", points.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(voxel_data.ldim() == 1, - "Expected voxel_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", voxel_data.ldim(), "list dimensions" - ); - auto ret = detail::autograd::SampleGridBezier::apply(impl(), points, voxel_data.jdata(), true /*returnGrad*/); - return {points.jagged_like(ret[0]), points.jagged_like(ret[1])}; -} - - -JaggedTensor GridBatch::splat_trilinear(const JaggedTensor& points, - const JaggedTensor& points_data) const { - TORCH_CHECK_VALUE(points.ldim() == 1, - "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", points.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(points_data.ldim() == 1, - "Expected points_data to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", points_data.ldim(), "list dimensions" - ); - torch::Tensor ret = detail::autograd::SplatIntoGridTrilinear::apply(impl(), points, points_data.jdata())[0]; +std::vector +GridBatch::sample_bezier_with_grad(const JaggedTensor &points, + const JaggedTensor &voxel_data) const { + TORCH_CHECK_VALUE( + points.ldim() == 1, + "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", + points.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + voxel_data.ldim() == 1, + "Expected voxel_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", + voxel_data.ldim(), "list dimensions"); + auto ret = detail::autograd::SampleGridBezier::apply(impl(), points, voxel_data.jdata(), + true /*returnGrad*/); + return { points.jagged_like(ret[0]), points.jagged_like(ret[1]) }; +} + +JaggedTensor +GridBatch::splat_trilinear(const JaggedTensor &points, const JaggedTensor &points_data) const { + TORCH_CHECK_VALUE( + points.ldim() == 1, + "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", + points.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + points_data.ldim() == 1, + "Expected points_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", + points_data.ldim(), "list dimensions"); + torch::Tensor ret = + detail::autograd::SplatIntoGridTrilinear::apply(impl(), points, points_data.jdata())[0]; if (grid_count() == 1) { return JaggedTensor(ret); } else { @@ -260,16 +271,18 @@ JaggedTensor GridBatch::splat_trilinear(const JaggedTensor& points, } } - -JaggedTensor GridBatch::splat_bezier(const JaggedTensor& points, - const JaggedTensor& points_data) const { - TORCH_CHECK_VALUE(points.ldim() == 1, - "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", points.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(points_data.ldim() == 1, - "Expected points_data to have 1 list dimension, i.e. be a single list of coordinate values, but got", points_data.ldim(), "list dimensions" - ); - torch::Tensor ret = detail::autograd::SplatIntoGridBezier::apply(impl(), points, points_data.jdata())[0]; +JaggedTensor +GridBatch::splat_bezier(const JaggedTensor &points, const JaggedTensor &points_data) const { + TORCH_CHECK_VALUE( + points.ldim() == 1, + "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", + points.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + points_data.ldim() == 1, + "Expected points_data to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", + points_data.ldim(), "list dimensions"); + torch::Tensor ret = + detail::autograd::SplatIntoGridBezier::apply(impl(), points, points_data.jdata())[0]; if (grid_count() == 1) { return JaggedTensor(ret); } else { @@ -277,51 +290,58 @@ JaggedTensor GridBatch::splat_bezier(const JaggedTensor& points, } } - -torch::Tensor GridBatch::voxel_size_at(int64_t bi, const torch::Dtype& dtype) const { - torch::Tensor retTorch = torch::empty({3}, torch::TensorOptions().device(this->device()).dtype(dtype)); - const nanovdb::Vec3d& voxSize = impl()->voxelSize(bi); - retTorch[0] = voxSize[0]; - retTorch[1] = voxSize[1]; - retTorch[2] = voxSize[2]; +torch::Tensor +GridBatch::voxel_size_at(int64_t bi, const torch::Dtype &dtype) const { + torch::Tensor retTorch = + torch::empty({ 3 }, torch::TensorOptions().device(this->device()).dtype(dtype)); + const nanovdb::Vec3d &voxSize = impl()->voxelSize(bi); + retTorch[0] = voxSize[0]; + retTorch[1] = voxSize[1]; + retTorch[2] = voxSize[2]; return retTorch; } -torch::Tensor GridBatch::voxel_sizes(const torch::Dtype& dtype) const { - torch::Tensor retTorch = torch::empty({grid_count(), 3}, torch::TensorOptions().device(this->device()).dtype(dtype)); +torch::Tensor +GridBatch::voxel_sizes(const torch::Dtype &dtype) const { + torch::Tensor retTorch = torch::empty( + { grid_count(), 3 }, torch::TensorOptions().device(this->device()).dtype(dtype)); for (int64_t bi = 0; bi < grid_count(); bi += 1) { const nanovdb::Vec3d voxSize = impl()->voxelSize(bi); - retTorch[bi][0] = voxSize[0]; - retTorch[bi][1] = voxSize[1]; - retTorch[bi][2] = voxSize[2]; + retTorch[bi][0] = voxSize[0]; + retTorch[bi][1] = voxSize[1]; + retTorch[bi][2] = voxSize[2]; } return retTorch; } -torch::Tensor GridBatch::origin_at(int64_t bi, const torch::Dtype& dtype) const { - const nanovdb::Vec3d& voxelOrigin = impl()->voxelOrigin(bi); - torch::Tensor retTorch = torch::empty({3}, torch::TensorOptions().device(this->device()).dtype(dtype)); +torch::Tensor +GridBatch::origin_at(int64_t bi, const torch::Dtype &dtype) const { + const nanovdb::Vec3d &voxelOrigin = impl()->voxelOrigin(bi); + torch::Tensor retTorch = + torch::empty({ 3 }, torch::TensorOptions().device(this->device()).dtype(dtype)); retTorch[0] = voxelOrigin[0]; retTorch[1] = voxelOrigin[1]; retTorch[2] = voxelOrigin[2]; return retTorch; } - -torch::Tensor GridBatch::origins(const torch::Dtype& dtype) const { - torch::Tensor retTorch = torch::empty({grid_count(), 3}, torch::TensorOptions().device(this->device()).dtype(dtype)); +torch::Tensor +GridBatch::origins(const torch::Dtype &dtype) const { + torch::Tensor retTorch = torch::empty( + { grid_count(), 3 }, torch::TensorOptions().device(this->device()).dtype(dtype)); for (int64_t bi = 0; bi < grid_count(); bi += 1) { - const nanovdb::Vec3d& voxOrigin = impl()->voxelOrigin(bi); - retTorch[bi][0] = voxOrigin[0]; - retTorch[bi][1] = voxOrigin[1]; - retTorch[bi][2] = voxOrigin[2]; + const nanovdb::Vec3d &voxOrigin = impl()->voxelOrigin(bi); + retTorch[bi][0] = voxOrigin[0]; + retTorch[bi][1] = voxOrigin[1]; + retTorch[bi][2] = voxOrigin[2]; } return retTorch; } - -torch::Tensor GridBatch::num_voxels() const { - torch::Tensor retTorch = torch::empty({grid_count()}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); +torch::Tensor +GridBatch::num_voxels() const { + torch::Tensor retTorch = torch::empty( + { grid_count() }, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); auto acc = retTorch.accessor(); for (int64_t bi = 
0; bi < grid_count(); bi += 1) { @@ -330,11 +350,13 @@ torch::Tensor GridBatch::num_voxels() const { return retTorch.to(device()); } -torch::Tensor GridBatch::num_enabled_voxels() const { +torch::Tensor +GridBatch::num_enabled_voxels() const { if (!is_mutable()) { return num_voxels(); } - torch::Tensor retTorch = torch::empty({grid_count()}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); + torch::Tensor retTorch = torch::empty( + { grid_count() }, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); auto acc = retTorch.accessor(); for (int64_t bi = 0; bi < grid_count(); bi += 1) { @@ -343,7 +365,8 @@ torch::Tensor GridBatch::num_enabled_voxels() const { return retTorch.to(device()); } -int64_t GridBatch::num_enabled_voxels_at(int64_t bi) const { +int64_t +GridBatch::num_enabled_voxels_at(int64_t bi) const { if (!is_mutable()) { return num_voxels_at(bi); } @@ -352,8 +375,10 @@ int64_t GridBatch::num_enabled_voxels_at(int64_t bi) const { }); } -torch::Tensor GridBatch::cum_voxels() const { - torch::Tensor retTorch = torch::empty({grid_count()}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); +torch::Tensor +GridBatch::cum_voxels() const { + torch::Tensor retTorch = torch::empty( + { grid_count() }, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); auto acc = retTorch.accessor(); for (int64_t bi = 0; bi < grid_count(); bi += 1) { @@ -362,11 +387,13 @@ torch::Tensor GridBatch::cum_voxels() const { return retTorch.to(device()); } -torch::Tensor GridBatch::cum_enabled_voxels() const { +torch::Tensor +GridBatch::cum_enabled_voxels() const { if (!is_mutable()) { return cum_voxels(); } - torch::Tensor retTorch = torch::empty({grid_count()}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); + torch::Tensor retTorch = torch::empty( + { grid_count() }, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); auto acc = retTorch.accessor(); for (int64_t bi = 0; bi < grid_count(); bi += 1) { @@ -375,7 +402,8 @@ torch::Tensor GridBatch::cum_enabled_voxels() const { return retTorch.to(device()); } -int64_t GridBatch::cum_enabled_voxels_at(int64_t bi) const { +int64_t +GridBatch::cum_enabled_voxels_at(int64_t bi) const { int64_t nCum = 0; for (int64_t b = 0; b < bi; ++b) { nCum += num_enabled_voxels_at(b); @@ -383,8 +411,10 @@ int64_t GridBatch::cum_enabled_voxels_at(int64_t bi) const { return nCum; } -torch::Tensor GridBatch::num_bytes() const { - torch::Tensor retTorch = torch::empty({grid_count()}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); +torch::Tensor +GridBatch::num_bytes() const { + torch::Tensor retTorch = torch::empty( + { grid_count() }, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); auto acc = retTorch.accessor(); for (int64_t bi = 0; bi < grid_count(); bi += 1) { @@ -393,9 +423,10 @@ torch::Tensor GridBatch::num_bytes() const { return retTorch.to(device()); } - -torch::Tensor GridBatch::num_leaf_nodes() const { - torch::Tensor retTorch = torch::empty({grid_count()}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); +torch::Tensor +GridBatch::num_leaf_nodes() const { + torch::Tensor retTorch = torch::empty( + { grid_count() }, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64)); auto acc = retTorch.accessor(); for (int64_t bi = 0; bi < grid_count(); bi += 1) { @@ -404,202 +435,232 @@ torch::Tensor GridBatch::num_leaf_nodes() const { return retTorch.to(device()); } - -void GridBatch::disable_ijk(const JaggedTensor& ijk) { - 
TORCH_CHECK_VALUE(ijk.ldim() == 1, - "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", ijk.ldim(), "list dimensions" - ); +void +GridBatch::disable_ijk(const JaggedTensor &ijk) { + TORCH_CHECK_VALUE( + ijk.ldim() == 1, + "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ijk.ldim(), "list dimensions"); FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { fvdb::detail::ops::dispatchSetMaskedIjk(*impl(), ijk, false); }); } - -void GridBatch::enable_ijk(const JaggedTensor& ijk) { - TORCH_CHECK_VALUE(ijk.ldim() == 1, - "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", ijk.ldim(), "list dimensions" - ); +void +GridBatch::enable_ijk(const JaggedTensor &ijk) { + TORCH_CHECK_VALUE( + ijk.ldim() == 1, + "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ijk.ldim(), "list dimensions"); FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { fvdb::detail::ops::dispatchSetMaskedIjk(*impl(), ijk, true); }); } -void GridBatch::set_from_mesh(const JaggedTensor& mesh_vertices, - const JaggedTensor& mesh_faces, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins) { - TORCH_CHECK_VALUE(mesh_vertices.ldim() == 1, - "Expected mesh_vertices to have 1 list dimension, i.e. be a single list of coordinate values, but got", mesh_vertices.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(mesh_faces.ldim() == 1, - "Expected mesh_faces to have 1 list dimension, i.e. be a single list of coordinate values, but got", mesh_faces.ldim(), "list dimensions" - ); - TORCH_CHECK_TYPE(mesh_vertices.is_floating_point(), "mesh_vertices must have a floating point type"); - TORCH_CHECK_VALUE(mesh_vertices.rdim() == 2, std::string("Expected mesh_vertices to have 2 dimensions (shape (n, 3)) but got ") + - std::to_string(mesh_vertices.rdim()) + " dimensions"); +void +GridBatch::set_from_mesh(const JaggedTensor &mesh_vertices, const JaggedTensor &mesh_faces, + const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins) { + TORCH_CHECK_VALUE( + mesh_vertices.ldim() == 1, + "Expected mesh_vertices to have 1 list dimension, i.e. be a single list of coordinate values, but got", + mesh_vertices.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + mesh_faces.ldim() == 1, + "Expected mesh_faces to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", + mesh_faces.ldim(), "list dimensions"); + TORCH_CHECK_TYPE(mesh_vertices.is_floating_point(), + "mesh_vertices must have a floating point type"); + TORCH_CHECK_VALUE( + mesh_vertices.rdim() == 2, + std::string("Expected mesh_vertices to have 2 dimensions (shape (n, 3)) but got ") + + std::to_string(mesh_vertices.rdim()) + " dimensions"); TORCH_CHECK_VALUE(mesh_vertices.rsize(1) == 3, "Expected 3 dimensional mesh_vertices but got mesh_vertices.rshape[1] = " + - std::to_string(mesh_vertices.rsize(1))); + std::to_string(mesh_vertices.rsize(1))); TORCH_CHECK_TYPE(!mesh_faces.is_floating_point(), "mesh_faces must have an integer type"); - TORCH_CHECK_VALUE(mesh_faces.rdim() == 2, std::string("Expected mesh_faces to have 2 dimensions (shape (n, 3)) but got ") + - std::to_string(mesh_faces.rdim()) + " dimensions"); + TORCH_CHECK_VALUE( + mesh_faces.rdim() == 2, + std::string("Expected mesh_faces to have 2 dimensions (shape (n, 3)) but got ") + + std::to_string(mesh_faces.rdim()) + " dimensions"); TORCH_CHECK_VALUE(mesh_faces.rsize(1) == 3, "Expected 3 dimensional mesh_faces but got mesh_faces.rshape[1] = " + - std::to_string(mesh_faces.rsize(1))); + std::to_string(mesh_faces.rsize(1))); TORCH_CHECK_VALUE(mesh_vertices.num_outer_lists() == mesh_faces.num_outer_lists(), "Expected same number of vertex and face sets got len(mesh_vertices) = ", - mesh_vertices.num_outer_lists(), " and len(mesh_faces) = ", mesh_faces.num_outer_lists()); + mesh_vertices.num_outer_lists(), + " and len(mesh_faces) = ", mesh_faces.num_outer_lists()); const int64_t numGrids = mesh_vertices.joffsets().size(0) - 1; - TORCH_CHECK(numGrids == mesh_vertices.num_outer_lists(), "If this happens, Francis' paranoia was justified. File a bug"); - TORCH_CHECK_VALUE(numGrids <= MAX_GRIDS_PER_BATCH, - "Cannot create a grid with more than ", MAX_GRIDS_PER_BATCH, " grids in a batch. ", - "You passed in ", numGrids, " mesh sets."); + TORCH_CHECK(numGrids == mesh_vertices.num_outer_lists(), + "If this happens, Francis' paranoia was justified. File a bug"); + TORCH_CHECK_VALUE(numGrids <= MAX_GRIDS_PER_BATCH, "Cannot create a grid with more than ", + MAX_GRIDS_PER_BATCH, " grids in a batch. ", "You passed in ", numGrids, + " mesh sets."); - const std::vector voxSizesVec = voxel_sizes.value(numGrids, true /* onlyPositive */, "voxel_sizes"); - const std::vector voxOriginsVec = origins.value(numGrids, false /* onlyPositive */, "voxel_origins"); + const std::vector voxSizesVec = + voxel_sizes.value(numGrids, true /* onlyPositive */, "voxel_sizes"); + const std::vector voxOriginsVec = + origins.value(numGrids, false /* onlyPositive */, "voxel_origins"); std::vector transforms; transforms.reserve(numGrids); for (int64_t i = 0; i < numGrids; i += 1) { - transforms.push_back(detail::primalVoxelTransformForSizeAndOrigin(voxSizesVec[i], voxOriginsVec[i])); + transforms.push_back( + detail::primalVoxelTransformForSizeAndOrigin(voxSizesVec[i], voxOriginsVec[i])); } mImpl = c10::make_intrusive( - detail::build::buildGridFromMesh(is_mutable(), mesh_vertices, mesh_faces, transforms), - voxSizesVec, voxOriginsVec); + detail::build::buildGridFromMesh(is_mutable(), mesh_vertices, mesh_faces, transforms), + voxSizesVec, voxOriginsVec); } - -void GridBatch::set_from_points(const JaggedTensor& points, - const Vec3i& pad_min, - const Vec3i& pad_max, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins) { - TORCH_CHECK_VALUE(points.ldim() == 1, - "Expected points to have 1 list dimension, i.e. 
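
set_from_mesh, reformatted above, voxelizes one or more triangle meshes into the batch. Below is a sketch of a single-mesh call. It assumes, as the checks above imply, that JaggedTensor can wrap a single [V, 3] float tensor and a single [F, 3] integer tensor, and that the Vec3dBatchOrScalar / Vec3dBatch arguments accept a plain double broadcast to every grid; if Types.h requires explicit constructors, substitute them. The helper name and the 5 cm voxel size are illustrative only.

#include <torch/torch.h>
#include "GridBatch.h"
#include "JaggedTensor.h"

// Sketch: voxelize one mesh with a 0.05 voxel size and the origin at (0, 0, 0).
fvdb::GridBatch gridFromMesh(const torch::Tensor &verts,   // float, [V, 3]
                             const torch::Tensor &faces) { // int,   [F, 3]
    fvdb::GridBatch grid(verts.device(), /*isMutable=*/false);
    grid.set_from_mesh(fvdb::JaggedTensor(verts), fvdb::JaggedTensor(faces),
                       /*voxel_sizes=*/0.05, /*origins=*/0.0);
    return grid;
}
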
be a single list of coordinate values, but got", points.ldim(), "list dimensions" - ); +void +GridBatch::set_from_points(const JaggedTensor &points, const Vec3i &pad_min, const Vec3i &pad_max, + const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins) { + TORCH_CHECK_VALUE( + points.ldim() == 1, + "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", + points.ldim(), "list dimensions"); TORCH_CHECK_TYPE(points.is_floating_point(), "points must have a floating point type"); - TORCH_CHECK_VALUE(points.rdim() == 2, std::string("Expected points to have 2 dimensions (shape (n, 3)) but got ") + - std::to_string(points.rdim()) + " dimensions"); + TORCH_CHECK_VALUE(points.rdim() == 2, + std::string("Expected points to have 2 dimensions (shape (n, 3)) but got ") + + std::to_string(points.rdim()) + " dimensions"); TORCH_CHECK_VALUE(points.rsize(1) == 3, "Expected 3 dimensional points but got points.rshape[1] = " + - std::to_string(points.rsize(1))); + std::to_string(points.rsize(1))); impl()->checkDevice(points); - TORCH_CHECK(points.num_tensors() == points.num_outer_lists(), "If this happens, Francis' paranoia about tensors and points was justified. File a bug"); + TORCH_CHECK( + points.num_tensors() == points.num_outer_lists(), + "If this happens, Francis' paranoia about tensors and points was justified. File a bug"); TORCH_CHECK_VALUE(points.num_outer_lists() <= MAX_GRIDS_PER_BATCH, - "Cannot create a grid with more than ", MAX_GRIDS_PER_BATCH, " grids in a batch. ", - "You passed in ", points.num_outer_lists(), " points sets."); + "Cannot create a grid with more than ", MAX_GRIDS_PER_BATCH, + " grids in a batch. ", "You passed in ", points.num_outer_lists(), + " points sets."); const nanovdb::Coord padMin = pad_min.value(); const nanovdb::Coord padMax = pad_max.value(); const int64_t numGrids = points.joffsets().size(0) - 1; - TORCH_CHECK(numGrids == points.num_outer_lists(), "If this happens, Francis' paranoia about grids and points was justified. File a bug"); + TORCH_CHECK( + numGrids == points.num_outer_lists(), + "If this happens, Francis' paranoia about grids and points was justified. File a bug"); - const std::vector voxSizesVec = voxel_sizes.value(numGrids, true /* onlyPositive */, "voxel_sizes"); - const std::vector voxOriginsVec = origins.value(numGrids, false /* onlyPositive */, "voxel_origins"); + const std::vector voxSizesVec = + voxel_sizes.value(numGrids, true /* onlyPositive */, "voxel_sizes"); + const std::vector voxOriginsVec = + origins.value(numGrids, false /* onlyPositive */, "voxel_origins"); std::vector transforms; transforms.reserve(numGrids); for (int64_t i = 0; i < numGrids; i += 1) { - transforms.push_back(detail::primalVoxelTransformForSizeAndOrigin(voxSizesVec[i], voxOriginsVec[i])); + transforms.push_back( + detail::primalVoxelTransformForSizeAndOrigin(voxSizesVec[i], voxOriginsVec[i])); } mImpl = c10::make_intrusive( - detail::build::buildPaddedGridFromPoints(is_mutable(), points, transforms, padMin, padMax), - voxSizesVec, voxOriginsVec); -} - - -void GridBatch::set_from_nearest_voxels_to_points(const JaggedTensor& points, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins) { - TORCH_CHECK_VALUE(points.ldim() == 1, - "Expected points to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", points.ldim(), "list dimensions" - ); + detail::build::buildPaddedGridFromPoints(is_mutable(), points, transforms, padMin, padMax), + voxSizesVec, voxOriginsVec); +} + +void +GridBatch::set_from_nearest_voxels_to_points(const JaggedTensor &points, + const Vec3dBatchOrScalar &voxel_sizes, + const Vec3dBatch &origins) { + TORCH_CHECK_VALUE( + points.ldim() == 1, + "Expected points to have 1 list dimension, i.e. be a single list of coordinate values, but got", + points.ldim(), "list dimensions"); TORCH_CHECK_TYPE(points.is_floating_point(), "points must have a floating point type"); - TORCH_CHECK_VALUE(points.rdim() == 2, std::string("Expected points to have 2 dimensions (shape (n, 3)) but got ") + - std::to_string(points.rdim()) + " dimensions"); + TORCH_CHECK_VALUE(points.rdim() == 2, + std::string("Expected points to have 2 dimensions (shape (n, 3)) but got ") + + std::to_string(points.rdim()) + " dimensions"); TORCH_CHECK_VALUE(points.rsize(1) == 3, "Expected 3 dimensional points but got points.shape[1] = " + - std::to_string(points.rsize(1))); + std::to_string(points.rsize(1))); impl()->checkDevice(points); - TORCH_CHECK(points.num_tensors() == points.num_outer_lists(), "If this happens, Francis' paranoia was justified. File a bug"); + TORCH_CHECK(points.num_tensors() == points.num_outer_lists(), + "If this happens, Francis' paranoia was justified. File a bug"); TORCH_CHECK_VALUE(points.num_outer_lists() <= MAX_GRIDS_PER_BATCH, - "Cannot create a grid with more than ", MAX_GRIDS_PER_BATCH, " grids in a batch. ", - "You passed in ", points.num_outer_lists(), " point sets."); + "Cannot create a grid with more than ", MAX_GRIDS_PER_BATCH, + " grids in a batch. ", "You passed in ", points.num_outer_lists(), + " point sets."); const int64_t numGrids = points.joffsets().size(0) - 1; - TORCH_CHECK(numGrids == points.num_outer_lists(), "If this happens, Francis' paranoia was justified. File a bug"); + TORCH_CHECK(numGrids == points.num_outer_lists(), + "If this happens, Francis' paranoia was justified. File a bug"); - const std::vector voxSizesVec = voxel_sizes.value(numGrids, true /* onlyPositive */, "voxel_sizes"); - const std::vector voxOriginsVec = origins.value(numGrids, false /* onlyPositive */, "voxel_origins"); + const std::vector voxSizesVec = + voxel_sizes.value(numGrids, true /* onlyPositive */, "voxel_sizes"); + const std::vector voxOriginsVec = + origins.value(numGrids, false /* onlyPositive */, "voxel_origins"); std::vector transforms; transforms.reserve(numGrids); for (int64_t i = 0; i < numGrids; i += 1) { - transforms.push_back(detail::primalVoxelTransformForSizeAndOrigin(voxSizesVec[i], voxOriginsVec[i])); + transforms.push_back( + detail::primalVoxelTransformForSizeAndOrigin(voxSizesVec[i], voxOriginsVec[i])); } mImpl = c10::make_intrusive( - detail::build::buildNearestNeighborGridFromPoints(is_mutable(), points, transforms), - voxSizesVec, voxOriginsVec); -} - - -void GridBatch::set_from_ijk(const JaggedTensor& coords, - const Vec3i& pad_min, - const Vec3i& pad_max, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins) { - TORCH_CHECK_VALUE(coords.ldim() == 1, - "Expected coords to have 1 list dimension, i.e. 
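
set_from_points builds a grid around the voxels touched by each point, padded by pad_min/pad_max, while set_from_nearest_voxels_to_points keeps only the voxels nearest to each point. A sketch contrasting the two follows, with the same caveat as above that the scalar-to-Vec3 conversions are assumed; Vec3i() is used for default padding exactly as clipped_grid does later in this file, and the helper name is mine.

#include <torch/torch.h>
#include "GridBatch.h"
#include "JaggedTensor.h"

// Sketch: two ways to build grids around a point cloud pts (float, [N, 3]).
void gridsFromPoints(const torch::Tensor &pts) {
    fvdb::GridBatch padded(pts.device(), /*isMutable=*/false);
    padded.set_from_points(fvdb::JaggedTensor(pts),
                           /*pad_min=*/fvdb::Vec3i(), /*pad_max=*/fvdb::Vec3i(),
                           /*voxel_sizes=*/0.1, /*origins=*/0.0);

    fvdb::GridBatch nearest(pts.device(), /*isMutable=*/false);
    nearest.set_from_nearest_voxels_to_points(fvdb::JaggedTensor(pts),
                                              /*voxel_sizes=*/0.1, /*origins=*/0.0);
}
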
be a single list of coordinate values, but got", coords.ldim(), "list dimensions" - ); - TORCH_CHECK_TYPE(at::isIntegralType(coords.scalar_type(), false), "coords must have an integer type"); - TORCH_CHECK_VALUE(coords.rdim() == 2, std::string("Expected points to have 2 dimensions (shape (n, 3)) but got ") + - std::to_string(coords.rdim()) + " dimensions"); + detail::build::buildNearestNeighborGridFromPoints(is_mutable(), points, transforms), + voxSizesVec, voxOriginsVec); +} + +void +GridBatch::set_from_ijk(const JaggedTensor &coords, const Vec3i &pad_min, const Vec3i &pad_max, + const Vec3dBatchOrScalar &voxel_sizes, const Vec3dBatch &origins) { + TORCH_CHECK_VALUE( + coords.ldim() == 1, + "Expected coords to have 1 list dimension, i.e. be a single list of coordinate values, but got", + coords.ldim(), "list dimensions"); + TORCH_CHECK_TYPE(at::isIntegralType(coords.scalar_type(), false), + "coords must have an integer type"); + TORCH_CHECK_VALUE(coords.rdim() == 2, + std::string("Expected points to have 2 dimensions (shape (n, 3)) but got ") + + std::to_string(coords.rdim()) + " dimensions"); TORCH_CHECK_VALUE(coords.rsize(1) == 3, "Expected 3 dimensional coords but got points.rshape[1] = " + - std::to_string(coords.rsize(1))); + std::to_string(coords.rsize(1))); impl()->checkDevice(coords); - TORCH_CHECK(coords.num_tensors() == coords.num_outer_lists(), "If this happens, Francis' paranoia was justified. File a bug"); + TORCH_CHECK(coords.num_tensors() == coords.num_outer_lists(), + "If this happens, Francis' paranoia was justified. File a bug"); TORCH_CHECK_VALUE(coords.num_outer_lists() <= MAX_GRIDS_PER_BATCH, - "Cannot create a grid with more than ", MAX_GRIDS_PER_BATCH, " grids in a batch. ", - "You passed in ", coords.num_outer_lists(), " coordinate sets."); + "Cannot create a grid with more than ", MAX_GRIDS_PER_BATCH, + " grids in a batch. ", "You passed in ", coords.num_outer_lists(), + " coordinate sets."); - const nanovdb::Coord& padMin = pad_min.value(); - const nanovdb::Coord& padMax = pad_max.value(); + const nanovdb::Coord &padMin = pad_min.value(); + const nanovdb::Coord &padMax = pad_max.value(); const int64_t numGrids = coords.joffsets().size(0) - 1; - TORCH_CHECK(numGrids == coords.num_outer_lists(), "If this happens, Francis' paranoia was justified. File a bug"); + TORCH_CHECK(numGrids == coords.num_outer_lists(), + "If this happens, Francis' paranoia was justified. 
File a bug"); - const std::vector voxSizesVec = voxel_sizes.value(numGrids, true /* onlyPositive */, "voxel_sizes"); - const std::vector voxOriginsVec = origins.value(numGrids, false /* onlyPositive */, "voxel_origins"); + const std::vector voxSizesVec = + voxel_sizes.value(numGrids, true /* onlyPositive */, "voxel_sizes"); + const std::vector voxOriginsVec = + origins.value(numGrids, false /* onlyPositive */, "voxel_origins"); mImpl = c10::make_intrusive( - detail::build::buildPaddedGridFromCoords(is_mutable(), coords, padMin, padMax), - voxSizesVec, voxOriginsVec); + detail::build::buildPaddedGridFromCoords(is_mutable(), coords, padMin, padMax), voxSizesVec, + voxOriginsVec); } - -void GridBatch::set_from_dense_grid(const int64_t num_grids, - const Vec3i& dense_dims, - const Vec3i& ijk_min, - const Vec3dBatchOrScalar& voxel_sizes, - const Vec3dBatch& origins, - torch::optional mask) { - +void +GridBatch::set_from_dense_grid(const int64_t num_grids, const Vec3i &dense_dims, + const Vec3i &ijk_min, const Vec3dBatchOrScalar &voxel_sizes, + const Vec3dBatch &origins, torch::optional mask) { TORCH_CHECK_VALUE(num_grids >= 0, "num_grids must be non-negative"); - const nanovdb::Coord& size = dense_dims.value(); + const nanovdb::Coord &size = dense_dims.value(); - const nanovdb::Coord& ijk_min_value = ijk_min.value(); + const nanovdb::Coord &ijk_min_value = ijk_min.value(); if (mask.has_value()) { impl()->checkDevice(mask.value()); - TORCH_CHECK_VALUE(mask.value().dtype() == torch::kBool, "mask must be a boolean type or None"); + TORCH_CHECK_VALUE(mask.value().dtype() == torch::kBool, + "mask must be a boolean type or None"); TORCH_CHECK_VALUE(mask.value().dim() == 3, "mask must be 3 dimensional"); TORCH_CHECK_VALUE(mask.value().size(0) == size[0], "mask must have shape (w, h, d) = size"); TORCH_CHECK_VALUE(mask.value().size(1) == size[1], "mask must have shape (w, h, d) = size"); @@ -608,23 +669,26 @@ void GridBatch::set_from_dense_grid(const int64_t num_grids, TORCH_CHECK_VALUE(size[0] >= 0 && size[1] >= 0 && size[2] >= 0, "size must be non-negative"); - std::vector voxSizesVec = voxel_sizes.value(num_grids, true /* onlyPositive */, "voxel_sizes"); - std::vector voxOriginsVec = origins.value(num_grids, false /* onlyPositive */, "voxel_origins"); + std::vector voxSizesVec = + voxel_sizes.value(num_grids, true /* onlyPositive */, "voxel_sizes"); + std::vector voxOriginsVec = + origins.value(num_grids, false /* onlyPositive */, "voxel_origins"); - TORCH_CHECK_VALUE(num_grids <= MAX_GRIDS_PER_BATCH, - "Cannot create a grid with more than ", MAX_GRIDS_PER_BATCH, " grids in a batch. ", - "You requested ", num_grids, " grids."); - TORCH_CHECK((size_t) num_grids == voxSizesVec.size(), "If this happens, Francis' paranoia was justified. File a bug"); - TORCH_CHECK((size_t) num_grids == voxOriginsVec.size(), "If this happens, Francis' paranoia was justified. File a bug"); + TORCH_CHECK_VALUE(num_grids <= MAX_GRIDS_PER_BATCH, "Cannot create a grid with more than ", + MAX_GRIDS_PER_BATCH, " grids in a batch. ", "You requested ", num_grids, + " grids."); + TORCH_CHECK((size_t)num_grids == voxSizesVec.size(), + "If this happens, Francis' paranoia was justified. File a bug"); + TORCH_CHECK((size_t)num_grids == voxOriginsVec.size(), + "If this happens, Francis' paranoia was justified. 
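
set_from_ijk builds topology directly from integer voxel coordinates. A small sketch follows; the int32 coordinate tensor satisfies the integer-type check above, and the scalar voxel_sizes / origins conversions are again an assumption about Types.h rather than something this hunk shows.

#include <torch/torch.h>
#include "GridBatch.h"
#include "JaggedTensor.h"

// Sketch: a three-voxel grid on the CPU from explicit ijk coordinates.
fvdb::GridBatch gridFromIjk() {
    torch::Tensor ijk = torch::tensor({{0, 0, 0}, {1, 0, 0}, {0, 2, 3}}, torch::kInt32);
    fvdb::GridBatch grid(torch::kCPU, /*isMutable=*/false);
    grid.set_from_ijk(fvdb::JaggedTensor(ijk),
                      /*pad_min=*/fvdb::Vec3i(), /*pad_max=*/fvdb::Vec3i(),
                      /*voxel_sizes=*/1.0, /*origins=*/0.0);
    // With default (zero) padding, grid.total_voxels() should report 3.
    return grid;
}
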
File a bug"); mImpl = c10::make_intrusive( - detail::build::buildDenseGrid(device(), is_mutable(), num_grids, size, ijk_min_value, mask), - voxSizesVec, voxOriginsVec); - + detail::build::buildDenseGrid(device(), is_mutable(), num_grids, size, ijk_min_value, mask), + voxSizesVec, voxOriginsVec); } - -GridBatch GridBatch::dual_grid(bool exclude_border) const { +GridBatch +GridBatch::dual_grid(bool exclude_border) const { GridBatch ret = GridBatch(device(), is_mutable()); if (grid_count() == 0) { return ret; @@ -633,14 +697,15 @@ GridBatch GridBatch::dual_grid(bool exclude_border) const { return ret; } - -GridBatch GridBatch::coarsened_grid(Vec3iOrScalar branch_factor) const { +GridBatch +GridBatch::coarsened_grid(Vec3iOrScalar branch_factor) const { nanovdb::Coord branchFactorCoord = branch_factor.value(); for (int i = 0; i < 3; i += 1) { - TORCH_CHECK_VALUE(branchFactorCoord[i] > 0, "branch_factor must be strictly positive. Got [" + - std::to_string(branchFactorCoord[0]) + ", " + - std::to_string(branchFactorCoord[1]) + ", " + - std::to_string(branchFactorCoord[2]) + "]"); + TORCH_CHECK_VALUE(branchFactorCoord[i] > 0, + "branch_factor must be strictly positive. Got [" + + std::to_string(branchFactorCoord[0]) + ", " + + std::to_string(branchFactorCoord[1]) + ", " + + std::to_string(branchFactorCoord[2]) + "]"); } GridBatch ret(device(), is_mutable()); if (grid_count() == 0) { @@ -650,20 +715,22 @@ GridBatch GridBatch::coarsened_grid(Vec3iOrScalar branch_factor) const { return ret; } - -GridBatch GridBatch::subdivided_grid(Vec3iOrScalar subdiv_factor, const torch::optional mask) const { - +GridBatch +GridBatch::subdivided_grid(Vec3iOrScalar subdiv_factor, + const torch::optional mask) const { if (mask.has_value()) { - TORCH_CHECK_VALUE(mask.value().ldim() == 1, - "Expected mask to have 1 list dimension, i.e. be a single list of coordinate values, but got", mask.value().ldim(), "list dimensions" - ); + TORCH_CHECK_VALUE( + mask.value().ldim() == 1, + "Expected mask to have 1 list dimension, i.e. be a single list of coordinate values, but got", + mask.value().ldim(), "list dimensions"); } const nanovdb::Coord subdivFactorCoord = subdiv_factor.value(); for (int i = 0; i < 3; i += 1) { - TORCH_CHECK_VALUE(subdivFactorCoord[i] > 0, "subdiv_factor must be strictly positive. Got [" + - std::to_string(subdivFactorCoord[0]) + ", " + - std::to_string(subdivFactorCoord[1]) + ", " + - std::to_string(subdivFactorCoord[2]) + "]"); + TORCH_CHECK_VALUE(subdivFactorCoord[i] > 0, + "subdiv_factor must be strictly positive. 
Got [" + + std::to_string(subdivFactorCoord[0]) + ", " + + std::to_string(subdivFactorCoord[1]) + ", " + + std::to_string(subdivFactorCoord[2]) + "]"); } GridBatch ret = GridBatch(device(), is_mutable()); @@ -674,11 +741,11 @@ GridBatch GridBatch::subdivided_grid(Vec3iOrScalar subdiv_factor, const torch::o return ret; } -GridBatch GridBatch::clipped_grid(const Vec3iBatch& ijk_min, - const Vec3iBatch& ijk_max) const { - +GridBatch +GridBatch::clipped_grid(const Vec3iBatch &ijk_min, const Vec3iBatch &ijk_max) const { JaggedTensor activeVoxelMask = FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return fvdb::detail::ops::dispatchActiveVoxelsInBoundsMask(*impl(), ijk_min, ijk_max, false); + return fvdb::detail::ops::dispatchActiveVoxelsInBoundsMask(*impl(), ijk_min, + ijk_max, false); }); JaggedTensor activeVoxelCoords = FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { @@ -690,26 +757,29 @@ GridBatch GridBatch::clipped_grid(const Vec3iBatch& ijk_min, // construct grid from ijk's clipped from original grid GridBatch clippedGrid = sparse_grid_from_ijk(activeVoxelMaskCoords, Vec3i(), Vec3i(), - voxel_sizes(), origins(), is_mutable()); + voxel_sizes(), origins(), is_mutable()); return clippedGrid; } -std::pair GridBatch::clip(const JaggedTensor& features, - const Vec3iBatch& ijk_min, - const Vec3iBatch& ijk_max) const { - - TORCH_CHECK_VALUE(features.ldim() == 1, - "Expected features to have 1 list dimension, i.e. be a single list of coordinate values, but got", features.ldim(), "list dimensions" - ); +std::pair +GridBatch::clip(const JaggedTensor &features, const Vec3iBatch &ijk_min, + const Vec3iBatch &ijk_max) const { + TORCH_CHECK_VALUE( + features.ldim() == 1, + "Expected features to have 1 list dimension, i.e. be a single list of coordinate values, but got", + features.ldim(), "list dimensions"); impl()->checkDevice(features); TORCH_CHECK(features.rsize(0) == total_voxels(), "Value count of features does not match grid"); - TORCH_CHECK(features.num_outer_lists() == grid_count(), "Batch size of features does not match grid."); - TORCH_CHECK(torch::equal(features.joffsets(), impl()->voxelOffsets(false)), "Offsets of features does not match grid."); + TORCH_CHECK(features.num_outer_lists() == grid_count(), + "Batch size of features does not match grid."); + TORCH_CHECK(torch::equal(features.joffsets(), impl()->voxelOffsets(false)), + "Offsets of features does not match grid."); JaggedTensor activeVoxelMask = FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return fvdb::detail::ops::dispatchActiveVoxelsInBoundsMask(*impl(), ijk_min, ijk_max, false); + return fvdb::detail::ops::dispatchActiveVoxelsInBoundsMask(*impl(), ijk_min, + ijk_max, false); }); JaggedTensor activeVoxelCoords = FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { @@ -728,10 +798,12 @@ std::pair GridBatch::clip(const JaggedTensor& features, return std::make_pair(clippedFeatures, clippedGrid); } -std::vector GridBatch::marching_cubes(const JaggedTensor& field, double level) const { - TORCH_CHECK_VALUE(field.ldim() == 1, - "Expected field to have 1 list dimension, i.e. be a single list of coordinate values, but got", field.ldim(), "list dimensions" - ); +std::vector +GridBatch::marching_cubes(const JaggedTensor &field, double level) const { + TORCH_CHECK_VALUE( + field.ldim() == 1, + "Expected field to have 1 list dimension, i.e. 
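
dual_grid, coarsened_grid and subdivided_grid each return a new GridBatch and leave the source batch untouched, as the implementations above show. A sketch follows, assuming Vec3iOrScalar converts from an int the same way Vec3iOrScalar(0) is used in conv_grid below; the helper name is illustrative.

#include <torch/torch.h>
#include "GridBatch.h"

// Sketch: derive new topology from an existing batch.
void derivedGrids(const fvdb::GridBatch &grid) {
    fvdb::GridBatch coarse = grid.coarsened_grid(/*branch_factor=*/2);      // every 2x2x2 block -> 1 voxel
    fvdb::GridBatch fine   = grid.subdivided_grid(/*subdiv_factor=*/2,
                                                  /*mask=*/torch::nullopt); // every voxel -> 2x2x2 block
    fvdb::GridBatch dual   = grid.dual_grid(/*exclude_border=*/false);      // voxels at primal corners
}
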
be a single list of coordinate values, but got", + field.ldim(), "list dimensions"); TORCH_CHECK_TYPE(field.is_floating_point(), "field must have a floating point type"); TORCH_CHECK_VALUE(field.numel() == total_voxels(), "Value count not match!"); TORCH_CHECK_VALUE(field.num_outer_lists() == grid_count(), "Batch size not match!"); @@ -742,30 +814,37 @@ std::vector GridBatch::marching_cubes(const JaggedTensor& field, d if (fieldJdata.dim() != 1) { fieldJdata = fieldJdata.squeeze(); } - TORCH_CHECK(fieldJdata.dim() == 1, std::string("Expected field to have 1 effective dimension but got ") + - std::to_string(field.rdim()) + " dimensions"); + TORCH_CHECK(fieldJdata.dim() == 1, + std::string("Expected field to have 1 effective dimension but got ") + + std::to_string(field.rdim()) + " dimensions"); impl()->checkDevice(field); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchMarchingCubes(*impl(), fieldJdata, level); }); } -JaggedTensor GridBatch::sparse_conv_halo(const JaggedTensor& input, const torch::Tensor& weight, int variant) const { - TORCH_CHECK_VALUE(input.ldim() == 1, - "Expected input to have 1 list dimension, i.e. be a single list of coordinate values, but got", input.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::sparse_conv_halo(const JaggedTensor &input, const torch::Tensor &weight, + int variant) const { + TORCH_CHECK_VALUE( + input.ldim() == 1, + "Expected input to have 1 list dimension, i.e. be a single list of coordinate values, but got", + input.ldim(), "list dimensions"); TORCH_CHECK_TYPE(input.is_floating_point(), "input must have a floating point type"); TORCH_CHECK_VALUE(input.rsize(0) == total_voxels(), "Value count not match!"); TORCH_CHECK_VALUE(input.num_outer_lists() == grid_count(), "Batch size not match!"); impl()->checkDevice(input); - torch::Tensor ret = detail::autograd::SparseConvolutionHalo::apply(impl(), input.jdata(), weight, variant)[0]; + torch::Tensor ret = + detail::autograd::SparseConvolutionHalo::apply(impl(), input.jdata(), weight, variant)[0]; return input.jagged_like(ret); } - -GridBatch GridBatch::conv_grid(Vec3iOrScalar kernel_size, Vec3iOrScalar stride) const { - TORCH_CHECK_VALUE(Vec3iOrScalar(0).value() < kernel_size.value(), "kernel_size must be strictly positive. Got " + kernel_size.toString()); - TORCH_CHECK_VALUE(Vec3iOrScalar(0).value() < stride.value(), "stride must be strictly positive. Got " + stride.toString()); +GridBatch +GridBatch::conv_grid(Vec3iOrScalar kernel_size, Vec3iOrScalar stride) const { + TORCH_CHECK_VALUE(Vec3iOrScalar(0).value() < kernel_size.value(), + "kernel_size must be strictly positive. Got " + kernel_size.toString()); + TORCH_CHECK_VALUE(Vec3iOrScalar(0).value() < stride.value(), + "stride must be strictly positive. 
Got " + stride.toString()); GridBatch ret = GridBatch(device(), is_mutable()); if (grid_count() == 0) { return ret; @@ -773,28 +852,36 @@ GridBatch GridBatch::conv_grid(Vec3iOrScalar kernel_size, Vec3iOrScalar stride) std::vector voxS, voxO; impl()->gridVoxelSizesAndOrigins(voxS, voxO); ret.mImpl = c10::make_intrusive( - detail::build::buildConvGridFromGrid(ret.is_mutable(), *impl(), kernel_size.value(), stride.value()), voxS, voxO); - ret.impl()->setCoarseTransformFromFineGrid(*impl(), nanovdb::Coord(stride.value().x(), stride.value().y(), stride.value().z())); + detail::build::buildConvGridFromGrid(ret.is_mutable(), *impl(), kernel_size.value(), + stride.value()), + voxS, voxO); + ret.impl()->setCoarseTransformFromFineGrid( + *impl(), nanovdb::Coord(stride.value().x(), stride.value().y(), stride.value().z())); return ret; } -void GridBatch::buildCoarseFromFineGrid(const GridBatch& fineGrid, nanovdb::Coord branchFactor) { +void +GridBatch::buildCoarseFromFineGrid(const GridBatch &fineGrid, nanovdb::Coord branchFactor) { std::vector voxS, voxO; fineGrid.impl()->gridVoxelSizesAndOrigins(voxS, voxO); mImpl = c10::make_intrusive( - detail::build::buildCoarseGridFromFineGrid(is_mutable(), *fineGrid.impl(), branchFactor), - voxS, voxO); + detail::build::buildCoarseGridFromFineGrid(is_mutable(), *fineGrid.impl(), branchFactor), + voxS, voxO); impl()->setCoarseTransformFromFineGrid(*fineGrid.impl(), branchFactor); } - -void GridBatch::buildFineFromCoarseGrid(const GridBatch& coarseGrid, const torch::optional& subdivMask, nanovdb::Coord subdivFactor) { +void +GridBatch::buildFineFromCoarseGrid(const GridBatch &coarseGrid, + const torch::optional &subdivMask, + nanovdb::Coord subdivFactor) { if (subdivMask.has_value()) { - TORCH_CHECK_VALUE(subdivMask.value().ldim() == 1, - "Expected subdiv_mask to have 1 list dimension, i.e. be a single list of coordinate values, but got", subdivMask.value().ldim(), "list dimensions" - ); + TORCH_CHECK_VALUE( + subdivMask.value().ldim() == 1, + "Expected subdiv_mask to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", + subdivMask.value().ldim(), "list dimensions"); impl()->checkDevice(subdivMask.value()); - TORCH_CHECK(subdivMask.value().jdata().sizes().size() == 1, "subdivision mask must have 1 dimension"); + TORCH_CHECK(subdivMask.value().jdata().sizes().size() == 1, + "subdivision mask must have 1 dimension"); TORCH_CHECK(subdivMask.value().jdata().size(0) == coarseGrid.total_voxels(), "subdivision mask must be either empty tensor or have one entry per voxel"); TORCH_CHECK(subdivMask.value().scalar_type() == torch::kBool, @@ -804,269 +891,295 @@ void GridBatch::buildFineFromCoarseGrid(const GridBatch& coarseGrid, const torch std::vector voxS, voxO; coarseGrid.impl()->gridVoxelSizesAndOrigins(voxS, voxO); mImpl = c10::make_intrusive( - detail::build::buildFineGridFromCoarseGrid(is_mutable(), *coarseGrid.impl(), subdivMask, subdivFactor), - voxS, voxO); + detail::build::buildFineGridFromCoarseGrid(is_mutable(), *coarseGrid.impl(), subdivMask, + subdivFactor), + voxS, voxO); impl()->setFineTransformFromCoarseGrid(*coarseGrid.impl(), subdivFactor); } - -void GridBatch::buildDualFromPrimalGrid(const GridBatch& primalGrid, bool excludeBorder) { +void +GridBatch::buildDualFromPrimalGrid(const GridBatch &primalGrid, bool excludeBorder) { std::vector voxS, voxO; primalGrid.impl()->gridVoxelSizesAndOrigins(voxS, voxO); mImpl = c10::make_intrusive( - detail::build::buildPaddedGridFromGrid(is_mutable(), *primalGrid.impl(), 0, 1, excludeBorder), + detail::build::buildPaddedGridFromGrid(is_mutable(), *primalGrid.impl(), 0, 1, + excludeBorder), voxS, voxO); impl()->setPrimalTransformFromDualGrid(*primalGrid.impl()); } - -std::vector GridBatch::voxels_along_rays(const JaggedTensor& ray_origins, - const JaggedTensor& ray_directions, - int64_t max_vox, double eps, - bool return_ijk, bool cumulative) const { - TORCH_CHECK_VALUE(ray_origins.ldim() == 1, - "Expected ray_origins to have 1 list dimension, i.e. be a single list of coordinate values, but got", ray_origins.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(ray_directions.ldim() == 1, - "Expected ray_directions to have 1 list dimension, i.e. be a single list of coordinate values, but got", ray_directions.ldim(), "list dimensions" - ); +std::vector +GridBatch::voxels_along_rays(const JaggedTensor &ray_origins, const JaggedTensor &ray_directions, + int64_t max_vox, double eps, bool return_ijk, bool cumulative) const { + TORCH_CHECK_VALUE( + ray_origins.ldim() == 1, + "Expected ray_origins to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ray_origins.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + ray_directions.ldim() == 1, + "Expected ray_directions to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ray_directions.ldim(), "list dimensions"); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return fvdb::detail::ops::dispatchVoxelsAlongRays(*impl(), ray_origins, ray_directions, max_vox, eps, return_ijk, cumulative); + return fvdb::detail::ops::dispatchVoxelsAlongRays( + *impl(), ray_origins, ray_directions, max_vox, eps, return_ijk, cumulative); }); } - -JaggedTensor GridBatch::segments_along_rays(const JaggedTensor& ray_origins, - const JaggedTensor& ray_directions, - int64_t max_segments, double eps, bool ignore_masked) const { - TORCH_CHECK_VALUE(ray_origins.ldim() == 1, - "Expected ray_origins to have 1 list dimension, i.e. 
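
voxels_along_rays marches each ray through the grid and returns the voxels it crosses. A sketch of a call follows; my reading of the returned JaggedTensors (hit voxels, then entry/exit distances per voxel) is an assumption based on the op's name and arguments, so check VoxelsAlongRays.cu before relying on it. The function name is mine.

#include <vector>
#include <torch/torch.h>
#include "GridBatch.h"
#include "JaggedTensor.h"

// Sketch: enumerate up to 128 voxels along each ray.
// rays_o / rays_d are float tensors of shape [R, 3] on the grid's device.
std::vector<fvdb::JaggedTensor> marchRays(const fvdb::GridBatch &grid,
                                          const torch::Tensor &rays_o,
                                          const torch::Tensor &rays_d) {
    return grid.voxels_along_rays(fvdb::JaggedTensor(rays_o), fvdb::JaggedTensor(rays_d),
                                  /*max_vox=*/128, /*eps=*/0.0,
                                  /*return_ijk=*/true, /*cumulative=*/false);
}
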
be a single list of coordinate values, but got", ray_origins.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(ray_directions.ldim() == 1, - "Expected ray_directions to have 1 list dimension, i.e. be a single list of coordinate values, but got", ray_directions.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::segments_along_rays(const JaggedTensor &ray_origins, const JaggedTensor &ray_directions, + int64_t max_segments, double eps, bool ignore_masked) const { + TORCH_CHECK_VALUE( + ray_origins.ldim() == 1, + "Expected ray_origins to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ray_origins.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + ray_directions.ldim() == 1, + "Expected ray_directions to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ray_directions.ldim(), "list dimensions"); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return fvdb::detail::ops::dispatchSegmentsAlongRays(*impl(), ray_origins, ray_directions, max_segments, eps, ignore_masked); + return fvdb::detail::ops::dispatchSegmentsAlongRays( + *impl(), ray_origins, ray_directions, max_segments, eps, ignore_masked); }); } - -JaggedTensor GridBatch::ray_implicit_intersection(const JaggedTensor& ray_origins, - const JaggedTensor& ray_directions, - const JaggedTensor& gridScalars, - double eps) const { - TORCH_CHECK_VALUE(ray_origins.ldim() == 1, - "Expected ray_origins to have 1 list dimension, i.e. be a single list of coordinate values, but got", ray_origins.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(ray_directions.ldim() == 1, - "Expected ray_directions to have 1 list dimension, i.e. be a single list of coordinate values, but got", ray_directions.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(gridScalars.ldim() == 1, - "Expected grid_scalars to have 1 list dimension, i.e. be a single list of coordinate values, but got", gridScalars.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::ray_implicit_intersection(const JaggedTensor &ray_origins, + const JaggedTensor &ray_directions, + const JaggedTensor &gridScalars, double eps) const { + TORCH_CHECK_VALUE( + ray_origins.ldim() == 1, + "Expected ray_origins to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ray_origins.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + ray_directions.ldim() == 1, + "Expected ray_directions to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ray_directions.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + gridScalars.ldim() == 1, + "Expected grid_scalars to have 1 list dimension, i.e. be a single list of coordinate values, but got", + gridScalars.ldim(), "list dimensions"); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return fvdb::detail::ops::dispatchRayImplicitIntersection(*impl(), ray_origins, ray_directions, gridScalars, eps); + return fvdb::detail::ops::dispatchRayImplicitIntersection( + *impl(), ray_origins, ray_directions, gridScalars, eps); }); } - -JaggedTensor GridBatch::uniform_ray_samples(const JaggedTensor& ray_origins, - const JaggedTensor& ray_directions, - const JaggedTensor& t_min, - const JaggedTensor& t_max, - double step_size, - double cone_angle, - bool include_end_segments, - bool return_midpoint, - double eps) const { - TORCH_CHECK_VALUE(ray_origins.ldim() == 1, - "Expected ray_origins to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", ray_origins.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(ray_directions.ldim() == 1, - "Expected ray_directions to have 1 list dimension, i.e. be a single list of coordinate values, but got", ray_directions.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(t_min.ldim() == 1, - "Expected t_min to have 1 list dimension, i.e. be a single list of coordinate values, but got", t_min.ldim(), "list dimensions" - ); - TORCH_CHECK_VALUE(t_max.ldim() == 1, - "Expected t_max to have 1 list dimension, i.e. be a single list of coordinate values, but got", t_max.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::uniform_ray_samples(const JaggedTensor &ray_origins, const JaggedTensor &ray_directions, + const JaggedTensor &t_min, const JaggedTensor &t_max, + double step_size, double cone_angle, bool include_end_segments, + bool return_midpoint, double eps) const { + TORCH_CHECK_VALUE( + ray_origins.ldim() == 1, + "Expected ray_origins to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ray_origins.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + ray_directions.ldim() == 1, + "Expected ray_directions to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ray_directions.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + t_min.ldim() == 1, + "Expected t_min to have 1 list dimension, i.e. be a single list of coordinate values, but got", + t_min.ldim(), "list dimensions"); + TORCH_CHECK_VALUE( + t_max.ldim() == 1, + "Expected t_max to have 1 list dimension, i.e. be a single list of coordinate values, but got", + t_max.ldim(), "list dimensions"); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return fvdb::detail::ops::dispatchUniformRaySamples(*impl(), ray_origins, ray_directions, t_min, t_max, step_size, cone_angle, include_end_segments, return_midpoint, eps); + return fvdb::detail::ops::dispatchUniformRaySamples( + *impl(), ray_origins, ray_directions, t_min, t_max, step_size, cone_angle, + include_end_segments, return_midpoint, eps); }); } - -JaggedTensor GridBatch::neighbor_indexes(const JaggedTensor& ijk, int32_t extent, int32_t bitshift) const { - TORCH_CHECK_VALUE(ijk.ldim() == 1, - "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", ijk.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::neighbor_indexes(const JaggedTensor &ijk, int32_t extent, int32_t bitshift) const { + TORCH_CHECK_VALUE( + ijk.ldim() == 1, + "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ijk.ldim(), "list dimensions"); TORCH_CHECK_VALUE(extent >= 0, "extent must be >= 0"); nanovdb::Coord extentMin(-extent, -extent, -extent); nanovdb::Coord extentMax(extent, extent, extent); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return fvdb::detail::ops::dispatchVoxelNeighborhood(*impl(), ijk, extentMin, extentMax, bitshift); + return fvdb::detail::ops::dispatchVoxelNeighborhood(*impl(), ijk, extentMin, + extentMax, bitshift); }); } - -JaggedTensor GridBatch::points_in_active_voxel(const JaggedTensor& xyz, bool ignore_disabled) const { - TORCH_CHECK_VALUE(xyz.ldim() == 1, - "Expected xyz to have 1 list dimension, i.e. be a single list of coordinate values, but got", xyz.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::points_in_active_voxel(const JaggedTensor &xyz, bool ignore_disabled) const { + TORCH_CHECK_VALUE( + xyz.ldim() == 1, + "Expected xyz to have 1 list dimension, i.e. 
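
uniform_ray_samples draws step_size-spaced samples along each ray while skipping unoccupied space, the usual pattern for volume-rendering samplers. A sketch follows, with t_min / t_max given as one scalar per ray; the layout of the returned JaggedTensor is not shown in this hunk, so treat it as opaque here. The helper name and the 0.01 step are illustrative.

#include <torch/torch.h>
#include "GridBatch.h"
#include "JaggedTensor.h"

// Sketch: uniform samples along rays, restricted to occupied voxels.
// rays_o / rays_d: float [R, 3]; t0 / t1: float [R] near/far distances.
fvdb::JaggedTensor sampleRays(const fvdb::GridBatch &grid,
                              const torch::Tensor &rays_o, const torch::Tensor &rays_d,
                              const torch::Tensor &t0, const torch::Tensor &t1) {
    return grid.uniform_ray_samples(fvdb::JaggedTensor(rays_o), fvdb::JaggedTensor(rays_d),
                                    fvdb::JaggedTensor(t0), fvdb::JaggedTensor(t1),
                                    /*step_size=*/0.01, /*cone_angle=*/0.0,
                                    /*include_end_segments=*/true,
                                    /*return_midpoint=*/false, /*eps=*/0.0);
}
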
be a single list of coordinate values, but got", + xyz.ldim(), "list dimensions"); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchPointsInGrid(*impl(), xyz, ignore_disabled); }); } - -JaggedTensor GridBatch::cubes_intersect_grid(const JaggedTensor& cube_centers, - const Vec3dOrScalar& cube_min, - const Vec3dOrScalar& cube_max, - bool ignore_disabled) const { - TORCH_CHECK_VALUE(cube_centers.ldim() == 1, - "Expected cube_centers to have 1 list dimension, i.e. be a single list of coordinate values, but got", cube_centers.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::cubes_intersect_grid(const JaggedTensor &cube_centers, const Vec3dOrScalar &cube_min, + const Vec3dOrScalar &cube_max, bool ignore_disabled) const { + TORCH_CHECK_VALUE( + cube_centers.ldim() == 1, + "Expected cube_centers to have 1 list dimension, i.e. be a single list of coordinate values, but got", + cube_centers.ldim(), "list dimensions"); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return fvdb::detail::ops::dispatchCubesIntersectGrid(*impl(), cube_centers, cube_min, cube_max, ignore_disabled); + return fvdb::detail::ops::dispatchCubesIntersectGrid( + *impl(), cube_centers, cube_min, cube_max, ignore_disabled); }); } - -JaggedTensor GridBatch::cubes_in_grid(const JaggedTensor& cube_centers, - const Vec3dOrScalar& cube_min, - const Vec3dOrScalar& cube_max, - bool ignore_disabled) const { - TORCH_CHECK_VALUE(cube_centers.ldim() == 1, - "Expected cube_centers to have 1 list dimension, i.e. be a single list of coordinate values, but got", cube_centers.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::cubes_in_grid(const JaggedTensor &cube_centers, const Vec3dOrScalar &cube_min, + const Vec3dOrScalar &cube_max, bool ignore_disabled) const { + TORCH_CHECK_VALUE( + cube_centers.ldim() == 1, + "Expected cube_centers to have 1 list dimension, i.e. be a single list of coordinate values, but got", + cube_centers.ldim(), "list dimensions"); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return fvdb::detail::ops::dispatchCubesInGrid(*impl(), cube_centers, cube_min, cube_max, ignore_disabled); + return fvdb::detail::ops::dispatchCubesInGrid(*impl(), cube_centers, cube_min, + cube_max, ignore_disabled); }); } - -JaggedTensor GridBatch::enabled_mask() const { +JaggedTensor +GridBatch::enabled_mask() const { return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchEnabledMask(*impl(), false); }); } -JaggedTensor GridBatch::disabled_mask() const { +JaggedTensor +GridBatch::disabled_mask() const { return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchEnabledMask(*impl(), true); }); } - -JaggedTensor GridBatch::coords_in_active_voxel(const JaggedTensor& ijk, bool ignore_disabled) const { - TORCH_CHECK_VALUE(ijk.ldim() == 1, - "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", ijk.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::coords_in_active_voxel(const JaggedTensor &ijk, bool ignore_disabled) const { + TORCH_CHECK_VALUE( + ijk.ldim() == 1, + "Expected ijk to have 1 list dimension, i.e. 
be a single list of coordinate values, but got", + ijk.ldim(), "list dimensions"); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchCoordsInGrid(*impl(), ijk, ignore_disabled); }); } - -JaggedTensor GridBatch::ijk_to_index(const JaggedTensor& ijk, bool cumulative) const { - TORCH_CHECK_VALUE(ijk.ldim() == 1, - "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", ijk.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::ijk_to_index(const JaggedTensor &ijk, bool cumulative) const { + TORCH_CHECK_VALUE( + ijk.ldim() == 1, + "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ijk.ldim(), "list dimensions"); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchIjkToIndex(*impl(), ijk, cumulative); }); } - -JaggedTensor GridBatch::ijk_to_inv_index(const JaggedTensor& ijk, bool cumulative) const { - TORCH_CHECK_VALUE(ijk.ldim() == 1, - "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", ijk.ldim(), "list dimensions" - ); +JaggedTensor +GridBatch::ijk_to_inv_index(const JaggedTensor &ijk, bool cumulative) const { + TORCH_CHECK_VALUE( + ijk.ldim() == 1, + "Expected ijk to have 1 list dimension, i.e. be a single list of coordinate values, but got", + ijk.ldim(), "list dimensions"); return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { return fvdb::detail::ops::dispatchIjkToInvIndex(*impl(), ijk, cumulative); }); } - -JaggedTensor GridBatch::ijk() const { +JaggedTensor +GridBatch::ijk() const { return FVDB_DISPATCH_KERNEL_DEVICE(this->device(), [&]() { return fvdb::detail::ops::dispatchActiveGridCoords(*impl(), true); }); } -JaggedTensor GridBatch::ijk_enabled() const { +JaggedTensor +GridBatch::ijk_enabled() const { return FVDB_DISPATCH_KERNEL_DEVICE(this->device(), [&]() { return fvdb::detail::ops::dispatchActiveGridCoords(*impl(), false); }); } - -const torch::Tensor GridBatch::bbox() const { +const torch::Tensor +GridBatch::bbox() const { const int64_t bs = grid_count(); - torch::Tensor ret = torch::zeros({bs, 2, 3}, torch::TensorOptions().device(device()).dtype(torch::kInt32)); + torch::Tensor ret = + torch::zeros({ bs, 2, 3 }, torch::TensorOptions().device(device()).dtype(torch::kInt32)); for (int64_t i = 0; i < bs; ++i) { - const nanovdb::CoordBBox& bbox = impl()->bbox(i); - ret[i][0][0] = bbox.min()[0]; - ret[i][0][1] = bbox.min()[1]; - ret[i][0][2] = bbox.min()[2]; - ret[i][1][0] = bbox.max()[0]; - ret[i][1][1] = bbox.max()[1]; - ret[i][1][2] = bbox.max()[2]; + const nanovdb::CoordBBox &bbox = impl()->bbox(i); + ret[i][0][0] = bbox.min()[0]; + ret[i][0][1] = bbox.min()[1]; + ret[i][0][2] = bbox.min()[2]; + ret[i][1][0] = bbox.max()[0]; + ret[i][1][1] = bbox.max()[1]; + ret[i][1][2] = bbox.max()[2]; } return ret; } -const torch::Tensor GridBatch::bbox_at(int64_t bi) const { - torch::Tensor ret = torch::zeros({2, 3}, torch::TensorOptions().device(device()).dtype(torch::kInt32)); - const nanovdb::CoordBBox& bbox = impl()->bbox(bi); - ret[0][0] = bbox.min()[0]; - ret[0][1] = bbox.min()[1]; - ret[0][2] = bbox.min()[2]; - ret[1][0] = bbox.max()[0]; - ret[1][1] = bbox.max()[1]; - ret[1][2] = bbox.max()[2]; +const torch::Tensor +GridBatch::bbox_at(int64_t bi) const { + torch::Tensor ret = + torch::zeros({ 2, 3 }, torch::TensorOptions().device(device()).dtype(torch::kInt32)); + const nanovdb::CoordBBox &bbox = impl()->bbox(bi); + ret[0][0] = bbox.min()[0]; + ret[0][1] = bbox.min()[1]; + ret[0][2] = 
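
ijk, ijk_to_index and coords_in_active_voxel compose naturally: the coordinates a grid reports map back to valid voxel indices. A sketch follows; the meaning of the cumulative flag (per-grid versus batch-wide offsets) is inferred from the surrounding code and should be double-checked against Ops.h, and the helper name is mine.

#include "GridBatch.h"
#include "JaggedTensor.h"

// Sketch: round-trip a grid's own coordinates through its index map.
void indexRoundTrip(const fvdb::GridBatch &grid) {
    fvdb::JaggedTensor coords = grid.ijk();                        // int coords, one list per grid
    fvdb::JaggedTensor index  = grid.ijk_to_index(coords, /*cumulative=*/false);
    fvdb::JaggedTensor inside = grid.coords_in_active_voxel(coords, /*ignore_disabled=*/false);
    // inside.jdata() should be all true, since coords came from the grid itself.
}
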
bbox.min()[2]; + ret[1][0] = bbox.max()[0]; + ret[1][1] = bbox.max()[1]; + ret[1][2] = bbox.max()[2]; return ret; } -const torch::Tensor GridBatch::dual_bbox() const { +const torch::Tensor +GridBatch::dual_bbox() const { const int64_t bs = grid_count(); - torch::Tensor ret = torch::zeros({bs, 2, 3}, torch::TensorOptions().device(device()).dtype(torch::kInt32)); + torch::Tensor ret = + torch::zeros({ bs, 2, 3 }, torch::TensorOptions().device(device()).dtype(torch::kInt32)); for (int64_t i = 0; i < bs; ++i) { - const nanovdb::CoordBBox& bbox = impl()->dualBbox(i); - ret[i][0][0] = bbox.min()[0]; - ret[i][0][1] = bbox.min()[1]; - ret[i][0][2] = bbox.min()[2]; - ret[i][1][0] = bbox.max()[0]; - ret[i][1][1] = bbox.max()[1]; - ret[i][1][2] = bbox.max()[2]; + const nanovdb::CoordBBox &bbox = impl()->dualBbox(i); + ret[i][0][0] = bbox.min()[0]; + ret[i][0][1] = bbox.min()[1]; + ret[i][0][2] = bbox.min()[2]; + ret[i][1][0] = bbox.max()[0]; + ret[i][1][1] = bbox.max()[1]; + ret[i][1][2] = bbox.max()[2]; } return ret; } -const torch::Tensor GridBatch::dual_bbox_at(int64_t bi) const { - torch::Tensor ret = torch::zeros({2, 3}, torch::TensorOptions().device(device()).dtype(torch::kInt32)); - const nanovdb::CoordBBox& bbox = impl()->dualBbox(bi); - ret[0][0] = bbox.min()[0]; - ret[0][1] = bbox.min()[1]; - ret[0][2] = bbox.min()[2]; - ret[1][0] = bbox.max()[0]; - ret[1][1] = bbox.max()[1]; - ret[1][2] = bbox.max()[2]; +const torch::Tensor +GridBatch::dual_bbox_at(int64_t bi) const { + torch::Tensor ret = + torch::zeros({ 2, 3 }, torch::TensorOptions().device(device()).dtype(torch::kInt32)); + const nanovdb::CoordBBox &bbox = impl()->dualBbox(bi); + ret[0][0] = bbox.min()[0]; + ret[0][1] = bbox.min()[1]; + ret[0][2] = bbox.min()[2]; + ret[1][0] = bbox.max()[0]; + ret[1][1] = bbox.max()[1]; + ret[1][2] = bbox.max()[2]; return ret; } -const torch::Tensor GridBatch::total_bbox() const { - const nanovdb::CoordBBox& bbox = impl()->totalBBox(); - return torch::tensor({{bbox.min()[0], bbox.min()[1], bbox.min()[2]}, - {bbox.max()[0], bbox.max()[1], bbox.max()[2]}}, - torch::TensorOptions().device(device()).dtype(torch::kInt32)); +const torch::Tensor +GridBatch::total_bbox() const { + const nanovdb::CoordBBox &bbox = impl()->totalBBox(); + return torch::tensor({ { bbox.min()[0], bbox.min()[1], bbox.min()[2] }, + { bbox.max()[0], bbox.max()[1], bbox.max()[2] } }, + torch::TensorOptions().device(device()).dtype(torch::kInt32)); } - -std::vector GridBatch::viz_edge_network(bool returnVoxelCoordinates) const { +std::vector +GridBatch::viz_edge_network(bool returnVoxelCoordinates) const { return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return fvdb::detail::ops::dispatchGridEdgeNetwork(*impl(), returnVoxelCoordinates); + return fvdb::detail::ops::dispatchGridEdgeNetwork(*impl(), + returnVoxelCoordinates); }); } diff --git a/fvdb/src/GridBatch.h b/fvdb/src/GridBatch.h index 5743cabd4e..796f165192 100644 --- a/fvdb/src/GridBatch.h +++ b/fvdb/src/GridBatch.h @@ -1,25 +1,24 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once -#include -#include - -#include -#include +#ifndef FVDB_GRIDBATCH_H +#define FVDB_GRIDBATCH_H -#include "detail/utils/Utils.h" #include "detail/GridBatchImpl.h" +#include "detail/utils/Utils.h" #include "JaggedTensor.h" #include "Types.h" -namespace fvdb { +#include +#include +#include +#include +namespace fvdb { struct GridBatch : torch::CustomClassHolder { - // Set some speed limits so you don't shoot yourself in the foot constexpr static 
int64_t MAX_GRIDS_PER_BATCH = 1024; @@ -30,63 +29,74 @@ struct GridBatch : torch::CustomClassHolder { /// @brief Return true if this is a contiguous view of the grid batch /// @return true if this is a contiguous view of the grid batch - bool is_contiguous() const { + bool + is_contiguous() const { return impl()->isContiguous(); } /// @brief Return a contiguous copy of this grid batch. If the grid batch is already contiguous, /// then return a reference to this /// @return A contiguous copy of this grid batch - GridBatch contiguous() const { + GridBatch + contiguous() const { return GridBatch(detail::GridBatchImpl::contiguous(impl())); } - /// @brief Get the voxel size of the bi^th grid in the batch and return is a tensor of type dtype + /// @brief Get the voxel size of the bi^th grid in the batch and return is a tensor of type + /// dtype /// @param bi The batch index of the grid for which to get the voxel size /// @param dtype The dtype of the returned tensor /// @return A tensor of shape [3,] containing the voxel size of the bi^th grid in the batch - torch::Tensor voxel_size_at(int64_t bi, const torch::Dtype& dtype = torch::kFloat32) const; + torch::Tensor voxel_size_at(int64_t bi, const torch::Dtype &dtype = torch::kFloat32) const; - /// @brief Get the voxel origin of the bi^th grid in the batch and return is a tensor of type dtype + /// @brief Get the voxel origin of the bi^th grid in the batch and return is a tensor of type + /// dtype /// @param bi The batch index of the grid for which to get the voxel origin /// @param dtype The dtype of the returned tensor /// @return A tensor of shape [3,] containing the voxel origin of the bi^th grid in the batch - torch::Tensor origin_at(int64_t bi, const torch::Dtype& dtype = torch::kFloat32) const; + torch::Tensor origin_at(int64_t bi, const torch::Dtype &dtype = torch::kFloat32) const; /// @brief Get the voxel size of all grids in this batch and return is a tensor of type dtype /// @param dtype The dtype of the returned tensor - /// @return A tensor of shape [grid_count(), 3] containing the voxel size of all grids indexed by this batch - torch::Tensor voxel_sizes(const torch::Dtype& dtype = torch::kFloat32) const; + /// @return A tensor of shape [grid_count(), 3] containing the voxel size of all grids indexed + /// by this batch + torch::Tensor voxel_sizes(const torch::Dtype &dtype = torch::kFloat32) const; /// @brief Get the voxel origins of all grids in this batch and return is a tensor of type dtype /// @param dtype The dtype of the returned tensor - /// @return A tensor of shape [grid_count(), 3] containing the voxel origins of all grids indexed by this batch - torch::Tensor origins(const torch::Dtype& dtype = torch::kFloat32) const; + /// @return A tensor of shape [grid_count(), 3] containing the voxel origins of all grids + /// indexed by this batch + torch::Tensor origins(const torch::Dtype &dtype = torch::kFloat32) const; /// @brief Get the number of grids indexed by this batch /// @return The number of grids indexed by this batch - int64_t grid_count() const { - TORCH_CHECK(impl()->batchSize() <= MAX_GRIDS_PER_BATCH, "Cannot have more than ", MAX_GRIDS_PER_BATCH, " grids in a batch"); + int64_t + grid_count() const { + TORCH_CHECK(impl()->batchSize() <= MAX_GRIDS_PER_BATCH, "Cannot have more than ", + MAX_GRIDS_PER_BATCH, " grids in a batch"); return impl()->batchSize(); } /// @brief The total number of enabled voxels indexed by this batch of grids /// For immutable grids, this returns the same value as total_voxels() /// @return 
The total number of enabled voxels indexed by this batch of grids - int64_t total_enabled_voxels() const { + int64_t + total_enabled_voxels() const { return impl()->totalEnabledVoxels(false); } /// @brief Get the total number of voxels indexed by this batch of grids /// @return The total number of voxels indexed by this batch of grids - int64_t total_voxels() const { + int64_t + total_voxels() const { return impl()->totalVoxels(); } /// @brief Get the number of voxels indexed by the bi^th grid in the batch /// @param bi The batch index of the grid for which to get the number of voxels /// @return The number of voxels indexed by the bi^th grid in the batch - int64_t num_voxels_at(int64_t bi) const { + int64_t + num_voxels_at(int64_t bi) const { return impl()->numVoxels(bi); } @@ -99,7 +109,8 @@ struct GridBatch : torch::CustomClassHolder { /// @brief Get the cumulative number of voxels indexed by the first bi+1 grids /// @param bi The batch index /// @return The cumulative number of voxels indexed by the first bi+1 grids - int64_t cum_voxels_at(int64_t bi) const { + int64_t + cum_voxels_at(int64_t bi) const { return impl()->cumVoxels(bi); } @@ -115,22 +126,26 @@ struct GridBatch : torch::CustomClassHolder { /// @brief Get the number of enabled voxels indexed by this batch of grids /// For immutable grids, this returns the same value as num_voxels() - /// @return An integer tensor containing the number of enabled voxels per grid indexed by this batch + /// @return An integer tensor containing the number of enabled voxels per grid indexed by this + /// batch torch::Tensor num_enabled_voxels() const; /// @brief Get the cumulative number of voxels indexed by the grids in this batch /// i.e. [nvox_0, nvox_0+nvox_1, nvox_0+nvox_1+nvox_2, ...] - /// @return An integer tensor containing the cumulative number of voxels indexed by the grids in this batch + /// @return An integer tensor containing the cumulative number of voxels indexed by the grids in + /// this batch torch::Tensor cum_voxels() const; /// @brief Get the cumulative number of voxels indexed by the grids in this batch /// i.e. [nvox_0, nvox_0+nvox_1, nvox_0+nvox_1+nvox_2, ...] 
- /// @return An integer tensor containing the cumulative number of voxels indexed by the grids in this batch + /// @return An integer tensor containing the cumulative number of voxels indexed by the grids in + /// this batch torch::Tensor cum_enabled_voxels() const; /// @brief Get the total number of bytes required to store all grids indexed by this batch /// @return The total number of bytes required to store all grids indexed by this batch - int64_t total_bytes() const { + int64_t + total_bytes() const { return impl()->totalBytes(); } @@ -140,7 +155,8 @@ struct GridBatch : torch::CustomClassHolder { /// @brief Get the total number of leaf nodes indexed by this batch of grids /// @return The total number of leaf nodes indexed by this batch of grids - int64_t total_leaf_nodes() const { + int64_t + total_leaf_nodes() const { return impl()->totalLeaves(); } @@ -148,64 +164,78 @@ struct GridBatch : torch::CustomClassHolder { /// @return An integer tensor containing the number of leaf nodes in each grid torch::Tensor num_leaf_nodes() const; - /// @brief Get the offsets of the voxels indexed by this batch of grid - /// @return A tensor of shape [batch_size, 2] where the [bi, 0]^th entry is the offset of the first voxel - /// and the [bi, 1]^th entry is the offset one past the last voxel indexed by the bi^th grid in the batch - torch::Tensor joffsets() const { + /// @return A tensor of shape [batch_size, 2] where the [bi, 0]^th entry is the offset of the + /// first voxel + /// and the [bi, 1]^th entry is the offset one past the last voxel indexed by the bi^th + /// grid in the batch + torch::Tensor + joffsets() const { return impl()->voxelOffsets(true); } /// @brief Get the list indices for theis batch of grids - /// @return A tensor of shape [total_grids, ldim] where the [i]^th entry is the list index of the i^th grid - torch::Tensor jlidx() const { + /// @return A tensor of shape [total_grids, ldim] where the [i]^th entry is the list index of + /// the i^th grid + torch::Tensor + jlidx() const { const torch::Tensor ret = impl()->jlidx(true); if (ret.numel() == 0) { - return torch::arange({grid_count()}, torch::TensorOptions().device(device()).dtype(torch::kInt64)); + return torch::arange({ grid_count() }, + torch::TensorOptions().device(device()).dtype(torch::kInt64)); } else { return ret; } } /// @brief Get the batch index for each voxel indexed by this batch of grids - /// @return An integer tensor of shape [total_voxels,] where the [i]^th entry is the batch index of the i^th voxel - torch::Tensor jidx() const { + /// @return An integer tensor of shape [total_voxels,] where the [i]^th entry is the batch index + /// of the i^th voxel + torch::Tensor + jidx() const { const torch::Tensor ret = impl()->jidx(true); if (grid_count() == 1 && ret.numel() == 0) { - return torch::zeros({total_voxels()}, torch::TensorOptions().device(device()).dtype(torch::kInt16)); + return torch::zeros({ total_voxels() }, + torch::TensorOptions().device(device()).dtype(torch::kInt16)); } else { return ret; } - } /// @brief Set the voxel size of all grids indexed by this batch to the specified value /// @param voxel_size A 3D (shape [3,]) tensor specifying the voxel size to set for each grid - inline void set_global_voxel_size(const Vec3dOrScalar& voxel_size) { + inline void + set_global_voxel_size(const Vec3dOrScalar &voxel_size) { impl()->setGlobalVoxelSize(voxel_size.value()); } /// @brief Set the voxel origin of all grids indexed by this batch to the specified value /// @param origin A 3D (shape [3,]) tensor 
specifying the voxel origin to set for each grid - inline void set_global_origin(const Vec3d& origin) { + inline void + set_global_origin(const Vec3d &origin) { impl()->setGlobalVoxelOrigin(origin.value()); } /// @brief Return true if this grid is mutable /// @return Whether the grid is mutable - inline bool is_mutable() const { + inline bool + is_mutable() const { return impl()->isMutable(); } /// @brief Get the device on which this grid is stored /// @return The device on which this grid is stored - inline c10::Device device() const { + inline c10::Device + device() const { return impl()->device(); } - /// @brief Get the primal transforms of the grids in this batch (i.e. world to primal grid coordinates) - /// @return A std::vector containing the primal transforms of the grids in this batch - inline const std::vector primal_transforms() const { + /// @brief Get the primal transforms of the grids in this batch (i.e. world to primal grid + /// coordinates) + /// @return A std::vector containing the primal transforms of the grids in + /// this batch + inline const std::vector + primal_transforms() const { std::vector transforms; transforms.reserve(grid_count()); for (int64_t bi = 0; bi < grid_count(); ++bi) { @@ -214,9 +244,12 @@ struct GridBatch : torch::CustomClassHolder { return transforms; } - /// @brief Get the dual transforms of the grids in this batch (i.e. world to dual grid coordinates) - /// @return A std::vector containing the dual transforms of the grids in this batch - inline const std::vector dual_transforms() const { + /// @brief Get the dual transforms of the grids in this batch (i.e. world to dual grid + /// coordinates) + /// @return A std::vector containing the dual transforms of the + /// grids in this batch + inline const std::vector + dual_transforms() const { std::vector transforms; transforms.reserve(grid_count()); for (int64_t bi = 0; bi < grid_count(); ++bi) { @@ -225,267 +258,330 @@ struct GridBatch : torch::CustomClassHolder { return transforms; } - /// @brief Get the primal transform of the bi^th grid in the batch (i.e. world to primal grid coordinates) + /// @brief Get the primal transform of the bi^th grid in the batch (i.e. world to primal grid + /// coordinates) /// @param bi The index of the grid in the batch for which to get the primal transform /// @return The primal transform of the bi^th grid in the batch - inline const fvdb::detail::VoxelCoordTransform primal_transform_at(int64_t bi) const { + inline const fvdb::detail::VoxelCoordTransform + primal_transform_at(int64_t bi) const { return impl()->primalTransform(bi); } - /// @brief Get the dual transform of the bi^th grid in the batch (i.e. world to dual grid coordinates) + /// @brief Get the dual transform of the bi^th grid in the batch (i.e. 
world to dual grid + /// coordinates) /// @param bi The index of the grid in the batch for which to get the dual transform /// @return The dual transform of the bi^th grid in the batch - inline const fvdb::detail::VoxelCoordTransform dual_transform_at(int64_t bi) const { + inline const fvdb::detail::VoxelCoordTransform + dual_transform_at(int64_t bi) const { return impl()->dualTransform(bi); } /// @brief Get the bounding box (in voxel coordinates) for each grid in the batch /// @return A tensor bboxes of shape [B, 2, 3] where - /// bboxes[bi] = [[bmin_i, bmin_j, bmin_z=k], [bmax_i, bmax_j, bmax_k]] is the bi^th bounding box - /// such that bmin <= ijk < bmax for all voxels ijk in the bi^th grid + /// bboxes[bi] = [[bmin_i, bmin_j, bmin_z=k], [bmax_i, bmax_j, bmax_k]] is the bi^th + /// bounding box such that bmin <= ijk < bmax for all voxels ijk in the bi^th grid const torch::Tensor bbox() const; /// @brief Get the bounding box (in voxel coordinates) of the bi^th grid in the batch /// @return A tensor, bbox, of shape [2, 3] where - /// bbox = [[bmin_i, bmin_j, bmin_z=k], [bmax_i, bmax_j, bmax_k]] is the bi^th bounding box - /// such that bmin <= ijk < bmax for all voxels ijk in the bi^th grid + /// bbox = [[bmin_i, bmin_j, bmin_z=k], [bmax_i, bmax_j, bmax_k]] is the bi^th bounding + /// box such that bmin <= ijk < bmax for all voxels ijk in the bi^th grid const torch::Tensor bbox_at(int64_t bi) const; /// @brief Get the bounding box (in voxel coordinates) for the dual of each grid in the batch /// @return A tensor bboxes of shape [B, 2, 3] where - /// bboxes[bi] = [[bmin_i, bmin_j, bmin_z=k], [bmax_i, bmax_j, bmax_k]] is the bi^th bounding box - /// such that bmin <= ijk < bmax for all voxels ijk in the dual of the bi^th grid + /// bboxes[bi] = [[bmin_i, bmin_j, bmin_z=k], [bmax_i, bmax_j, bmax_k]] is the bi^th + /// bounding box such that bmin <= ijk < bmax for all voxels ijk in the dual of the + /// bi^th grid const torch::Tensor dual_bbox() const; - /// @brief Get the bounding box (in voxel coordinates) of the dual of the bi^th grid in the batch + /// @brief Get the bounding box (in voxel coordinates) of the dual of the bi^th grid in the + /// batch /// @return A tensor, bbox, of shape [2, 3] where - /// bbox = [[bmin_i, bmin_j, bmin_z=k], [bmax_i, bmax_j, bmax_k]] is the bi^th bounding box - /// such that bmin <= ijk < bmax for all voxels ijk in the dual of the bi^th grid + /// bbox = [[bmin_i, bmin_j, bmin_z=k], [bmax_i, bmax_j, bmax_k]] is the bi^th bounding + /// box such that bmin <= ijk < bmax for all voxels ijk in the dual of the bi^th grid const torch::Tensor dual_bbox_at(int64_t bi) const; - /// @brief Get the bounding box (in voxel coordinates) which contains all the grids in this batch + /// @brief Get the bounding box (in voxel coordinates) which contains all the grids in this + /// batch /// @return A tensor, total_bbox, of shape [2, 3] where - /// total_bbox = [[bmin_i, bmin_j, bmin_z=k], [bmax_i, bmax_j, bmax_k]] is the bounding box - /// such that bmin <= ijk < bmax for all voxels ijk in the batch + /// total_bbox = [[bmin_i, bmin_j, bmin_z=k], [bmax_i, bmax_j, bmax_k]] is the bounding + /// box such that bmin <= ijk < bmax for all voxels ijk in the batch const torch::Tensor total_bbox() const; /// @brief Downsample this batch of grids using maxpooling - /// @param pool_factor How much to pool by (i,e, (2,2,2) means take max over 2x2x2 from start of window) - /// @param data Data at each voxel in this grid to be downsampled (JaggedTensor of shape [B, -1, *]) + /// @param 
pool_factor How much to pool by (i.e. (2,2,2) means take max over 2x2x2 from start of + /// window) + /// @param data Data at each voxel in this grid to be downsampled (JaggedTensor of shape [B, -1, + /// *]) /// @param stride The stride to use when pooling - /// @param coarse_grid An optional coarse grid used to specify the output. This is mainly used for memory - /// efficiency so you can chache grids. If you don't pass it in, we'll just create it for you. - /// @return A pair (coarseData, coarseGrid) where coarseData is a JaggedTensor of shape [B, -1, *] of downsampled data + /// @param coarse_grid An optional coarse grid used to specify the output. This is mainly used + /// for memory + /// efficiency so you can cache grids. If you don't pass it in, we'll just + /// create it for you. + /// @return A pair (coarseData, coarseGrid) where coarseData is a JaggedTensor of shape [B, -1, + /// *] of downsampled data /// and coarseGrid is a GridBatch representing the downsampled grid batch - std::pair max_pool(Vec3iOrScalar pool_factor, - const JaggedTensor& data, - Vec3iOrScalar stride = 0, - torch::optional coarse_grid = torch::nullopt) const; + std::pair + max_pool(Vec3iOrScalar pool_factor, const JaggedTensor &data, Vec3iOrScalar stride = 0, + torch::optional coarse_grid = torch::nullopt) const; /// @brief Downsample this batch of grids using average pooling - /// @param pool_factor How much to pool by (i,e, (2, 2, 2) means take max over 2x2x2 from start of window) - /// @param data Data at each voxel in this grid to be downsampled (JaggedTensor of shape [B, -1, *]) + /// @param pool_factor How much to pool by (i.e. (2, 2, 2) means take the average over 2x2x2 from start + /// of window) + /// @param data Data at each voxel in this grid to be downsampled (JaggedTensor of shape [B, -1, + /// *]) /// @param stride The stride to use when pooling - /// @param coarse_grid An optional coarse grid used to specify the output. This is mainly used for memory - /// efficiency so you can chache grids. If you don't pass it in, we'll just create it for you. - /// @return A pair (coarseData, coarseGrid) where coarseData is a JaggedTensor of shape [B, -1, *] of downsampled data + /// @param coarse_grid An optional coarse grid used to specify the output. This is mainly used + /// for memory + /// efficiency so you can cache grids. If you don't pass it in, we'll just + /// create it for you. + /// @return A pair (coarseData, coarseGrid) where coarseData is a JaggedTensor of shape [B, -1, + /// *] of downsampled data /// and coarseGrid is a GridBatch representing the downsampled grid batch - std::pair avg_pool(Vec3iOrScalar pool_factor, - const JaggedTensor& data, - Vec3iOrScalar stride = 0, - torch::optional coarse_grid = torch::nullopt) const; + std::pair + avg_pool(Vec3iOrScalar pool_factor, const JaggedTensor &data, Vec3iOrScalar stride = 0, + torch::optional coarse_grid = torch::nullopt) const; /// @brief Subdivide this batch of grids using nearest neighbor interpolation /// @param subdiv_factor How much to upsample by (i,e, (2,2,2) means upsample by 2x2x2) - /// @param data Data at each voxel in this grid to be upsampled (JaggedTensor of shape [B, -1, *]) + /// @param data Data at each voxel in this grid to be upsampled (JaggedTensor of shape [B, -1, + /// *]) /// @param mask An optional mask of shape [B, -1] specifying which coarse voxels to upsample - /// @param fine_grid An optional coarse grid used to specify the output. This is mainly used for memory - /// efficiency so you can chache grids.
If you don't pass it in, we'll just create it for you. - /// @return A pair (fineData, fineGrid) where fineData is a JaggedTensor of shape [B, -1, *] of upsampled data and + /// @param fine_grid An optional fine grid used to specify the output. This is mainly used for + /// memory + /// efficiency so you can cache grids. If you don't pass it in, we'll just + /// create it for you. + /// @return A pair (fineData, fineGrid) where fineData is a JaggedTensor of shape [B, -1, *] of + /// upsampled data and /// fineGrid is a GridBatch representing the upsampled grid batch - std::pair subdivide(Vec3iOrScalar subdiv_factor, - const JaggedTensor& data, - const torch::optional mask = torch::nullopt, - torch::optional fine_grid = torch::nullopt) const; + std::pair + subdivide(Vec3iOrScalar subdiv_factor, const JaggedTensor &data, + const torch::optional mask = torch::nullopt, + torch::optional fine_grid = torch::nullopt) const; /// @brief Read the values from a dense tensor of the voxels at the specified coordinates /// @param dense_data A dense tensor of shape [B, W, H, D, *] - /// @param dense_origins A tensor of shape [B, 3] or [3,] specifying the voxel coordinate(s) of the origin of the dense tensor i.e. [:, 0, 0, 0] - /// @return A JaggedTensor with shape [B, -1, *] containing the values at the specified coordinates - JaggedTensor read_from_dense(const torch::Tensor& dense_data, - const Vec3iBatch& dense_origins = torch::zeros(3, torch::kInt32)) const; + /// @param dense_origins A tensor of shape [B, 3] or [3,] specifying the voxel coordinate(s) of + /// the origin of the dense tensor i.e. [:, 0, 0, 0] + /// @return A JaggedTensor with shape [B, -1, *] containing the values at the specified + /// coordinates + JaggedTensor + read_from_dense(const torch::Tensor &dense_data, + const Vec3iBatch &dense_origins = torch::zeros(3, torch::kInt32)) const; /// @brief Read the values from a JaggedTensor indexed by this batch into a dense tensor - /// @param sparse_data A JaggedTensor of shape [B, -1, *] containing one value per voxel in the batch - /// @param min_coord An optional minimum coordinate to read from the batch (in voxel coordinates). + /// @param sparse_data A JaggedTensor of shape [B, -1, *] containing one value per voxel in the + /// batch + /// @param min_coord An optional minimum coordinate to read from the batch (in voxel + /// coordinates). /// Defaults to the minimum coordinate of the batch. /// @param grid_size An optional grid size to read from the batch (in voxel coordinates). /// Defaults to the total size of a grid containing the whole batch. - /// @return A dense tensor of shape [B, W, H, D, *] containing the values at the specified coordinates (and zero elsewhere) - torch::Tensor read_into_dense(const JaggedTensor& sparse_data, - const torch::optional& min_coord = torch::nullopt, - const torch::optional& grid_size = torch::nullopt) const; + /// @return A dense tensor of shape [B, W, H, D, *] containing the values at the specified + /// coordinates (and zero elsewhere) + torch::Tensor read_into_dense(const JaggedTensor &sparse_data, + const torch::optional &min_coord = torch::nullopt, + const torch::optional &grid_size = torch::nullopt) const; /// @brief Given a GridBatch and features associated with it, /// return a JaggedTensor representing features for this batch of grid. /// Fill any voxels not in the GridBatch with the default value. - /// @param features A JaggedTensor of shape [B, -1, *] containing features associated with other_grid.
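The pooling and subdivision declarations above pair naturally, so a short usage sketch may help. It is a minimal sketch only: it assumes the Python bindings keep the argument order and the (data, grid) return pairs documented here, and that the constructor, keyword spellings, and property spellings shown are correct.

    import torch
    import fvdb

    # A small batch of two grids built from random points (keyword spellings are
    # assumed; see the earlier construction sketch).
    points = fvdb.JaggedTensor([torch.rand(1000, 3), torch.rand(500, 3)])
    grid = fvdb.GridBatch(device="cpu")
    grid.set_from_points(points, voxel_sizes=0.05)

    # Per-voxel features with 16 channels, laid out to match `grid` (jagged_like
    # is declared later in this header).
    feats = grid.jagged_like(torch.randn(grid.total_voxels, 16))

    # Downsample by 2x2x2 with average pooling; a coarse grid is created because
    # no coarse_grid argument is passed.
    coarse_feats, coarse_grid = grid.avg_pool(2, feats)

    # Push the coarse features back onto a finer grid with nearest-neighbor
    # subdivision (a resolution round trip, not an exact inverse of pooling).
    fine_feats, fine_grid = coarse_grid.subdivide(2, coarse_feats)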
+ /// @param features A JaggedTensor of shape [B, -1, *] containing features associated with + /// other_grid. /// @param other_grid A GridBatch representing the grid to fill from. /// @param default_value The value to fill in for voxels not in other_grid. - JaggedTensor fill_to_grid(const JaggedTensor& features, - const GridBatch& other_grid, + JaggedTensor fill_to_grid(const JaggedTensor &features, const GridBatch &other_grid, float default_value = 0.0f) const; /// @brief Convert grid coordinates to world coordinates - /// @param ijk A JaggedTensor of grid coordinates with shape [B, -1, 3] (one point set per grid in the batch) - /// @return A JaggedTensor of world coordinates with shape [B, -1, 3] (one point set per grid in the batch) - JaggedTensor grid_to_world(const JaggedTensor& ijk) const; + /// @param ijk A JaggedTensor of grid coordinates with shape [B, -1, 3] (one point set per grid + /// in the batch) + /// @return A JaggedTensor of world coordinates with shape [B, -1, 3] (one point set per grid in + /// the batch) + JaggedTensor grid_to_world(const JaggedTensor &ijk) const; /// @brief Convert world coordinates to grid coordinates - /// @param xyz A JaggedTensor of world coordinates with shape [B, -1, 3] (one point set per grid in the batch) - /// @return A JaggedTensor of grid coordinates with shape [B, -1, 3] (one point set per grid in the batch) - JaggedTensor world_to_grid(const JaggedTensor& xyz) const; + /// @param xyz A JaggedTensor of world coordinates with shape [B, -1, 3] (one point set per grid + /// in the batch) + /// @return A JaggedTensor of grid coordinates with shape [B, -1, 3] (one point set per grid in + /// the batch) + JaggedTensor world_to_grid(const JaggedTensor &xyz) const; /// @brief Get grid-to-world matrices /// @return A JaggedTensor of grid-to-world matrices with shape [B, 4, 4] - torch::Tensor grid_to_world_matrices(const torch::Dtype& dtype = torch::kFloat32) const; + torch::Tensor grid_to_world_matrices(const torch::Dtype &dtype = torch::kFloat32) const; /// @brief Get world-to-grid matrices /// @return A JaggedTensor of world-to-grid matrices with shape [B, 4, 4] - torch::Tensor world_to_grid_matrices(const torch::Dtype& dtype = torch::kFloat32) const; + torch::Tensor world_to_grid_matrices(const torch::Dtype &dtype = torch::kFloat32) const; /// @brief Sample features on the grid batch using trilinear interpolation - /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the batch) - /// @param voxel_data a JaggedTensor of C-dimensional features at each voxel with shape [B, -1, C] or a Tensor of + /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the + /// batch) + /// @param voxel_data a JaggedTensor of C-dimensional features at each voxel with shape [B, -1, + /// C] or a Tensor of /// shape [N, C] where N is the total number of voxels in the batch /// (one item for each voxel in each grid in the batch) /// @return a JaggedTensor of sampled data with shape [B, -1, C] (one sample set per point) - JaggedTensor sample_trilinear(const JaggedTensor& points, - const JaggedTensor& voxel_data) const; + JaggedTensor sample_trilinear(const JaggedTensor &points, const JaggedTensor &voxel_data) const; /// @brief Sample features and spatial gradients on the grid batch using trilinear interpolation - /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the batch) - /// @param voxel_data a JaggedTensor of C-dimensional features at each voxel 
with shape [B, -1, C] or a Tensor of + /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the + /// batch) + /// @param voxel_data a JaggedTensor of C-dimensional features at each voxel with shape [B, -1, + /// C] or a Tensor of /// shape [N, C] where N is the total number of voxels in the batch /// (one item for each voxel in each grid in the batch) - /// @return a pair (feat, grad_feat) which are JaggedTensors of sampled data with shape [B, -1, C], and [B, -1, C, 3] - /// respectively where feat are the sampled features and grad_feat are the spatial gradients of the sampled - /// features (one sample set per point) - std::vector sample_trilinear_with_grad(const JaggedTensor& points, - const JaggedTensor& voxel_data) const; + /// @return a pair (feat, grad_feat) which are JaggedTensors of sampled data with shape [B, -1, + /// C], and [B, -1, C, 3] + /// respectively where feat are the sampled features and grad_feat are the spatial + /// gradients of the sampled features (one sample set per point) + std::vector sample_trilinear_with_grad(const JaggedTensor &points, + const JaggedTensor &voxel_data) const; /// @brief Sample features on the grid batch using bezier interpolation - /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the batch) - /// @param voxel_data a JaggedTensor of C-dimensional features at each voxel with shape [B, -1, C] or a Tensor of + /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the + /// batch) + /// @param voxel_data a JaggedTensor of C-dimensional features at each voxel with shape [B, -1, + /// C] or a Tensor of /// shape [N, C] where N is the total number of voxels in the batch /// (one item for each voxel in each grid in the batch) /// @return a JaggedTensor of sampled data with shape [B, -1, C] (one sample set per point) - JaggedTensor sample_bezier(const JaggedTensor& points, - const JaggedTensor& voxel_data) const; + JaggedTensor sample_bezier(const JaggedTensor &points, const JaggedTensor &voxel_data) const; /// @brief Sample features and spatial gradients on the grid batch using bezier interpolation - /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the batch) - /// @param voxel_data a JaggedTensor of C-dimensional features at each voxel with shape [B, -1, C] or a Tensor of + /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the + /// batch) + /// @param voxel_data a JaggedTensor of C-dimensional features at each voxel with shape [B, -1, + /// C] or a Tensor of /// shape [N, C] where N is the total number of voxels in the batch /// (one item for each voxel in each grid in the batch) - /// @return a pair (feat, grad_feat) which are JaggedTensors of sampled data with shape [B, -1, C], and [B, -1, C, 3] - /// respectively where feat are the sampled features and grad_feat are the spatial gradients of the sampled - /// features (one sample set per point) - std::vector sample_bezier_with_grad(const JaggedTensor& points, - const JaggedTensor& voxel_data) const; + /// @return a pair (feat, grad_feat) which are JaggedTensors of sampled data with shape [B, -1, + /// C], and [B, -1, C, 3] + /// respectively where feat are the sampled features and grad_feat are the spatial + /// gradients of the sampled features (one sample set per point) + std::vector sample_bezier_with_grad(const JaggedTensor &points, + const JaggedTensor &voxel_data) const; /// @brief Splat 
features at points into a grid batch using trilinear interpolation - /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the batch) - /// @param points_data a JaggedTensor of C-dimensional features at each point with shape [B, -1, C] - /// @return a JaggedTensor of C-dimensional features at each voxel in the batch with shape [B, -1, C] - JaggedTensor splat_trilinear(const JaggedTensor& points, - const JaggedTensor& points_data) const; + /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the + /// batch) + /// @param points_data a JaggedTensor of C-dimensional features at each point with shape [B, -1, + /// C] + /// @return a JaggedTensor of C-dimensional features at each voxel in the batch with shape [B, + /// -1, C] + JaggedTensor splat_trilinear(const JaggedTensor &points, const JaggedTensor &points_data) const; /// @brief Splat features at points into a grid using bezier interpolation - /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the batch) - /// @param points_data a JaggedTensor of C-dimensional features at each point with shape [B, -1, C] - /// @return a JaggedTensor of C-dimensional features at each voxel in the batch with shape [B, -1, C] - JaggedTensor splat_bezier(const JaggedTensor& points, - const JaggedTensor& points_data) const; + /// @param points a JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the + /// batch) + /// @param points_data a JaggedTensor of C-dimensional features at each point with shape [B, -1, + /// C] + /// @return a JaggedTensor of C-dimensional features at each voxel in the batch with shape [B, + /// -1, C] + JaggedTensor splat_bezier(const JaggedTensor &points, const JaggedTensor &points_data) const; /// @brief Get the indices of neighbors in the N-ring of each voxel in the grid batch - /// (possibly bitshifting the coordinates which is useful when you use multiple grids to represent different - /// levels of a hierarchy and you want to query this grid with coordinates at a finer level) - /// @param ijk A JaggedTensor of voxel coordinates with shape [B, -1, 3] (one set of coordinates per grid in the batch) + /// (possibly bitshifting the coordinates which is useful when you use multiple grids to + /// represent different levels of a hierarchy and you want to query this grid with + /// coordinates at a finer level) + /// @param ijk A JaggedTensor of voxel coordinates with shape [B, -1, 3] (one set of coordinates + /// per grid in the batch) /// @param extent The size of a neighborhood to find indexes /// @param bitshift The number of bits to shift the coordinates by - /// @return A JaggedTensor of neighbor indexes with shape [B, -1, 2*extent+1, 2*extent+1, 2*extent+1] (-1 value indicates no neighbor at that index) - JaggedTensor neighbor_indexes(const JaggedTensor& ijk, - int32_t extent, + /// @return A JaggedTensor of neighbor indexes with shape [B, -1, 2*extent+1, 2*extent+1, + /// 2*extent+1] (-1 value indicates no neighbor at that index) + JaggedTensor neighbor_indexes(const JaggedTensor &ijk, int32_t extent, int32_t bitshift = 0) const; /// @brief Return whether each point lies inside the grid batch - /// @param xyz A JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the batch) - /// @param ignore_disabled Whether to ignore voxels that have been disabled (only applicable to mutable grids) + /// @param xyz A JaggedTensor of points with shape [B, -1, 3] (one point set per grid in 
the + /// batch) + /// @param ignore_disabled Whether to ignore voxels that have been disabled (only applicable to + /// mutable grids) /// @return A JaggedTensor of booleans with shape [B, -1] (one boolean per point) - /// where the [bi, i]^th entry is true if points[bi, i] lies inside the bi^th grid in the batch - JaggedTensor points_in_active_voxel(const JaggedTensor& xyz, - bool ignore_disabled = false) const; + /// where the [bi, i]^th entry is true if points[bi, i] lies inside the bi^th grid in + /// the batch + JaggedTensor points_in_active_voxel(const JaggedTensor &xyz, + bool ignore_disabled = false) const; - /// @brief Return whether the cube with corners at cube_min and cube_max centered at each point in world space + /// @brief Return whether the cube with corners at cube_min and cube_max centered at each point + /// in world space /// intersect the grid batch - /// @param cube_centers A JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the batch) + /// @param cube_centers A JaggedTensor of points with shape [B, -1, 3] (one point set per grid + /// in the batch) /// @param cube_min A 3D tensor specifying the min corner relative to each point to check /// @param cube_max A 3D tensor specifying the max corner relative to each point to check - /// @param ignore_disabled Whether to ignore voxels that have been disabled (only applicable to mutable grids) + /// @param ignore_disabled Whether to ignore voxels that have been disabled (only applicable to + /// mutable grids) /// @return A JaggedTensor of booleans with shape [B, -1] (one boolean per point) - /// where the [bi, i]^th entry is true if the cube with extent (min, max) + points[bi, i] intersects - /// the bi^th grid in the batch - JaggedTensor cubes_intersect_grid(const JaggedTensor& cube_centers, - const Vec3dOrScalar& cube_min = 0.0, - const Vec3dOrScalar& cube_max = 0.0, - bool ignore_disabled = false) const; - - /// @brief Return whether the cube with corners at cube_min and cube_max centered at each point in world space + /// where the [bi, i]^th entry is true if the cube with extent (min, max) + points[bi, + /// i] intersects the bi^th grid in the batch + JaggedTensor cubes_intersect_grid(const JaggedTensor &cube_centers, + const Vec3dOrScalar &cube_min = 0.0, + const Vec3dOrScalar &cube_max = 0.0, + bool ignore_disabled = false) const; + + /// @brief Return whether the cube with corners at cube_min and cube_max centered at each point + /// in world space /// is fully contained in the grid batch's stencil - /// @param cube_centers A JaggedTensor of points with shape [B, -1, 3] (one point set per grid in the batch) + /// @param cube_centers A JaggedTensor of points with shape [B, -1, 3] (one point set per grid + /// in the batch) /// @param cube_min A 3D tensor specifying the min corner relative to each point to check /// @param cube_max A 3D tensor specifying the max corner relative to each point to check - /// @param ignore_disabled Whether to ignore voxels that have been disabled (only applicable to mutable grids) + /// @param ignore_disabled Whether to ignore voxels that have been disabled (only applicable to + /// mutable grids) /// @return A JaggedTensor of booleans with shape [B, -1] (one boolean per point) - /// where the [bi, i]^th entry is true if the cube with extent (min, max) + points[bi, i] lies - /// inside the bi^th grid in the batch - JaggedTensor cubes_in_grid(const JaggedTensor& cube_centers, - const Vec3dOrScalar& cube_min = 0.0, - const Vec3dOrScalar& cube_max = 0.0, - bool 
ignore_disabled = false) const; + /// where the [bi, i]^th entry is true if the cube with extent (min, max) + points[bi, + /// i] lies inside the bi^th grid in the batch + JaggedTensor cubes_in_grid(const JaggedTensor &cube_centers, + const Vec3dOrScalar &cube_min = 0.0, + const Vec3dOrScalar &cube_max = 0.0, + bool ignore_disabled = false) const; /// @brief Return a boolean mask indicating whether each voxel in the grid is enabled or not - /// @return A boolean JaggedTensor of shape [B, -1] indicating whether each voxel in the grid is enabled or not + /// @return A boolean JaggedTensor of shape [B, -1] indicating whether each voxel in the grid is + /// enabled or not JaggedTensor enabled_mask() const; /// @brief Return a boolean mask indicating whether each voxel in the grid is disabled or not - /// @return A boolean JaggedTensor of shape [B, -1] indicating whether each voxel in the grid is disabled or not + /// @return A boolean JaggedTensor of shape [B, -1] indicating whether each voxel in the grid is + /// disabled or not JaggedTensor disabled_mask() const; /// @brief Return whether each coordinate is in the grid batch or not /// @param ijk A JaggedTensor of ijk coordinates with lshape [N_0, ..., N_B] and eshape (3,) /// (one coordinate set per grid in the batch) - /// @param ignore_disabled Whether to ignore voxels that have been disabled (only applicable to mutable grids) + /// @param ignore_disabled Whether to ignore voxels that have been disabled (only applicable to + /// mutable grids) /// @return A JaggedTensor of booleans with shape [B, -1] (one boolean per coordinate) - /// where the [bi, i]^th entry is true if coords[bi, i] lies inside the bi^th grid in the batch - JaggedTensor coords_in_active_voxel(const JaggedTensor& ijk, bool ignore_disabled = false) const; + /// where the [bi, i]^th entry is true if coords[bi, i] lies inside the bi^th grid in + /// the batch + JaggedTensor coords_in_active_voxel(const JaggedTensor &ijk, + bool ignore_disabled = false) const; /// @brief Return the integer offset of each ijk value in the grid batch - /// @param ijk A JaggedTensor of ijk coordinates with shape [B, -1, 3] (one coordinate set per grid in the batch) - /// @param cumulative Whether to return cumulative offsets in the batch or offsets relative to each grid - /// @return A JaggedTensor of integer offsets with shape [B, -1] into the grid batch (one offset per coordinate) - JaggedTensor ijk_to_index(const JaggedTensor& ijk, bool cumulative = false) const; - - /// @brief Return a JaggedTensor of integers such that if it is used as a permutation of the input IJK coordinates, - /// it will re-order them to the indexing order of the grid batch. This effectively performs the inverse of - /// ijk_to_index if you pass in the ijk coordinates in the grid. + /// @param ijk A JaggedTensor of ijk coordinates with shape [B, -1, 3] (one coordinate set per + /// grid in the batch) + /// @param cumulative Whether to return cumulative offsets in the batch or offsets relative to + /// each grid + /// @return A JaggedTensor of integer offsets with shape [B, -1] into the grid batch (one offset + /// per coordinate) + JaggedTensor ijk_to_index(const JaggedTensor &ijk, bool cumulative = false) const; + + /// @brief Return a JaggedTensor of integers such that if it is used as a permutation of the + /// input IJK coordinates, + /// it will re-order them to the indexing order of the grid batch. This effectively + /// performs the inverse of ijk_to_index if you pass in the ijk coordinates in the grid. 
/// i.e. output[ijk_to_index(ijk[i])] = i /// @param ijk A JaggedTensor of ijk coordinates with lshape [N_0, ..., N_B] and eshape (3,) /// (one coordinate set per grid in the batch) - /// @param cumulative Whether to return cumulative offsets in the batch or offsets relative to each grid - /// @return A JaggedTensor of integers with shape [B, -1] (one integer per grids' ijk) which inverts ijkToIndex - JaggedTensor ijk_to_inv_index(const JaggedTensor& ijk, bool cumulative = false) const; + /// @param cumulative Whether to return cumulative offsets in the batch or offsets relative to + /// each grid + /// @return A JaggedTensor of integers with shape [B, -1] (one integer per grids' ijk) which + /// inverts ijkToIndex + JaggedTensor ijk_to_inv_index(const JaggedTensor &ijk, bool cumulative = false) const; /// @brief Return the set of active ijk coordinates indexed by this grid batch /// @return A JaggedTensor of voxel coordinates indexed by this grid batch (shape [B, -1, 3]) @@ -496,35 +592,48 @@ struct GridBatch : torch::CustomClassHolder { /// @return A JaggedTensor of voxel coordinates indexed by this grid batch (shape [B, -1, 3]) JaggedTensor ijk_enabled() const; - /// @brief Find the intersection between a collection of rays and the zero level set of a scalar field + /// @brief Find the intersection between a collection of rays and the zero level set of a scalar + /// field /// at each voxel in the grid batch - /// @param ray_origins A JaggedTensor of ray origins with shape [B, -1, 3] (one ray set per grid in the batch) - /// @param ray_directions A JaggedTensor of ray directions with shape [B, -1, 3] (one ray set per grid in the batch) - /// @param grid_scalars A JaggedTensor of scalar values with shape [B, -1] (one scalar per voxel in the batch) + /// @param ray_origins A JaggedTensor of ray origins with shape [B, -1, 3] (one ray set per grid + /// in the batch) + /// @param ray_directions A JaggedTensor of ray directions with shape [B, -1, 3] (one ray set + /// per grid in the batch) + /// @param grid_scalars A JaggedTensor of scalar values with shape [B, -1] (one scalar per voxel + /// in the batch) /// @param eps Skip voxels where the ray intersects by less than this distance /// @return A JaggedTensor of intersection times with shape [B, -1] (one time per ray) - JaggedTensor ray_implicit_intersection(const JaggedTensor& ray_origins, - const JaggedTensor& ray_directions, - const JaggedTensor& grid_scalars, - double eps = 0.0) const; + JaggedTensor ray_implicit_intersection(const JaggedTensor &ray_origins, + const JaggedTensor &ray_directions, + const JaggedTensor &grid_scalars, + double eps = 0.0) const; - /// @brief Enumerate the voxels in this grid batch (in-sorted order) intersected by a collection of rays + /// @brief Enumerate the voxels in this grid batch (in-sorted order) intersected by a collection + /// of rays /// @param ray_origins A JaggedTensor of ray origins with lshape [N_0, ..., N_B] and eshape [3,] /// where N_i is the number of rays to intersect with the i^th grid - /// @param ray_directions A JaggedTensor of ray directions with lshape [N_0, ..., N_B] and eshape [3,] + /// @param ray_directions A JaggedTensor of ray directions with lshape [N_0, ..., N_B] and + /// eshape [3,] /// where N_i is the number of rays to intersect with the i^th grid /// @param max_voxels The maximum number of voxels to return per ray /// @param eps Skip voxels where the ray intersects by less than this distance - /// @param return_ijk Whether to return the voxel coordinates in the grid or 
world coordinates or the voxel index - /// @param cumulative Whether to return cumulative indices in the batch or indices relative to each grid + /// @param return_ijk Whether to return the voxel coordinates in the grid or world coordinates + /// or the voxel index + /// @param cumulative Whether to return cumulative indices in the batch or indices relative to + /// each grid /// (only applicable to return_ijk = false, otherwise ignored) - /// @return A pair of JaggedTensors containing the voxels (or voxel indices) intersected by the rays. i.e.: - /// - voxels: A JaggedTensor with lshape [[V_{0,0}, ..., V_{0,N_0}], ..., [V_{B,0}, ..., V_{B,N_B}]] - /// and eshape (3,) or (,) containing the ijk coordinates or indices of the voxels - /// - times: A JaggedTensor with lshape [[T_{0,0}, ..., T_{0,N_0}], ..., [T_{B,0}, ..., T_{B,N_B}]] - /// and eshape (2,) containg the entry and exit distance along the ray of each voxel - std::vector voxels_along_rays(const JaggedTensor& ray_origins, - const JaggedTensor& ray_directions, + /// @return A pair of JaggedTensors containing the voxels (or voxel indices) intersected by the + /// rays. i.e.: + /// - voxels: A JaggedTensor with lshape [[V_{0,0}, ..., V_{0,N_0}], ..., [V_{B,0}, + /// ..., V_{B,N_B}]] + /// and eshape (3,) or (,) containing the ijk coordinates or indices of + /// the voxels + /// - times: A JaggedTensor with lshape [[T_{0,0}, ..., T_{0,N_0}], ..., [T_{B,0}, + /// ..., T_{B,N_B}]] + /// and eshape (2,) containg the entry and exit distance along the ray of + /// each voxel + std::vector voxels_along_rays(const JaggedTensor &ray_origins, + const JaggedTensor &ray_directions, int64_t max_voxels, double eps = 0.0, bool return_ijk = true, bool cumulative = false) const; @@ -533,90 +642,99 @@ struct GridBatch : torch::CustomClassHolder { /// grid batch (in-sorted order) intersected by a collection of rays /// @param ray_origins A JaggedTensor of ray origins with lshape [N_0, ..., N_B] and eshape [3,] /// where N_i is the number of rays to intersect with the i^th grid - /// @param ray_directions A JaggedTensor of ray directions with lshape [N_0, ..., N_B] and eshape [3,] + /// @param ray_directions A JaggedTensor of ray directions with lshape [N_0, ..., N_B] and + /// eshape [3,] /// where N_i is the number of rays to intersect with the i^th grid /// @param max_segments The maximum number of segments to return per ray /// @param eps Skip segments whose length is less than this distance /// @param ignore_masked If set to true, will treat masked voxels as active /// @return A JaggedTensor containing the segments intersected by the rays. i.e. 
a JaggedTensor /// with lshape [[S_{0,0}, ..., S_{0,N_0}], ..., [S_{B,0}, ..., S_{B,N_B}]] - JaggedTensor segments_along_rays(const JaggedTensor& ray_origins, - const JaggedTensor& ray_directions, - int64_t max_segments, double eps = 0.0, bool ignore_masked = false) const; + JaggedTensor segments_along_rays(const JaggedTensor &ray_origins, + const JaggedTensor &ray_directions, int64_t max_segments, + double eps = 0.0, bool ignore_masked = false) const; /// @brief Generate a set of uniform samples in active regions along a specified set of rays /// @param ray_origins A JaggedTensor of ray origins with lshape [N_0, ..., N_B] and eshape [3,] /// where N_i is the number of rays to intersect with the i^th grid - /// @param ray_directions A JaggedTensor of ray directions with lshape [N_0, ..., N_B] and eshape [3,] + /// @param ray_directions A JaggedTensor of ray directions with lshape [N_0, ..., N_B] and + /// eshape [3,] /// where N_i is the number of rays to intersect with the i^th grid /// @param t_min The start distance along each ray to begin generating samples /// @param t_max The end distance along each ray to stop generating samples /// @param step_size The distance between samples along each ray /// @param cone_angle A cone angle for each ray used to space samples along the ray /// @param include_end_segments Whether to include the end segments of the rays in the samples - /// @param return_midpoints Whether to return the midpoint of each sample instead of the start and end + /// @param return_midpoints Whether to return the midpoint of each sample instead of the start + /// and end /// @param eps Skip segments whose length is less than this distance /// @return A JaggedTensor containing the samples along the rays. i.e. a JaggedTensor - /// with lshape [[S_{0,0}, ..., S_{0,N_0}], ..., [S_{B,0}, ..., S_{B,N_B}]] and eshape (2,) or (1,) - /// representing the start and end distance of each sample or the midpoint of each sample - /// if return_midpoints is true - JaggedTensor uniform_ray_samples(const JaggedTensor& ray_origins, - const JaggedTensor& ray_directions, - const JaggedTensor& t_min, - const JaggedTensor& t_max, - double step_size, - double cone_angle = 0.0, - bool include_end_segments = true, - bool return_midpoints = false, - double eps = 0.0) const; + /// with lshape [[S_{0,0}, ..., S_{0,N_0}], ..., [S_{B,0}, ..., S_{B,N_B}]] and eshape + /// (2,) or (1,) representing the start and end distance of each sample or the midpoint + /// of each sample if return_midpoints is true + JaggedTensor uniform_ray_samples(const JaggedTensor &ray_origins, + const JaggedTensor &ray_directions, const JaggedTensor &t_min, + const JaggedTensor &t_max, double step_size, + double cone_angle = 0.0, bool include_end_segments = true, + bool return_midpoints = false, double eps = 0.0) const; /// @brief Return an edge network used which can be used to plot the grids in this batch - /// @param return_voxel_coordinates Whether to return the vertices in voxel coordinates or world coordinates - /// @return A pair (verts, edges) where verts is a JaggedTensor of vertex positions with shape [B, -1, 3] - /// (one vertex set per grid in the batch) and edges is a JaggedTensor of edge indices of - /// shape [B, -1, 2] (one edge set per grid in the batch) + /// @param return_voxel_coordinates Whether to return the vertices in voxel coordinates or world + /// coordinates + /// @return A pair (verts, edges) where verts is a JaggedTensor of vertex positions with shape + /// [B, -1, 3] + /// (one vertex set per grid 
in the batch) and edges is a JaggedTensor of edge indices + /// of shape [B, -1, 2] (one edge set per grid in the batch) std::vector viz_edge_network(bool return_voxel_coordinates = false) const; - /// @brief Disable the specified voxels in the grid batch. If the input ijk values refer to non-indexed voxels, + /// @brief Disable the specified voxels in the grid batch. If the input ijk values refer to + /// non-indexed voxels, /// then these are simply ignored. - /// @param ijk A Jagged tensor of shape [B, -1, 3] of coordinates to disable(one set of coordinates per grid in the batch) + /// @param ijk A Jagged tensor of shape [B, -1, 3] of coordinates to disable (one set of + /// coordinates per grid in the batch) /// @note This is only applicable to mutable grids - void disable_ijk(const JaggedTensor& ijk); + void disable_ijk(const JaggedTensor &ijk); - /// @brief Enable the specified voxels in the grid batch. If the input ijk values refer to non-indexed voxels, + /// @brief Enable the specified voxels in the grid batch. If the input ijk values refer to + /// non-indexed voxels, /// then these are simply ignored. - /// @param ijk A Jagged tensor of shape [B, -1, 3] of coordinates to enable (one set of coordinates per grid in the batch) + /// @param ijk A Jagged tensor of shape [B, -1, 3] of coordinates to enable (one set of + /// coordinates per grid in the batch) /// @note This is only applicable to mutable grids - void enable_ijk(const JaggedTensor& ijk); - - /// @brief Return a batch of grids representing the dual of this batch. i.e. The centers of the dual grid correspond - /// to the corners of this grid batch. The [i, j, k] coordinate of the dual grid corresponds to the bottom/left/back - /// corner of the [i, j, k] voxel in this grid batch. - /// @param exclude_border Whether to exclude the border of the grid batch when computing the dual grid + void enable_ijk(const JaggedTensor &ijk); + + /// @brief Return a batch of grids representing the dual of this batch. i.e. The centers of the + /// dual grid correspond + /// to the corners of this grid batch. The [i, j, k] coordinate of the dual grid + /// corresponds to the bottom/left/back corner of the [i, j, k] voxel in this grid batch. + /// @param exclude_border Whether to exclude the border of the grid batch when computing the + /// dual grid /// @return A GridBatch representing the dual of this grid batch GridBatch dual_grid(bool exclude_border = false) const; /// @brief Return a batch of grids representing the coarsened version of this batch. - /// Each voxel [i, j, k] in this grid batch maps to voxel [i / branchFactor, j / branchFactor, k / branchFactor] - /// in the coarse batch. - /// @param coarsening_factor The factor by which to coarsen the grid batch (i.e (2, 2, 2) coarses by a factor of 2x2x2) + /// Each voxel [i, j, k] in this grid batch maps to voxel [i / branchFactor, j / + /// branchFactor, k / branchFactor] in the coarse batch. + /// @param coarsening_factor The factor by which to coarsen the grid batch (i.e. (2, 2, 2) + /// coarsens by a factor of 2x2x2) /// @return A GridBatch representing the coarsened version of this batch. GridBatch coarsened_grid(Vec3iOrScalar coarsening_factor) const; /// @brief Subdivide the grid batch into a finer grid batch. - /// Each voxel [i, j, k] in this grid batch maps to voxels [i * subdivFactor, j * subdivFactor, k * subdivFactor] - /// in the fine batch.
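The ray queries documented a little earlier (voxels_along_rays, segments_along_rays, uniform_ray_samples) are easiest to see end to end. The sketch below is hedged: it assumes the Python bindings keep the documented names and defaults, that per-batch rays are passed as JaggedTensors with one ray set per grid, and that grid_count and total_voxels are exposed as properties.

    import torch
    import fvdb

    # A small batch of two grids to trace rays against.
    points = fvdb.JaggedTensor([torch.rand(1000, 3), torch.rand(500, 3)])
    grid = fvdb.GridBatch(device="cpu")
    grid.set_from_points(points, voxel_sizes=0.05)

    # One bundle of 64 rays per grid in the batch, all starting at the origin.
    num_grids = grid.grid_count
    origins = fvdb.JaggedTensor([torch.zeros(64, 3) for _ in range(num_grids)])
    directions = fvdb.JaggedTensor(
        [torch.nn.functional.normalize(torch.randn(64, 3), dim=-1) for _ in range(num_grids)])

    # Enumerate up to 32 voxels hit by each ray, returning ijk coordinates and the
    # entry/exit distance of the ray through each voxel, as documented above.
    voxels, times = grid.voxels_along_rays(origins, directions, max_voxels=32, eps=0.0)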
+ /// Each voxel [i, j, k] in this grid batch maps to voxels [i * subdivFactor, j * + /// subdivFactor, k * subdivFactor] in the fine batch. /// @param subdiv_factor The factor by which to subdivide the grid batch - /// @param mask An optional JaggedTensor of shape [B, -1] of boolean values indicating which voxels to subdivide + /// @param mask An optional JaggedTensor of shape [B, -1] of boolean values indicating which + /// voxels to subdivide /// @return A GridBatch representing the subdivided version of this batch. - GridBatch subdivided_grid(Vec3iOrScalar subdiv_factor, + GridBatch subdivided_grid(Vec3iOrScalar subdiv_factor, const torch::optional mask = torch::nullopt) const; /// @brief Return a batch of grids representing the clipped version of this batch of grids. /// @param ijk_min Index space minimum bound of the clip region. /// @param ijk_max Index space maximum bound of the clip region. /// @return A GridBatch representing the clipped version of this batch of grids. - GridBatch clipped_grid(const Vec3iBatch& ijk_min, const Vec3iBatch& ijk_max) const; + GridBatch clipped_grid(const Vec3iBatch &ijk_min, const Vec3iBatch &ijk_max) const; /// @brief Generate the grid that is affected by the convolution operator. /// @param kernel_size The kernel size of convolution @@ -624,31 +742,41 @@ struct GridBatch : torch::CustomClassHolder { /// @return A GridBatch representing the convolved grid. GridBatch conv_grid(Vec3iOrScalar kernel_size, Vec3iOrScalar stride) const; - /// @brief Return a batch of grids representing the clipped version of this batch of grids and corresponding features. - /// @param features A JaggedTensor of shape [B, -1, *] containing features associated with this batch of grids. + /// @brief Return a batch of grids representing the clipped version of this batch of grids and + /// corresponding features. + /// @param features A JaggedTensor of shape [B, -1, *] containing features associated with this + /// batch of grids. /// @param ijk_min Index space minimum bound of the clip region. /// @param ijk_max Index space maximum bound of the clip region. - /// @return A pair (clipped_features, clipped_grid) where clipped_features is a JaggedTensor of shape [B, -1, *] and + /// @return A pair (clipped_features, clipped_grid) where clipped_features is a JaggedTensor of + /// shape [B, -1, *] and /// clipped_grid is a GridBatch representing the clipped version of this batch of grids. - std::pair clip(const JaggedTensor& features, const Vec3iBatch& ijk_min, const Vec3iBatch& ijk_max) const; + std::pair clip(const JaggedTensor &features, const Vec3iBatch &ijk_min, + const Vec3iBatch &ijk_max) const; /// @brief Extract 0-isosurface from an implicit field. /// @param field implicit value stored on each voxel center (or voxel corner on a dual grid) /// @param level level set of the surface to extract /// @return vertices and faces arrays of the extracted isosurface - std::vector marching_cubes(const JaggedTensor& field, double level = 0.0) const; + std::vector marching_cubes(const JaggedTensor &field, double level = 0.0) const; - /// @brief Perform in-grid convolution using fast halo buffer method. Currently only supports kernel_size = 3. - /// @param features A JaggedTensor of shape [B, -1, *] containing features associated with this batch of grids. + /// @brief Perform in-grid convolution using fast halo buffer method. Currently only supports + /// kernel_size = 3. 
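For the marching_cubes declaration above, a short hedged sketch: the field below is random noise standing in for a real signed-distance field, and the two-value unpacking assumes the Python binding returns the (vertices, faces) pair documented here; the construction and keyword spellings are likewise assumptions.

    import torch
    import fvdb

    # A single grid built from random points.
    points = fvdb.JaggedTensor([torch.rand(4096, 3)])
    grid = fvdb.GridBatch(device="cpu")
    grid.set_from_points(points, voxel_sizes=0.05)

    # One scalar field value per voxel; in practice this would come from data or a
    # network rather than torch.randn.
    field = grid.jagged_like(torch.randn(grid.total_voxels))

    # Extract the zero level set as one (vertices, faces) pair per grid.
    verts, faces = grid.marching_cubes(field, level=0.0)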
+ /// @param features A JaggedTensor of shape [B, -1, *] containing features associated with this + /// batch of grids. /// @param kernel A tensor of shape [Out, In, 3, 3, 3] containing the kernel to convolve with. /// @return A JaggedTensor of shape [B, -1, *] containing the convolved features. - JaggedTensor sparse_conv_halo(const JaggedTensor& features, const torch::Tensor& kernel, int variant) const; + JaggedTensor sparse_conv_halo(const JaggedTensor &features, const torch::Tensor &kernel, + int variant) const; - /// @brief Return a grid batch on the specified device. If the passed in device is the same as this grid batch's - /// device, then this grid batch is returned. Otherwise, a copy of this grid batch is returned on the specified device. + /// @brief Return a grid batch on the specified device. If the passed in device is the same as + /// this grid batch's + /// device, then this grid batch is returned. Otherwise, a copy of this grid batch is + /// returned on the specified device. /// @param to_device The device to return the grid batch on /// @return A GridBatch representing this grid batch on the specified device - GridBatch to(TorchDeviceOrString to_device) const { + GridBatch + to(TorchDeviceOrString to_device) const { torch::Device toDevice = to_device.value(); if (toDevice == device()) { return GridBatch(impl()); @@ -657,34 +785,45 @@ struct GridBatch : torch::CustomClassHolder { } } - /// @brief Return a grid batch on the same device as the specified grid batch. If the passed in grid has the same device as this grid batch's - /// device, then this grid batch is returned. Otherwise, a copy of this grid batch is returned on the specified device. + /// @brief Return a grid batch on the same device as the specified grid batch. If the passed in + /// grid has the same device as this grid batch's + /// device, then this grid batch is returned. Otherwise, a copy of this grid batch is + /// returned on the specified device. /// @param to_grid The grid batch used to specify which device to return the grid batch on /// @return A GridBatch representing this grid batch on the specified device - GridBatch to(const GridBatch& to_grid) const { + GridBatch + to(const GridBatch &to_grid) const { return this->to(to_grid.device()); } - /// @brief Return a grid batch on the same device as the specified tensor. If the passed in tensor has the same device as this grid batch's - /// device, then this grid batch is returned. Otherwise, a copy of this grid batch is returned on the specified device. + /// @brief Return a grid batch on the same device as the specified tensor. If the passed in + /// tensor has the same device as this grid batch's + /// device, then this grid batch is returned. Otherwise, a copy of this grid batch is + /// returned on the specified device. /// @param to_tensor The tensor used to specify which device to return the grid batch on /// @return A GridBatch representing this grid batch on the specified device - GridBatch to(const torch::Tensor& to_tensor) const { + GridBatch + to(const torch::Tensor &to_tensor) const { return this->to(to_tensor.device()); } - /// @brief Return a grid batch on the same device as the specified JaggedTensor. If the passed in JaggedTensor has the same device as this grid batch's - /// device, then this grid batch is returned. Otherwise, a copy of this grid batch is returned on the specified device. + /// @brief Return a grid batch on the same device as the specified JaggedTensor. 
If the passed + /// in JaggedTensor has the same device as this grid batch's + /// device, then this grid batch is returned. Otherwise, a copy of this grid batch is + /// returned on the specified device. /// @param to_jtensor The JaggedTensor used to specify which device to return the grid batch on /// @return A GridBatch representing this grid batch on the specified device - GridBatch to(const JaggedTensor& to_jtensor) const { + GridBatch + to(const JaggedTensor &to_jtensor) const { return this->to(to_jtensor.device()); } - /// @brief Return a view of this grid batch containing the grid at the specified index i.e. grid_batch[bi] + /// @brief Return a view of this grid batch containing the grid at the specified index i.e. + /// grid_batch[bi] /// @param bi The index to get a view on /// @return A GridBatch representing the grid at the specified index - GridBatch index(int64_t bi) const { + GridBatch + index(int64_t bi) const { return GridBatch(impl()->index(bi)); } @@ -693,146 +832,182 @@ struct GridBatch : torch::CustomClassHolder { /// @param stop The stop index of the slice /// @param step The step of the slice /// @return A GridBatch representing the slice of this grid batch - GridBatch index(size_t start, size_t stop, size_t step) const { + GridBatch + index(size_t start, size_t stop, size_t step) const { return GridBatch(impl()->index(start, stop, step)); } - /// @brief Return a view of this grid batch at the specified indices i.e. grid_batch[[i1, i2, ...]] + /// @brief Return a view of this grid batch at the specified indices i.e. grid_batch[[i1, i2, + /// ...]] /// @param bi A list of integers representing the indices to get a view on /// @return The grid batch vieweed at the specified indices - GridBatch index(const std::vector& bi) const { + GridBatch + index(const std::vector &bi) const { return GridBatch(impl()->index(bi)); } - /// @brief Return a view of this grid batch at indices specified by the given mask i.e. grid_batch[mask] + /// @brief Return a view of this grid batch at indices specified by the given mask i.e. + /// grid_batch[mask] /// @param bi A list of integers representing the indices to get a view on /// @return The grid batch vieweed at the specified indices - GridBatch index(const std::vector& bi) const { + GridBatch + index(const std::vector &bi) const { return GridBatch(impl()->index(bi)); } - /// @brief Return a view of this grid batch at the specified indices (or mask if bi is a bool tensor) i.e. grid_batch[[i1, i2, ...]] + /// @brief Return a view of this grid batch at the specified indices (or mask if bi is a bool + /// tensor) i.e. 
grid_batch[[i1, i2, ...]] /// @param bi A list of integers representing the indices to get a view on /// @return The grid batch vieweed at the specified indices - GridBatch index(const torch::Tensor& bi) const { + GridBatch + index(const torch::Tensor &bi) const { return GridBatch(impl()->index(bi)); } /// @brief Return a JaggedTensor whose joffsets and jidx match this grid batch's - /// @param data The data to use for the JaggedTensor (first dimension must match the total number of voxels in the grid batch) - /// @param ignore_disabled If true, then voxels which are disabled will be included in the returned JaggedTensor + /// @param data The data to use for the JaggedTensor (first dimension must match the total + /// number of voxels in the grid batch) + /// @param ignore_disabled If true, then voxels which are disabled will be included in the + /// returned JaggedTensor /// @return A JaggedTensor corresponding to the voxel grid of this grid batch - JaggedTensor jagged_like(const torch::Tensor& data, bool ignore_disabled = true) const { + JaggedTensor + jagged_like(const torch::Tensor &data, bool ignore_disabled = true) const { return impl()->jaggedTensor(data, ignore_disabled); } /// @brief Populate the grid batch with voxels that intersect a triangle mesh - /// @param vertices A JaggedTensor of shape [B, -1, 3] containing one vertex set per grid to create + /// @param vertices A JaggedTensor of shape [B, -1, 3] containing one vertex set per grid to + /// create /// @param faces A JaggedTensor of shape [B, -1, 3] containing one face set per grid to create - /// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in the batch or one voxel size for all grids - /// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, 0, 0] voxel + /// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid + /// in the batch or one voxel size for all grids + /// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the + /// [0, 0, 0] voxel /// for each grid in the batch, or one origin for all grids - void set_from_mesh(const JaggedTensor& vertices, - const JaggedTensor& faces, - const Vec3dBatchOrScalar& voxel_sizes = 1.0, - const Vec3dBatch& origins = torch::zeros(3, torch::kInt32)); + void set_from_mesh(const JaggedTensor &vertices, const JaggedTensor &faces, + const Vec3dBatchOrScalar &voxel_sizes = 1.0, + const Vec3dBatch &origins = torch::zeros(3, torch::kInt32)); - /// @brief Populate the grid batch with voxels which contain a point in an input set of point clouds + /// @brief Populate the grid batch with voxels which contain a point in an input set of point + /// clouds /// (possibly padding each voxel containing a point) - /// @param points A JaggedTensor with shape [B, -1, 3] containing one point set per grid to create - /// @param pad_min A tensor of shape [3,] containing the number of voxels to pad each inserted voxel with to the left/back/bottom - /// @param pad_max A tensor of shape [3,] containing the number of voxels to pad each inserted voxel with to the right/front/top - /// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in the batch or one voxel size for all grids - /// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, 0, 0] voxel + /// @param points A JaggedTensor with shape [B, -1, 3] containing one point set per grid to + /// create + /// @param 
pad_min A tensor of shape [3,] containing the number of voxels to pad each inserted + /// voxel with to the left/back/bottom + /// @param pad_max A tensor of shape [3,] containing the number of voxels to pad each inserted + /// voxel with to the right/front/top + /// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid + /// in the batch or one voxel size for all grids + /// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the + /// [0, 0, 0] voxel /// for each grid in the batch, or one origin for all grids /// @param isMutable Whether the grid should be mutable or not - void set_from_points(const JaggedTensor& points, - const Vec3i& pad_min = torch::zeros(3, torch::kInt32), - const Vec3i& pad_max = torch::zeros(3, torch::kInt32), - const Vec3dBatchOrScalar& voxel_sizes = 1.0, - const Vec3dBatch& origins = torch::zeros(3, torch::kInt32)); - - /// @brief Populate the grid batch with the eight nearest voxels to each point in an input set of point clouds - /// @param points A JaggedTensor with shape [B, -1, 3] containing one point set per grid to create - /// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in the batch or one voxel size for all grids - /// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, 0, 0] voxel + void set_from_points(const JaggedTensor &points, + const Vec3i &pad_min = torch::zeros(3, torch::kInt32), + const Vec3i &pad_max = torch::zeros(3, torch::kInt32), + const Vec3dBatchOrScalar &voxel_sizes = 1.0, + const Vec3dBatch &origins = torch::zeros(3, torch::kInt32)); + + /// @brief Populate the grid batch with the eight nearest voxels to each point in an input set + /// of point clouds + /// @param points A JaggedTensor with shape [B, -1, 3] containing one point set per grid to + /// create + /// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid + /// in the batch or one voxel size for all grids + /// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the + /// [0, 0, 0] voxel /// for each grid in the batch, or one origin for all grids /// @param isMutable Whether the grid should be mutable or not - void set_from_nearest_voxels_to_points(const JaggedTensor& points, - const Vec3dBatchOrScalar& voxel_sizes = 1.0, - const Vec3dBatch& origins = torch::zeros(3, torch::kInt32)); - + void set_from_nearest_voxels_to_points(const JaggedTensor &points, + const Vec3dBatchOrScalar &voxel_sizes = 1.0, + const Vec3dBatch &origins = torch::zeros(3, + torch::kInt32)); /// @brief Populate the grid batch with the specified voxel coordinates (possibly with padding) - /// @param ijk A JaggedTensor of shape [B, -1, 3] specifying the coordinates of each voxel to insert - /// @param pad_min A tensor of shape [3,] containing the number of voxels to pad each inserted voxel with to the left/back/bottom - /// @param pad_max A tensor of shape [3,] containing the number of voxels to pad each inserted voxel with to the right/front/top - /// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in the batch or one voxel size for all grids - /// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, 0, 0] voxel + /// @param ijk A JaggedTensor of shape [B, -1, 3] specifying the coordinates of each voxel to + /// insert + /// @param pad_min A tensor of shape [3,] containing the number 
of voxels to pad each inserted + /// voxel with to the left/back/bottom + /// @param pad_max A tensor of shape [3,] containing the number of voxels to pad each inserted + /// voxel with to the right/front/top + /// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid + /// in the batch or one voxel size for all grids + /// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the + /// [0, 0, 0] voxel /// for each grid in the batch, or one origin for all grids - void set_from_ijk(const JaggedTensor& ijk, - const Vec3i& pad_min = torch::zeros(3, torch::kInt32), - const Vec3i& pad_max = torch::zeros(3, torch::kInt32), - const Vec3dBatchOrScalar& voxel_sizes = 1.0, - const Vec3dBatch& origins = torch::zeros(3, torch::kInt32)); + void set_from_ijk(const JaggedTensor &ijk, + const Vec3i &pad_min = torch::zeros(3, torch::kInt32), + const Vec3i &pad_max = torch::zeros(3, torch::kInt32), + const Vec3dBatchOrScalar &voxel_sizes = 1.0, + const Vec3dBatch &origins = torch::zeros(3, torch::kInt32)); /// @brief Populate the grid batch densely from ijk_min to ijk_min + size /// @param num_grids The number of grids to create in the batch /// @param dense_dims The size of each dense grid (shape [3,] = [W, H, D]) /// @param ijk_min The minimum ijk coordinate of each dense grid in the batch (shape [3,]) - /// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid in the batch or one voxel size for all grids - /// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the [0, 0, 0] voxel + /// @param voxel_sizes A tensor of shape [B, 3] or [3,] containing the voxel size of each grid + /// in the batch or one voxel size for all grids + /// @param origins A tensor of shape [B, 3] or [3,] containing the world space coordinate of the + /// [0, 0, 0] voxel /// for each grid in the batch, or one origin for all grids - /// @param mask Optional mask of shape [W, H, D] to specify voxels which are included in the dense grid. + /// @param mask Optional mask of shape [W, H, D] to specify voxels which are included in the + /// dense grid. /// Note that the same mask will be re-used for all the grids in the batch. 
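    // Illustrative usage sketch for the point-based builders documented above
    // (a minimal example, not taken from the fVDB sources): build one grid from
    // a point cloud with a single voxel of padding on every side. It assumes
    // GridBatch is default-constructible and uses arbitrary point data; the
    // scalar voxel size relies on the Vec3dBatchOrScalar conversion implied by
    // the 1.0 default argument.
    //
    //     std::vector<torch::Tensor> pointSets = { torch::rand({ 1000, 3 }) };
    //     fvdb::JaggedTensor         points(pointSets);   // batch of one point set
    //     fvdb::GridBatch            grid;
    //     grid.set_from_points(points,
    //                          /*pad_min=*/torch::ones(3, torch::kInt32),
    //                          /*pad_max=*/torch::ones(3, torch::kInt32),
    //                          /*voxel_sizes=*/0.1);
    //
    // set_from_ijk and set_from_dense_grid (below) follow the same pattern,
    // taking voxel coordinates or dense dimensions instead of world-space points.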
- void set_from_dense_grid(const int64_t num_grids, - const Vec3i& dense_dims, - const Vec3i& ijk_min = torch::zeros(3, torch::kInt32), - const Vec3dBatchOrScalar& voxel_sizes = 1.0, - const Vec3dBatch& origins = torch::zeros(3), - torch::optional mask = torch::nullopt); + void set_from_dense_grid(const int64_t num_grids, const Vec3i &dense_dims, + const Vec3i &ijk_min = torch::zeros(3, torch::kInt32), + const Vec3dBatchOrScalar &voxel_sizes = 1.0, + const Vec3dBatch &origins = torch::zeros(3), + torch::optional mask = torch::nullopt); /// @brief Serialize this grid batch to a torch tensor of bytes (dtype = int8) /// @return A serialized grid batch encoded as a torch::Tensor of type int8 - torch::Tensor serialize() const { + torch::Tensor + serialize() const { return impl()->serialize(); } /// @brief Deserialize an int8 tensor (returned by serialize()) into a grid batch /// @param data A tensor enccoding a serialized grid batch as an int8 tensor /// @return The deserializes grid batch - static GridBatch deserialize(const torch::Tensor& data) { + static GridBatch + deserialize(const torch::Tensor &data) { return GridBatch(detail::GridBatchImpl::deserialize(data)); } /// @brief Return an integer representing the actual data /// @return the value - int64_t address() const { + int64_t + address() const { return reinterpret_cast(impl().get()); } /// @brief Get the underlying nanovdb::GridHandle for the grid batch /// @return The underlying nanovdb::GridHandle for the grid batch - const nanovdb::GridHandle& nanovdb_grid_handle() const { + const nanovdb::GridHandle & + nanovdb_grid_handle() const { return impl()->nanoGridHandle(); } - inline const c10::intrusive_ptr impl() const { + inline const c10::intrusive_ptr + impl() const { return mImpl; } -private: - - void buildCoarseFromFineGrid(const GridBatch& fineGrid, nanovdb::Coord branchFactor); + private: + void buildCoarseFromFineGrid(const GridBatch &fineGrid, nanovdb::Coord branchFactor); - void buildFineFromCoarseGrid(const GridBatch& coarseGrid, const torch::optional& subdivMask, nanovdb::Coord subdivFactor); + void buildFineFromCoarseGrid(const GridBatch &coarseGrid, + const torch::optional &subdivMask, + nanovdb::Coord subdivFactor); - void buildDualFromPrimalGrid(const GridBatch& primalGrid, bool excludeBorder = false); + void buildDualFromPrimalGrid(const GridBatch &primalGrid, bool excludeBorder = false); c10::intrusive_ptr mImpl; }; - // using GridBatchPtr = c10::intrusive_ptr; } // namespace fvdb + +#endif // FVDB_GRIDBATCH_H \ No newline at end of file diff --git a/fvdb/src/JaggedTensor.cpp b/fvdb/src/JaggedTensor.cpp index 5486aa10e4..9f1531c5a1 100644 --- a/fvdb/src/JaggedTensor.cpp +++ b/fvdb/src/JaggedTensor.cpp @@ -2,49 +2,61 @@ // SPDX-License-Identifier: MPL-2.0 // #include "JaggedTensor.h" + #include "Config.h" -#include "detail/ops/Ops.h" -#include "detail/utils/Utils.h" #include "detail/autograd/JaggedReduce.h" +#include "detail/ops/Ops.h" #include "detail/ops/jagged/JaggedOps.h" +#include "detail/utils/Utils.h" namespace fvdb { -void JaggedTensor::binary_op_check(const JaggedTensor& other) const { - TORCH_CHECK(this->device() == other.device(), "device should match between this tensor and other tensor"); - TORCH_CHECK(mData.sizes().equals(other.jdata().sizes()), "data shape should match between this tensor and other tensor"); - TORCH_CHECK(mBatchIdx.sizes().equals(other.jidx().sizes()), "batch indices' shape should match between this tensor and other tensor"); - 
TORCH_CHECK(mOffsets.sizes().equals(other.joffsets().sizes()), "offsets shape should match between this tensor and other tensor"); +void +JaggedTensor::binary_op_check(const JaggedTensor &other) const { + TORCH_CHECK(this->device() == other.device(), + "device should match between this tensor and other tensor"); + TORCH_CHECK(mData.sizes().equals(other.jdata().sizes()), + "data shape should match between this tensor and other tensor"); + TORCH_CHECK(mBatchIdx.sizes().equals(other.jidx().sizes()), + "batch indices' shape should match between this tensor and other tensor"); + TORCH_CHECK(mOffsets.sizes().equals(other.joffsets().sizes()), + "offsets shape should match between this tensor and other tensor"); if (Config::global().pendanticErrorCheckingEnabled()) { // This is a slow check that we cap optionally do for correctness. - TORCH_CHECK_VALUE(torch::equal(mOffsets, other.joffsets()), "offsets shape should match between this tensor and other tensor"); - TORCH_CHECK_VALUE(torch::equal(other.mListIdx, mListIdx), - "JaggedTensors must have the same lshape. ", - "This error was raised because config.pendatic_error_checking was enabled"); + TORCH_CHECK_VALUE(torch::equal(mOffsets, other.joffsets()), + "offsets shape should match between this tensor and other tensor"); + TORCH_CHECK_VALUE( + torch::equal(other.mListIdx, mListIdx), "JaggedTensors must have the same lshape. ", + "This error was raised because config.pendatic_error_checking was enabled"); } } -torch::Tensor JaggedTensor::joffsets_from_jidx_and_jdata(torch::Tensor jidx, torch::Tensor jdata, int64_t num_tensors) { +torch::Tensor +JaggedTensor::joffsets_from_jidx_and_jdata(torch::Tensor jidx, torch::Tensor jdata, + int64_t num_tensors) { return FVDB_DISPATCH_KERNEL_DEVICE(jdata.device(), [&]() { return detail::ops::dispatchJOffsetsForJIdx(jidx, jdata, num_tensors); }); } -torch::Tensor JaggedTensor::jidx_from_joffsets(torch::Tensor joffsets, int64_t num_elements) { +torch::Tensor +JaggedTensor::jidx_from_joffsets(torch::Tensor joffsets, int64_t num_elements) { return FVDB_DISPATCH_KERNEL_DEVICE(joffsets.device(), [&]() { return detail::ops::dispatchJIdxForJOffsets(joffsets, num_elements); }); } JaggedTensor::JaggedTensor(torch::Tensor data) - : mData(data), mBatchIdx(torch::empty({0}, torch::TensorOptions().dtype(JIdxScalarType).device(data.device()))) { - mListIdx = torch::empty({0, 1}, torch::TensorOptions().dtype(JLIdxScalarType).device(data.device())); - mOffsets = joffsets_from_jidx_and_jdata(mBatchIdx, mData, 1); + : mData(data), mBatchIdx(torch::empty( + { 0 }, torch::TensorOptions().dtype(JIdxScalarType).device(data.device()))) { + mListIdx = + torch::empty({ 0, 1 }, torch::TensorOptions().dtype(JLIdxScalarType).device(data.device())); + mOffsets = joffsets_from_jidx_and_jdata(mBatchIdx, mData, 1); mNumOuterLists = 1; } -JaggedTensor::JaggedTensor(const std::vector& tensors) { +JaggedTensor::JaggedTensor(const std::vector &tensors) { // TODO: (Francis): rewrite as a cuda kernel TORCH_CHECK(tensors.size() > 0, "empty tensor list"); @@ -56,10 +68,15 @@ JaggedTensor::JaggedTensor(const std::vector& tensors) { if (tensors[0].dim() == 0) { mData = mData.unsqueeze(0); } - TORCH_CHECK(mData.dim() > 0, "assigned data must have shape [N, ...], but got data.dim() = 0"); - mBatchIdx = torch::empty({0}, torch::TensorOptions().dtype(JIdxScalarType).device(mData.device())); - mOffsets = torch::tensor({JOffsetsType(0), mData.size(0)}, torch::TensorOptions().dtype(JOffsetsScalarType).device(mData.device())); - mListIdx = torch::empty({0, 1}, 
torch::TensorOptions().dtype(JLIdxScalarType).device(mData.device())); + TORCH_CHECK(mData.dim() > 0, + "assigned data must have shape [N, ...], but got data.dim() = 0"); + mBatchIdx = torch::empty( + { 0 }, torch::TensorOptions().dtype(JIdxScalarType).device(mData.device())); + mOffsets = + torch::tensor({ JOffsetsType(0), mData.size(0) }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(mData.device())); + mListIdx = torch::empty( + { 0, 1 }, torch::TensorOptions().dtype(JLIdxScalarType).device(mData.device())); mNumOuterLists = 1; return; } @@ -67,45 +84,49 @@ JaggedTensor::JaggedTensor(const std::vector& tensors) { torch::Device device = tensors[0].device(); std::vector jIdxs; - mOffsets = torch::empty({(JOffsetsType) tensors.size() + 1}, torch::TensorOptions().dtype(JOffsetsScalarType).device(torch::kCPU)); + mOffsets = torch::empty({ (JOffsetsType)tensors.size() + 1 }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(torch::kCPU)); auto elementCountsAcc = mOffsets.accessor(); - elementCountsAcc[0] = 0; + elementCountsAcc[0] = 0; jIdxs.reserve(tensors.size()); - std::vector tensorsReshaped; // Reshape 0D tensors to 1D + std::vector tensorsReshaped; // Reshape 0D tensors to 1D tensorsReshaped.reserve(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { TORCH_CHECK_VALUE(tensors[i].device() == device, "All tensors must be on the same device"); if (tensors[i].dim() == 0 && tensors[i].numel() == 1) { - tensorsReshaped.push_back(tensors[i].view({1})); + tensorsReshaped.push_back(tensors[i].view({ 1 })); } else { tensorsReshaped.push_back(tensors[i]); } - jIdxs.push_back(torch::full({tensorsReshaped[i].size(0)}, (int) i, torch::TensorOptions().dtype(JIdxScalarType).device(tensorsReshaped[i].device()))); - elementCountsAcc[i+1] = tensorsReshaped[i].size(0); + jIdxs.push_back(torch::full( + { tensorsReshaped[i].size(0) }, (int)i, + torch::TensorOptions().dtype(JIdxScalarType).device(tensorsReshaped[i].device()))); + elementCountsAcc[i + 1] = tensorsReshaped[i].size(0); } mOffsets = mOffsets.to(tensors[0].device()); torch::cumsum_out(mOffsets, mOffsets, 0); mBatchIdx = torch::cat(jIdxs, 0); - mData = torch::cat(tensorsReshaped, 0); - mListIdx = torch::empty({0, 1}, torch::TensorOptions().dtype(JLIdxScalarType).device(device)); + mData = torch::cat(tensorsReshaped, 0); + mListIdx = torch::empty({ 0, 1 }, torch::TensorOptions().dtype(JLIdxScalarType).device(device)); mNumOuterLists = tensors.size(); } -JaggedTensor::JaggedTensor(const std::vector>& tensors) { +JaggedTensor::JaggedTensor(const std::vector> &tensors) { // TODO: (Francis): rewrite as a cuda kernel - torch::Device device = torch::kCPU; - bool deviceIsNotSet = true; - JOffsetsType totalTensors = 0; + torch::Device device = torch::kCPU; + bool deviceIsNotSet = true; + JOffsetsType totalTensors = 0; TORCH_CHECK(tensors.size() > 0, "empty tensor list"); for (size_t i = 0; i < tensors.size(); ++i) { for (size_t j = 0; j < tensors[i].size(); j += 1) { if (deviceIsNotSet) { - device = tensors[i][j].device(); + device = tensors[i][j].device(); deviceIsNotSet = false; } - TORCH_CHECK_VALUE(tensors[i][j].device() == device, "All tensors must be on the same device"); + TORCH_CHECK_VALUE(tensors[i][j].device() == device, + "All tensors must be on the same device"); totalTensors += 1; } } @@ -113,16 +134,23 @@ JaggedTensor::JaggedTensor(const std::vector>& tensor // This is an implementation detail where we don't store jidx for // a single list since everything is just zero by default. 
if (totalTensors == 1) { - TORCH_CHECK(tensors.size() == 1, "Single tensor must be a 1D tensor. This should never happen."); - TORCH_CHECK(tensors[0].size() == 1, "Single tensor must be a 1D tensor. This should never happen."); + TORCH_CHECK(tensors.size() == 1, + "Single tensor must be a 1D tensor. This should never happen."); + TORCH_CHECK(tensors[0].size() == 1, + "Single tensor must be a 1D tensor. This should never happen."); mData = tensors[0][0]; if (mData.dim() == 0) { mData = mData.unsqueeze(0); } - TORCH_CHECK(mData.dim() > 0, "assigned data must have shape [N, ...], but got data.dim() = 0"); - mBatchIdx = torch::empty({0}, torch::TensorOptions().dtype(JIdxScalarType).device(mData.device())); - mOffsets = torch::tensor({JOffsetsType(0), mData.size(0)}, torch::TensorOptions().dtype(JOffsetsScalarType).device(mData.device())); - mListIdx = torch::zeros({1, 2}, torch::TensorOptions().dtype(JLIdxScalarType).device(mData.device())); + TORCH_CHECK(mData.dim() > 0, + "assigned data must have shape [N, ...], but got data.dim() = 0"); + mBatchIdx = torch::empty( + { 0 }, torch::TensorOptions().dtype(JIdxScalarType).device(mData.device())); + mOffsets = + torch::tensor({ JOffsetsType(0), mData.size(0) }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(mData.device())); + mListIdx = torch::zeros( + { 1, 2 }, torch::TensorOptions().dtype(JLIdxScalarType).device(mData.device())); mNumOuterLists = 1; return; } @@ -131,14 +159,17 @@ JaggedTensor::JaggedTensor(const std::vector>& tensor std::vector batchIdxs; batchIdxs.reserve(totalTensors); - mOffsets = torch::empty({totalTensors + 1}, torch::TensorOptions().dtype(JOffsetsScalarType).device(torch::kCPU)); + mOffsets = torch::empty({ totalTensors + 1 }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(torch::kCPU)); auto elementCountsAcc = mOffsets.accessor(); - elementCountsAcc[0] = 0; + elementCountsAcc[0] = 0; - torch::Tensor listIndexes = torch::empty({totalTensors, (JLIdxType) 2}, torch::TensorOptions().dtype(JLIdxScalarType).device(torch::kCPU)); + torch::Tensor listIndexes = + torch::empty({ totalTensors, (JLIdxType)2 }, + torch::TensorOptions().dtype(JLIdxScalarType).device(torch::kCPU)); auto listIndexesAcc = listIndexes.accessor(); - std::vector tensorsReshaped; // Reshape 0D tensors to 1D + std::vector tensorsReshaped; // Reshape 0D tensors to 1D tensorsReshaped.reserve(totalTensors); int64_t tensorCount = 0; @@ -149,51 +180,59 @@ JaggedTensor::JaggedTensor(const std::vector>& tensor torch::Tensor tij = tensors[i][j]; if (tij.dim() == 0 && tij.numel() == 1) { - tensorsReshaped.push_back(tij.view({1})); + tensorsReshaped.push_back(tij.view({ 1 })); } else { tensorsReshaped.push_back(tij); } - batchIdxs.push_back(torch::full({tensorsReshaped[tensorCount].size(0)}, - tensorCount, - torch::TensorOptions().dtype(JIdxScalarType).device(device))); - elementCountsAcc[tensorCount+1] = tensorsReshaped[tensorCount].size(0); + batchIdxs.push_back( + torch::full({ tensorsReshaped[tensorCount].size(0) }, tensorCount, + torch::TensorOptions().dtype(JIdxScalarType).device(device))); + elementCountsAcc[tensorCount + 1] = tensorsReshaped[tensorCount].size(0); tensorCount += 1; - } } mOffsets = mOffsets.to(device); torch::cumsum_out(mOffsets, mOffsets, 0); - mBatchIdx = torch::cat(batchIdxs, 0); - mData = torch::cat(tensorsReshaped, 0); - mListIdx = listIndexes.to(device); + mBatchIdx = torch::cat(batchIdxs, 0); + mData = torch::cat(tensorsReshaped, 0); + mListIdx = listIndexes.to(device); mNumOuterLists = tensors.size(); } 
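// Illustrative sketch of the nested-list constructor above (not taken from the
// fVDB sources; shapes are arbitrary): a list of lists of tensors is flattened
// into jdata plus offset / list-index metadata.
//
//     std::vector<std::vector<torch::Tensor>> nested = {
//         { torch::rand({ 3, 4 }), torch::rand({ 5, 4 }) },   // outer list 0: two tensors
//         { torch::rand({ 2, 4 }) }                           // outer list 1: one tensor
//     };
//     fvdb::JaggedTensor jt(nested);
//
// Expected layout, following the constructor logic above:
//     jt.jdata()            -> shape [10, 4]       (3 + 5 + 2 rows concatenated)
//     jt.joffsets()         -> [0, 3, 8, 10]       (per-tensor start/end offsets)
//     jt.jlidx()            -> [[0,0],[0,1],[1,0]] (outer, inner index per tensor)
//     jt.num_outer_lists()  -> 2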
-JaggedTensor::JaggedTensor(const std::vector& lsizes, const torch::Tensor data) { +JaggedTensor::JaggedTensor(const std::vector &lsizes, const torch::Tensor data) { // TODO: (Francis): rewrite as a cuda kernel TORCH_CHECK_VALUE(lsizes.size() > 0, "empty list sizes"); // This is an implementation detail where we don't store jidx for // a single list since everything is just zero by default. if (lsizes.size() == 1) { - TORCH_CHECK_VALUE(lsizes[0] == data.size(0), "Sum of list sizes must equal the number of elements in data"); - mOffsets = torch::tensor({JOffsetsType(0), data.size(0)}, torch::TensorOptions().dtype(JOffsetsScalarType).device(data.device())); - mListIdx = torch::empty({0, 1}, torch::TensorOptions().dtype(JLIdxScalarType).device(data.device())); + TORCH_CHECK_VALUE(lsizes[0] == data.size(0), + "Sum of list sizes must equal the number of elements in data"); + mOffsets = + torch::tensor({ JOffsetsType(0), data.size(0) }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(data.device())); + mListIdx = torch::empty( + { 0, 1 }, torch::TensorOptions().dtype(JLIdxScalarType).device(data.device())); mNumOuterLists = 1; - mBatchIdx = torch::empty({0}, torch::TensorOptions().dtype(JIdxScalarType).device(data.device())); + mBatchIdx = + torch::empty({ 0 }, torch::TensorOptions().dtype(JIdxScalarType).device(data.device())); mData = data; if (mData.dim() == 0) { mData = mData.unsqueeze(0); } - TORCH_CHECK(mData.dim() > 0, "assigned data must have shape [N, ...], but got data.dim() = 0"); + TORCH_CHECK(mData.dim() > 0, + "assigned data must have shape [N, ...], but got data.dim() = 0"); return; } - torch::Tensor offsetsCPU = torch::empty({(JOffsetsType) lsizes.size() + 1}, torch::TensorOptions().dtype(JOffsetsScalarType).device(torch::kCPU)); + torch::Tensor offsetsCPU = + torch::empty({ (JOffsetsType)lsizes.size() + 1 }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(torch::kCPU)); auto offsetsCPUAcc = offsetsCPU.accessor(); - mListIdx = torch::empty({0, 1}, torch::TensorOptions().dtype(JLIdxScalarType).device(data.device())); + mListIdx = + torch::empty({ 0, 1 }, torch::TensorOptions().dtype(JLIdxScalarType).device(data.device())); mNumOuterLists = lsizes.size(); JOffsetsType cumulativeElements = 0; @@ -202,46 +241,59 @@ JaggedTensor::JaggedTensor(const std::vector& lsizes, const torch::Tens cumulativeElements += lsizes[i]; } offsetsCPUAcc[lsizes.size()] = cumulativeElements; - TORCH_CHECK_VALUE(cumulativeElements == data.size(0), "Sum of list sizes must equal the number of elements in data"); + TORCH_CHECK_VALUE(cumulativeElements == data.size(0), + "Sum of list sizes must equal the number of elements in data"); - mOffsets = offsetsCPU.to(data.device()); - mData = data; + mOffsets = offsetsCPU.to(data.device()); + mData = data; mBatchIdx = jidx_from_joffsets(mOffsets, data.size(0)); } -JaggedTensor::JaggedTensor(const std::vector>& lsizes, const int64_t totalTensors, const torch::Tensor data) { +JaggedTensor::JaggedTensor(const std::vector> &lsizes, + const int64_t totalTensors, const torch::Tensor data) { // TODO (Francis) : Rewrite as a cuda kernel TORCH_CHECK_VALUE(lsizes.size() > 0, "empty lshape"); // This is an implementation detail where we don't store jidx for // a single list since everything is just zero by default. if (totalTensors == 1) { - TORCH_CHECK(lsizes.size() == 1, "Single tensor must be a 1D tensor. This should never happen."); - TORCH_CHECK(lsizes[0].size() == 1, "Single tensor must be a 1D tensor. 
This should never happen."); + TORCH_CHECK(lsizes.size() == 1, + "Single tensor must be a 1D tensor. This should never happen."); + TORCH_CHECK(lsizes[0].size() == 1, + "Single tensor must be a 1D tensor. This should never happen."); TORCH_CHECK_VALUE(lsizes[0][0] == data.size(0), "Invalid size for data tensor."); mData = data; if (mData.dim() == 0) { mData = mData.unsqueeze(0); } - TORCH_CHECK(mData.dim() > 0, "assigned data must have shape [N, ...], but got data.dim() = 0"); - mBatchIdx = torch::empty({0}, torch::TensorOptions().dtype(JIdxScalarType).device(mData.device())); - mOffsets = torch::tensor({JOffsetsType(0), mData.size(0)}, torch::TensorOptions().dtype(JOffsetsScalarType).device(mData.device())); - mListIdx = torch::zeros({1, 2}, torch::TensorOptions().dtype(JLIdxScalarType).device(mData.device())); + TORCH_CHECK(mData.dim() > 0, + "assigned data must have shape [N, ...], but got data.dim() = 0"); + mBatchIdx = torch::empty( + { 0 }, torch::TensorOptions().dtype(JIdxScalarType).device(mData.device())); + mOffsets = + torch::tensor({ JOffsetsType(0), mData.size(0) }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(mData.device())); + mListIdx = torch::zeros( + { 1, 2 }, torch::TensorOptions().dtype(JLIdxScalarType).device(mData.device())); mNumOuterLists = 1; return; } - torch::Tensor offsetsCPU = torch::empty({(JOffsetsType) totalTensors + 1}, torch::TensorOptions().dtype(JOffsetsScalarType).device(torch::kCPU)); - torch::Tensor listIdsCPU = torch::empty({(JLIdxType) totalTensors, 2}, torch::TensorOptions().dtype(JLIdxScalarType).device(torch::kCPU)); + torch::Tensor offsetsCPU = + torch::empty({ (JOffsetsType)totalTensors + 1 }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(torch::kCPU)); + torch::Tensor listIdsCPU = + torch::empty({ (JLIdxType)totalTensors, 2 }, + torch::TensorOptions().dtype(JLIdxScalarType).device(torch::kCPU)); auto offsetsCPUAcc = offsetsCPU.accessor(); auto listIdsCPUAcc = listIdsCPU.accessor(); JOffsetsType cumulativeElements = 0; - int64_t tensorCount = 0; + int64_t tensorCount = 0; for (size_t i = 0; i < lsizes.size(); ++i) { TORCH_CHECK_VALUE(lsizes[i].size() > 0, "empty lshape"); for (size_t j = 0; j < lsizes[i].size(); j += 1) { - offsetsCPUAcc[tensorCount] = cumulativeElements; + offsetsCPUAcc[tensorCount] = cumulativeElements; listIdsCPUAcc[tensorCount][0] = i; listIdsCPUAcc[tensorCount][1] = j; cumulativeElements += lsizes[i][j]; @@ -249,26 +301,28 @@ JaggedTensor::JaggedTensor(const std::vector>& lsizes, cons } } offsetsCPUAcc[totalTensors] = cumulativeElements; - TORCH_CHECK_VALUE(cumulativeElements == data.size(0), "Sum of list sizes must equal the number of elements in data"); + TORCH_CHECK_VALUE(cumulativeElements == data.size(0), + "Sum of list sizes must equal the number of elements in data"); - mOffsets = offsetsCPU.to(data.device()); - mListIdx = listIdsCPU.to(data.device()); - mBatchIdx = jidx_from_joffsets(mOffsets, data.size(0)); - mData = data; + mOffsets = offsetsCPU.to(data.device()); + mListIdx = listIdsCPU.to(data.device()); + mBatchIdx = jidx_from_joffsets(mOffsets, data.size(0)); + mData = data; mNumOuterLists = lsizes.size(); } -void JaggedTensor::recompute_lsizes_if_dirty() { +void +JaggedTensor::recompute_lsizes_if_dirty() { if (!mLShapeCache.mDirty) { return; } mLShapeCache.clear(); if (ldim() == 1) { const torch::Tensor offsetsCpu = mOffsets.cpu(); - const auto acc = offsetsCpu.accessor(); + const auto acc = offsetsCpu.accessor(); for (int i = 0; i < num_tensors(); ++i) { const JOffsetsType startIdx 
= acc[i]; - const JOffsetsType endIdx = acc[i+1]; + const JOffsetsType endIdx = acc[i + 1]; mLShapeCache.mLShape1.push_back(endIdx - startIdx); } mLShapeCache.mDirty = false; @@ -276,8 +330,8 @@ void JaggedTensor::recompute_lsizes_if_dirty() { } else if (ldim() == 2) { const torch::Tensor offsetsCpu = mOffsets.cpu(); const torch::Tensor listIdxCpu = mListIdx.cpu(); - const auto offAcc = offsetsCpu.accessor(); - const auto lixAcc = listIdxCpu.accessor(); + const auto offAcc = offsetsCpu.accessor(); + const auto lixAcc = listIdxCpu.accessor(); ssize_t currentList = -1; for (int i = 0; i < num_tensors(); ++i) { @@ -288,37 +342,41 @@ void JaggedTensor::recompute_lsizes_if_dirty() { mLShapeCache.mLShape2.push_back(std::vector()); } const JOffsetsType startIdx = offAcc[i]; - const JOffsetsType endIdx = offAcc[i+1]; + const JOffsetsType endIdx = offAcc[i + 1]; mLShapeCache.mLShape2.back().push_back(endIdx - startIdx); } mLShapeCache.mDirty = false; return; } else { - TORCH_CHECK(false, "Unsupported list dimension. Currently JaggedTensor only supports up to 2."); + TORCH_CHECK(false, + "Unsupported list dimension. Currently JaggedTensor only supports up to 2."); } } -std::vector JaggedTensor::unbind1() const { +std::vector +JaggedTensor::unbind1() const { std::vector ret(num_tensors()); int64_t ldim = mListIdx.size(1); if (ldim != 1) { - TORCH_WARN("Calling unbind on a multidimensional list of jagged tensors will return a flattened list"); + TORCH_WARN( + "Calling unbind on a multidimensional list of jagged tensors will return a flattened list"); } torch::Tensor offsetsCpu = mOffsets.cpu(); - auto acc = offsetsCpu.accessor(); + auto acc = offsetsCpu.accessor(); for (int i = 0; i < num_tensors(); ++i) { const JOffsetsType startIdx = acc[i]; - const JOffsetsType endIdx = acc[i+1]; + const JOffsetsType endIdx = acc[i + 1]; - ret[i] = mData.index({torch::indexing::Slice(startIdx, endIdx)}); + ret[i] = mData.index({ torch::indexing::Slice(startIdx, endIdx) }); } return ret; } -std::vector> JaggedTensor::unbind2() const { +std::vector> +JaggedTensor::unbind2() const { std::vector> ret; int64_t ldim = mListIdx.size(1); @@ -327,9 +385,9 @@ std::vector> JaggedTensor::unbind2() const { TORCH_CHECK_VALUE(false, "Called unbind2() on a list with list dimension != 2"); } - torch::Tensor listIdxCpu = mListIdx.cpu(); - torch::Tensor offsetsCpu = mOffsets.cpu(); - ssize_t currentList = -1; + torch::Tensor listIdxCpu = mListIdx.cpu(); + torch::Tensor offsetsCpu = mOffsets.cpu(); + ssize_t currentList = -1; auto offAcc = offsetsCpu.accessor(); auto lixAcc = listIdxCpu.accessor(); @@ -342,34 +400,38 @@ std::vector> JaggedTensor::unbind2() const { ret.push_back(std::vector()); } const JOffsetsType startIdx = offAcc[i]; - const JOffsetsType endIdx = offAcc[i+1]; - + const JOffsetsType endIdx = offAcc[i + 1]; - ret.back().push_back(mData.index({torch::indexing::Slice(startIdx, endIdx)})); + ret.back().push_back(mData.index({ torch::indexing::Slice(startIdx, endIdx) })); } return ret; } -std::vector JaggedTensor::lsizes1() const { +std::vector +JaggedTensor::lsizes1() const { TORCH_CHECK(ldim() == 1, "Nesting dimension must be 1"); - const_cast(this)->recompute_lsizes_if_dirty(); + const_cast(this)->recompute_lsizes_if_dirty(); return mLShapeCache.mLShape1; } -std::vector> JaggedTensor::lsizes2() const { +std::vector> +JaggedTensor::lsizes2() const { TORCH_CHECK(ldim() == 2, "Nesting dimension must be 2"); - const_cast(this)->recompute_lsizes_if_dirty(); + const_cast(this)->recompute_lsizes_if_dirty(); return 
mLShapeCache.mLShape2; } -int64_t JaggedTensor::ldim() const { +int64_t +JaggedTensor::ldim() const { TORCH_CHECK_VALUE(mListIdx.dim() == 2, "Corrupt list indices. This should never happen"); - TORCH_CHECK_VALUE(mListIdx.numel() == 0 || mListIdx.size(0) == (mOffsets.size(0) - 1), "Corrupt list indices. This should never happen"); + TORCH_CHECK_VALUE(mListIdx.numel() == 0 || mListIdx.size(0) == (mOffsets.size(0) - 1), + "Corrupt list indices. This should never happen"); return mListIdx.size(1); } -std::vector JaggedTensor::esizes() const { +std::vector +JaggedTensor::esizes() const { std::vector sizes; for (size_t i = 1; i < mData.sizes().size(); i++) { sizes.push_back(mData.size(i)); @@ -377,111 +439,148 @@ std::vector JaggedTensor::esizes() const { return sizes; } -int64_t JaggedTensor::edim() const { +int64_t +JaggedTensor::edim() const { return mData.dim() > 0 ? mData.dim() - 1 : 0; } - -JaggedTensor JaggedTensor::jagged_like(torch::Tensor data) const { - TORCH_CHECK_VALUE(data.dim() > 0, "assigned data must have shape [N, ...], but got data.dim() = 0"); +JaggedTensor +JaggedTensor::jagged_like(torch::Tensor data) const { + TORCH_CHECK_VALUE(data.dim() > 0, + "assigned data must have shape [N, ...], but got data.dim() = 0"); TORCH_CHECK_VALUE(mListIdx.dim() == 2, "Corrupt list indices. This should never happen"); - TORCH_CHECK_VALUE(mListIdx.numel() == 0 || mListIdx.size(0) == (mOffsets.size(0) - 1), "Corrupt list indices. This should never happen"); - TORCH_CHECK_VALUE(data.size(0) == mData.size(0), "Assigned data must have the same number of elements as the JaggedTensor"); + TORCH_CHECK_VALUE(mListIdx.numel() == 0 || mListIdx.size(0) == (mOffsets.size(0) - 1), + "Corrupt list indices. This should never happen"); + TORCH_CHECK_VALUE(data.size(0) == mData.size(0), + "Assigned data must have the same number of elements as the JaggedTensor"); JaggedTensor ret; - ret.mBatchIdx = jidx(); - ret.mOffsets = joffsets(); - ret.mListIdx = jlidx(); + ret.mBatchIdx = jidx(); + ret.mOffsets = joffsets(); + ret.mListIdx = jlidx(); ret.mNumOuterLists = mNumOuterLists; - ret.mData = data.to(device()); - ret.mLShapeCache = mLShapeCache; + ret.mData = data.to(device()); + ret.mLShapeCache = mLShapeCache; return ret; } -JaggedTensor JaggedTensor::from_data_indices_and_list_ids(torch::Tensor data, torch::Tensor indices, torch::Tensor list_ids, int64_t num_tensors) { +JaggedTensor +JaggedTensor::from_data_indices_and_list_ids(torch::Tensor data, torch::Tensor indices, + torch::Tensor list_ids, int64_t num_tensors) { JaggedTensor ret; - ret.mData = data; - ret.mBatchIdx = indices; - ret.mListIdx = list_ids; - ret.mOffsets = joffsets_from_jidx_and_jdata(indices, data, num_tensors); + ret.mData = data; + ret.mBatchIdx = indices; + ret.mListIdx = list_ids; + ret.mOffsets = joffsets_from_jidx_and_jdata(indices, data, num_tensors); ret.mNumOuterLists = ret.joffsets().size(0) - 1; ret.mLShapeCache.markDirty(); return ret; } -JaggedTensor JaggedTensor::from_data_offsets_and_list_ids(torch::Tensor data, torch::Tensor offsets, torch::Tensor list_ids) { - TORCH_CHECK_VALUE(list_ids.dim() == 2, "Invalid list indices when constructing JaggedTensor from data, offsets, and list indices"); - TORCH_CHECK_VALUE(list_ids.numel() == 0 || list_ids.size(0) == (offsets.size(0) - 1), "Invalid list indices when constructing JaggedTensor from data, offsets, and list indices"); - TORCH_CHECK_VALUE(offsets.dim() == 1, "Invalid offsets when constructing JaggedTensor from data, offsets, and list indices"); +JaggedTensor 
+JaggedTensor::from_data_offsets_and_list_ids(torch::Tensor data, torch::Tensor offsets, + torch::Tensor list_ids) { + TORCH_CHECK_VALUE( + list_ids.dim() == 2, + "Invalid list indices when constructing JaggedTensor from data, offsets, and list indices"); + TORCH_CHECK_VALUE( + list_ids.numel() == 0 || list_ids.size(0) == (offsets.size(0) - 1), + "Invalid list indices when constructing JaggedTensor from data, offsets, and list indices"); + TORCH_CHECK_VALUE( + offsets.dim() == 1, + "Invalid offsets when constructing JaggedTensor from data, offsets, and list indices"); JaggedTensor ret; - ret.mData = data; - ret.mOffsets = offsets; - ret.mListIdx = list_ids; + ret.mData = data; + ret.mOffsets = offsets; + ret.mListIdx = list_ids; ret.mNumOuterLists = offsets.size(0) - 1; - ret.mBatchIdx = jidx_from_joffsets(offsets, data.size(0)); + ret.mBatchIdx = jidx_from_joffsets(offsets, data.size(0)); ret.mLShapeCache.markDirty(); return ret; } -JaggedTensor JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(torch::Tensor jdata, torch::Tensor joffsets, - torch::Tensor jidx, torch::Tensor lidx, - int64_t numOuterLists) { - TORCH_CHECK_VALUE(lidx.dim() == 2, "Invalid list indices when constructing JaggedTensor from data, offsets, indices, list indices"); - TORCH_CHECK_VALUE(lidx.numel() == 0 || lidx.size(0) == (joffsets.size(0) - 1), "Invalid list indices when constructing JaggedTensor from data, offsets, indices, list indices"); - TORCH_CHECK_VALUE(joffsets.dim() == 1, "Invalid offsets when constructing JaggedTensor from data, offsets, indices, list indices"); +JaggedTensor +JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(torch::Tensor jdata, torch::Tensor joffsets, + torch::Tensor jidx, torch::Tensor lidx, + int64_t numOuterLists) { + TORCH_CHECK_VALUE( + lidx.dim() == 2, + "Invalid list indices when constructing JaggedTensor from data, offsets, indices, list indices"); + TORCH_CHECK_VALUE( + lidx.numel() == 0 || lidx.size(0) == (joffsets.size(0) - 1), + "Invalid list indices when constructing JaggedTensor from data, offsets, indices, list indices"); + TORCH_CHECK_VALUE( + joffsets.dim() == 1, + "Invalid offsets when constructing JaggedTensor from data, offsets, indices, list indices"); JaggedTensor ret; - ret.mData = jdata; - ret.mOffsets = joffsets; - ret.mListIdx = lidx; + ret.mData = jdata; + ret.mOffsets = joffsets; + ret.mListIdx = lidx; ret.mNumOuterLists = numOuterLists; - ret.mBatchIdx = jidx; + ret.mBatchIdx = jidx; ret.mLShapeCache.markDirty(); ret.recompute_lsizes_if_dirty(); return ret; } -void JaggedTensor::set_data(const torch::Tensor& data) { - TORCH_CHECK_VALUE(data.dim() > 0, "assigned data must have shape [N, ...], but got data.dim() = 0"); - TORCH_CHECK_VALUE((data.device() == mBatchIdx.device()) || (mBatchIdx.numel() == 0 && num_tensors() == 1), "Incorrect device for data"); +void +JaggedTensor::set_data(const torch::Tensor &data) { + TORCH_CHECK_VALUE(data.dim() > 0, + "assigned data must have shape [N, ...], but got data.dim() = 0"); + TORCH_CHECK_VALUE((data.device() == mBatchIdx.device()) || + (mBatchIdx.numel() == 0 && num_tensors() == 1), + "Incorrect device for data"); TORCH_CHECK_VALUE(data.device() == mOffsets.device(), "Incorrect device for data"); TORCH_CHECK_VALUE(mListIdx.dim() == 2, "Corrupt list indices. This should never happen"); - TORCH_CHECK_VALUE(mListIdx.numel() == 0 || mListIdx.size(0) == (mOffsets.size(0) - 1), "Corrupt list indices. 
This should never happen"); + TORCH_CHECK_VALUE(mListIdx.numel() == 0 || mListIdx.size(0) == (mOffsets.size(0) - 1), + "Corrupt list indices. This should never happen"); if (mBatchIdx.size(0) == 0) { TORCH_CHECK(mOffsets.dim() == 1, "bad offsets. this should never happen"); - TORCH_CHECK(mOffsets.size(0) == (num_outer_lists() + 1), "bad offsets. this should never happen"); + TORCH_CHECK(mOffsets.size(0) == (num_outer_lists() + 1), + "bad offsets. this should never happen"); TORCH_CHECK_VALUE(data.size(0) == mData.size(0), "assigned data must have shape [N, ...]"); } else { - TORCH_CHECK_VALUE(data.size(0) == mBatchIdx.size(0), "assigned data must have shape [N, ...]"); + TORCH_CHECK_VALUE(data.size(0) == mBatchIdx.size(0), + "assigned data must have shape [N, ...]"); } mData = data; } -JaggedTensor JaggedTensor::rmask(const torch::Tensor& mask) const { - TORCH_CHECK(mask.device() == mBatchIdx.device(), "mask must be on the same device as the JaggedTensor"); +JaggedTensor +JaggedTensor::rmask(const torch::Tensor &mask) const { + TORCH_CHECK(mask.device() == mBatchIdx.device(), + "mask must be on the same device as the JaggedTensor"); TORCH_CHECK(mask.dim() == 1, "mask must be 1-dimensional"); - TORCH_CHECK(mask.size(0) == mData.size(0), "mask must have the same size as the first dimension of the JaggedTensor"); + TORCH_CHECK(mask.size(0) == mData.size(0), + "mask must have the same size as the first dimension of the JaggedTensor"); TORCH_CHECK(mask.scalar_type() == torch::kBool, "mask must be of type bool"); - TORCH_CHECK((mask.size(0) == mBatchIdx.size(0)) || (mBatchIdx.size(0) == 0 && mOffsets.size(0) == 2), - "Bad jidx. This should never happen. mask.size(0) = ", mask.size(0), " mBatchIdx.size(0) = ", mBatchIdx.size(0)); - const torch::Tensor retData = mData.index({mask, "..."}); - const torch::Tensor retBatchIds = mBatchIdx.size(0) > 0 ? mBatchIdx.index({mask}) : mBatchIdx; - const torch::Tensor retOffsets = joffsets_from_jidx_and_jdata(retBatchIds, retData, num_tensors()); - return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retData, retOffsets, retBatchIds, mListIdx, mNumOuterLists); + TORCH_CHECK((mask.size(0) == mBatchIdx.size(0)) || + (mBatchIdx.size(0) == 0 && mOffsets.size(0) == 2), + "Bad jidx. This should never happen. mask.size(0) = ", mask.size(0), + " mBatchIdx.size(0) = ", mBatchIdx.size(0)); + const torch::Tensor retData = mData.index({ mask, "..." }); + const torch::Tensor retBatchIds = mBatchIdx.size(0) > 0 ? mBatchIdx.index({ mask }) : mBatchIdx; + const torch::Tensor retOffsets = + joffsets_from_jidx_and_jdata(retBatchIds, retData, num_tensors()); + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retData, retOffsets, retBatchIds, + mListIdx, mNumOuterLists); } -JaggedTensor JaggedTensor::index(JaggedTensorIndex idx) const { +JaggedTensor +JaggedTensor::index(JaggedTensorIndex idx) const { if (idx.is_integer()) { return FVDB_DISPATCH_KERNEL_DEVICE(mData.device(), [&]() { return detail::ops::dispatchJaggedTensorIndex(*this, idx.integer()); }); } else if (idx.is_slice()) { int64_t start = idx.slice().start().as_int_unchecked(); - int64_t end = idx.slice().stop().as_int_unchecked(); - int64_t step = idx.slice().step().as_int_unchecked(); - TORCH_CHECK_INDEX(step == 1, "step must be 1 for JaggedTensor. Only contiguous slicing is supported."); + int64_t end = idx.slice().stop().as_int_unchecked(); + int64_t step = idx.slice().step().as_int_unchecked(); + TORCH_CHECK_INDEX(step == 1, + "step must be 1 for JaggedTensor. 
Only contiguous slicing is supported."); // Deal with symbolic int if (start >= at::indexing::INDEX_MAX) { @@ -502,93 +601,109 @@ JaggedTensor JaggedTensor::index(JaggedTensorIndex idx) const { start = end; } - start = std::max(start, (int64_t) 0); - end = std::min(end, mNumOuterLists); + start = std::max(start, (int64_t)0); + end = std::min(end, mNumOuterLists); if (mListIdx.size(0) == 0) { TORCH_CHECK(ldim() == 1, "bad list indexes. this should never happen"); - const JOffsetsType startIdx = mOffsets[start].item(); - const JOffsetsType endIdx = mOffsets[end].item(); - const torch::Tensor retLidx = mListIdx.numel() == 0 ? mListIdx : mListIdx.index({torch::indexing::Slice(start, end)}); + const JOffsetsType startIdx = mOffsets[start].item(); + const JOffsetsType endIdx = mOffsets[end].item(); + const torch::Tensor retLidx = + mListIdx.numel() == 0 ? mListIdx + : mListIdx.index({ torch::indexing::Slice(start, end) }); return JaggedTensor::from_data_offsets_and_list_ids( - mData.index({torch::indexing::Slice(startIdx, endIdx)}), - mOffsets.index({torch::indexing::Slice(start, end+1)}) - startIdx, - retLidx); + mData.index({ torch::indexing::Slice(startIdx, endIdx) }), + mOffsets.index({ torch::indexing::Slice(start, end + 1) }) - startIdx, retLidx); } else { // Find all tensors that belong to the slice - const torch::Tensor outerLidx = mListIdx.index({torch::indexing::Slice(), 0}); - const torch::Tensor mask = outerLidx.ge(start).logical_and(outerLidx.lt(end)); - const torch::Tensor joffsetCat = torch::stack({ - mOffsets.index({torch::indexing::Slice(0, num_tensors())}), - mOffsets.index({torch::indexing::Slice(1, num_tensors()+1)}) - }, 1); - const torch::Tensor selectedOffsets = joffsetCat.index({mask}); + const torch::Tensor outerLidx = mListIdx.index({ torch::indexing::Slice(), 0 }); + const torch::Tensor mask = outerLidx.ge(start).logical_and(outerLidx.lt(end)); + const torch::Tensor joffsetCat = + torch::stack({ mOffsets.index({ torch::indexing::Slice(0, num_tensors()) }), + mOffsets.index({ torch::indexing::Slice(1, num_tensors() + 1) }) }, + 1); + const torch::Tensor selectedOffsets = joffsetCat.index({ mask }); // Get the start and end offsets into the data tensor for the slice - JOffsetsType startIdx = selectedOffsets.size(0) > 0 ? selectedOffsets[0][0].item() : 0; - JOffsetsType endIdx = selectedOffsets.size(0) > 0 ? selectedOffsets[-1][1].item() : 0; + JOffsetsType startIdx = + selectedOffsets.size(0) > 0 ? selectedOffsets[0][0].item() : 0; + JOffsetsType endIdx = + selectedOffsets.size(0) > 0 ? selectedOffsets[-1][1].item() : 0; // Slice the data tensor - const torch::Tensor retData = mData.index({torch::indexing::Slice(startIdx, endIdx)}); + const torch::Tensor retData = mData.index({ torch::indexing::Slice(startIdx, endIdx) }); // Subtract the start offset from the selected offsets to get the new offsets // NOTE: This assumes offsets are always contiguous - const torch::Tensor retOffsets = selectedOffsets.numel() > 0 ? torch::cat({ - selectedOffsets.index({torch::indexing::Slice(), 0}), - selectedOffsets.index({-1, 1}).unsqueeze(0) - }) - startIdx : torch::zeros({1}, torch::TensorOptions().dtype(JOffsetsScalarType).device(mData.device())); + const torch::Tensor retOffsets = + selectedOffsets.numel() > 0 + ? 
torch::cat({ selectedOffsets.index({ torch::indexing::Slice(), 0 }), + selectedOffsets.index({ -1, 1 }).unsqueeze(0) }) - + startIdx + : torch::zeros( + { 1 }, + torch::TensorOptions().dtype(JOffsetsScalarType).device(mData.device())); // Slice the list indices and subtract the start index TORCH_CHECK(mListIdx.size(1) > 1, "bad list indexes. this should never happen"); - torch::Tensor retListIdx = mListIdx.index({mask}); - retListIdx.index({torch::indexing::Slice(), 0}) -= start; + torch::Tensor retListIdx = mListIdx.index({ mask }); + retListIdx.index({ torch::indexing::Slice(), 0 }) -= start; if (retListIdx.dim() == 0) { retListIdx = retListIdx.unsqueeze(1); } - const int64_t retNumOuterLists = end - start; - const torch::Tensor retJidx = jidx_from_joffsets(retOffsets, retData.size(0)); - return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retData, retOffsets, retJidx, retListIdx, retNumOuterLists); + const int64_t retNumOuterLists = end - start; + const torch::Tensor retJidx = jidx_from_joffsets(retOffsets, retData.size(0)); + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( + retData, retOffsets, retJidx, retListIdx, retNumOuterLists); } } else if (idx.is_ellipsis()) { return *this; } else if (idx.is_jagged_tensor()) { - const JaggedTensor& jtIndices = idx.jagged_tensor(); - TORCH_CHECK_VALUE(jtIndices.device() == device(), "indices must be on the same device as the JaggedTensor"); + const JaggedTensor &jtIndices = idx.jagged_tensor(); + TORCH_CHECK_VALUE(jtIndices.device() == device(), + "indices must be on the same device as the JaggedTensor"); - TORCH_CHECK_INDEX(jtIndices.mListIdx.dim() == mListIdx.dim(), - "Indices must have the same list structure as JaggedTensor being indexed"); + TORCH_CHECK_INDEX( + jtIndices.mListIdx.dim() == mListIdx.dim(), + "Indices must have the same list structure as JaggedTensor being indexed"); for (int i = 0; i < mListIdx.dim(); ++i) { - TORCH_CHECK_INDEX(jtIndices.mListIdx.size(i) == mListIdx.size(i), - "Indices must have the same list structure as JaggedTensor being indexed"); + TORCH_CHECK_INDEX( + jtIndices.mListIdx.size(i) == mListIdx.size(i), + "Indices must have the same list structure as JaggedTensor being indexed"); } if (Config::global().pendanticErrorCheckingEnabled()) { // This is a slow check that we cap optionally do for correctness. - TORCH_CHECK_INDEX(torch::all(jtIndices.mListIdx == mListIdx).item(), - "Indices must have the same list structure as JaggedTensor being indexed. ", - "This error was raised because config.pendatic_error_checking was enabled"); + TORCH_CHECK_INDEX( + torch::all(jtIndices.mListIdx == mListIdx).item(), + "Indices must have the same list structure as JaggedTensor being indexed. ", + "This error was raised because config.pendatic_error_checking was enabled"); } - c10::ScalarType idxdt = jtIndices.scalar_type(); + c10::ScalarType idxdt = jtIndices.scalar_type(); const bool isIndexType = (idxdt == c10::ScalarType::Long || idxdt == c10::ScalarType::Int || idxdt == c10::ScalarType::Byte || idxdt == c10::ScalarType::Bool); - TORCH_CHECK_INDEX(isIndexType, "JaggedTensors used as indices must be long, int, byte or bool JaggedTensors but got ", idxdt); + TORCH_CHECK_INDEX( + isIndexType, + "JaggedTensors used as indices must be long, int, byte or bool JaggedTensors but got ", + idxdt); torch::Tensor selidx; if (jtIndices.scalar_type() == torch::kBool) { selidx = jtIndices.jdata(); } else { - // FIXME (Francis): We're not checking out of range here and it's sketchy! 
Fix in a unified CUDA kernel + // FIXME (Francis): We're not checking out of range here and it's sketchy! Fix in a + // unified CUDA kernel selidx = jtIndices.jdata().clone(); for (int i = 0; i < jtIndices.joffsets().size(0) - 1; ++i) { const JOffsetsType start = jtIndices.joffsets()[i].item(); - const JOffsetsType end = jtIndices.joffsets()[i+1].item(); - const JOffsetsType add = mOffsets[i].item(); - selidx.index({torch::indexing::Slice(start, end)}).add_(add); + const JOffsetsType end = jtIndices.joffsets()[i + 1].item(); + const JOffsetsType add = mOffsets[i].item(); + selidx.index({ torch::indexing::Slice(start, end) }).add_(add); } } - const torch::Tensor retJdata = mData.index({selidx}); - torch::Tensor retJidx = mBatchIdx.index({selidx}); + const torch::Tensor retJdata = mData.index({ selidx }); + torch::Tensor retJidx = mBatchIdx.index({ selidx }); if (retJidx.dim() > 1) { std::vector idx; idx.reserve(retJidx.dim()); @@ -599,51 +714,64 @@ JaggedTensor JaggedTensor::index(JaggedTensorIndex idx) const { retJidx = retJidx.index(idx); } retJidx = retJidx.contiguous(); - const torch::Tensor retJOffsets = joffsets_from_jidx_and_jdata(retJidx, retJdata, num_tensors()); + const torch::Tensor retJOffsets = + joffsets_from_jidx_and_jdata(retJidx, retJdata, num_tensors()); const torch::Tensor retListIdx = mListIdx; - return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retJdata, retJOffsets, - retJidx, retListIdx, - mNumOuterLists); + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( + retJdata, retJOffsets, retJidx, retListIdx, mNumOuterLists); } else { TORCH_CHECK_INDEX(false, "Unsupported indexing operation"); } } -JaggedTensor JaggedTensor::jreshape(const std::vector& lsizes) const { +JaggedTensor +JaggedTensor::jreshape(const std::vector &lsizes) const { return JaggedTensor(lsizes, mData); } -JaggedTensor JaggedTensor::jreshape(const std::vector>& lsizes) const { +JaggedTensor +JaggedTensor::jreshape(const std::vector> &lsizes) const { return JaggedTensor(lsizes, num_tensors(), mData); } -JaggedTensor JaggedTensor::jreshape_as(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::jreshape_as(const JaggedTensor &other) const { return other.jagged_like(mData); } -JaggedTensor JaggedTensor::jflatten(const int64_t dim) const { +JaggedTensor +JaggedTensor::jflatten(const int64_t dim) const { int64_t jdim = dim; if (dim < 0) { jdim += ldim(); } TORCH_CHECK_INDEX(jdim >= 0 && jdim < ldim(), "Invalid dimension to flatten"); - if (ldim() == 2) { if (jdim == 1) { - torch::Tensor newJIdx = mListIdx.index({torch::indexing::Slice(), 0}).index({mBatchIdx.to(torch::kInt)}).to(JIdxScalarType); - torch::Tensor newOffsets = joffsets_from_jidx_and_jdata(newJIdx, mData, num_outer_lists()); + torch::Tensor newJIdx = mListIdx.index({ torch::indexing::Slice(), 0 }) + .index({ mBatchIdx.to(torch::kInt) }) + .to(JIdxScalarType); + torch::Tensor newOffsets = + joffsets_from_jidx_and_jdata(newJIdx, mData, num_outer_lists()); return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( - mData, newOffsets, newJIdx, torch::empty({0, 1}, torch::TensorOptions().dtype(JLIdxScalarType).device(mData.device())), newOffsets.size(0) - 1); + mData, newOffsets, newJIdx, + torch::empty({ 0, 1 }, + torch::TensorOptions().dtype(JLIdxScalarType).device(mData.device())), + newOffsets.size(0) - 1); } else { return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( - mData, mOffsets, mBatchIdx, torch::empty({0, 1}, torch::TensorOptions().dtype(JLIdxScalarType).device(mData.device())), 
mOffsets.size(0) - 1); + mData, mOffsets, mBatchIdx, + torch::empty({ 0, 1 }, + torch::TensorOptions().dtype(JLIdxScalarType).device(mData.device())), + mOffsets.size(0) - 1); } } else if (ldim() == 1) { return JaggedTensor(mData); } else { - TORCH_CHECK(false, "Unsupported list dimension. Currently JaggedTensor only supports up to 2."); + TORCH_CHECK(false, + "Unsupported list dimension. Currently JaggedTensor only supports up to 2."); } } // JaggedTensor JaggedTensor::jagged_argsort() { @@ -655,9 +783,11 @@ JaggedTensor JaggedTensor::jflatten(const int64_t dim) const { // return jagged_like(argsortIdx); // } -JaggedTensor JaggedTensor::jsum(int64_t dim, bool keepdim) const { +JaggedTensor +JaggedTensor::jsum(int64_t dim, bool keepdim) const { const int64_t jdim = mData.dim(); - TORCH_CHECK_INDEX(dim >= -(jdim-1) && dim < jdim, "dim must be between ", -(jdim-1), " and ", jdim-1, " inclusive"); + TORCH_CHECK_INDEX(dim >= -(jdim - 1) && dim < jdim, "dim must be between ", -(jdim - 1), + " and ", jdim - 1, " inclusive"); if (dim < 0) { dim += jdim; } @@ -667,20 +797,26 @@ JaggedTensor JaggedTensor::jsum(int64_t dim, bool keepdim) const { if (mBatchIdx.size(0) == 0) { retData = mData.sum(0).unsqueeze(0); } else { - retData = detail::autograd::JaggedSum::apply(jdata(), jidx(), joffsets(), num_tensors())[0]; + retData = + detail::autograd::JaggedSum::apply(jdata(), jidx(), joffsets(), num_tensors())[0]; } - const torch::Tensor retOffsets = torch::arange(0, retData.size(0) + 1, torch::TensorOptions().dtype(JOffsetsScalarType).device(retData.device())); + const torch::Tensor retOffsets = torch::arange( + 0, retData.size(0) + 1, + torch::TensorOptions().dtype(JOffsetsScalarType).device(retData.device())); const torch::Tensor retJidx = jidx_from_joffsets(retOffsets, retData.size(0)); - return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retData, retOffsets, retJidx, mListIdx, mNumOuterLists); + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retData, retOffsets, retJidx, + mListIdx, mNumOuterLists); } else { return jagged_like(mData.sum(dim, keepdim)); } } -std::vector JaggedTensor::jmin(int64_t dim, bool keepdim) const { +std::vector +JaggedTensor::jmin(int64_t dim, bool keepdim) const { const int64_t jdim = mData.dim(); - TORCH_CHECK_INDEX(dim >= -(jdim-1) && dim <= jdim, "dim must be between ", -(jdim-1), " and ", jdim-1, " inclusive"); + TORCH_CHECK_INDEX(dim >= -(jdim - 1) && dim <= jdim, "dim must be between ", -(jdim - 1), + " and ", jdim - 1, " inclusive"); if (dim < 0) { dim += jdim; } @@ -689,31 +825,37 @@ std::vector JaggedTensor::jmin(int64_t dim, bool keepdim) const { torch::Tensor minVals, minIndices; if (mBatchIdx.size(0) == 0) { auto minTuple = mData.min(0); - minVals = std::get<0>(minTuple).unsqueeze(0); - minIndices = std::get<1>(minTuple).unsqueeze(0); - } else { - auto minTuple = detail::autograd::JaggedMin::apply(jdata(), jidx(), joffsets(), num_tensors()); - minVals = minTuple[0]; + minVals = std::get<0>(minTuple).unsqueeze(0); + minIndices = std::get<1>(minTuple).unsqueeze(0); + } else { + auto minTuple = + detail::autograd::JaggedMin::apply(jdata(), jidx(), joffsets(), num_tensors()); + minVals = minTuple[0]; minIndices = minTuple[1]; } - const torch::Tensor retOffsets = torch::arange(0, minVals.size(0) + 1, torch::TensorOptions().dtype(JOffsetsScalarType).device(minVals.device())); + const torch::Tensor retOffsets = torch::arange( + 0, minVals.size(0) + 1, + torch::TensorOptions().dtype(JOffsetsScalarType).device(minVals.device())); const 
torch::Tensor retJidx = jidx_from_joffsets(retOffsets, minVals.size(0)); - JaggedTensor retVals = JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(minVals, retOffsets, retJidx, mListIdx, mNumOuterLists); + JaggedTensor retVals = JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( + minVals, retOffsets, retJidx, mListIdx, mNumOuterLists); JaggedTensor retIdxs = retVals.jagged_like(minIndices); return { retVals, retIdxs }; } else { - auto minTuple = mData.min(dim, keepdim); - torch::Tensor minVals = std::get<0>(minTuple); + auto minTuple = mData.min(dim, keepdim); + torch::Tensor minVals = std::get<0>(minTuple); torch::Tensor minIndices = std::get<1>(minTuple); return { jagged_like(minVals), jagged_like(minIndices) }; } } -std::vector JaggedTensor::jmax(int64_t dim, bool keepdim) const { +std::vector +JaggedTensor::jmax(int64_t dim, bool keepdim) const { const int64_t jdim = mData.dim(); - TORCH_CHECK_INDEX(dim >= -(jdim-1) && dim <= jdim, "dim must be between ", -(jdim-1), " and ", jdim-1, " inclusive"); + TORCH_CHECK_INDEX(dim >= -(jdim - 1) && dim <= jdim, "dim must be between ", -(jdim - 1), + " and ", jdim - 1, " inclusive"); if (dim < 0) { dim += jdim; } @@ -722,28 +864,33 @@ std::vector JaggedTensor::jmax(int64_t dim, bool keepdim) const { torch::Tensor maxVals, maxIndices; if (mBatchIdx.size(0) == 0) { auto maxTuple = mData.max(0); - maxVals = std::get<0>(maxTuple).unsqueeze(0); - maxIndices = std::get<1>(maxTuple).unsqueeze(0); - } else { - auto maxTuple = detail::autograd::JaggedMax::apply(jdata(), jidx(), joffsets(), num_tensors()); - maxVals = maxTuple[0]; + maxVals = std::get<0>(maxTuple).unsqueeze(0); + maxIndices = std::get<1>(maxTuple).unsqueeze(0); + } else { + auto maxTuple = + detail::autograd::JaggedMax::apply(jdata(), jidx(), joffsets(), num_tensors()); + maxVals = maxTuple[0]; maxIndices = maxTuple[1]; } - const torch::Tensor retOffsets = torch::arange(0, maxVals.size(0) + 1, torch::TensorOptions().dtype(JOffsetsScalarType).device(maxVals.device())); + const torch::Tensor retOffsets = torch::arange( + 0, maxVals.size(0) + 1, + torch::TensorOptions().dtype(JOffsetsScalarType).device(maxVals.device())); const torch::Tensor retJidx = jidx_from_joffsets(retOffsets, maxVals.size(0)); - JaggedTensor retVals = JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(maxVals, retOffsets, retJidx, mListIdx, mNumOuterLists); + JaggedTensor retVals = JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( + maxVals, retOffsets, retJidx, mListIdx, mNumOuterLists); JaggedTensor retIdxs = retVals.jagged_like(maxIndices); return { retVals, retIdxs }; } else { - auto maxTuple = mData.max(dim, keepdim); - torch::Tensor maxVals = std::get<0>(maxTuple); + auto maxTuple = mData.max(dim, keepdim); + torch::Tensor maxVals = std::get<0>(maxTuple); torch::Tensor maxIndices = std::get<1>(maxTuple); - return {jagged_like(maxVals), jagged_like(maxIndices) }; + return { jagged_like(maxVals), jagged_like(maxIndices) }; } } -JaggedTensor JaggedTensor::jcat(const std::vector& vec, c10::optional dimension) { +JaggedTensor +JaggedTensor::jcat(const std::vector &vec, c10::optional dimension) { // Null dimension is just list concatenation if (!dimension.has_value()) { TORCH_CHECK_VALUE(vec.size() > 0, "Empty jagged tensor list"); @@ -752,18 +899,24 @@ JaggedTensor JaggedTensor::jcat(const std::vector& vec, c10::optio std::vector data; std::vector offsets; std::vector lidx; - JOffsetsType curOffset = 0; - int64_t totalLists = 0; - torch::Tensor curListOffset = torch::zeros({1, 
vec[0].mListIdx.size(1)}, torch::TensorOptions().dtype(JLIdxScalarType).device(vec[0].mData.device())); + JOffsetsType curOffset = 0; + int64_t totalLists = 0; + torch::Tensor curListOffset = torch::zeros( + { 1, vec[0].mListIdx.size(1) }, + torch::TensorOptions().dtype(JLIdxScalarType).device(vec[0].mData.device())); for (size_t i = 0; i < vec.size(); ++i) { - const auto& jvec = vec[i]; - TORCH_CHECK_VALUE(jvec.mData.device() == vec[0].mData.device(), "All JaggedTensors must be on the same device"); - TORCH_CHECK_VALUE(jvec.mListIdx.size(1) == vec[0].mListIdx.size(1), "All JaggedTensors must have the same list dimension"); - TORCH_CHECK_VALUE(jvec.scalar_type() == vec[0].scalar_type(), "All JaggedTensors must have the same scalar type"); + const auto &jvec = vec[i]; + TORCH_CHECK_VALUE(jvec.mData.device() == vec[0].mData.device(), + "All JaggedTensors must be on the same device"); + TORCH_CHECK_VALUE(jvec.mListIdx.size(1) == vec[0].mListIdx.size(1), + "All JaggedTensors must have the same list dimension"); + TORCH_CHECK_VALUE(jvec.scalar_type() == vec[0].scalar_type(), + "All JaggedTensors must have the same scalar type"); data.push_back(jvec.mData); if (i < vec.size() - 1) { - offsets.push_back(jvec.mOffsets.index({torch::indexing::Slice(0, -1)}) + curOffset); + offsets.push_back(jvec.mOffsets.index({ torch::indexing::Slice(0, -1) }) + + curOffset); } else { offsets.push_back(jvec.mOffsets + curOffset); } @@ -772,408 +925,501 @@ JaggedTensor JaggedTensor::jcat(const std::vector& vec, c10::optio curListOffset[0][0] += jvec.mNumOuterLists; totalLists += jvec.mNumOuterLists; } - const torch::Tensor retJData = torch::cat(data, 0); + const torch::Tensor retJData = torch::cat(data, 0); const torch::Tensor retJOffsets = torch::cat(offsets, 0); - const torch::Tensor retJidx = jidx_from_joffsets(retJOffsets, retJData.size(0)); - const torch::Tensor retLidx = torch::cat(lidx, 0); - return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retJData, retJOffsets, retJidx, retLidx, totalLists); + const torch::Tensor retJidx = jidx_from_joffsets(retJOffsets, retJData.size(0)); + const torch::Tensor retLidx = torch::cat(lidx, 0); + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(retJData, retJOffsets, + retJidx, retLidx, totalLists); } else { int64_t dim = dimension.value(); TORCH_CHECK_VALUE(vec.size() > 0, "empty tensor list"); const int64_t jdim = vec[0].mData.dim(); - TORCH_CHECK_INDEX(dim >= -(jdim-1) && dim <= jdim, "dim must be between ", -(jdim-1), " and ", jdim-1, " inclusive"); + TORCH_CHECK_INDEX(dim >= -(jdim - 1) && dim <= jdim, "dim must be between ", -(jdim - 1), + " and ", jdim - 1, " inclusive"); if (dim < 0) { dim += jdim; } if (dim == 0) { - return FVDB_DISPATCH_KERNEL_DEVICE(vec[0].device(), [&]() { - return detail::ops::dispatchJCat0(vec); - }); + return FVDB_DISPATCH_KERNEL_DEVICE( + vec[0].device(), [&]() { return detail::ops::dispatchJCat0(vec); }); } else { std::vector data; - for (const auto& jvec : vec) { data.push_back(jvec.mData); } - return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(torch::cat(data, dim), vec[0].mOffsets, vec[0].mBatchIdx, vec[0].mListIdx, vec[0].mNumOuterLists); + for (const auto &jvec: vec) { + data.push_back(jvec.mData); + } + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( + torch::cat(data, dim), vec[0].mOffsets, vec[0].mBatchIdx, vec[0].mListIdx, + vec[0].mNumOuterLists); } } - } - -JaggedTensor JaggedTensor::sqrt() const { +JaggedTensor +JaggedTensor::sqrt() const { return jagged_like(torch::sqrt(mData)); 
} -JaggedTensor JaggedTensor::abs() const { +JaggedTensor +JaggedTensor::abs() const { return jagged_like(torch::abs(mData)); } -JaggedTensor JaggedTensor::floor() const { +JaggedTensor +JaggedTensor::floor() const { return jagged_like(torch::floor(mData)); } -JaggedTensor JaggedTensor::ceil() const { +JaggedTensor +JaggedTensor::ceil() const { return jagged_like(torch::ceil(mData)); } -JaggedTensor JaggedTensor::round(int decimals) const { +JaggedTensor +JaggedTensor::round(int decimals) const { return jagged_like(torch::round(mData, decimals)); } - -JaggedTensor& JaggedTensor::sqrt_() { +JaggedTensor & +JaggedTensor::sqrt_() { mData.sqrt_(); return *this; } -JaggedTensor& JaggedTensor::abs_() { +JaggedTensor & +JaggedTensor::abs_() { mData.abs_(); return *this; } -JaggedTensor& JaggedTensor::floor_() { +JaggedTensor & +JaggedTensor::floor_() { mData.floor_(); return *this; } -JaggedTensor& JaggedTensor::ceil_() { +JaggedTensor & +JaggedTensor::ceil_() { mData.ceil_(); return *this; } -JaggedTensor& JaggedTensor::round_(int decimals) { +JaggedTensor & +JaggedTensor::round_(int decimals) { mData.round_(decimals); return *this; } - - -const JaggedTensor& JaggedTensor::set_requires_grad(bool requires_grad) const { +const JaggedTensor & +JaggedTensor::set_requires_grad(bool requires_grad) const { mData.set_requires_grad(requires_grad); return *this; } -bool JaggedTensor::requires_grad() const { +bool +JaggedTensor::requires_grad() const { return mData.requires_grad(); } -JaggedTensor JaggedTensor::detach() const { +JaggedTensor +JaggedTensor::detach() const { return jagged_like(mData.detach()); } -JaggedTensor JaggedTensor::clone() const { +JaggedTensor +JaggedTensor::clone() const { return jagged_like(mData.clone()); } - -JaggedTensor JaggedTensor::operator+(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator+(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData + other.mData); } -JaggedTensor JaggedTensor::operator+(const int other) const { +JaggedTensor +JaggedTensor::operator+(const int other) const { return jagged_like(mData + other); } -JaggedTensor JaggedTensor::operator+(const float other) const { +JaggedTensor +JaggedTensor::operator+(const float other) const { return jagged_like(mData + other); } -JaggedTensor JaggedTensor::operator+(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator+(const torch::Tensor &other) const { return jagged_like(mData + other); } -JaggedTensor& JaggedTensor::operator+=(const JaggedTensor& other) { +JaggedTensor & +JaggedTensor::operator+=(const JaggedTensor &other) { binary_op_check(other); mData += other.mData; return *this; } -JaggedTensor& JaggedTensor::operator+=(const int other) { +JaggedTensor & +JaggedTensor::operator+=(const int other) { mData += other; return *this; } -JaggedTensor& JaggedTensor::operator+=(const float other) { +JaggedTensor & +JaggedTensor::operator+=(const float other) { mData += other; return *this; } -JaggedTensor& JaggedTensor::operator+=(const torch::Tensor& other) { +JaggedTensor & +JaggedTensor::operator+=(const torch::Tensor &other) { mData += other; return *this; } -JaggedTensor JaggedTensor::operator-(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator-(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData - other.mData); } -JaggedTensor JaggedTensor::operator-(const int other) const { +JaggedTensor +JaggedTensor::operator-(const int other) const { return jagged_like(mData - other); } 
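// [Editor's illustrative sketch -- not part of the upstream patch.]
// The element-wise operators in this file all defer to jagged_like(), so the
// result keeps the caller's joffsets/jidx/jlidx and only jdata changes. The
// hypothetical helper below shows that behaviour using only API declared in
// JaggedTensor.h later in this patch (the std::vector<torch::Tensor>
// constructor, operator*, operator+, sqrt(), jsum() and unbind1()). It assumes
// #include <torch/torch.h> plus the fvdb JaggedTensor header, and the shapes
// are made up purely for illustration.
static void
jaggedArithmeticSketch() {
    std::vector<torch::Tensor> tensors = {
        torch::rand({ 3, 4 }), // first list entry: 3 elements with 4 features
        torch::rand({ 5, 4 })  // second list entry: 5 elements with 4 features
    };
    fvdb::JaggedTensor jt(tensors);

    // Element-wise math preserves the jagged layout; only the raw data changes.
    fvdb::JaggedTensor scaled = (jt * 2.0f + 1.0f).sqrt();

    // jsum() reduces each tensor in the list separately: one output row per tensor.
    fvdb::JaggedTensor sums = jt.jsum();

    // Unbinding the result yields per-tensor views with the original sizes.
    std::vector<torch::Tensor> pieces = scaled.unbind1(); // shapes [3, 4] and [5, 4]
}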
-JaggedTensor JaggedTensor::operator-(const float other) const { +JaggedTensor +JaggedTensor::operator-(const float other) const { return jagged_like(mData - other); } -JaggedTensor JaggedTensor::operator-(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator-(const torch::Tensor &other) const { return jagged_like(mData - other); } -JaggedTensor JaggedTensor::operator-() const { +JaggedTensor +JaggedTensor::operator-() const { return jagged_like(-mData); } -JaggedTensor& JaggedTensor::operator-=(const JaggedTensor& other) { +JaggedTensor & +JaggedTensor::operator-=(const JaggedTensor &other) { binary_op_check(other); mData -= other.mData; return *this; } -JaggedTensor& JaggedTensor::operator-=(const int other) { +JaggedTensor & +JaggedTensor::operator-=(const int other) { mData -= other; return *this; } -JaggedTensor& JaggedTensor::operator-=(const float other) { +JaggedTensor & +JaggedTensor::operator-=(const float other) { mData -= other; return *this; } -JaggedTensor& JaggedTensor::operator-=(const torch::Tensor& other) { +JaggedTensor & +JaggedTensor::operator-=(const torch::Tensor &other) { mData -= other; return *this; } -JaggedTensor JaggedTensor::operator*(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator*(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData * other.mData); } -JaggedTensor JaggedTensor::operator*(const int other) const { +JaggedTensor +JaggedTensor::operator*(const int other) const { return jagged_like(mData * other); } -JaggedTensor JaggedTensor::operator*(const float other) const { +JaggedTensor +JaggedTensor::operator*(const float other) const { return jagged_like(mData * other); } -JaggedTensor JaggedTensor::operator*(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator*(const torch::Tensor &other) const { return jagged_like(mData * other); } -JaggedTensor& JaggedTensor::operator*=(const JaggedTensor& other) { +JaggedTensor & +JaggedTensor::operator*=(const JaggedTensor &other) { binary_op_check(other); mData *= other.mData; return *this; } -JaggedTensor& JaggedTensor::operator*=(const int other) { +JaggedTensor & +JaggedTensor::operator*=(const int other) { mData *= other; return *this; } -JaggedTensor& JaggedTensor::operator*=(const float other) { +JaggedTensor & +JaggedTensor::operator*=(const float other) { mData *= other; return *this; } -JaggedTensor& JaggedTensor::operator*=(const torch::Tensor& other) { +JaggedTensor & +JaggedTensor::operator*=(const torch::Tensor &other) { mData *= other; return *this; } -JaggedTensor JaggedTensor::operator/(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator/(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData / other.mData); } -JaggedTensor JaggedTensor::operator/(const int other) const { +JaggedTensor +JaggedTensor::operator/(const int other) const { return jagged_like(mData / other); } -JaggedTensor JaggedTensor::operator/(const float other) const { +JaggedTensor +JaggedTensor::operator/(const float other) const { return jagged_like(mData / other); } -JaggedTensor JaggedTensor::operator/(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator/(const torch::Tensor &other) const { return jagged_like(mData / other); } -JaggedTensor& JaggedTensor::operator/=(const JaggedTensor& other) { +JaggedTensor & +JaggedTensor::operator/=(const JaggedTensor &other) { binary_op_check(other); mData /= other.mData; return *this; } -JaggedTensor& 
JaggedTensor::operator/=(const int other) { +JaggedTensor & +JaggedTensor::operator/=(const int other) { mData /= other; return *this; } -JaggedTensor& JaggedTensor::operator/=(const float other) { +JaggedTensor & +JaggedTensor::operator/=(const float other) { mData /= other; return *this; } -JaggedTensor& JaggedTensor::operator/=(const torch::Tensor& other) { +JaggedTensor & +JaggedTensor::operator/=(const torch::Tensor &other) { mData /= other; return *this; } -JaggedTensor JaggedTensor::floordiv(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::floordiv(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(torch::floor_divide(mData, other.mData)); } -JaggedTensor JaggedTensor::floordiv(const int other) const { +JaggedTensor +JaggedTensor::floordiv(const int other) const { return jagged_like(torch::floor_divide(mData, other)); } -JaggedTensor JaggedTensor::floordiv(const float other) const { +JaggedTensor +JaggedTensor::floordiv(const float other) const { return jagged_like(torch::floor_divide(mData, other)); } -JaggedTensor JaggedTensor::floordiv(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::floordiv(const torch::Tensor &other) const { return jagged_like(torch::floor_divide(mData, other)); } -JaggedTensor& JaggedTensor::floordiveq(const JaggedTensor& other) { +JaggedTensor & +JaggedTensor::floordiveq(const JaggedTensor &other) { binary_op_check(other); mData.floor_divide_(other.mData); return *this; } -JaggedTensor& JaggedTensor::floordiveq(const int other) { +JaggedTensor & +JaggedTensor::floordiveq(const int other) { mData = torch::floor_divide(mData, other); return *this; } -JaggedTensor& JaggedTensor::floordiveq(const float other) { +JaggedTensor & +JaggedTensor::floordiveq(const float other) { mData = torch::floor_divide(mData, other); return *this; } -JaggedTensor& JaggedTensor::floordiveq(const torch::Tensor& other) { +JaggedTensor & +JaggedTensor::floordiveq(const torch::Tensor &other) { mData.floor_divide_(other); return *this; } -JaggedTensor JaggedTensor::operator%(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator%(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData % other.mData); } -JaggedTensor JaggedTensor::operator%(const int other) const { +JaggedTensor +JaggedTensor::operator%(const int other) const { return jagged_like(mData % other); } -JaggedTensor JaggedTensor::operator%(const float other) const { +JaggedTensor +JaggedTensor::operator%(const float other) const { return jagged_like(mData % other); } -JaggedTensor JaggedTensor::operator%(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator%(const torch::Tensor &other) const { return jagged_like(mData % other); } -JaggedTensor& JaggedTensor::operator%=(const JaggedTensor& other) { +JaggedTensor & +JaggedTensor::operator%=(const JaggedTensor &other) { binary_op_check(other); mData = mData % other.mData; return *this; } -JaggedTensor& JaggedTensor::operator%=(const int other) { +JaggedTensor & +JaggedTensor::operator%=(const int other) { mData = mData % other; return *this; } -JaggedTensor& JaggedTensor::operator%=(const float other) { +JaggedTensor & +JaggedTensor::operator%=(const float other) { mData = mData % other; return *this; } -JaggedTensor& JaggedTensor::operator%=(const torch::Tensor& other) { +JaggedTensor & +JaggedTensor::operator%=(const torch::Tensor &other) { mData = mData % other; return *this; } -JaggedTensor JaggedTensor::pow(const JaggedTensor& other) const { 
+JaggedTensor +JaggedTensor::pow(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData.pow(other.mData)); } -JaggedTensor JaggedTensor::pow(const int other) const { +JaggedTensor +JaggedTensor::pow(const int other) const { return jagged_like(mData.pow(other)); } -JaggedTensor JaggedTensor::pow(const float other) const { +JaggedTensor +JaggedTensor::pow(const float other) const { return jagged_like(mData.pow(other)); } -JaggedTensor JaggedTensor::pow(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::pow(const torch::Tensor &other) const { return jagged_like(mData.pow(other)); } -JaggedTensor& JaggedTensor::poweq(const JaggedTensor& other) { +JaggedTensor & +JaggedTensor::poweq(const JaggedTensor &other) { binary_op_check(other); mData.pow_(other.mData); return *this; } -JaggedTensor& JaggedTensor::poweq(const int other) { +JaggedTensor & +JaggedTensor::poweq(const int other) { mData = mData.pow(other); return *this; } -JaggedTensor& JaggedTensor::poweq(const float other) { +JaggedTensor & +JaggedTensor::poweq(const float other) { mData = mData.pow(other); return *this; } -JaggedTensor& JaggedTensor::poweq(const torch::Tensor& other) { +JaggedTensor & +JaggedTensor::poweq(const torch::Tensor &other) { mData.pow_(other); return *this; } - -JaggedTensor JaggedTensor::operator>(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator>(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData > other.mData); } -JaggedTensor JaggedTensor::operator>(const int other) const { +JaggedTensor +JaggedTensor::operator>(const int other) const { return jagged_like(mData > other); } -JaggedTensor JaggedTensor::operator>(const float other) const { +JaggedTensor +JaggedTensor::operator>(const float other) const { return jagged_like(mData > other); } -JaggedTensor JaggedTensor::operator>(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator>(const torch::Tensor &other) const { return jagged_like(mData > other); } -JaggedTensor JaggedTensor::operator>=(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator>=(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData >= other.mData); } -JaggedTensor JaggedTensor::operator>=(const int other) const { +JaggedTensor +JaggedTensor::operator>=(const int other) const { return jagged_like(mData >= other); } -JaggedTensor JaggedTensor::operator>=(const float other) const { +JaggedTensor +JaggedTensor::operator>=(const float other) const { return jagged_like(mData >= other); } -JaggedTensor JaggedTensor::operator>=(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator>=(const torch::Tensor &other) const { return jagged_like(mData >= other); } -JaggedTensor JaggedTensor::operator<(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator<(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData < other.mData); } -JaggedTensor JaggedTensor::operator<(const int other) const { +JaggedTensor +JaggedTensor::operator<(const int other) const { return jagged_like(mData < other); } -JaggedTensor JaggedTensor::operator<(const float other) const { +JaggedTensor +JaggedTensor::operator<(const float other) const { return jagged_like(mData < other); } -JaggedTensor JaggedTensor::operator<(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator<(const torch::Tensor &other) const { return jagged_like(mData < other); } -JaggedTensor JaggedTensor::operator<=(const 
JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator<=(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData <= other.mData); } -JaggedTensor JaggedTensor::operator<=(const int other) const { +JaggedTensor +JaggedTensor::operator<=(const int other) const { return jagged_like(mData <= other); } -JaggedTensor JaggedTensor::operator<=(const float other) const { +JaggedTensor +JaggedTensor::operator<=(const float other) const { return jagged_like(mData <= other); } -JaggedTensor JaggedTensor::operator<=(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator<=(const torch::Tensor &other) const { return jagged_like(mData <= other); } -JaggedTensor JaggedTensor::operator==(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator==(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData == other.mData); } -JaggedTensor JaggedTensor::operator==(const int other) const { +JaggedTensor +JaggedTensor::operator==(const int other) const { return jagged_like(mData == other); } -JaggedTensor JaggedTensor::operator==(const float other) const { +JaggedTensor +JaggedTensor::operator==(const float other) const { return jagged_like(mData == other); } -JaggedTensor JaggedTensor::operator==(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator==(const torch::Tensor &other) const { return jagged_like(mData == other); } -JaggedTensor JaggedTensor::operator!=(const JaggedTensor& other) const { +JaggedTensor +JaggedTensor::operator!=(const JaggedTensor &other) const { binary_op_check(other); return jagged_like(mData != other.mData); } -JaggedTensor JaggedTensor::operator!=(const int other) const { +JaggedTensor +JaggedTensor::operator!=(const int other) const { return jagged_like(mData != other); } -JaggedTensor JaggedTensor::operator!=(const float other) const { +JaggedTensor +JaggedTensor::operator!=(const float other) const { return jagged_like(mData != other); } -JaggedTensor JaggedTensor::operator!=(const torch::Tensor& other) const { +JaggedTensor +JaggedTensor::operator!=(const torch::Tensor &other) const { return jagged_like(mData != other); } diff --git a/fvdb/src/JaggedTensor.h b/fvdb/src/JaggedTensor.h index eb94ed875d..6875671d8c 100644 --- a/fvdb/src/JaggedTensor.h +++ b/fvdb/src/JaggedTensor.h @@ -1,125 +1,136 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_JAGGEDTENSOR_H +#define FVDB_JAGGEDTENSOR_H +#include "detail/utils/Utils.h" + +#include #include #include -#include - -#include "detail/utils/Utils.h" namespace fvdb { struct JaggedTensorIndex; -using JIdxType = int32_t; +using JIdxType = int32_t; using JOffsetsType = int64_t; -using JLIdxType = int32_t; +using JLIdxType = int32_t; -constexpr c10::ScalarType JIdxScalarType = c10::CppTypeToScalarType::value; +constexpr c10::ScalarType JIdxScalarType = c10::CppTypeToScalarType::value; constexpr c10::ScalarType JOffsetsScalarType = c10::CppTypeToScalarType::value; -constexpr c10::ScalarType JLIdxScalarType = c10::CppTypeToScalarType::value; +constexpr c10::ScalarType JLIdxScalarType = c10::CppTypeToScalarType::value; -template -class JaggedAccessor { - torch::TensorAccessor mBatchIdx; +template class JaggedAccessor { + torch::TensorAccessor mBatchIdx; torch::TensorAccessor mOffsets; - torch::TensorAccessor mListIndexes; - torch::TensorAccessor mData; + torch::TensorAccessor mListIndexes; + torch::TensorAccessor mData; friend class JaggedTensor; - 
JaggedAccessor(torch::TensorAccessor batchIdx, + JaggedAccessor(torch::TensorAccessor batchIdx, torch::TensorAccessor offsets, - torch::TensorAccessor listIndexes, - torch::TensorAccessor data) + torch::TensorAccessor listIndexes, + torch::TensorAccessor data) : mBatchIdx(batchIdx), mOffsets(offsets), mListIndexes(listIndexes), mData(data) {} -public: - template - using TensorAccessorType = torch::TensorAccessor; + public: + template using TensorAccessorType = torch::TensorAccessor; - inline __hostdev__ int64_t elementCount() const { + inline __hostdev__ int64_t + elementCount() const { return mData.size(0); } - inline __hostdev__ JIdxType batchIdx(int64_t idx) const { + inline __hostdev__ JIdxType + batchIdx(int64_t idx) const { return mBatchIdx.size(0) > 0 ? mBatchIdx[idx] : 0; } - inline __hostdev__ JOffsetsType offsetStart(int64_t idx) const { + inline __hostdev__ JOffsetsType + offsetStart(int64_t idx) const { return mOffsets[idx]; } - inline __hostdev__ JOffsetsType offsetEnd(int64_t idx) const { - return mOffsets[idx+1]; + inline __hostdev__ JOffsetsType + offsetEnd(int64_t idx) const { + return mOffsets[idx + 1]; } - inline __hostdev__ const torch::TensorAccessor& data() const { + inline __hostdev__ const torch::TensorAccessor & + data() const { return mData; } }; - -template typename PtrTraits = torch::DefaultPtrTraits> +template typename PtrTraits = torch::DefaultPtrTraits> class PackedJaggedAccessor32 { - torch::PackedTensorAccessor32 mBatchIdx; + torch::PackedTensorAccessor32 mBatchIdx; torch::PackedTensorAccessor32 mOffsets; - torch::PackedTensorAccessor32 mListIndexes; - torch::PackedTensorAccessor32 mData; + torch::PackedTensorAccessor32 mListIndexes; + torch::PackedTensorAccessor32 mData; friend class JaggedTensor; - PackedJaggedAccessor32(torch::PackedTensorAccessor32 batchIdx, + PackedJaggedAccessor32(torch::PackedTensorAccessor32 batchIdx, torch::PackedTensorAccessor32 offsets, - torch::PackedTensorAccessor32 listIndexes, - torch::PackedTensorAccessor32 data) + torch::PackedTensorAccessor32 listIndexes, + torch::PackedTensorAccessor32 data) : mBatchIdx(batchIdx), mOffsets(offsets), mListIndexes(listIndexes), mData(data) {} -public: - + public: template using TensorAccessorType = torch::TensorAccessor; - inline __hostdev__ int64_t elementCount() const { + inline __hostdev__ int64_t + elementCount() const { return mData.size(0); } - inline __hostdev__ JIdxType batchIdx(int64_t idx) const { + inline __hostdev__ JIdxType + batchIdx(int64_t idx) const { return mBatchIdx.size(0) > 0 ? 
mBatchIdx[idx] : 0; } - inline __hostdev__ JOffsetsType offsetStart(int64_t idx) const { + inline __hostdev__ JOffsetsType + offsetStart(int64_t idx) const { return mOffsets[idx]; } - inline __hostdev__ JOffsetsType offsetEnd(int64_t idx) const { - return mOffsets[idx+1]; + inline __hostdev__ JOffsetsType + offsetEnd(int64_t idx) const { + return mOffsets[idx + 1]; } - inline __hostdev__ const torch::PackedTensorAccessor32& data() const { + inline __hostdev__ const torch::PackedTensorAccessor32 & + data() const { return mData; } }; - class JaggedTensor : public torch::CustomClassHolder { - torch::Tensor mData; // Actual data indexed by a jagged tensor - torch::Tensor mBatchIdx; // Which (linear) batch is each datum in - torch::Tensor mOffsets; // Offset of each tensor in the list of lists - torch::Tensor mListIdx; // LoL indexing of tensor with shape [num_tensors, ldim] - int64_t mNumOuterLists; // Number of outer lists in this JaggedTensor + torch::Tensor mData; // Actual data indexed by a jagged tensor + torch::Tensor mBatchIdx; // Which (linear) batch is each datum in + torch::Tensor mOffsets; // Offset of each tensor in the list of lists + torch::Tensor mListIdx; // LoL indexing of tensor with shape [num_tensors, ldim] + int64_t mNumOuterLists; // Number of outer lists in this JaggedTensor // Store the number of elements in each tensor in the jagged tensor // Computing this requires a GPU -> CPU copy so we cache it struct { - std::vector mLShape1; - std::vector> mLShape2; + std::vector mLShape1; + std::vector> mLShape2; std::vector>> mLShape3; - bool mDirty = true; - void markDirty() { mDirty = true; } - void clear() { + bool mDirty = true; + void + markDirty() { + mDirty = true; + } + void + clear() { mLShape1.clear(); mLShape2.clear(); mLShape3.clear(); @@ -129,25 +140,30 @@ class JaggedTensor : public torch::CustomClassHolder { void recompute_lsizes_if_dirty(); + void binary_op_check(const JaggedTensor &other) const; - void binary_op_check(const JaggedTensor& other) const; - -public: - static torch::Tensor joffsets_from_jidx_and_jdata(torch::Tensor jidx, torch::Tensor jdata, int64_t num_tensors); + public: + static torch::Tensor joffsets_from_jidx_and_jdata(torch::Tensor jidx, torch::Tensor jdata, + int64_t num_tensors); static torch::Tensor jidx_from_joffsets(torch::Tensor joffsets, int64_t num_elements); - static JaggedTensor from_jdata_joffsets_jidx_and_lidx_unsafe(torch::Tensor jdata, torch::Tensor joffsets, - torch::Tensor jidx, torch::Tensor jlidx, - int64_t numOuterLists); + static JaggedTensor from_jdata_joffsets_jidx_and_lidx_unsafe(torch::Tensor jdata, + torch::Tensor joffsets, + torch::Tensor jidx, + torch::Tensor jlidx, + int64_t numOuterLists); - static JaggedTensor from_data_indices_and_list_ids(torch::Tensor data, torch::Tensor indices, torch::Tensor list_ids, int64_t num_tensors); - static JaggedTensor from_data_offsets_and_list_ids(torch::Tensor data, torch::Tensor offsets, torch::Tensor list_ids); + static JaggedTensor from_data_indices_and_list_ids(torch::Tensor data, torch::Tensor indices, + torch::Tensor list_ids, int64_t num_tensors); + static JaggedTensor from_data_offsets_and_list_ids(torch::Tensor data, torch::Tensor offsets, + torch::Tensor list_ids); /// @brief Concatenate the list of JaggedTensors along a given dimension. /// There are two modes for this function. /// 1. If dim is an integer: /// e.g. 
if [jt_a, jt_b] are two JaggedTensors of the form - /// jt_a = [[a_11, a_12], [a_21], [a_31, a_32]] and jt_b = [[b_11, b_12], [b_21], [b_31, b_32]], - /// then JaggedTensor::jcat({jt_a, jt_b}) will return a JaggedTensor of the form + /// jt_a = [[a_11, a_12], [a_21], [a_31, a_32]] and jt_b = [[b_11, b_12], [b_21], + /// [b_31, b_32]], then JaggedTensor::jcat({jt_a, jt_b}) will return a JaggedTensor + /// of the form /// [[torch.cat([a_11, b_11], dim=dim), torch.cat([a_12, b_12], dim=dim)], /// [torch.cat([a_21, b_21], dim=dim)], /// [torch.cat([a_31, b_31], dim=dim), torch.cat([a_32, b_32], dim=dim)]] @@ -157,176 +173,229 @@ class JaggedTensor : public torch::CustomClassHolder { /// then JaggedTensor::jcat({jt_a, jt_b}) will return a JaggedTensor of the form /// [[a_11, a_12], [a_21], [a_31, a_32], [b_11], [b_21, b_22]] /// @param vec A vector of JaggedTensors to concatenate - /// @param dim The dimension along which to concatenate each JaggedTensor or c10::nullopt to concatenate + /// @param dim The dimension along which to concatenate each JaggedTensor or c10::nullopt to + /// concatenate /// the JaggedTensors as lists /// @return A JaggedTensor containing the concatenated data - static JaggedTensor jcat(const std::vector& vec, c10::optional dim); + static JaggedTensor jcat(const std::vector &vec, c10::optional dim); /// @brief Create an empty JaggedTensor JaggedTensor() { - mData = torch::Tensor(); - mBatchIdx = torch::empty({0}, torch::TensorOptions().dtype(JIdxScalarType)); - mOffsets = torch::zeros({1}, torch::TensorOptions().dtype(JOffsetsScalarType)); - mListIdx = torch::empty({0, 1}, torch::TensorOptions().dtype(JLIdxScalarType)); - mNumOuterLists = 0; + mData = torch::Tensor(); + mBatchIdx = torch::empty({ 0 }, torch::TensorOptions().dtype(JIdxScalarType)); + mOffsets = torch::zeros({ 1 }, torch::TensorOptions().dtype(JOffsetsScalarType)); + mListIdx = torch::empty({ 0, 1 }, torch::TensorOptions().dtype(JLIdxScalarType)); + mNumOuterLists = 0; } - /// @brief Create a JaggedTensor representing a list with a single tensor. Note this function does not copy the + /// @brief Create a JaggedTensor representing a list with a single tensor. Note this function + /// does not copy the /// data tensor, it only creates a view of it. /// @param data The data tensor JaggedTensor(torch::Tensor data); /// @brief Create a JaggedTensor representing a list of tensors. /// @param tensors A list of tensors - JaggedTensor(const std::vector& tensors); + JaggedTensor(const std::vector &tensors); /// @brief Create a JaggedTensor representing a list of lists of tensors. /// @param tensors A list of lists of tensors - JaggedTensor(const std::vector>& tensors); - - /// @brief Create a JaggedTensor representing a list of tensors where the number of elements in each tensor is given - /// by the lsizes vector. i.e. if lsizes = [2, 1, 2], then the first tensor will have 2 elements, the second - /// tensor will have 1 element, and the third tensor will have 2 elements. The raw data tensor must then have - /// a number of elements equal to the sum of the elements in lsizes (i.e. shape [sum(lsizes), ...]) + JaggedTensor(const std::vector> &tensors); + + /// @brief Create a JaggedTensor representing a list of tensors where the number of elements in + /// each tensor is given + /// by the lsizes vector. i.e. if lsizes = [2, 1, 2], then the first tensor will have 2 + /// elements, the second tensor will have 1 element, and the third tensor will have 2 + /// elements. 
The raw data tensor must then have a number of elements equal to the sum of + /// the elements in lsizes (i.e. shape [sum(lsizes), ...]) /// @param lsizes A vector of integers indicating the number of elements in each tensor /// @param data The raw data tensor - JaggedTensor(const std::vector& lsizes, const torch::Tensor data); - - /// @brief Create a JaggedTensor representing a list of lists of tensors where the number of elements in each tensor - /// is given by the lsizes vector. i.e. if lsizes = [[2, 1], [5, 6, 7]], then the first list will have 2 tensors with 1 and 2 elements - /// respectively and the second list will have 3 tensors with 5, 6, and 7 elements respectively. - /// The raw data tensor must then have a number of elements equal to the sum of the elements in lsizes (i.e. shape [sum(lsizes), ...]) - /// @param lsizes A vector of vectors of integers indicating the number of elements in each tensor + JaggedTensor(const std::vector &lsizes, const torch::Tensor data); + + /// @brief Create a JaggedTensor representing a list of lists of tensors where the number of + /// elements in each tensor + /// is given by the lsizes vector. i.e. if lsizes = [[2, 1], [5, 6, 7]], then the first + /// list will have 2 tensors with 1 and 2 elements respectively and the second list will + /// have 3 tensors with 5, 6, and 7 elements respectively. The raw data tensor must then + /// have a number of elements equal to the sum of the elements in lsizes (i.e. shape + /// [sum(lsizes), ...]) + /// @param lsizes A vector of vectors of integers indicating the number of elements in each + /// tensor /// @param total_tensors The total number of tensors in the list of lists /// @param data The raw data tensor - JaggedTensor(const std::vector>& lsizes, const int64_t total_tensors, const torch::Tensor data); + JaggedTensor(const std::vector> &lsizes, const int64_t total_tensors, + const torch::Tensor data); - /// @brief Create a JaggedTensor with the same list structure as this one but with the given raw data. - /// The returned JaggedTensor will share the same memory for indices/list ids/offsets as this one - /// those are modified. + /// @brief Create a JaggedTensor with the same list structure as this one but with the given raw + /// data. + /// The returned JaggedTensor will share the same memory for indices/list ids/offsets as + /// this one those are modified. /// @param data A tensor with the same number of elements as the original data /// @return A JaggedTensor with the same list structure as this one but with the given data JaggedTensor jagged_like(torch::Tensor data) const; /// @brief Set the raw data of this JaggedTensor to the given tensor /// @param data A data tensor with the same number of elements as the original data - void set_data(const torch::Tensor& data); + void set_data(const torch::Tensor &data); /// @brief Get the raw data indexed by this JaggedTensor /// @return The raw data tensor - const torch::Tensor& jdata() const { return mData; } + const torch::Tensor & + jdata() const { + return mData; + } - /// @brief Get the indices of this jagged tensor. i.e. a tensor of size (num_elements,) indicating which + /// @brief Get the indices of this jagged tensor. i.e. 
a tensor of size (num_elements,) + /// indicating which /// tensor each element belongs to /// @return The indices of this JaggedTensor - const torch::Tensor& jidx() const { return mBatchIdx; } + const torch::Tensor & + jidx() const { + return mBatchIdx; + } - /// @brief Get the offsets of each tensor indexed by this JaggedTensor. i.e. a tensor of size (num_tensors + 1) - /// where joffsets[i] is the start offset in jdata and joffsets[i+1] is the end offset in jdata + /// @brief Get the offsets of each tensor indexed by this JaggedTensor. i.e. a tensor of size + /// (num_tensors + 1) + /// where joffsets[i] is the start offset in jdata and joffsets[i+1] is the end offset in + /// jdata /// @return The offsets of each tensor indexed by this JaggedTensor - const torch::Tensor& joffsets() const { return mOffsets; } + const torch::Tensor & + joffsets() const { + return mOffsets; + } - /// @brief Get the list indices of each tensor indexed by this JaggedTensor. i.e. a tensor of size (num_tensors, ldim) - /// where e.g. jlidx[i][j] is the index of the j-th list in the i-th tensor (for a list of lists JaggedTensor) + /// @brief Get the list indices of each tensor indexed by this JaggedTensor. i.e. a tensor of + /// size (num_tensors, ldim) + /// where e.g. jlidx[i][j] is the index of the j-th list in the i-th tensor (for a list + /// of lists JaggedTensor) /// @return The list indices of each tensor indexed by this JaggedTensor - const torch::Tensor& jlidx() const { return mListIdx; } + const torch::Tensor & + jlidx() const { + return mListIdx; + } /// @brief Get the number of outer lists in this JaggedTensor - int64_t num_outer_lists() const { return mNumOuterLists; } + int64_t + num_outer_lists() const { + return mNumOuterLists; + } /// @brief Get the number of tensors in this JaggedTensor - int64_t num_tensors() const { return mOffsets.size(0) - 1; } + int64_t + num_tensors() const { + return mOffsets.size(0) - 1; + } - /// @brief Get the number of elements in each tensor indexed by this JaggedTensor. Assumes the JaggedTensor has ldim() == 1 + /// @brief Get the number of elements in each tensor indexed by this JaggedTensor. Assumes the + /// JaggedTensor has ldim() == 1 /// i.e. it represents a list of tensors /// @return The number of elements in each tensor indexed by this JaggedTensor std::vector lsizes1() const; - /// @brief Get the number of elements in each tensor indexed by this JaggedTensor. Assumes JaggedTensor has ldim() == 2 + /// @brief Get the number of elements in each tensor indexed by this JaggedTensor. Assumes + /// JaggedTensor has ldim() == 2 /// i.e. it represents a list of lists of tensors - /// @return The number of elements in each tensor indexed by this JaggedTensor such that lsizes2()[i][j] is the number of elements + /// @return The number of elements in each tensor indexed by this JaggedTensor such that + /// lsizes2()[i][j] is the number of elements /// in the j-th tensor in i-th list std::vector> lsizes2() const; - /// @brief Get the number of nested lists encoded by this JaggedTensor. An ldim of one means this JaggedTensor encodes a list - // of tensors, an ldim of 2 means this JaggedTensor encodes a list of lists of tensors, etc. + /// @brief Get the number of nested lists encoded by this JaggedTensor. An ldim of one means + /// this JaggedTensor encodes a list + // of tensors, an ldim of 2 means this JaggedTensor encodes a list of lists of tensors, + // etc. 
/// @return The number of nested lists encoded by this JaggedTensor int64_t ldim() const; - /// @brief Get the size of each element indexed by this JaggedTensor. i.e. if the JaggedTensor represents a list of tensors + /// @brief Get the size of each element indexed by this JaggedTensor. i.e. if the JaggedTensor + /// represents a list of tensors /// where each tensor has shape [N, A, B, C], then esizes() will return [A, B, C] /// @return The size of each element indexed by this JaggedTensor std::vector<int64_t> esizes() const; - /// @brief Get the number of dimensions of each element indexed by this JaggedTensor. i.e. if the JaggedTensor represents a list of tensors + /// @brief Get the number of dimensions of each element indexed by this JaggedTensor. i.e. if + /// the JaggedTensor represents a list of tensors /// where each tensor has shape [N, A, B, C], then edim() will return 3 /// @return The number of dimensions of each element indexed by this JaggedTensor int64_t edim() const; - /// @brief Convert the JaggedTensor to a list of tensors assuming this JaggedTensor represents a list of tensors. + /// @brief Convert the JaggedTensor to a list of tensors assuming this JaggedTensor represents a + /// list of tensors. /// Note this function doesn't work for nested lists of tensors (instead use unbind2()) /// @return A list of tensors where each tensor is indexed by this JaggedTensor. std::vector<torch::Tensor> unbind1() const; - /// @brief Convert the JaggedTensor to a list of lists of tensors assuming this JaggedTensor represents a list of lists of tensors. + /// @brief Convert the JaggedTensor to a list of lists of tensors assuming this JaggedTensor + /// represents a list of lists of tensors. /// Note this function doesn't work for a flat list of tensors (instead use unbind1()) /// @return A list of lists of tensors where each tensor is indexed by this JaggedTensor. std::vector<std::vector<torch::Tensor>> unbind2() const; - /// @brief Index this JaggedTensor along the outer list dimension. There are several ways to index a JaggedTensor jt: - /// 1. Indexing with an integer jt[i] will return the i^th list in this tensor (or a list containing the i^th + /// @brief Index this JaggedTensor along the outer list dimension. There are several ways to + /// index a JaggedTensor jt: + /// 1. Indexing with an integer jt[i] will return the i^th list in this tensor (or a list + /// containing the i^th /// tensor if jt represents a list of tensors) - /// 2. Indexing with a slice jt[2:5] will return a JaggedTensor containing the 2nd, 3rd, and 4th lists in this tensor + /// 2. Indexing with a slice jt[2:5] will return a JaggedTensor containing the 2nd, 3rd, + /// and 4th lists in this tensor /// Note: We currently only support contiguous slices (i.e. stride = 1) /// 3. Indexing with another JaggedTensor of boolean mask values jt[mask] /// will return a JaggedTensor containing tensors masked by the boolean mask /// i.e. jt[mask][i][j].jdata = jt[i][j].jdata[mask[i][j].jdata] - /// 4. Indexing with a tensor of integer indices jt[indices] will return a JaggedTensor containing tensors - /// indexed by the integer indices. i.e. jt[indices][i][j].jdata = jt[i][j].jdata[indices[i][j]] + /// 4. Indexing with a tensor of integer indices jt[indices] will return a JaggedTensor + /// containing tensors + /// indexed by the integer indices. i.e. jt[indices][i][j].jdata = + /// jt[i][j].jdata[indices[i][j]] /// 5. Indexing with ellipses jt[...]
is a no-op /// @param idx The index to use to index this JaggedTensor /// @return A JaggedTensor containing the indexed data JaggedTensor index(JaggedTensorIndex idx) const; - /// @brief Reshape this JaggedTensor to have a new list structure. The provided lshape should be compatible with + /// @brief Reshape this JaggedTensor to have a new list structure. The provided lshape should be + /// compatible with /// this tensor. i.e. the sum of the elements in lshape should be equal to the number of /// elements in this JaggedTensor. - /// Note this function creates a view over the original JaggedTensor so modifying the returned JaggedTensor - /// will modify the original tensor. + /// Note this function creates a view over the original JaggedTensor so modifying the + /// returned JaggedTensor will modify the original tensor. /// @param lsizes The new list structure /// @return A JaggedTensor with the new list structure - JaggedTensor jreshape(const std::vector& lsizes) const; - JaggedTensor jreshape(const std::vector>& lsizes) const; + JaggedTensor jreshape(const std::vector &lsizes) const; + JaggedTensor jreshape(const std::vector> &lsizes) const; /// @brief Reshape this JaggedTensor to have the same list structure as another JaggedTensor. - /// Note this function creates a view over the original JaggedTensor so modifying the returned JaggedTensor - /// will modify the original tensor. - /// @param other The JaggedTensor to reshape this JaggedTensor to have the same list structure as + /// Note this function creates a view over the original JaggedTensor so modifying the + /// returned JaggedTensor will modify the original tensor. + /// @param other The JaggedTensor to reshape this JaggedTensor to have the same list structure + /// as /// @return A JaggedTensor with the same list structure as the other JaggedTensor - JaggedTensor jreshape_as(const JaggedTensor& other) const; + JaggedTensor jreshape_as(const JaggedTensor &other) const; - /// Flatten one of the list dimensions of this JaggedTensor. i.e. if this JaggedTensor represents a list of lists of tensors - /// then jflatten(0) will flatten the outer list dimension and jflatten(1) will flatten the inner list dimension. - /// e.g. if this JaggedTensor represents a list of lists of tensors [[A, B], [C], [D, E]] then + /// Flatten one of the list dimensions of this JaggedTensor. i.e. if this JaggedTensor + /// represents a list of lists of tensors then jflatten(0) will flatten the outer list dimension + /// and jflatten(1) will flatten the inner list dimension. e.g. if this JaggedTensor represents + /// a list of lists of tensors [[A, B], [C], [D, E]] then /// - jflatten(0) will return a JaggedTensor [A, B, C, D, E] - /// - jflatten(1) will return a JaggedTensor [[torch.cat(A, B, dim=0)], [C], [torch.cat(D, E, dim=0)]] + /// - jflatten(1) will return a JaggedTensor [[torch.cat(A, B, dim=0)], [C], [torch.cat(D, + /// E, dim=0)]] /// e.g. if this JaggedTensor represents a list of tensors with shapes [A, B, C] then /// - jflatten(0) will return a JaggedTensor with shape [torch.cat(A, B, C, dim=0)] /// - jflatten(1) will raise an exception as there is no inner list dimension - /// Note this function creates a view over the original JaggedTensor so modifying the returned JaggedTensor - /// will modify the original tensor. + /// Note this function creates a view over the original JaggedTensor so modifying the returned + /// JaggedTensor will modify the original tensor. 
/// @param dim The dimension to flatten /// @return A JaggedTensor with the flattened list dimension JaggedTensor jflatten(const int64_t dim = 0) const; /// @brief Sorts each batch element in ascending order, note that jdata has to be 1-dimensional - /// @return An indexing tensor with the same size as jdata, that permutes the elements of data to be in sorted order + /// @return An indexing tensor with the same size as jdata, that permutes the elements of data + /// to be in sorted order // JaggedTensor jagged_argsort(); /// @brief Compute the summation of each batch element /// @param dim The dimension to sum over /// @param keepdim Whether to keep the summed dimension - /// @return A tensor of size (batch_size, *) containing the sum of each batch element, feature dimensions are preserved + /// @return A tensor of size (batch_size, *) containing the sum of each batch element, feature + /// dimensions are preserved JaggedTensor jsum(int64_t dim = 0, bool keepdim = false) const; /// @brief Compute the minimum of each batch element @@ -342,31 +411,43 @@ class JaggedTensor : public torch::CustomClassHolder { std::vector jmax(int64_t dim = 0, bool keepdim = false) const; // Operators on raw data - inline int64_t rsize(int64_t dim) const { return mData.size(dim); } - inline int64_t rdim() const { return mData.dim(); } - inline std::vector rsizes() const { return mData.sizes().vec(); } - JaggedTensor rmask(const torch::Tensor& mask) const; + inline int64_t + rsize(int64_t dim) const { + return mData.size(dim); + } + inline int64_t + rdim() const { + return mData.dim(); + } + inline std::vector + rsizes() const { + return mData.sizes().vec(); + } + JaggedTensor rmask(const torch::Tensor &mask) const; - /// @brief Get an accessor for the JaggedTensor. Useful for reading/writing values in the JaggedTensor + /// @brief Get an accessor for the JaggedTensor. Useful for reading/writing values in the + /// JaggedTensor /// @tparam Scalar The type of the data in the JaggedTensor /// @tparam NDims The number of dimensions of the data in the JaggedTensor (i.e. edim() + 1) /// @return An accessor for the JaggedTensor template - JaggedAccessor accessor() const { + JaggedAccessor + accessor() const { return JaggedAccessor( - mBatchIdx.accessor(), - mOffsets.accessor(), - mListIdx.accessor(), - mData.accessor()); + mBatchIdx.accessor(), mOffsets.accessor(), + mListIdx.accessor(), mData.accessor()); } - /// @brief Get a packed accessor for the JaggedTensor. Useful for reading/writing values in the JaggedTensor in Cuda + /// @brief Get a packed accessor for the JaggedTensor. Useful for reading/writing values in the + /// JaggedTensor in Cuda /// @tparam Scalar The type of the data in the JaggedTensor /// @tparam NDims The number of dimensions of the data in the JaggedTensor (i.e. 
edim() + 1) /// @tparam PtrTraits The type of the pointer traits for the packed accessor /// @return A packed accessor for the JaggedTensor - template typename PtrTraits = torch::DefaultPtrTraits> - PackedJaggedAccessor32 packed_accessor32() const { + template typename PtrTraits = torch::DefaultPtrTraits> + PackedJaggedAccessor32 + packed_accessor32() const { return PackedJaggedAccessor32( mBatchIdx.packed_accessor32(), mOffsets.packed_accessor32(), @@ -375,217 +456,255 @@ class JaggedTensor : public torch::CustomClassHolder { } /// @brief Raise an exception if the JaggedTensor is in an invalid state - inline void check_valid() const { - TORCH_CHECK((jidx().size(0) == 0 && joffsets().size(0) == 2) || (jidx().size(0) == jdata().size(0)), "tensor must be a valid JaggedTensor"); - TORCH_CHECK(jidx().device() == jdata().device(), "batch index and data must be on the same device"); + inline void + check_valid() const { + TORCH_CHECK((jidx().size(0) == 0 && joffsets().size(0) == 2) || + (jidx().size(0) == jdata().size(0)), + "tensor must be a valid JaggedTensor"); + TORCH_CHECK(jidx().device() == jdata().device(), + "batch index and data must be on the same device"); TORCH_CHECK(jidx().dtype() == JIdxScalarType, "batch index must be int"); - TORCH_CHECK(joffsets().device() == jdata().device(), "offsets and data must be on the same device"); - TORCH_CHECK_VALUE(jlidx().numel() == 0 || jlidx().size(0) == (joffsets().size(0) - 1), "Corrupt list indices. This should never happen"); + TORCH_CHECK(joffsets().device() == jdata().device(), + "offsets and data must be on the same device"); + TORCH_CHECK_VALUE(jlidx().numel() == 0 || jlidx().size(0) == (joffsets().size(0) - 1), + "Corrupt list indices. This should never happen"); } - inline int64_t element_count() const { + inline int64_t + element_count() const { return jdata().size(0); } - inline torch::Device device() const { + inline torch::Device + device() const { return mData.device(); } - caffe2::TypeMeta dtype() const { + caffe2::TypeMeta + dtype() const { return mData.dtype(); } - torch::Layout layout() const { + torch::Layout + layout() const { return mData.layout(); } - inline torch::ScalarType scalar_type() const { + inline torch::ScalarType + scalar_type() const { return mData.scalar_type(); } - inline bool is_cuda() const { + inline bool + is_cuda() const { return mData.is_cuda(); } - inline bool is_cpu() const { + inline bool + is_cpu() const { return mData.is_cpu(); } - int64_t get_device() const { + int64_t + get_device() const { return mData.get_device(); } - bool is_complex() const { + bool + is_complex() const { return at::isComplexType(this->scalar_type()); } - bool is_floating_point() const { + bool + is_floating_point() const { return at::isFloatingType(this->scalar_type()); } - bool is_signed() const { + bool + is_signed() const { return at::isSignedType(this->scalar_type()); } - int64_t numel() const { + int64_t + numel() const { return mData.numel(); } - inline bool is_contiguous() const { + inline bool + is_contiguous() const { return mData.is_contiguous(); } - inline JaggedTensor contiguous() const { - return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe(mData.contiguous(), mOffsets.contiguous(), mBatchIdx.contiguous(), mListIdx.contiguous(), mNumOuterLists); + inline JaggedTensor + contiguous() const { + return JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( + mData.contiguous(), mOffsets.contiguous(), mBatchIdx.contiguous(), + mListIdx.contiguous(), mNumOuterLists); } - inline JaggedTensor 
to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, c10::optional memory_format=c10::nullopt) const { + inline JaggedTensor + to(at::TensorOptions options = {}, bool non_blocking = false, bool copy = false, + c10::optional memory_format = c10::nullopt) const { JaggedTensor ret = *this; - ret.mData = ret.mData.to(options, non_blocking, copy, memory_format); - ret.mBatchIdx = ret.mBatchIdx.to(ret.mData.device(), non_blocking, copy, memory_format); - ret.mOffsets = ret.mOffsets.to(ret.mData.device(), non_blocking, copy, memory_format); - ret.mListIdx = ret.mListIdx.to(ret.mData.device(), non_blocking, copy, memory_format); + ret.mData = ret.mData.to(options, non_blocking, copy, memory_format); + ret.mBatchIdx = ret.mBatchIdx.to(ret.mData.device(), non_blocking, copy, memory_format); + ret.mOffsets = ret.mOffsets.to(ret.mData.device(), non_blocking, copy, memory_format); + ret.mListIdx = ret.mListIdx.to(ret.mData.device(), non_blocking, copy, memory_format); return ret; } - inline JaggedTensor to(c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, bool non_blocking, bool copy, c10::optional memory_format) { + inline JaggedTensor + to(c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory, bool non_blocking, + bool copy, c10::optional memory_format) { JaggedTensor ret = *this; - ret.mData = ret.mData.to(dtype, layout, device, pin_memory, non_blocking, copy, memory_format); - ret.mBatchIdx = ret.mBatchIdx.to(JIdxScalarType, layout, device, pin_memory, non_blocking, copy, memory_format); - ret.mOffsets = ret.mOffsets.to(JOffsetsScalarType, layout, device, pin_memory, non_blocking, copy, memory_format); - ret.mListIdx = ret.mListIdx.to(JLIdxScalarType, layout, device, pin_memory, non_blocking, copy, memory_format); + ret.mData = + ret.mData.to(dtype, layout, device, pin_memory, non_blocking, copy, memory_format); + ret.mBatchIdx = ret.mBatchIdx.to(JIdxScalarType, layout, device, pin_memory, non_blocking, + copy, memory_format); + ret.mOffsets = ret.mOffsets.to(JOffsetsScalarType, layout, device, pin_memory, non_blocking, + copy, memory_format); + ret.mListIdx = ret.mListIdx.to(JLIdxScalarType, layout, device, pin_memory, non_blocking, + copy, memory_format); return ret; } - inline JaggedTensor to(torch::Device device, torch::ScalarType dtype, bool non_blocking=false, bool copy=false, c10::optional memory_format=c10::nullopt) { + inline JaggedTensor + to(torch::Device device, torch::ScalarType dtype, bool non_blocking = false, bool copy = false, + c10::optional memory_format = c10::nullopt) { JaggedTensor ret = *this; - ret.mData = ret.mData.to(device, dtype, non_blocking, copy, memory_format); - ret.mBatchIdx = ret.mBatchIdx.to(device, non_blocking, copy, memory_format); - ret.mOffsets = ret.mOffsets.to(device, non_blocking, copy, memory_format); - ret.mListIdx = ret.mListIdx.to(device, non_blocking, copy, memory_format); + ret.mData = ret.mData.to(device, dtype, non_blocking, copy, memory_format); + ret.mBatchIdx = ret.mBatchIdx.to(device, non_blocking, copy, memory_format); + ret.mOffsets = ret.mOffsets.to(device, non_blocking, copy, memory_format); + ret.mListIdx = ret.mListIdx.to(device, non_blocking, copy, memory_format); return ret; } - inline JaggedTensor to(torch::ScalarType dtype, bool non_blocking=false, bool copy=false, c10::optional memory_format=c10::nullopt) { + inline JaggedTensor + to(torch::ScalarType dtype, bool non_blocking = false, bool copy = false, + c10::optional memory_format 
= c10::nullopt) { JaggedTensor ret = *this; - ret.mData = ret.mData.to(dtype, non_blocking, copy, memory_format); - ret.mBatchIdx = ret.mBatchIdx.to(JIdxScalarType, non_blocking, copy, memory_format); - ret.mOffsets = ret.mOffsets.to(JOffsetsScalarType, non_blocking, copy, memory_format); - ret.mListIdx = ret.mListIdx.to(JLIdxScalarType, non_blocking, copy, memory_format); + ret.mData = ret.mData.to(dtype, non_blocking, copy, memory_format); + ret.mBatchIdx = ret.mBatchIdx.to(JIdxScalarType, non_blocking, copy, memory_format); + ret.mOffsets = ret.mOffsets.to(JOffsetsScalarType, non_blocking, copy, memory_format); + ret.mListIdx = ret.mListIdx.to(JLIdxScalarType, non_blocking, copy, memory_format); return ret; } - torch::TensorOptions options() const { + torch::TensorOptions + options() const { return torch::TensorOptions().dtype(dtype()).device(device()).layout(layout()); } - JaggedTensor cuda() const { + JaggedTensor + cuda() const { return to(this->options().device(torch::kCUDA), /*non_blocking*/ false, /*copy*/ false); } - JaggedTensor cpu() const { + JaggedTensor + cpu() const { return to(this->options().device(torch::kCPU), /*non_blocking*/ false, /*copy*/ false); } - JaggedTensor operator+(const JaggedTensor& other) const; + JaggedTensor operator+(const JaggedTensor &other) const; JaggedTensor operator+(const int other) const; JaggedTensor operator+(const float other) const; - JaggedTensor operator+(const torch::Tensor& other) const; + JaggedTensor operator+(const torch::Tensor &other) const; - JaggedTensor& operator+=(const JaggedTensor& other); - JaggedTensor& operator+=(const int other); - JaggedTensor& operator+=(const float other); - JaggedTensor& operator+=(const torch::Tensor& other); + JaggedTensor &operator+=(const JaggedTensor &other); + JaggedTensor &operator+=(const int other); + JaggedTensor &operator+=(const float other); + JaggedTensor &operator+=(const torch::Tensor &other); - JaggedTensor operator-(const JaggedTensor& other) const; + JaggedTensor operator-(const JaggedTensor &other) const; JaggedTensor operator-(const int other) const; JaggedTensor operator-(const float other) const; - JaggedTensor operator-(const torch::Tensor& other) const; + JaggedTensor operator-(const torch::Tensor &other) const; JaggedTensor operator-() const; - JaggedTensor& operator-=(const JaggedTensor& other); - JaggedTensor& operator-=(const int other); - JaggedTensor& operator-=(const float other); - JaggedTensor& operator-=(const torch::Tensor& other); + JaggedTensor &operator-=(const JaggedTensor &other); + JaggedTensor &operator-=(const int other); + JaggedTensor &operator-=(const float other); + JaggedTensor &operator-=(const torch::Tensor &other); - JaggedTensor operator*(const JaggedTensor& other) const; + JaggedTensor operator*(const JaggedTensor &other) const; JaggedTensor operator*(const int other) const; JaggedTensor operator*(const float other) const; - JaggedTensor operator*(const torch::Tensor& other) const; + JaggedTensor operator*(const torch::Tensor &other) const; - JaggedTensor& operator*=(const JaggedTensor& other); - JaggedTensor& operator*=(const int other); - JaggedTensor& operator*=(const float other); - JaggedTensor& operator*=(const torch::Tensor& other); + JaggedTensor &operator*=(const JaggedTensor &other); + JaggedTensor &operator*=(const int other); + JaggedTensor &operator*=(const float other); + JaggedTensor &operator*=(const torch::Tensor &other); - JaggedTensor operator/(const JaggedTensor& other) const; + JaggedTensor operator/(const JaggedTensor &other) 
const; JaggedTensor operator/(const int other) const; JaggedTensor operator/(const float other) const; - JaggedTensor operator/(const torch::Tensor& other) const; + JaggedTensor operator/(const torch::Tensor &other) const; - JaggedTensor& operator/=(const JaggedTensor& other); - JaggedTensor& operator/=(const int other); - JaggedTensor& operator/=(const float other); - JaggedTensor& operator/=(const torch::Tensor& other); + JaggedTensor &operator/=(const JaggedTensor &other); + JaggedTensor &operator/=(const int other); + JaggedTensor &operator/=(const float other); + JaggedTensor &operator/=(const torch::Tensor &other); - JaggedTensor floordiv(const JaggedTensor& other) const; + JaggedTensor floordiv(const JaggedTensor &other) const; JaggedTensor floordiv(const int other) const; JaggedTensor floordiv(const float other) const; - JaggedTensor floordiv(const torch::Tensor& other) const; + JaggedTensor floordiv(const torch::Tensor &other) const; - JaggedTensor& floordiveq(const JaggedTensor& other); - JaggedTensor& floordiveq(const int other); - JaggedTensor& floordiveq(const float other); - JaggedTensor& floordiveq(const torch::Tensor& other); + JaggedTensor &floordiveq(const JaggedTensor &other); + JaggedTensor &floordiveq(const int other); + JaggedTensor &floordiveq(const float other); + JaggedTensor &floordiveq(const torch::Tensor &other); - JaggedTensor operator%(const JaggedTensor& other) const; + JaggedTensor operator%(const JaggedTensor &other) const; JaggedTensor operator%(const int other) const; JaggedTensor operator%(const float other) const; - JaggedTensor operator%(const torch::Tensor& other) const; + JaggedTensor operator%(const torch::Tensor &other) const; - JaggedTensor& operator%=(const JaggedTensor& other); - JaggedTensor& operator%=(const int other); - JaggedTensor& operator%=(const float other); - JaggedTensor& operator%=(const torch::Tensor& other); + JaggedTensor &operator%=(const JaggedTensor &other); + JaggedTensor &operator%=(const int other); + JaggedTensor &operator%=(const float other); + JaggedTensor &operator%=(const torch::Tensor &other); - JaggedTensor pow(const JaggedTensor& other) const; + JaggedTensor pow(const JaggedTensor &other) const; JaggedTensor pow(const int other) const; JaggedTensor pow(const float other) const; - JaggedTensor pow(const torch::Tensor& other) const; + JaggedTensor pow(const torch::Tensor &other) const; - JaggedTensor& poweq(const JaggedTensor& other); - JaggedTensor& poweq(const int other); - JaggedTensor& poweq(const float other); - JaggedTensor& poweq(const torch::Tensor& other); + JaggedTensor &poweq(const JaggedTensor &other); + JaggedTensor &poweq(const int other); + JaggedTensor &poweq(const float other); + JaggedTensor &poweq(const torch::Tensor &other); - JaggedTensor operator>(const JaggedTensor& other) const; + JaggedTensor operator>(const JaggedTensor &other) const; JaggedTensor operator>(const int other) const; JaggedTensor operator>(const float other) const; - JaggedTensor operator>(const torch::Tensor& other) const; + JaggedTensor operator>(const torch::Tensor &other) const; - JaggedTensor operator>=(const JaggedTensor& other) const; + JaggedTensor operator>=(const JaggedTensor &other) const; JaggedTensor operator>=(const int other) const; JaggedTensor operator>=(const float other) const; - JaggedTensor operator>=(const torch::Tensor& other) const; + JaggedTensor operator>=(const torch::Tensor &other) const; - JaggedTensor operator<(const JaggedTensor& other) const; + JaggedTensor operator<(const JaggedTensor &other) 
const; JaggedTensor operator<(const int other) const; JaggedTensor operator<(const float other) const; - JaggedTensor operator<(const torch::Tensor& other) const; + JaggedTensor operator<(const torch::Tensor &other) const; - JaggedTensor operator<=(const JaggedTensor& other) const; + JaggedTensor operator<=(const JaggedTensor &other) const; JaggedTensor operator<=(const int other) const; JaggedTensor operator<=(const float other) const; - JaggedTensor operator<=(const torch::Tensor& other) const; + JaggedTensor operator<=(const torch::Tensor &other) const; - JaggedTensor operator==(const JaggedTensor& other) const; + JaggedTensor operator==(const JaggedTensor &other) const; JaggedTensor operator==(const int other) const; JaggedTensor operator==(const float other) const; - JaggedTensor operator==(const torch::Tensor& other) const; + JaggedTensor operator==(const torch::Tensor &other) const; - JaggedTensor operator!=(const JaggedTensor& other) const; + JaggedTensor operator!=(const JaggedTensor &other) const; JaggedTensor operator!=(const int other) const; JaggedTensor operator!=(const float other) const; - JaggedTensor operator!=(const torch::Tensor& other) const; + JaggedTensor operator!=(const torch::Tensor &other) const; JaggedTensor sqrt() const; JaggedTensor abs() const; @@ -593,88 +712,111 @@ class JaggedTensor : public torch::CustomClassHolder { JaggedTensor floor() const; JaggedTensor ceil() const; - JaggedTensor& sqrt_(); - JaggedTensor& abs_(); - JaggedTensor& round_(int decimals = 0); - JaggedTensor& floor_(); - JaggedTensor& ceil_(); + JaggedTensor &sqrt_(); + JaggedTensor &abs_(); + JaggedTensor &round_(int decimals = 0); + JaggedTensor &floor_(); + JaggedTensor &ceil_(); - const JaggedTensor& set_requires_grad(bool requires_grad) const; - bool requires_grad() const; - JaggedTensor detach() const; - JaggedTensor clone() const; + const JaggedTensor &set_requires_grad(bool requires_grad) const; + bool requires_grad() const; + JaggedTensor detach() const; + JaggedTensor clone() const; }; - struct JaggedTensorIndex { JaggedTensorIndex(c10::nullopt_t) : mType(JaggedTensorIndexType::None) {} JaggedTensorIndex(int64_t integer) : mType(JaggedTensorIndexType::Integer), mInteger(integer) {} - JaggedTensorIndex(torch::indexing::EllipsisIndexType) : mType(JaggedTensorIndexType::Ellipsis) {} + JaggedTensorIndex(torch::indexing::EllipsisIndexType) + : mType(JaggedTensorIndexType::Ellipsis) {} JaggedTensorIndex(at::Tensor tensor) : mType(JaggedTensorIndexType::Tensor), mTensor(tensor) {} - JaggedTensorIndex(torch::indexing::Slice slice) : mType(JaggedTensorIndexType::Slice), mSlice(slice) {} - JaggedTensorIndex(fvdb::JaggedTensor jaggedTensor) : mType(JaggedTensorIndexType::JaggedTensor), mJaggedTensor(jaggedTensor) {} + JaggedTensorIndex(torch::indexing::Slice slice) + : mType(JaggedTensorIndexType::Slice), mSlice(slice) {} + JaggedTensorIndex(fvdb::JaggedTensor jaggedTensor) + : mType(JaggedTensorIndexType::JaggedTensor), mJaggedTensor(jaggedTensor) {} template ::value>::type> JaggedTensorIndex(T boolean) : mType(JaggedTensorIndexType::Boolean), mBoolean(boolean) {} - inline bool is_none() const { + inline bool + is_none() const { return mType == JaggedTensorIndexType::None; } - inline bool is_ellipsis() const { + inline bool + is_ellipsis() const { return mType == JaggedTensorIndexType::Ellipsis; } - inline bool is_integer() const { + inline bool + is_integer() const { return mType == JaggedTensorIndexType::Integer; } - inline bool is_boolean() const { + inline bool + is_boolean() const { 
return mType == JaggedTensorIndexType::Boolean; } - inline bool is_slice() const { + inline bool + is_slice() const { return mType == JaggedTensorIndexType::Slice; } - inline bool is_tensor() const { + inline bool + is_tensor() const { return mType == JaggedTensorIndexType::Tensor; } - inline bool is_jagged_tensor() const { + inline bool + is_jagged_tensor() const { return mType == JaggedTensorIndexType::JaggedTensor; } - inline int64_t integer() const { + inline int64_t + integer() const { return mInteger; } - inline bool boolean() const { + inline bool + boolean() const { return mBoolean; } - inline const torch::indexing::Slice& slice() const { + inline const torch::indexing::Slice & + slice() const { return mSlice; } - inline const torch::Tensor& tensor() const { + inline const torch::Tensor & + tensor() const { return mTensor; } - inline const fvdb::JaggedTensor& jagged_tensor() const { + inline const fvdb::JaggedTensor & + jagged_tensor() const { return mJaggedTensor; } -private: - enum class JaggedTensorIndexType { None, Ellipsis, Integer, Slice, Tensor, Boolean, JaggedTensor }; + private: + enum class JaggedTensorIndexType { + None, + Ellipsis, + Integer, + Slice, + Tensor, + Boolean, + JaggedTensor + }; JaggedTensorIndexType mType; - torch::Tensor mTensor; - int64_t mInteger; + torch::Tensor mTensor; + int64_t mInteger; torch::indexing::Slice mSlice; - bool mBoolean; - fvdb::JaggedTensor mJaggedTensor; + bool mBoolean; + fvdb::JaggedTensor mJaggedTensor; }; +} // namespace fvdb -} // namespace fvdb \ No newline at end of file +#endif // FVDB_JAGGEDTENSOR_H \ No newline at end of file diff --git a/fvdb/src/SparseConvPackInfo.cpp b/fvdb/src/SparseConvPackInfo.cpp index 4f8a46ed63..cb2f9fca94 100644 --- a/fvdb/src/SparseConvPackInfo.cpp +++ b/fvdb/src/SparseConvPackInfo.cpp @@ -3,17 +3,19 @@ // #include "SparseConvPackInfo.h" +#include "detail/autograd/Autograd.h" #include "detail/ops/Ops.h" #include "detail/ops/convolution/pack_info/PackInfoOps.h" -#include "detail/autograd/Autograd.h" - namespace fvdb { -SparseConvPackInfo::SparseConvPackInfo(Vec3iOrScalar kernelsize, Vec3iOrScalar stride, GridBatch srcGrid, - torch::optional maybeTarget) { - TORCH_CHECK(Vec3iOrScalar(0).value() < kernelsize.value(), "Expect kernel size to be larger than {0,0,0}, but got " + kernelsize.toString() + "."); - TORCH_CHECK(Vec3iOrScalar(0).value() < stride.value(), "Expect stride to be larger than 0, but got " + stride.toString() + "."); +SparseConvPackInfo::SparseConvPackInfo(Vec3iOrScalar kernelsize, Vec3iOrScalar stride, + GridBatch srcGrid, torch::optional maybeTarget) { + TORCH_CHECK(Vec3iOrScalar(0).value() < kernelsize.value(), + "Expect kernel size to be larger than {0,0,0}, but got " + kernelsize.toString() + + "."); + TORCH_CHECK(Vec3iOrScalar(0).value() < stride.value(), + "Expect stride to be larger than 0, but got " + stride.toString() + "."); GridBatch targetGrid; if (!maybeTarget.has_value()) { @@ -26,47 +28,56 @@ SparseConvPackInfo::SparseConvPackInfo(Vec3iOrScalar kernelsize, Vec3iOrScalar s targetGrid = maybeTarget.value(); } - TORCH_CHECK(srcGrid.is_mutable() == targetGrid.is_mutable(), "Source and target grids must both be mutable or immutable"); - TORCH_CHECK(srcGrid.device() == targetGrid.device(), "Source and target grids must both be on the same device"); - TORCH_CHECK(srcGrid.device() == targetGrid.device(), "Device should match between this grid and target grid."); - TORCH_CHECK(!(kernelsize.value() == Vec3iOrScalar(1).value() && stride.value() == Vec3iOrScalar(1).value()), "1x1 
conv does not need kernel map to be built!"); - - mStride = stride; + TORCH_CHECK(srcGrid.is_mutable() == targetGrid.is_mutable(), + "Source and target grids must both be mutable or immutable"); + TORCH_CHECK(srcGrid.device() == targetGrid.device(), + "Source and target grids must both be on the same device"); + TORCH_CHECK(srcGrid.device() == targetGrid.device(), + "Device should match between this grid and target grid."); + TORCH_CHECK(!(kernelsize.value() == Vec3iOrScalar(1).value() && + stride.value() == Vec3iOrScalar(1).value()), + "1x1 conv does not need kernel map to be built!"); + + mStride = stride; mKernelSize = kernelsize; mTargetGrid = targetGrid; mSourceGrid = srcGrid; } -void SparseConvPackInfo::buildGatherScatter(bool use_me) { +void +SparseConvPackInfo::buildGatherScatter(bool use_me) { if (mGSNeighborMap.has_value() && mGSNeighborSizes.has_value()) { - TORCH_CHECK(mGSUseME == use_me, "Gather scatter is already built with different use_me value"); + TORCH_CHECK(mGSUseME == use_me, + "Gather scatter is already built with different use_me value"); return; } int kernelVolume = mKernelSize.value().x() * mKernelSize.value().y() * mKernelSize.value().z(); - torch::Tensor kmap = torch::full( - {mTargetGrid.total_voxels(), kernelVolume}, -1, - torch::TensorOptions().dtype(torch::kInt32).device(mTargetGrid.device())); + torch::Tensor kmap = + torch::full({ mTargetGrid.total_voxels(), kernelVolume }, -1, + torch::TensorOptions().dtype(torch::kInt32).device(mTargetGrid.device())); FVDB_DISPATCH_KERNEL_DEVICE(mSourceGrid.device(), [&]() { detail::ops::dispatchConvolutionKernelMap( *mSourceGrid.impl(), *mTargetGrid.impl(), kmap, mKernelSize, mStride); }); - kmap = kmap.t(); - torch::Tensor kmask = kmap != -1; + kmap = kmap.t(); + torch::Tensor kmask = kmap != -1; torch::Tensor nbsizes = torch::sum(kmask, -1); - torch::Tensor nbmap = torch::nonzero(kmask).contiguous(); + torch::Tensor nbmap = torch::nonzero(kmask).contiguous(); - torch::Tensor indices = nbmap.index({torch::indexing::Slice(), 0}) * kmap.size(1) + \ - nbmap.index({torch::indexing::Slice(), 1}); - nbmap.index_put_({torch::indexing::Slice(), 0}, kmap.reshape({-1}).index({indices})); - mGSNeighborMap = nbmap.to(torch::kInt32); + torch::Tensor indices = nbmap.index({ torch::indexing::Slice(), 0 }) * kmap.size(1) + + nbmap.index({ torch::indexing::Slice(), 1 }); + nbmap.index_put_({ torch::indexing::Slice(), 0 }, kmap.reshape({ -1 }).index({ indices })); + mGSNeighborMap = nbmap.to(torch::kInt32); mGSNeighborSizes = nbsizes.to(torch::kInt32); - mGSUseME = use_me; + mGSUseME = use_me; } -void SparseConvPackInfo::buildImplicitGEMM(bool sorted, int splitMaskNum, bool training, int splitMaskNumBwd, bool use_tf32) { +void +SparseConvPackInfo::buildImplicitGEMM(bool sorted, int splitMaskNum, bool training, + int splitMaskNumBwd, bool use_tf32) { if (mIGEMMOutInMap.has_value()) { if (mIGEMMReorderLoc.has_value()) { TORCH_CHECK(mIGEMMReorderLoc->size(0) == splitMaskNum, @@ -80,9 +91,9 @@ void SparseConvPackInfo::buildImplicitGEMM(bool sorted, int splitMaskNum, bool t int kernelVolume = mKernelSize.value().x() * mKernelSize.value().y() * mKernelSize.value().z(); int outInMapSize = (mTargetGrid.total_voxels() + 128 - 1) / 128 * 128; - mIGEMMOutInMap = torch::full( - {outInMapSize, kernelVolume}, -1, - torch::TensorOptions().dtype(torch::kInt32).device(mTargetGrid.device())); + mIGEMMOutInMap = + torch::full({ outInMapSize, kernelVolume }, -1, + torch::TensorOptions().dtype(torch::kInt32).device(mTargetGrid.device())); mIGEMMUseTF32 = 
use_tf32; // Note: This could also be converted from GSNeighbourMap if exists @@ -92,141 +103,146 @@ void SparseConvPackInfo::buildImplicitGEMM(bool sorted, int splitMaskNum, bool t }); if (sorted) { - TORCH_CHECK(mSourceGrid.device().is_cuda(), "Implicit GEMM with sorted kernel map is only supported on CUDA"); + TORCH_CHECK(mSourceGrid.device().is_cuda(), + "Implicit GEMM with sorted kernel map is only supported on CUDA"); torch::Tensor bitmask = detail::ops::dispatchBitmaskFromOutInMap( mIGEMMOutInMap.value(), splitMaskNum, mTargetGrid.total_voxels()); - auto ret = torch::sort(bitmask, -1L, true); - mIGEMMSortedMask = std::get<0>(ret); // Mainly used for transpose. - mIGEMMReorderLoc = std::get<1>(ret).to(torch::kInt32); + auto ret = torch::sort(bitmask, -1L, true); + mIGEMMSortedMask = std::get<0>(ret); // Mainly used for transpose. + mIGEMMReorderLoc = std::get<1>(ret).to(torch::kInt32); mIGEMMReoderOutInMap = detail::ops::dispatchReorderOutInMap( mIGEMMOutInMap.value(), mIGEMMReorderLoc.value()); - mIGEMMReducedSortedMask = detail::ops::dispatchReduceMask( - mIGEMMSortedMask.value(), 128); + mIGEMMReducedSortedMask = + detail::ops::dispatchReduceMask(mIGEMMSortedMask.value(), 128); } if (training) { int outInMapTSize = (mSourceGrid.total_voxels() + 128 - 1) / 128 * 128; - mIGEMMOutInMapBwd = torch::full( - {outInMapTSize, kernelVolume}, -1, - torch::TensorOptions().dtype(torch::kInt32).device(mSourceGrid.device())); - detail::ops::dispatchTransposeOutInMap( - mIGEMMOutInMap.value(), mIGEMMOutInMapBwd.value()); + mIGEMMOutInMapBwd = + torch::full({ outInMapTSize, kernelVolume }, -1, + torch::TensorOptions().dtype(torch::kInt32).device(mSourceGrid.device())); + detail::ops::dispatchTransposeOutInMap(mIGEMMOutInMap.value(), + mIGEMMOutInMapBwd.value()); torch::Tensor bitmask = detail::ops::dispatchBitmaskFromOutInMap( mIGEMMOutInMapBwd.value(), splitMaskNumBwd, mSourceGrid.total_voxels()); - auto ret = torch::sort(bitmask, -1L, true); + auto ret = torch::sort(bitmask, -1L, true); torch::Tensor sortedMaskBwd = std::get<0>(ret); - mIGEMMReorderLocBwd = std::get<1>(ret).to(torch::kInt32); - mIGEMMReorderOutInMapBwd = detail::ops::dispatchReorderOutInMap( + mIGEMMReorderLocBwd = std::get<1>(ret).to(torch::kInt32); + mIGEMMReorderOutInMapBwd = detail::ops::dispatchReorderOutInMap( mIGEMMOutInMapBwd.value(), mIGEMMReorderLocBwd.value()); - mIGEMMSortedMaskBwdW = detail::ops::dispatchReduceMask( - sortedMaskBwd, 64); - mIGEMMSortedMaskBwdD = detail::ops::dispatchReduceMask( - sortedMaskBwd, 128); + mIGEMMSortedMaskBwdW = detail::ops::dispatchReduceMask(sortedMaskBwd, 64); + mIGEMMSortedMaskBwdD = detail::ops::dispatchReduceMask(sortedMaskBwd, 128); } - } -SparseConvPackInfo SparseConvPackInfo::transposed() const { +SparseConvPackInfo +SparseConvPackInfo::transposed() const { SparseConvPackInfo ret(mKernelSize, mStride, mSourceGrid, mTargetGrid); - bool sorted = mIGEMMReorderLoc.has_value(); - bool training = mIGEMMOutInMapBwd.has_value(); + bool sorted = mIGEMMReorderLoc.has_value(); + bool training = mIGEMMOutInMapBwd.has_value(); int splitMaskNum = mIGEMMReorderLoc.has_value() ? 
mIGEMMReorderLoc.value().size(0) : 1; int outInMapSize = (mSourceGrid.total_voxels() + 128 - 1) / 128 * 128; int kernelVolume = mKernelSize.value().x() * mKernelSize.value().y() * mKernelSize.value().z(); - ret.mIGEMMOutInMap = torch::full( - {outInMapSize, kernelVolume}, -1, - torch::TensorOptions().dtype(torch::kInt32).device(mSourceGrid.device())); - detail::ops::dispatchTransposeOutInMap( - mIGEMMOutInMap.value(), ret.mIGEMMOutInMap.value()); + ret.mIGEMMOutInMap = + torch::full({ outInMapSize, kernelVolume }, -1, + torch::TensorOptions().dtype(torch::kInt32).device(mSourceGrid.device())); + detail::ops::dispatchTransposeOutInMap(mIGEMMOutInMap.value(), + ret.mIGEMMOutInMap.value()); if (sorted) { if (training) { - ret.mIGEMMOutInMapBwd = mIGEMMOutInMap; + ret.mIGEMMOutInMapBwd = mIGEMMOutInMap; ret.mIGEMMReorderOutInMapBwd = mIGEMMReoderOutInMap; - ret.mIGEMMReorderLocBwd = mIGEMMReorderLoc; - torch::Tensor sortedMaskBwd = mIGEMMSortedMask.value(); - ret.mIGEMMSortedMaskBwdW = detail::ops::dispatchReduceMask( - sortedMaskBwd, 64); - ret.mIGEMMSortedMaskBwdD = detail::ops::dispatchReduceMask( - sortedMaskBwd, 128); + ret.mIGEMMReorderLocBwd = mIGEMMReorderLoc; + torch::Tensor sortedMaskBwd = mIGEMMSortedMask.value(); + ret.mIGEMMSortedMaskBwdW = + detail::ops::dispatchReduceMask(sortedMaskBwd, 64); + ret.mIGEMMSortedMaskBwdD = + detail::ops::dispatchReduceMask(sortedMaskBwd, 128); } torch::Tensor bitmask = detail::ops::dispatchBitmaskFromOutInMap( ret.mIGEMMOutInMap.value(), splitMaskNum, mSourceGrid.total_voxels()); - auto rets = torch::sort(bitmask, -1L, true); - ret.mIGEMMSortedMask = std::get<0>(rets); - ret.mIGEMMReorderLoc = std::get<1>(rets).to(torch::kInt32); + auto rets = torch::sort(bitmask, -1L, true); + ret.mIGEMMSortedMask = std::get<0>(rets); + ret.mIGEMMReorderLoc = std::get<1>(rets).to(torch::kInt32); ret.mIGEMMReoderOutInMap = detail::ops::dispatchReorderOutInMap( ret.mIGEMMOutInMap.value(), ret.mIGEMMReorderLoc.value()); - ret.mIGEMMReducedSortedMask = detail::ops::dispatchReduceMask( - ret.mIGEMMSortedMask.value(), 128); + ret.mIGEMMReducedSortedMask = + detail::ops::dispatchReduceMask(ret.mIGEMMSortedMask.value(), 128); } else if (training) { - int splitMaskNumBwd = mIGEMMReorderLocBwd.value().size(0); - ret.mIGEMMOutInMapBwd = mIGEMMOutInMap; + int splitMaskNumBwd = mIGEMMReorderLocBwd.value().size(0); + ret.mIGEMMOutInMapBwd = mIGEMMOutInMap; torch::Tensor bitmaskBwd = detail::ops::dispatchBitmaskFromOutInMap( ret.mIGEMMOutInMapBwd.value(), splitMaskNumBwd, mTargetGrid.total_voxels()); - auto rets = torch::sort(bitmaskBwd, -1L, true); - torch::Tensor sortedMaskBwd = std::get<0>(rets); - ret.mIGEMMReorderLocBwd = std::get<1>(rets).to(torch::kInt32); + auto rets = torch::sort(bitmaskBwd, -1L, true); + torch::Tensor sortedMaskBwd = std::get<0>(rets); + ret.mIGEMMReorderLocBwd = std::get<1>(rets).to(torch::kInt32); ret.mIGEMMReorderOutInMapBwd = detail::ops::dispatchReorderOutInMap( ret.mIGEMMOutInMapBwd.value(), ret.mIGEMMReorderLocBwd.value()); - ret.mIGEMMSortedMaskBwdW = detail::ops::dispatchReduceMask( - sortedMaskBwd, 64); - ret.mIGEMMSortedMaskBwdD = detail::ops::dispatchReduceMask( - sortedMaskBwd, 128); + ret.mIGEMMSortedMaskBwdW = detail::ops::dispatchReduceMask(sortedMaskBwd, 64); + ret.mIGEMMSortedMaskBwdD = + detail::ops::dispatchReduceMask(sortedMaskBwd, 128); } ret.mIGEMMUseTF32 = mIGEMMUseTF32; return ret; } -void SparseConvPackInfo::buildCutlass(bool benchmark) { +void +SparseConvPackInfo::buildCutlass(bool benchmark) { if 
(mCUTLASSHaloIndexBuffer.has_value()) { - TORCH_CHECK(mCUTLASSBenchmark == benchmark, "Cutlass is already built with different benchmark flag"); + TORCH_CHECK(mCUTLASSBenchmark == benchmark, + "Cutlass is already built with different benchmark flag"); return; } std::vector res = FVDB_DISPATCH_KERNEL_DEVICE(mSourceGrid.device(), [&]() { - return detail::ops::dispatchBrickHaloBuffer( - *mSourceGrid.impl(), benchmark); + return detail::ops::dispatchBrickHaloBuffer(*mSourceGrid.impl(), benchmark); }); - mCUTLASSHaloIndexBuffer = res[1]; - mCUTLASSOutputIndexBuffer = res[2]; - mCUTLASSBenchmark = benchmark; + mCUTLASSHaloIndexBuffer = res[1]; + mCUTLASSOutputIndexBuffer = res[2]; + mCUTLASSBenchmark = benchmark; } -void SparseConvPackInfo::buildLGGS() { - TORCH_CHECK(mKernelSize.value().x() == 3 && mKernelSize.value().y() == 3 && mKernelSize.value().z() == 3, +void +SparseConvPackInfo::buildLGGS() { + TORCH_CHECK(mKernelSize.value().x() == 3 && mKernelSize.value().y() == 3 && + mKernelSize.value().z() == 3, "LGGS only supports 3x3x3 kernel size"); - int outInMapSize = (mTargetGrid.total_voxels() + 64 - 1) / 64 * 64; - torch::Tensor outInMap = torch::full( - {outInMapSize, 27}, -1, - torch::TensorOptions().dtype(torch::kInt32).device(mTargetGrid.device())); + int outInMapSize = (mTargetGrid.total_voxels() + 64 - 1) / 64 * 64; + torch::Tensor outInMap = + torch::full({ outInMapSize, 27 }, -1, + torch::TensorOptions().dtype(torch::kInt32).device(mTargetGrid.device())); FVDB_DISPATCH_KERNEL_DEVICE(mSourceGrid.device(), [&]() { detail::ops::dispatchConvolutionKernelMap( *mSourceGrid.impl(), *mTargetGrid.impl(), outInMap, mKernelSize, mStride); }); - outInMap = outInMap.view({-1, 64, 27}).transpose(1, 2); // [#blocks, 27, 64] + outInMap = outInMap.view({ -1, 64, 27 }).transpose(1, 2); // [#blocks, 27, 64] torch::Tensor mapMask = outInMap != -1; - torch::Tensor mapNNZ = torch::nonzero(mapMask); + torch::Tensor mapNNZ = torch::nonzero(mapMask); torch::Tensor kernelRanges = mapMask.sum(-1).view(-1).cumsum(0); - kernelRanges = torch::cat({torch::zeros(1, kernelRanges.options()), kernelRanges}, 0); + kernelRanges = torch::cat({ torch::zeros(1, kernelRanges.options()), kernelRanges }, 0); - torch::Tensor relOutIndices = mapNNZ.index({torch::indexing::Slice(), -1}); - torch::Tensor inIndices = outInMap.index({mapNNZ.index({torch::indexing::Slice(), 0}), - mapNNZ.index({torch::indexing::Slice(), 1}), - mapNNZ.index({torch::indexing::Slice(), 2})}); + torch::Tensor relOutIndices = mapNNZ.index({ torch::indexing::Slice(), -1 }); + torch::Tensor inIndices = outInMap.index({ mapNNZ.index({ torch::indexing::Slice(), 0 }), + mapNNZ.index({ torch::indexing::Slice(), 1 }), + mapNNZ.index({ torch::indexing::Slice(), 2 }) }); mLGGSSpokeOutputLocalOffsetsRelativeToBlockFlattenedData = relOutIndices.to(torch::kInt32); - mLGGSSpokeInputGlobalIndicesFlattenedData = inIndices; - mLGGSSpokeIndicesFlattenedOffset = kernelRanges.to(torch::kInt32); + mLGGSSpokeInputGlobalIndicesFlattenedData = inIndices; + mLGGSSpokeIndicesFlattenedOffset = kernelRanges.to(torch::kInt32); } -JaggedTensor SparseConvPackInfo::sparseConv3d(const JaggedTensor& input, const torch::Tensor& weights, ConvPackBackend backend) const { - TORCH_CHECK_VALUE(input.num_outer_lists() == mSourceGrid.grid_count(), "Input batch size must match target grid batch size"); - TORCH_CHECK_VALUE(input.element_count() == mSourceGrid.total_voxels(), "Input element count must match target grid total voxels"); +JaggedTensor +SparseConvPackInfo::sparseConv3d(const 
JaggedTensor &input, const torch::Tensor &weights, + ConvPackBackend backend) const { + TORCH_CHECK_VALUE(input.num_outer_lists() == mSourceGrid.grid_count(), + "Input batch size must match target grid batch size"); + TORCH_CHECK_VALUE(input.element_count() == mSourceGrid.total_voxels(), + "Input element count must match target grid total voxels"); if (backend == ConvPackBackend::GATHER_SCATTER) { auto ret = detail::autograd::SparseConvolutionKernelMap::apply( @@ -242,41 +258,43 @@ JaggedTensor SparseConvPackInfo::sparseConv3d(const JaggedTensor& input, const t // Re-shape kernel from [Do, Di, D, H, W] to [Do, D, H, W, Di]. TORCH_CHECK(mCUTLASSHaloIndexBuffer.has_value() && mCUTLASSOutputIndexBuffer.has_value(), "Cutlass buffer is not built"); - auto kernel = weights.permute({0, 4, 3, 2, 1}).contiguous(); + auto kernel = weights.permute({ 0, 4, 3, 2, 1 }).contiguous(); torch::Tensor out = FVDB_DISPATCH_KERNEL_DEVICE(mCUTLASSHaloIndexBuffer->device(), [&]() { return detail::ops::dispatchSparseConvolutionCutlass( - input.jdata(), kernel, - mCUTLASSHaloIndexBuffer.value(), mCUTLASSOutputIndexBuffer.value(), - mCUTLASSBenchmark); + input.jdata(), kernel, mCUTLASSHaloIndexBuffer.value(), + mCUTLASSOutputIndexBuffer.value(), mCUTLASSBenchmark); }); return mTargetGrid.impl()->jaggedTensor(out, false); } else if (backend == ConvPackBackend::LGGS) { TORCH_CHECK(mLGGSSpokeIndicesFlattenedOffset.has_value() && - mLGGSSpokeInputGlobalIndicesFlattenedData.has_value() && - mLGGSSpokeOutputLocalOffsetsRelativeToBlockFlattenedData.has_value(), + mLGGSSpokeInputGlobalIndicesFlattenedData.has_value() && + mLGGSSpokeOutputLocalOffsetsRelativeToBlockFlattenedData.has_value(), "LGGS buffer is not built"); // Reshape kernel from [Do, Di, D, H, W] to [WHD, Di, Do]. - auto kernel = weights.permute({4, 3, 2, 1, 0}).contiguous(); - kernel = kernel.reshape({-1, kernel.size(3), kernel.size(4)}); - torch::Tensor out = FVDB_DISPATCH_KERNEL_DEVICE(mLGGSSpokeIndicesFlattenedOffset->device(), [&]() { - return detail::ops::dispatchSparseConvolutionLggs( - input.jdata(), kernel, - mLGGSSpokeIndicesFlattenedOffset.value(), - mLGGSSpokeInputGlobalIndicesFlattenedData.value(), - mLGGSSpokeOutputLocalOffsetsRelativeToBlockFlattenedData.value()); - }); + auto kernel = weights.permute({ 4, 3, 2, 1, 0 }).contiguous(); + kernel = kernel.reshape({ -1, kernel.size(3), kernel.size(4) }); + torch::Tensor out = + FVDB_DISPATCH_KERNEL_DEVICE(mLGGSSpokeIndicesFlattenedOffset->device(), [&]() { + return detail::ops::dispatchSparseConvolutionLggs( + input.jdata(), kernel, mLGGSSpokeIndicesFlattenedOffset.value(), + mLGGSSpokeInputGlobalIndicesFlattenedData.value(), + mLGGSSpokeOutputLocalOffsetsRelativeToBlockFlattenedData.value()); + }); return mTargetGrid.impl()->jaggedTensor(out, false); } else { TORCH_CHECK(false, "Unknown backend"); } - } -JaggedTensor SparseConvPackInfo::sparseTransposeConv3d(const JaggedTensor& input, const torch::Tensor& weights, ConvPackBackend backend) const { - TORCH_CHECK_VALUE(input.num_outer_lists() == mTargetGrid.grid_count(), "Input batch size must match target grid batch size"); - TORCH_CHECK_VALUE(input.element_count() == mTargetGrid.total_voxels(), "Input element count must match target grid total voxels"); +JaggedTensor +SparseConvPackInfo::sparseTransposeConv3d(const JaggedTensor &input, const torch::Tensor &weights, + ConvPackBackend backend) const { + TORCH_CHECK_VALUE(input.num_outer_lists() == mTargetGrid.grid_count(), + "Input batch size must match target grid batch size"); + 
TORCH_CHECK_VALUE(input.element_count() == mTargetGrid.total_voxels(), + "Input element count must match target grid total voxels"); if (backend == ConvPackBackend::GATHER_SCATTER) { auto ret = detail::autograd::SparseConvolutionKernelMap::apply( @@ -293,8 +311,6 @@ JaggedTensor SparseConvPackInfo::sparseTransposeConv3d(const JaggedTensor& input } else { TORCH_CHECK(false, "Unknown backend"); } - } - } // namespace fvdb diff --git a/fvdb/src/SparseConvPackInfo.h b/fvdb/src/SparseConvPackInfo.h index 2dfe0e01e7..3adcdc11e5 100644 --- a/fvdb/src/SparseConvPackInfo.h +++ b/fvdb/src/SparseConvPackInfo.h @@ -1,11 +1,11 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_SPARSECONVPACKINFO_H +#define FVDB_SPARSECONVPACKINFO_H #include "GridBatch.h" - namespace fvdb { enum ConvPackBackend { @@ -16,7 +16,6 @@ enum ConvPackBackend { }; class SparseConvPackInfo : torch::CustomClassHolder { - // #IO: Number of input-output pairs // #O-P: Number of output voxels, padded to multiple of 128 // #I-P: Number of input voxels, padded to multiple of 128 @@ -24,29 +23,40 @@ class SparseConvPackInfo : torch::CustomClassHolder { // S: Split count bool mGSUseME = false; - torch::optional mGSNeighborMap; // [#IO, 2] (int32), GATHER_SCATTER, GATHER_SCATTER(me) - torch::optional mGSNeighborSizes; // [#IO, 2] (int32), GATHER_SCATTER, GATHER_SCATTER(me) - - bool mIGEMMUseTF32 = false; - torch::optional mIGEMMOutInMap; // [#O-P, K] (int32), IGEMM, IGEMM(sorted) - torch::optional mIGEMMReorderLoc; // [S, #O-P] (int32), IGEMM(sorted) - torch::optional mIGEMMSortedMask; // [S, #O-P] (int32), IGEMM(sorted) + torch::optional + mGSNeighborMap; // [#IO, 2] (int32), GATHER_SCATTER, GATHER_SCATTER(me) + torch::optional + mGSNeighborSizes; // [#IO, 2] (int32), GATHER_SCATTER, GATHER_SCATTER(me) + + bool mIGEMMUseTF32 = false; + torch::optional mIGEMMOutInMap; // [#O-P, K] (int32), IGEMM, IGEMM(sorted) + torch::optional mIGEMMReorderLoc; // [S, #O-P] (int32), IGEMM(sorted) + torch::optional mIGEMMSortedMask; // [S, #O-P] (int32), IGEMM(sorted) torch::optional mIGEMMReducedSortedMask; // [S, #O-P//128] (int32), IGEMM(sorted) torch::optional mIGEMMReoderOutInMap; // [#O-P, K] (int32), IGEMM(sorted) - torch::optional mIGEMMOutInMapBwd; // [#I-P, K] (int32), IGEMM, IGEMM(sorted, training) - torch::optional mIGEMMReorderLocBwd; // [S, #I-P] (int32), IGEMM, IGEMM(sorted, training) - torch::optional mIGEMMSortedMaskBwdW; // [S, #I-P//x] (int32), IGEMM, IGEMM(sorted, training) - torch::optional mIGEMMSortedMaskBwdD; // [S, #I-P//y] (int32), IGEMM, IGEMM(sorted, training) - torch::optional mIGEMMReorderOutInMapBwd; // [#I-P, K] (int32), IGEMM, IGEMM(sorted, training) + torch::optional + mIGEMMOutInMapBwd; // [#I-P, K] (int32), IGEMM, IGEMM(sorted, training) + torch::optional + mIGEMMReorderLocBwd; // [S, #I-P] (int32), IGEMM, IGEMM(sorted, training) + torch::optional + mIGEMMSortedMaskBwdW; // [S, #I-P//x] (int32), IGEMM, IGEMM(sorted, training) + torch::optional + mIGEMMSortedMaskBwdD; // [S, #I-P//y] (int32), IGEMM, IGEMM(sorted, training) + torch::optional + mIGEMMReorderOutInMapBwd; // [#I-P, K] (int32), IGEMM, IGEMM(sorted, training) bool mCUTLASSBenchmark = false; - torch::optional mCUTLASSHaloIndexBuffer; // [#active_brick, 6, 4, 4] (int32), CUTLASS - torch::optional mCUTLASSOutputIndexBuffer; // [#active_brick, 4, 2, 2] (int32), CUTLASS + torch::optional + mCUTLASSHaloIndexBuffer; // [#active_brick, 6, 4, 4] (int32), CUTLASS + torch::optional + 
mCUTLASSOutputIndexBuffer; // [#active_brick, 4, 2, 2] (int32), CUTLASS - torch::optional mLGGSSpokeIndicesFlattenedOffset; // 1D array. (int32), LGGS - torch::optional mLGGSSpokeInputGlobalIndicesFlattenedData; // 1D array. (int32), LGGS - torch::optional mLGGSSpokeOutputLocalOffsetsRelativeToBlockFlattenedData; // 1D array. (int32), LGGS + torch::optional mLGGSSpokeIndicesFlattenedOffset; // 1D array. (int32), LGGS + torch::optional + mLGGSSpokeInputGlobalIndicesFlattenedData; // 1D array. (int32), LGGS + torch::optional + mLGGSSpokeOutputLocalOffsetsRelativeToBlockFlattenedData; // 1D array. (int32), LGGS Vec3iOrScalar mStride; Vec3iOrScalar mKernelSize; @@ -54,37 +64,109 @@ class SparseConvPackInfo : torch::CustomClassHolder { GridBatch mSourceGrid; GridBatch mTargetGrid; -public: - const torch::optional neighborMap() const { return mGSNeighborMap; } - const torch::optional neighborSizes() const { return mGSNeighborSizes; } - const bool useME() const { return mGSUseME; } - - const torch::optional outInMap() const { return mIGEMMOutInMap; } - const torch::optional reorderLoc() const { return mIGEMMReorderLoc; } - const torch::optional sortedMask() const { return mIGEMMSortedMask; } - const torch::optional reducedSortedMask() const { return mIGEMMReducedSortedMask; } - const torch::optional reoderOutInMap() const { return mIGEMMReoderOutInMap; } - const bool useTF32() const { return mIGEMMUseTF32; } - - const torch::optional outInMapBwd() const { return mIGEMMOutInMapBwd; } - const torch::optional reorderLocBwd() const { return mIGEMMReorderLocBwd; } - const torch::optional sortedMaskBwdW() const { return mIGEMMSortedMaskBwdW; } - const torch::optional sortedMaskBwdD() const { return mIGEMMSortedMaskBwdD; } - const torch::optional reorderOutInMapBwd() const { return mIGEMMReorderOutInMapBwd; } - - const torch::optional haloIndexBuffer() const { return mCUTLASSHaloIndexBuffer; } - const torch::optional outputIndexBuffer() const { return mCUTLASSOutputIndexBuffer; } - const bool benchmark() const { return mCUTLASSBenchmark; } - - const torch::optional blockKernelRanges() const { return mLGGSSpokeIndicesFlattenedOffset; } - const torch::optional blockKernelInIdx() const { return mLGGSSpokeInputGlobalIndicesFlattenedData; } - const torch::optional blockKernelRelOutIdx() const { return mLGGSSpokeOutputLocalOffsetsRelativeToBlockFlattenedData; } - - const Vec3iOrScalar stride() const { return mStride; } - const Vec3iOrScalar kernelSize() const { return mKernelSize; } - - GridBatch targetGrid() const { return mTargetGrid; } - GridBatch sourceGrid() const { return mSourceGrid; } + public: + const torch::optional + neighborMap() const { + return mGSNeighborMap; + } + const torch::optional + neighborSizes() const { + return mGSNeighborSizes; + } + const bool + useME() const { + return mGSUseME; + } + + const torch::optional + outInMap() const { + return mIGEMMOutInMap; + } + const torch::optional + reorderLoc() const { + return mIGEMMReorderLoc; + } + const torch::optional + sortedMask() const { + return mIGEMMSortedMask; + } + const torch::optional + reducedSortedMask() const { + return mIGEMMReducedSortedMask; + } + const torch::optional + reoderOutInMap() const { + return mIGEMMReoderOutInMap; + } + const bool + useTF32() const { + return mIGEMMUseTF32; + } + + const torch::optional + outInMapBwd() const { + return mIGEMMOutInMapBwd; + } + const torch::optional + reorderLocBwd() const { + return mIGEMMReorderLocBwd; + } + const torch::optional + sortedMaskBwdW() const { + return 
mIGEMMSortedMaskBwdW; + } + const torch::optional + sortedMaskBwdD() const { + return mIGEMMSortedMaskBwdD; + } + const torch::optional + reorderOutInMapBwd() const { + return mIGEMMReorderOutInMapBwd; + } + + const torch::optional + haloIndexBuffer() const { + return mCUTLASSHaloIndexBuffer; + } + const torch::optional + outputIndexBuffer() const { + return mCUTLASSOutputIndexBuffer; + } + const bool + benchmark() const { + return mCUTLASSBenchmark; + } + + const torch::optional + blockKernelRanges() const { + return mLGGSSpokeIndicesFlattenedOffset; + } + const torch::optional + blockKernelInIdx() const { + return mLGGSSpokeInputGlobalIndicesFlattenedData; + } + const torch::optional + blockKernelRelOutIdx() const { + return mLGGSSpokeOutputLocalOffsetsRelativeToBlockFlattenedData; + } + + const Vec3iOrScalar + stride() const { + return mStride; + } + const Vec3iOrScalar + kernelSize() const { + return mKernelSize; + } + + GridBatch + targetGrid() const { + return mTargetGrid; + } + GridBatch + sourceGrid() const { + return mSourceGrid; + } SparseConvPackInfo(Vec3iOrScalar kernelsize, Vec3iOrScalar stride, GridBatch src, torch::optional maybeTarget); @@ -93,13 +175,18 @@ class SparseConvPackInfo : torch::CustomClassHolder { // Will not rebuild if already built void buildGatherScatter(bool use_me = false); - void buildImplicitGEMM(bool sorted, int splitMaskNum, bool training, int splitMaskNumBwd, bool use_tf32 = false); + void buildImplicitGEMM(bool sorted, int splitMaskNum, bool training, int splitMaskNumBwd, + bool use_tf32 = false); void buildCutlass(bool benchmark = false); void buildLGGS(); - JaggedTensor sparseConv3d(const JaggedTensor& input, const torch::Tensor& weights, ConvPackBackend backend = ConvPackBackend::GATHER_SCATTER) const; - JaggedTensor sparseTransposeConv3d(const JaggedTensor& input, const torch::Tensor& weights, ConvPackBackend backend = ConvPackBackend::GATHER_SCATTER) const; + JaggedTensor sparseConv3d(const JaggedTensor &input, const torch::Tensor &weights, + ConvPackBackend backend = ConvPackBackend::GATHER_SCATTER) const; + JaggedTensor + sparseTransposeConv3d(const JaggedTensor &input, const torch::Tensor &weights, + ConvPackBackend backend = ConvPackBackend::GATHER_SCATTER) const; }; +} // namespace fvdb -} +#endif // FVDB_SPARSECONVPACKINFO_H \ No newline at end of file diff --git a/fvdb/src/Types.h b/fvdb/src/Types.h index c407eff022..6c469f3db8 100644 --- a/fvdb/src/Types.h +++ b/fvdb/src/Types.h @@ -1,137 +1,157 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once - -#include -#include -#include +#ifndef FVDB_TYPES_H +#define FVDB_TYPES_H #include "detail/TypesImpl.h" +#include -namespace fvdb { +#include +#include +namespace fvdb { -// These are union types that can be constructed from nanovdb types, torch tensors, std::vectors, single scalars, etc... -// They are used to allow the user to pass in a variety of types to the API, and then convert them to the correct type -using Vec3i = detail::Coord3Impl; +// These are union types that can be constructed from nanovdb types, torch tensors, std::vectors, +// single scalars, etc... 
They are used to allow the user to pass in a variety of types to the API, +// and then convert them to the correct type +using Vec3i = detail::Coord3Impl; using Vec3iOrScalar = detail::Coord3Impl; -using Vec4i = detail::Coord4Impl; -using Vec3d = detail::Vec3dImpl; +using Vec4i = detail::Coord4Impl; +using Vec3d = detail::Vec3dImpl; using Vec3dOrScalar = detail::Vec3dImpl; -// These are union types that can be constructed from nanovdb types, torch tensors, std::vectors, single scalars, etc... -// and resolve to a batch of values. They are used to allow the user to pass in a single vector (or scalar) and have -// it be broadcast to a whole batch of values. -// E.g. if you are constructing a batch of grids, you can pass in a single scalar 1.0 to have a voxel size of [1, 1, 1] -// for every grid in the batch. Or a user can pass in a vector [1, 2, 3] to have each grid have a voxel -// size of [1, 2, 3]. Alternatively, a user can specify a voxel size for each grid in the batch -// [[v1x, v1y, v1z], ..., [vnx, vny, vnz]]. The Vec3dBatchOrScalar will accept all these inputs -// and resolve them to a batch of values. -using Vec3dBatchOrScalar = detail::Vec3BatchImpl; -using Vec3dBatch = detail::Vec3BatchImpl; -using Vec3iBatch = detail::Vec3BatchImpl; - +// These are union types that can be constructed from nanovdb types, torch tensors, std::vectors, +// single scalars, etc... and resolve to a batch of values. They are used to allow the user to pass +// in a single vector (or scalar) and have it be broadcast to a whole batch of values. E.g. if you +// are constructing a batch of grids, you can pass in a single scalar 1.0 to have a voxel size of +// [1, 1, 1] +// for every grid in the batch. Or a user can pass in a vector [1, 2, 3] to have each grid have +// a voxel +// size of [1, 2, 3]. Alternatively, a user can specify a voxel size for each grid in the +// batch +// [[v1x, v1y, v1z], ..., [vnx, vny, vnz]]. The Vec3dBatchOrScalar will accept all these +// inputs and resolve them to a batch of values. +using Vec3dBatchOrScalar = + detail::Vec3BatchImpl; +using Vec3dBatch = + detail::Vec3BatchImpl; +using Vec3iBatch = + detail::Vec3BatchImpl; /// @brief A class that can be constructed from a torch::Device or a string. /// Calling value() returns a torch::device class TorchDeviceOrString { torch::Device mValue; - void setIndex() { - if (mValue.is_cuda() && ! 
mValue.has_index()) {
+    void
+    setIndex() {
+        if (mValue.is_cuda() && !mValue.has_index()) {
             mValue.set_index(c10::cuda::current_device());
         }
     }
-public:
+
+  public:
     TorchDeviceOrString() : mValue(torch::kCPU) { setIndex(); }
     TorchDeviceOrString(torch::Device device) : mValue(device) { setIndex(); }
     TorchDeviceOrString(c10::DeviceType deviceType) : mValue(deviceType) { setIndex(); }
-    TorchDeviceOrString(std::string& str) : mValue(str) { setIndex(); }
+    TorchDeviceOrString(std::string &str) : mValue(str) { setIndex(); }
-    const torch::Device& value() const {
+    const torch::Device &
+    value() const {
         return mValue;
     }
 };
-
-/// @brief A class that con be constructed from a string or a list of strings but always returns a list of strings
-/// Used to enable broadcasting for arguments that specify a single value or a list of values for a whole batch
+/// @brief A class that con be constructed from a string or a list of strings but always returns a
+/// list of strings
+/// Used to enable broadcasting for arguments that specify a single value or a list of values
+/// for a whole batch
 class StringOrListOfStrings {
     std::vector<std::string> mValue;
-public:
+
+  public:
     StringOrListOfStrings() : mValue() {}
-    StringOrListOfStrings(std::string str) : mValue({str}) {}
+    StringOrListOfStrings(std::string str) : mValue({ str }) {}
     StringOrListOfStrings(std::vector<std::string> str) : mValue(str) {}
-    const std::vector<std::string>& value() const {
+    const std::vector<std::string> &
+    value() const {
         return mValue;
     }
 };
-
-/// @brief A class representing a set of unique IDs for a nanovdb grid (used to specify which grids to load
-/// from an .nvdb file). You can specify the set of grids to load as a integer index, a single string name,
-/// a vector of integer indices, or a vector of string names
+/// @brief A class representing a set of unique IDs for a nanovdb grid (used to specify which grids
+/// to load
+/// from an .nvdb file).
You can specify the set of grids to load as a integer index, a +/// single string name, a vector of integer indices, or a vector of string names class NanoVDBFileGridIdentifier { - std::vector mIndices; + std::vector mIndices; std::vector mGridNames; -public: + public: NanoVDBFileGridIdentifier() : mIndices(), mGridNames() {}; - NanoVDBFileGridIdentifier(uint64_t index) : mIndices({index}) {}; + NanoVDBFileGridIdentifier(uint64_t index) : mIndices({ index }) {}; NanoVDBFileGridIdentifier(std::vector indices) : mIndices(indices) {}; - NanoVDBFileGridIdentifier(std::string gridName) : mGridNames({gridName}) {}; + NanoVDBFileGridIdentifier(std::string gridName) : mGridNames({ gridName }) {}; NanoVDBFileGridIdentifier(std::vector gridNames) : mGridNames(gridNames) {}; - std::string toString() const { + std::string + toString() const { std::stringstream ss; if (specifiesIndices()) { - for(auto idx : mIndices) { + for (auto idx: mIndices) { ss << idx << ", "; } return "NanoVDBFileGridIdentifier indices: " + ss.str(); } else { - for(auto idx : mGridNames) { + for (auto idx: mGridNames) { ss << idx << ", "; } return "NanoVDBFileGridIdentifier gridNames: " + ss.str(); } } - bool isValid() const { + bool + isValid() const { return (mIndices.empty() != mGridNames.empty()); } - bool specifiesIndices() const { + bool + specifiesIndices() const { return !mIndices.empty(); } - bool specifiesNames() const { + bool + specifiesNames() const { return !mGridNames.empty(); } - const std::vector& indicesValue() const { + const std::vector & + indicesValue() const { return mIndices; } - const std::vector& namesValue() const { + const std::vector & + namesValue() const { return mGridNames; } - bool empty() const { + bool + empty() const { return (mIndices.empty() && mGridNames.empty()); } - size_t size() const { + size_t + size() const { if (specifiesIndices()) { return mIndices.size(); } else { return mGridNames.size(); } } - }; +} // namespace fvdb -} // namespace fvdb \ No newline at end of file +#endif // FVDB_TYPES_H \ No newline at end of file diff --git a/fvdb/src/detail/GridBatchImpl.cu b/fvdb/src/detail/GridBatchImpl.cu index dd200f6aa2..0159c991a8 100644 --- a/fvdb/src/detail/GridBatchImpl.cu +++ b/fvdb/src/detail/GridBatchImpl.cu @@ -3,34 +3,34 @@ // #include "GridBatchImpl.h" -#include +#include +#include #include +#include #include #include -#include -#include "detail/ops/Ops.h" -#include "detail/build/Build.h" +#include namespace { -__global__ void computeBatchOffsetsFromMetadata( - uint32_t numGrids, - fvdb::detail::GridBatchImpl::GridMetadata* perGridMetadata, - torch::PackedTensorAccessor32 outBatchOffsets) { - +__global__ void +computeBatchOffsetsFromMetadata( + uint32_t numGrids, fvdb::detail::GridBatchImpl::GridMetadata *perGridMetadata, + torch::PackedTensorAccessor32 + outBatchOffsets) { if (numGrids == 0) { return; } outBatchOffsets[0] = 0; for (uint32_t i = 1; i < (numGrids + 1); i += 1) { - outBatchOffsets[i] = outBatchOffsets[i-1] + perGridMetadata[i-1].mNumVoxels; + outBatchOffsets[i] = outBatchOffsets[i - 1] + perGridMetadata[i - 1].mNumVoxels; } } -} +} // namespace namespace fvdb { namespace detail { @@ -39,40 +39,49 @@ GridBatchImpl::GridBatchImpl(torch::Device device, bool isMutable) { std::vector dummy; dummy.push_back(nanovdb::Vec3d(1.0, 1.0, 1.0)); // TODO (Francis): No list-of-lists support for now, so we just pass an empty list of indices - const torch::Tensor lidx = torch::empty({0, 1}, torch::TensorOptions().dtype(fvdb::JLIdxScalarType).device(device)); + const torch::Tensor 
lidx = + torch::empty({ 0, 1 }, torch::TensorOptions().dtype(fvdb::JLIdxScalarType).device(device)); setGrid(build::buildEmptyGrid(device, isMutable), lidx, dummy, dummy, false); mHostGridMetadata.clear(); syncMetadataToDeviceIfCUDA(false); mBatchMetadata.mIsContiguous = true; } -GridBatchImpl::GridBatchImpl(nanovdb::GridHandle&& gridHdl, - const std::vector& voxelSizes, - const std::vector& voxelOrigins) { - TORCH_CHECK(!gridHdl.buffer().isEmpty(), "Cannot create a batched grid handle from an empty grid handle"); +GridBatchImpl::GridBatchImpl(nanovdb::GridHandle &&gridHdl, + const std::vector &voxelSizes, + const std::vector &voxelOrigins) { + TORCH_CHECK(!gridHdl.buffer().isEmpty(), + "Cannot create a batched grid handle from an empty grid handle"); for (std::size_t i = 0; i < voxelSizes.size(); i += 1) { - TORCH_CHECK_VALUE(voxelSizes[i][0] > 0 && voxelSizes[i][1] > 0 && voxelSizes[i][2] > 0, "Voxel size must be greater than 0"); + TORCH_CHECK_VALUE(voxelSizes[i][0] > 0 && voxelSizes[i][1] > 0 && voxelSizes[i][2] > 0, + "Voxel size must be greater than 0"); } mDeviceGridMetadata = nullptr; // TODO (Francis): No list-of-lists support for now, so we just pass an empty list of indices - const torch::Tensor lidx = torch::empty({0, 1}, torch::TensorOptions().dtype(fvdb::JLIdxScalarType).device(gridHdl.buffer().device())); + const torch::Tensor lidx = torch::empty( + { 0, 1 }, + torch::TensorOptions().dtype(fvdb::JLIdxScalarType).device(gridHdl.buffer().device())); setGrid(std::move(gridHdl), lidx, voxelSizes, voxelOrigins, false /* blocking */); mBatchMetadata.mIsContiguous = true; }; -GridBatchImpl::GridBatchImpl(nanovdb::GridHandle&& gridHdl, - const nanovdb::Vec3d& globalVoxelSize, - const nanovdb::Vec3d& globalVoxelOrigin) { - TORCH_CHECK(!gridHdl.buffer().isEmpty(), "Cannot create a batched grid handle from an empty grid handle"); - TORCH_CHECK_VALUE(globalVoxelSize[0] > 0 && globalVoxelSize[1] > 0 && globalVoxelSize[2] > 0, "Voxel size must be greater than 0"); +GridBatchImpl::GridBatchImpl(nanovdb::GridHandle &&gridHdl, + const nanovdb::Vec3d &globalVoxelSize, + const nanovdb::Vec3d &globalVoxelOrigin) { + TORCH_CHECK(!gridHdl.buffer().isEmpty(), + "Cannot create a batched grid handle from an empty grid handle"); + TORCH_CHECK_VALUE(globalVoxelSize[0] > 0 && globalVoxelSize[1] > 0 && globalVoxelSize[2] > 0, + "Voxel size must be greater than 0"); mDeviceGridMetadata = nullptr; std::vector voxelSizes, voxelOrigins; - for(size_t i = 0; i < gridHdl.gridCount(); ++i) { + for (size_t i = 0; i < gridHdl.gridCount(); ++i) { voxelSizes.push_back(globalVoxelSize); voxelOrigins.push_back(globalVoxelOrigin); } // TODO (Francis): No list-of-lists support for now, so we just pass an empty list of indices - const torch::Tensor lidx = torch::empty({0, 1}, torch::TensorOptions().dtype(fvdb::JLIdxScalarType).device(gridHdl.buffer().device())); + const torch::Tensor lidx = torch::empty( + { 0, 1 }, + torch::TensorOptions().dtype(fvdb::JLIdxScalarType).device(gridHdl.buffer().device())); setGrid(std::move(gridHdl), lidx, voxelSizes, voxelOrigins, false /* blocking */); mBatchMetadata.mIsContiguous = true; }; @@ -84,13 +93,15 @@ GridBatchImpl::~GridBatchImpl() { } }; -torch::Tensor GridBatchImpl::worldToGridMatrix(int64_t bi) const { +torch::Tensor +GridBatchImpl::worldToGridMatrix(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); - torch::Tensor xformMat = torch::eye(4, torch::TensorOptions().device(device()).dtype(torch::kDouble)); - const VoxelCoordTransform& transform = 
primalTransform(bi); - const nanovdb::Vec3d& scale = transform.scale(); - const nanovdb::Vec3d& translate = transform.translate(); + torch::Tensor xformMat = + torch::eye(4, torch::TensorOptions().device(device()).dtype(torch::kDouble)); + const VoxelCoordTransform &transform = primalTransform(bi); + const nanovdb::Vec3d &scale = transform.scale(); + const nanovdb::Vec3d &translate = transform.translate(); xformMat[0][0] = scale[0]; xformMat[1][1] = scale[1]; @@ -103,41 +114,49 @@ torch::Tensor GridBatchImpl::worldToGridMatrix(int64_t bi) const { return xformMat; } -void GridBatchImpl::recomputeBatchOffsets() { +void +GridBatchImpl::recomputeBatchOffsets() { TORCH_CHECK(batchSize() == mHostGridMetadata.size(), "Batch size does not match metadata size"); - mBatchOffsets = torch::empty({batchSize() + 1}, torch::TensorOptions().dtype(fvdb::JOffsetsScalarType).device(device())); + mBatchOffsets = + torch::empty({ batchSize() + 1 }, + torch::TensorOptions().dtype(fvdb::JOffsetsScalarType).device(device())); if (device().is_cuda()) { - computeBatchOffsetsFromMetadata<<<1, 1>>>(batchSize(), mDeviceGridMetadata, mBatchOffsets.packed_accessor32()); + computeBatchOffsetsFromMetadata<<<1, 1>>>( + batchSize(), mDeviceGridMetadata, + mBatchOffsets.packed_accessor32()); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { auto outBatchOffsets = mBatchOffsets.accessor(); - outBatchOffsets[0] = 0; + outBatchOffsets[0] = 0; for (int i = 1; i < (mHostGridMetadata.size() + 1); i += 1) { - outBatchOffsets[i] = outBatchOffsets[i-1] + mHostGridMetadata[i-1].mNumVoxels; + outBatchOffsets[i] = outBatchOffsets[i - 1] + mHostGridMetadata[i - 1].mNumVoxels; } } } - -torch::Tensor GridBatchImpl::gridToWorldMatrix(int64_t bi) const { +torch::Tensor +GridBatchImpl::gridToWorldMatrix(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return torch::linalg::inv(worldToGridMatrix(bi)); } -c10::intrusive_ptr GridBatchImpl::clone(torch::Device device, bool blocking) const { - // If you're cloning an empty grid, just create a new empty grid on the right device and return it +c10::intrusive_ptr +GridBatchImpl::clone(torch::Device device, bool blocking) const { + // If you're cloning an empty grid, just create a new empty grid on the right device and return + // it if (batchSize() == 0) { return c10::make_intrusive(device, isMutable()); } - // The guide buffer is a hack to perform the correct copy (i.e. host -> device / device -> host etc...) - // The guide carries the desired target device to the copy. - // The reason we do this is to conform with the nanovdb which can only accept a buffer as an extra argument. + // The guide buffer is a hack to perform the correct copy (i.e. host -> device / device -> host + // etc...) The guide carries the desired target device to the copy. The reason we do this is to + // conform with the nanovdb which can only accept a buffer as an extra argument. 
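// A minimal sketch of the guide-buffer pattern described above, restating only the calls that
// appear just below (TorchDeviceBuffer::setDevice and GridHandle::copy); the GridHandle template
// argument is elided in this extract, so it is omitted here as well:
//
//     TorchDeviceBuffer guide(0, nullptr);     // empty buffer that only carries the target device
//     guide.setDevice(device, true);           // record which device the copy should live on
//     auto clonedHdl = mGridHdl->copy(guide);  // nanovdb allocates the cloned grids via the guide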
TorchDeviceBuffer guideBuffer(0, nullptr); guideBuffer.setDevice(device, true); // Make a copy of this gridHandle on the same device as the guide buffer - nanovdb::GridHandle clonedHdl = mGridHdl->copy(guideBuffer); + nanovdb::GridHandle clonedHdl = + mGridHdl->copy(guideBuffer); // Copy the voxel sizes and origins for this grid std::vector voxelSizes, voxelOrigins; @@ -145,28 +164,31 @@ c10::intrusive_ptr GridBatchImpl::clone(torch::Device device, boo // Build a GridBatchImpl from the cloned grid handle and voxel sizes/origins // FIXME: (@fwilliams) This makes an extra copy or non contiguous grids - return GridBatchImpl::contiguous(c10::make_intrusive(std::move(clonedHdl), voxelSizes, voxelOrigins)); + return GridBatchImpl::contiguous( + c10::make_intrusive(std::move(clonedHdl), voxelSizes, voxelOrigins)); } -void GridBatchImpl::syncMetadataToDeviceIfCUDA(bool blocking) { +void +GridBatchImpl::syncMetadataToDeviceIfCUDA(bool blocking) { if (device().is_cuda()) { // There is something to sync and we're on a cuda device // We haven't allocated the cuda memory yet, so we need to do that now if (mDeviceGridMetadata == nullptr) { // We need to allocate the memory on the device c10::cuda::CUDAGuard deviceGuard(device()); - size_t metaDataByteSize = sizeof(GridMetadata) * mHostGridMetadata.size(); + size_t metaDataByteSize = sizeof(GridMetadata) * mHostGridMetadata.size(); at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(device().index()); - mDeviceGridMetadata = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(metaDataByteSize, defaultStream.stream())); + mDeviceGridMetadata = + static_cast(c10::cuda::CUDACachingAllocator::raw_alloc_with_stream( + metaDataByteSize, defaultStream.stream())); } // Copy host grid metadata to device buffer - size_t metaDataByteSize = sizeof(GridMetadata) * mHostGridMetadata.size(); - at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(mGridHdl->buffer().device().index()); - C10_CUDA_CHECK(cudaMemcpyAsync(mDeviceGridMetadata, - mHostGridMetadata.data(), - metaDataByteSize, - cudaMemcpyHostToDevice, + size_t metaDataByteSize = sizeof(GridMetadata) * mHostGridMetadata.size(); + at::cuda::CUDAStream defaultStream = + at::cuda::getCurrentCUDAStream(mGridHdl->buffer().device().index()); + C10_CUDA_CHECK(cudaMemcpyAsync(mDeviceGridMetadata, mHostGridMetadata.data(), + metaDataByteSize, cudaMemcpyHostToDevice, defaultStream.stream())); // Block if you asked for it if (blocking) { @@ -175,7 +197,8 @@ void GridBatchImpl::syncMetadataToDeviceIfCUDA(bool blocking) { } } -void GridBatchImpl::setGlobalPrimalTransform(const VoxelCoordTransform& transform, bool syncToDevice) { +void +GridBatchImpl::setGlobalPrimalTransform(const VoxelCoordTransform &transform, bool syncToDevice) { for (size_t i = 0; i < mHostGridMetadata.size(); i++) { mHostGridMetadata[i].mPrimalTransform = transform; } @@ -185,7 +208,8 @@ void GridBatchImpl::setGlobalPrimalTransform(const VoxelCoordTransform& transfor } } -void GridBatchImpl::setGlobalDualTransform(const VoxelCoordTransform& transform, bool syncToDevice) { +void +GridBatchImpl::setGlobalDualTransform(const VoxelCoordTransform &transform, bool syncToDevice) { for (size_t i = 0; i < mHostGridMetadata.size(); i++) { mHostGridMetadata[i].mDualTransform = transform; } @@ -195,7 +219,8 @@ void GridBatchImpl::setGlobalDualTransform(const VoxelCoordTransform& transform, } } -void GridBatchImpl::setGlobalVoxelSize(const nanovdb::Vec3d& voxelSize, bool syncToDevice) { +void 
+GridBatchImpl::setGlobalVoxelSize(const nanovdb::Vec3d &voxelSize, bool syncToDevice) { TORCH_CHECK(batchSize() > 0, "Cannot set global voxel size on an empty batch of grids"); for (size_t i = 0; i < mHostGridMetadata.size(); i++) { @@ -207,7 +232,8 @@ void GridBatchImpl::setGlobalVoxelSize(const nanovdb::Vec3d& voxelSize, bool syn } } -void GridBatchImpl::setGlobalVoxelOrigin(const nanovdb::Vec3d& voxelOrigin, bool syncToDevice) { +void +GridBatchImpl::setGlobalVoxelOrigin(const nanovdb::Vec3d &voxelOrigin, bool syncToDevice) { TORCH_CHECK(batchSize() > 0, "Cannot set global voxel origin on an empty batch of grids"); for (size_t i = 0; i < mHostGridMetadata.size(); i++) { @@ -219,8 +245,11 @@ void GridBatchImpl::setGlobalVoxelOrigin(const nanovdb::Vec3d& voxelOrigin, bool } } -void GridBatchImpl::setGlobalVoxelSizeAndOrigin(const nanovdb::Vec3d& voxelSize, const nanovdb::Vec3d& voxelOrigin, bool syncToDevice) { - TORCH_CHECK(batchSize() > 0, "Cannot set global voxel size and origin on an empty batch of grids"); +void +GridBatchImpl::setGlobalVoxelSizeAndOrigin(const nanovdb::Vec3d &voxelSize, + const nanovdb::Vec3d &voxelOrigin, bool syncToDevice) { + TORCH_CHECK(batchSize() > 0, + "Cannot set global voxel size and origin on an empty batch of grids"); for (size_t i = 0; i < mHostGridMetadata.size(); i++) { mHostGridMetadata[i].setTransform(voxelSize, voxelOrigin); @@ -231,9 +260,11 @@ void GridBatchImpl::setGlobalVoxelSizeAndOrigin(const nanovdb::Vec3d& voxelSize, } } - -void GridBatchImpl::setFineTransformFromCoarseGrid(const GridBatchImpl& coarseBatch, nanovdb::Coord subdivisionFactor) { - TORCH_CHECK(coarseBatch.batchSize() == batchSize(), "Coarse grid batch size must match fine grid batch size"); +void +GridBatchImpl::setFineTransformFromCoarseGrid(const GridBatchImpl &coarseBatch, + nanovdb::Coord subdivisionFactor) { + TORCH_CHECK(coarseBatch.batchSize() == batchSize(), + "Coarse grid batch size must match fine grid batch size"); for (size_t i = 0; i < mHostGridMetadata.size(); i++) { auto sizeAndOrigin = coarseBatch.fineVoxSizeAndOrigin(i, subdivisionFactor); @@ -243,9 +274,11 @@ void GridBatchImpl::setFineTransformFromCoarseGrid(const GridBatchImpl& coarseBa syncMetadataToDeviceIfCUDA(false); } - -void GridBatchImpl::setCoarseTransformFromFineGrid(const GridBatchImpl& fineBatch, nanovdb::Coord coarseningFactor) { - TORCH_CHECK(fineBatch.batchSize() == batchSize(), "Fine grid batch size must match coarse grid batch size"); +void +GridBatchImpl::setCoarseTransformFromFineGrid(const GridBatchImpl &fineBatch, + nanovdb::Coord coarseningFactor) { + TORCH_CHECK(fineBatch.batchSize() == batchSize(), + "Fine grid batch size must match coarse grid batch size"); for (size_t i = 0; i < mHostGridMetadata.size(); i++) { auto sizeAndOrigin = fineBatch.coarseVoxSizeAndOrigin(i, coarseningFactor); @@ -255,29 +288,34 @@ void GridBatchImpl::setCoarseTransformFromFineGrid(const GridBatchImpl& fineBatc syncMetadataToDeviceIfCUDA(false); } - -void GridBatchImpl::setPrimalTransformFromDualGrid(const GridBatchImpl& dualBatch) { - TORCH_CHECK(dualBatch.batchSize() == batchSize(), "Dual grid batch size must match primal grid batch size"); +void +GridBatchImpl::setPrimalTransformFromDualGrid(const GridBatchImpl &dualBatch) { + TORCH_CHECK(dualBatch.batchSize() == batchSize(), + "Dual grid batch size must match primal grid batch size"); for (size_t i = 0; i < mHostGridMetadata.size(); i++) { - mHostGridMetadata[i].mDualTransform = dualBatch.mHostGridMetadata[i].mPrimalTransform; + 
mHostGridMetadata[i].mDualTransform = dualBatch.mHostGridMetadata[i].mPrimalTransform; mHostGridMetadata[i].mPrimalTransform = dualBatch.mHostGridMetadata[i].mDualTransform; - mHostGridMetadata[i].mVoxelSize = dualBatch.mHostGridMetadata[i].mVoxelSize; + mHostGridMetadata[i].mVoxelSize = dualBatch.mHostGridMetadata[i].mVoxelSize; } syncMetadataToDeviceIfCUDA(false); } - -void GridBatchImpl::setGrid(nanovdb::GridHandle&& gridHdl, - const torch::Tensor listIndices, - const std::vector& voxelSizes, - const std::vector& voxelOrigins, - bool blocking) { +void +GridBatchImpl::setGrid(nanovdb::GridHandle &&gridHdl, + const torch::Tensor listIndices, + const std::vector &voxelSizes, + const std::vector &voxelOrigins, bool blocking) { TORCH_CHECK(!gridHdl.buffer().isEmpty(), "Empty grid handle"); - TORCH_CHECK(voxelSizes.size() == gridHdl.gridCount(), "voxelSizes array does not have the same size as the number of grids, got ", voxelSizes.size(), " expected ", gridHdl.gridCount()); - TORCH_CHECK(voxelOrigins.size() == gridHdl.gridCount(), "Voxel origins must be the same size as the number of grids"); - TORCH_CHECK((gridHdl.gridType(0) == nanovdb::GridType::OnIndex) || (gridHdl.gridType(0) == nanovdb::GridType::OnIndexMask), "GridBatchImpl only supports ValueOnIndex and ValueOnIndexMask grids"); + TORCH_CHECK(voxelSizes.size() == gridHdl.gridCount(), + "voxelSizes array does not have the same size as the number of grids, got ", + voxelSizes.size(), " expected ", gridHdl.gridCount()); + TORCH_CHECK(voxelOrigins.size() == gridHdl.gridCount(), + "Voxel origins must be the same size as the number of grids"); + TORCH_CHECK((gridHdl.gridType(0) == nanovdb::GridType::OnIndex) || + (gridHdl.gridType(0) == nanovdb::GridType::OnIndexMask), + "GridBatchImpl only supports ValueOnIndex and ValueOnIndexMask grids"); const torch::Device device = gridHdl.buffer().device(); // Clear out old grid metadata @@ -292,31 +330,35 @@ void GridBatchImpl::setGrid(nanovdb::GridHandle&& gridHdl, FVDB_DISPATCH_KERNEL_DEVICE(device, [&]() { // Allocate device memory for metadata - GridBatchMetadata* deviceBatchMetadataPtr = nullptr; + GridBatchMetadata *deviceBatchMetadataPtr = nullptr; if constexpr (DeviceTag == torch::kCUDA) { c10::cuda::CUDAGuard deviceGuard(device); - const size_t metaDataByteSize = sizeof(GridMetadata) * gridHdl.gridCount(); - at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(device.index()); - mDeviceGridMetadata = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(metaDataByteSize, defaultStream.stream())); - deviceBatchMetadataPtr = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(sizeof(GridBatchMetadata), defaultStream.stream())); + const size_t metaDataByteSize = sizeof(GridMetadata) * gridHdl.gridCount(); + at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(device.index()); + mDeviceGridMetadata = + static_cast(c10::cuda::CUDACachingAllocator::raw_alloc_with_stream( + metaDataByteSize, defaultStream.stream())); + deviceBatchMetadataPtr = static_cast( + c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(sizeof(GridBatchMetadata), + defaultStream.stream())); } // Populate host and/or device metadata const bool isGridMutable = gridHdl.gridType(0) == nanovdb::GridType::OnIndexMask; ops::dispatchPopulateGridMetadata( - gridHdl, voxelSizes, voxelOrigins, isGridMutable, - mBatchOffsets, + gridHdl, voxelSizes, voxelOrigins, isGridMutable, mBatchOffsets, mHostGridMetadata.data(), mDeviceGridMetadata, &mBatchMetadata, 
deviceBatchMetadataPtr); - TORCH_CHECK(listIndices.numel() == 0 || listIndices.size(0) == (mBatchOffsets.size(0) - 1), "Invalid list indices when building grid"); + TORCH_CHECK(listIndices.numel() == 0 || listIndices.size(0) == (mBatchOffsets.size(0) - 1), + "Invalid list indices when building grid"); mListIndices = listIndices; - // We don't need the device copy of the global batch metadata anymore (we only carry around the host version and pass it by value to device kernels), so delete it + // We don't need the device copy of the global batch metadata anymore (we only carry around + // the host version and pass it by value to device kernels), so delete it if constexpr (DeviceTag == torch::kCUDA) { c10::cuda::CUDACachingAllocator::raw_delete(deviceBatchMetadataPtr); } }); - // FIXME: This is slow // Populate batch offsets for each leaf node { @@ -324,8 +366,7 @@ void GridBatchImpl::setGrid(nanovdb::GridHandle&& gridHdl, leafBatchIdxs.reserve(gridHdl.gridCount()); for (uint32_t i = 0; i < gridHdl.gridCount(); i += 1) { leafBatchIdxs.push_back( - torch::full({mHostGridMetadata[i].mNumLeaves}, - static_cast(i), + torch::full({ mHostGridMetadata[i].mNumLeaves }, static_cast(i), torch::TensorOptions().dtype(fvdb::JIdxScalarType).device(device))); } mLeafBatchIndices = torch::cat(leafBatchIdxs, 0); @@ -335,39 +376,43 @@ void GridBatchImpl::setGrid(nanovdb::GridHandle&& gridHdl, mGridHdl = std::make_shared>(std::move(gridHdl)); } - -c10::intrusive_ptr GridBatchImpl::index(int64_t bi) const { +c10::intrusive_ptr +GridBatchImpl::index(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); - return index(bi, bi+1, 1); + return index(bi, bi + 1, 1); } - -c10::intrusive_ptr GridBatchImpl::index(const torch::Tensor& indices) const { +c10::intrusive_ptr +GridBatchImpl::index(const torch::Tensor &indices) const { TORCH_CHECK_INDEX(indices.dim() == 1, "indices must be a 1D tensor"); TORCH_CHECK_INDEX(!indices.is_floating_point(), "indices must be an integer tensor"); torch::Tensor numericIndices; - if(indices.scalar_type() == torch::kBool) { + if (indices.scalar_type() == torch::kBool) { TORCH_CHECK_INDEX(indices.dim() == 1, "bool indices must be a 1D tensor"); - TORCH_CHECK_INDEX(indices.numel() == batchSize(), "bool indices must have the same number of entries as grids in the batch"); - numericIndices = torch::arange(batchSize(), torch::TensorOptions().dtype(torch::kInt64).device(indices.device())); + TORCH_CHECK_INDEX( + indices.numel() == batchSize(), + "bool indices must have the same number of entries as grids in the batch"); + numericIndices = torch::arange( + batchSize(), torch::TensorOptions().dtype(torch::kInt64).device(indices.device())); numericIndices = numericIndices.masked_select(indices); } else { numericIndices = indices; } - torch::Tensor indicesCpu = numericIndices.to(torch::kCPU).to(torch::kInt64); - auto indicesAccessor = indicesCpu.accessor(); + torch::Tensor indicesCpu = numericIndices.to(torch::kCPU).to(torch::kInt64); + auto indicesAccessor = indicesCpu.accessor(); return indexInternal(indicesAccessor, indicesAccessor.size(0)); } - -c10::intrusive_ptr GridBatchImpl::index(const std::vector& indices) const { +c10::intrusive_ptr +GridBatchImpl::index(const std::vector &indices) const { return indexInternal(indices, indices.size()); } -c10::intrusive_ptr GridBatchImpl::index(const std::vector& indices) const { +c10::intrusive_ptr +GridBatchImpl::index(const std::vector &indices) const { std::vector indicesInt; indicesInt.reserve(indices.size()); for (size_t i = 0; i < 
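Two of the small libtorch idioms above, written as a standalone sketch (the function names and the int32 index dtype are illustrative): expanding per-grid leaf counts into a flat batch-index tensor, and turning a boolean grid mask into integer indices the way index(const torch::Tensor&) does.

    #include <torch/torch.h>
    #include <vector>

    // Expand per-grid leaf counts into one flat batch-index tensor,
    // e.g. counts {2, 3} -> [0, 0, 1, 1, 1].
    torch::Tensor leafBatchIndices(const std::vector<int64_t> &leafCounts, torch::Device device) {
        std::vector<torch::Tensor> perGrid;
        perGrid.reserve(leafCounts.size());
        for (size_t i = 0; i < leafCounts.size(); ++i) {
            perGrid.push_back(
                torch::full({ leafCounts[i] }, static_cast<int64_t>(i),
                            torch::TensorOptions().dtype(torch::kInt32).device(device)));
        }
        return torch::cat(perGrid, 0);
    }

    // Convert a boolean selection mask over the batch into integer indices via
    // arange + masked_select.
    torch::Tensor maskToIndices(const torch::Tensor &mask, int64_t batchSize) {
        TORCH_CHECK(mask.scalar_type() == torch::kBool && mask.dim() == 1 &&
                        mask.numel() == batchSize,
                    "mask must be a 1D bool tensor with one entry per grid");
        torch::Tensor all = torch::arange(
            batchSize, torch::TensorOptions().dtype(torch::kInt64).device(mask.device()));
        return all.masked_select(mask);
    }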
indices.size(); i += 1) { @@ -379,15 +424,16 @@ c10::intrusive_ptr GridBatchImpl::index(const std::vector& return indexInternal(indicesInt, indicesInt.size()); } - -c10::intrusive_ptr GridBatchImpl::index(ssize_t start, ssize_t stop, ssize_t step) const { +c10::intrusive_ptr +GridBatchImpl::index(ssize_t start, ssize_t stop, ssize_t step) const { struct RangeAccessor { ssize_t mStart; ssize_t mStop; ssize_t mStep; ssize_t mLen; - RangeAccessor(ssize_t start, ssize_t stop, ssize_t step, ssize_t batchSize) : mStart(start), mStop(stop), mStep(step) { + RangeAccessor(ssize_t start, ssize_t stop, ssize_t step, ssize_t batchSize) + : mStart(start), mStop(stop), mStep(step) { TORCH_CHECK_INDEX(step != 0, "slice step cannot be zero"); TORCH_CHECK_INDEX(0 <= start && start <= batchSize, "slice index out of range"); TORCH_CHECK_INDEX(-1 <= stop && stop <= batchSize, "slice index out of range"); @@ -399,10 +445,12 @@ c10::intrusive_ptr GridBatchImpl::index(ssize_t start, ssize_t st } else if (stop <= start && step < 0) { mLen = (mStart - mStop - mStep - 1) / -mStep; } else { - TORCH_CHECK_INDEX(false, "Invalid slice start=", start, ", stop=", stop, ", step=", step, " for batch size ", batchSize); + TORCH_CHECK_INDEX(false, "Invalid slice start=", start, ", stop=", stop, + ", step=", step, " for batch size ", batchSize); } } - size_t operator[](size_t i) const { + size_t + operator[](size_t i) const { return mStart + i * mStep; } }; @@ -411,21 +459,19 @@ c10::intrusive_ptr GridBatchImpl::index(ssize_t start, ssize_t st return indexInternal(acc, acc.mLen); } - -c10::intrusive_ptr GridBatchImpl::concatenate( - const std::vector>& elements) { - +c10::intrusive_ptr +GridBatchImpl::concatenate(const std::vector> &elements) { TORCH_CHECK_VALUE(elements.size() > 0, "Must provide at least one grid for concatenate!") - torch::Device device = elements[0]->device(); - bool isMutable = elements[0]->isMutable(); + torch::Device device = elements[0]->device(); + bool isMutable = elements[0]->isMutable(); std::vector>> handles; - std::vector> byteSizes; - std::vector> readByteOffsets; - std::vector> writeByteOffsets; - int64_t totalByteSize = 0; - int64_t totalGrids = 0; + std::vector> byteSizes; + std::vector> readByteOffsets; + std::vector> writeByteOffsets; + int64_t totalByteSize = 0; + int64_t totalGrids = 0; handles.reserve(elements.size()); byteSizes.reserve(elements.size()); readByteOffsets.reserve(elements.size()); @@ -434,8 +480,10 @@ c10::intrusive_ptr GridBatchImpl::concatenate( std::vector voxelSizes, voxelOrigins; for (size_t i = 0; i < elements.size(); i += 1) { - TORCH_CHECK(elements[i]->device() == device, "All grid batches must be on the same device!"); - TORCH_CHECK(elements[i]->isMutable() == isMutable, "All grid batches must have the same mutability!"); + TORCH_CHECK(elements[i]->device() == device, + "All grid batches must be on the same device!"); + TORCH_CHECK(elements[i]->isMutable() == isMutable, + "All grid batches must have the same mutability!"); // Empty grids don't contribute to the concatenation if (elements[i]->batchSize() == 0) { @@ -457,21 +505,22 @@ c10::intrusive_ptr GridBatchImpl::concatenate( voxelSizes.push_back(elements[i]->voxelSize(j)); voxelOrigins.push_back(elements[i]->voxelOrigin(j)); - readByteOffsets.back().push_back(elements[i]->cumBytes(j)); // Where to start reading from in the current grid - byteSizes.back().push_back(elements[i]->numBytes(j)); // How many bytes to read - writeByteOffsets.back().push_back(totalByteSize); // Where to start writing to in the 
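The slice arithmetic inside RangeAccessor, isolated into a tiny sketch: the number of selected indices for a Python-style start/stop/step triple and the i-th selected index (the real code raises on an invalid combination instead of returning zero).

    #include <cassert>
    #include <cstdint>

    struct Slice {
        int64_t start, stop, step;

        // Number of indices selected by [start:stop:step].
        int64_t length() const {
            assert(step != 0);
            if (start <= stop && step > 0) {
                return (stop - start + step - 1) / step;  // ceil((stop - start) / step)
            } else if (stop <= start && step < 0) {
                return (start - stop - step - 1) / -step; // ceil((start - stop) / -step)
            }
            return 0; // empty; GridBatchImpl::index raises TORCH_CHECK_INDEX here instead
        }

        // i-th selected index.
        int64_t operator[](int64_t i) const { return start + i * step; }
    };

    // Example: Slice{0, 7, 2}.length() == 4, selecting indices 0, 2, 4, 6.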
concatenated grid + readByteOffsets.back().push_back( + elements[i]->cumBytes(j)); // Where to start reading from in the current grid + byteSizes.back().push_back(elements[i]->numBytes(j)); // How many bytes to read + writeByteOffsets.back().push_back( + totalByteSize); // Where to start writing to in the concatenated grid totalByteSize += elements[i]->numBytes(j); } - } if (handles.size() == 0) { return c10::make_intrusive(device, isMutable); } - const bool isHost = device.is_cpu(); + const bool isHost = device.is_cpu(); TorchDeviceBuffer buffer(totalByteSize, nullptr, isHost, device.index()); - int count = 0; + int count = 0; int nonEmptyCount = 0; if (isHost) { for (size_t i = 0; i < elements.size(); i += 1) { @@ -480,50 +529,53 @@ c10::intrusive_ptr GridBatchImpl::concatenate( } for (size_t j = 0; j < elements[i]->batchSize(); j += 1) { - const int64_t readOffset = readByteOffsets[nonEmptyCount][j]; + const int64_t readOffset = readByteOffsets[nonEmptyCount][j]; const int64_t writeOffset = writeByteOffsets[nonEmptyCount][j]; - const int64_t numBytes = byteSizes[nonEmptyCount][j]; + const int64_t numBytes = byteSizes[nonEmptyCount][j]; - nanovdb::GridData* dst = reinterpret_cast(buffer.data() + writeOffset); - const uint8_t* src = elements[i]->mGridHdl->buffer().data() + readOffset; - memcpy((void*) dst, (void*) src, numBytes); + nanovdb::GridData *dst = + reinterpret_cast(buffer.data() + writeOffset); + const uint8_t *src = elements[i]->mGridHdl->buffer().data() + readOffset; + memcpy((void *)dst, (void *)src, numBytes); nanovdb::tools::updateGridCount(dst, count++, totalGrids); } nonEmptyCount += 1; } - } - else { + } else { for (size_t i = 0; i < elements.size(); i += 1) { if (elements[i]->batchSize() == 0) { continue; } for (size_t j = 0; j < elements[i]->batchSize(); j += 1) { - const int64_t readOffset = readByteOffsets[nonEmptyCount][j]; + const int64_t readOffset = readByteOffsets[nonEmptyCount][j]; const int64_t writeOffset = writeByteOffsets[nonEmptyCount][j]; - const int64_t numBytes = byteSizes[nonEmptyCount][j]; + const int64_t numBytes = byteSizes[nonEmptyCount][j]; c10::cuda::CUDAGuard deviceGuard(device.index()); - nanovdb::GridData* dst = reinterpret_cast(buffer.deviceData() + writeOffset); - const uint8_t* src = elements[i]->mGridHdl->buffer().deviceData() + readOffset; - cudaMemcpyAsync((uint8_t*) dst, src, numBytes, cudaMemcpyDeviceToDevice); + nanovdb::GridData *dst = + reinterpret_cast(buffer.deviceData() + writeOffset); + const uint8_t *src = elements[i]->mGridHdl->buffer().deviceData() + readOffset; + cudaMemcpyAsync((uint8_t *)dst, src, numBytes, cudaMemcpyDeviceToDevice); bool dirty, *d_dirty; - cudaMallocAsync((void**)&d_dirty, sizeof(bool), 0); + cudaMallocAsync((void **)&d_dirty, sizeof(bool), 0); nanovdb::cuda::updateGridCount<<<1, 1>>>(dst, count++, totalGrids, d_dirty); C10_CUDA_KERNEL_LAUNCH_CHECK(); cudaMemcpyAsync(&dirty, d_dirty, sizeof(bool), cudaMemcpyDeviceToHost); - if (dirty) nanovdb::tools::cuda::updateChecksum(dst, nanovdb::CheckMode::Partial); + if (dirty) + nanovdb::tools::cuda::updateChecksum(dst, nanovdb::CheckMode::Partial); } nonEmptyCount += 1; } } - nanovdb::GridHandle gridHdl = nanovdb::GridHandle(std::move(buffer)); + nanovdb::GridHandle gridHdl = + nanovdb::GridHandle(std::move(buffer)); return c10::make_intrusive(std::move(gridHdl), voxelSizes, voxelOrigins); } - -c10::intrusive_ptr GridBatchImpl::contiguous(c10::intrusive_ptr input) { +c10::intrusive_ptr +GridBatchImpl::contiguous(c10::intrusive_ptr input) { if 
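The CPU concatenation path boils down to a two-pass copy over byte ranges, sketched below with the nanovdb-specific pieces (grid-count fix-up, checksums) stripped out; SrcGrid and concatenateBytes are illustrative names.

    #include <cstdint>
    #include <cstring>
    #include <vector>

    struct SrcGrid {
        const uint8_t *data;     // start of this grid's serialized bytes
        int64_t        numBytes; // size of this grid's serialized bytes
    };

    std::vector<uint8_t> concatenateBytes(const std::vector<SrcGrid> &grids) {
        // First pass: compute the write offset of each grid in the output buffer.
        int64_t              total = 0;
        std::vector<int64_t> writeOffsets;
        writeOffsets.reserve(grids.size());
        for (const SrcGrid &g: grids) {
            writeOffsets.push_back(total);
            total += g.numBytes;
        }
        // Second pass: copy each payload back-to-back.
        std::vector<uint8_t> out(static_cast<size_t>(total));
        for (size_t i = 0; i < grids.size(); i += 1) {
            std::memcpy(out.data() + writeOffsets[i], grids[i].data,
                        static_cast<size_t>(grids[i].numBytes));
        }
        return out;
    }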
(input->isContiguous()) { return input; } @@ -535,10 +587,10 @@ c10::intrusive_ptr GridBatchImpl::contiguous(c10::intrusive_ptrnumBytes(i); } - const bool isHost = input->device().is_cpu(); + const bool isHost = input->device().is_cpu(); TorchDeviceBuffer buffer(totalByteSize, nullptr, isHost, input->device().index()); - int64_t writeOffset = 0; + int64_t writeOffset = 0; std::vector voxelSizes, voxelOrigins; voxelSizes.reserve(input->batchSize()); voxelOrigins.reserve(input->batchSize()); @@ -548,39 +600,42 @@ c10::intrusive_ptr GridBatchImpl::contiguous(c10::intrusive_ptrvoxelSize(i)); voxelOrigins.push_back(input->voxelOrigin(i)); - nanovdb::GridData* dst = reinterpret_cast(buffer.data() + writeOffset); - const uint8_t* src = input->nanoGridHandle().buffer().data() + input->cumBytes(i); - memcpy((void*) dst, (void*) src, input->numBytes(i)); + nanovdb::GridData *dst = + reinterpret_cast(buffer.data() + writeOffset); + const uint8_t *src = input->nanoGridHandle().buffer().data() + input->cumBytes(i); + memcpy((void *)dst, (void *)src, input->numBytes(i)); nanovdb::tools::updateGridCount(dst, i, totalGrids); writeOffset += input->numBytes(i); } - } - else { + } else { for (size_t i = 0; i < input->batchSize(); i += 1) { voxelSizes.push_back(input->voxelSize(i)); voxelOrigins.push_back(input->voxelOrigin(i)); c10::cuda::CUDAGuard deviceGuard(input->device().index()); - nanovdb::GridData* dst = reinterpret_cast(buffer.deviceData() + writeOffset); - const uint8_t* src = input->nanoGridHandle().buffer().deviceData() + input->cumBytes(i); - cudaMemcpyAsync((uint8_t*) dst, src, input->numBytes(i), cudaMemcpyDeviceToDevice); + nanovdb::GridData *dst = + reinterpret_cast(buffer.deviceData() + writeOffset); + const uint8_t *src = input->nanoGridHandle().buffer().deviceData() + input->cumBytes(i); + cudaMemcpyAsync((uint8_t *)dst, src, input->numBytes(i), cudaMemcpyDeviceToDevice); bool dirty, *d_dirty; - cudaMallocAsync((void**)&d_dirty, sizeof(bool), 0); + cudaMallocAsync((void **)&d_dirty, sizeof(bool), 0); nanovdb::cuda::updateGridCount<<<1, 1>>>(dst, i, totalGrids, d_dirty); C10_CUDA_KERNEL_LAUNCH_CHECK(); cudaMemcpyAsync(&dirty, d_dirty, sizeof(bool), cudaMemcpyDeviceToHost); - if (dirty) nanovdb::tools::cuda::updateChecksum(dst, nanovdb::CheckMode::Partial); + if (dirty) + nanovdb::tools::cuda::updateChecksum(dst, nanovdb::CheckMode::Partial); writeOffset += input->numBytes(i); } } - return c10::make_intrusive(nanovdb::GridHandle(std::move(buffer)), voxelSizes, voxelOrigins); + return c10::make_intrusive( + nanovdb::GridHandle(std::move(buffer)), voxelSizes, voxelOrigins); } - -JaggedTensor GridBatchImpl::jaggedTensor(const torch::Tensor& data, bool ignoreDisabledVoxels) const { +JaggedTensor +GridBatchImpl::jaggedTensor(const torch::Tensor &data, bool ignoreDisabledVoxels) const { checkDevice(data); TORCH_CHECK(data.dim() >= 1, "Data have more than one dimensions"); if (ignoreDisabledVoxels || !isMutable()) { @@ -588,45 +643,52 @@ JaggedTensor GridBatchImpl::jaggedTensor(const torch::Tensor& data, bool ignoreD } else { // TODO: (@fwilliams) check data size need to call totalActiveVoxels() } - return JaggedTensor::from_data_offsets_and_list_ids(data, voxelOffsets(ignoreDisabledVoxels), jlidx(ignoreDisabledVoxels)); + return JaggedTensor::from_data_offsets_and_list_ids(data, voxelOffsets(ignoreDisabledVoxels), + jlidx(ignoreDisabledVoxels)); } - -int64_t GridBatchImpl::totalEnabledVoxels(bool ignoreDisabledVoxels) const { +int64_t +GridBatchImpl::totalEnabledVoxels(bool ignoreDisabledVoxels) 
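The jagged layout that jaggedTensor() produces is just a flat data tensor plus a (batchSize + 1)-long offsets tensor, where grid i owns rows [offsets[i], offsets[i+1]). A sketch of reading one element back, assuming 64-bit offsets (the helper name is illustrative):

    #include <torch/torch.h>

    // Slice out the rows of `data` that belong to grid i.
    torch::Tensor jaggedElement(const torch::Tensor &data, const torch::Tensor &offsets, int64_t i) {
        torch::Tensor offsetsCpu = offsets.to(torch::kCPU); // keep the CPU copy alive
        auto          off        = offsetsCpu.accessor<int64_t, 1>();
        return data.slice(/*dim=*/0, /*start=*/off[i], /*end=*/off[i + 1]);
    }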
const { if (!isMutable() || ignoreDisabledVoxels) { return totalVoxels(); } - return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return ops::dispatchCountEnabledVoxels(*this, -1); - }); + return FVDB_DISPATCH_KERNEL_DEVICE( + device(), [&]() { return ops::dispatchCountEnabledVoxels(*this, -1); }); } - -torch::Tensor GridBatchImpl::jidx(bool ignoreDisabledVoxels) const { +torch::Tensor +GridBatchImpl::jidx(bool ignoreDisabledVoxels) const { return FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { if (batchSize() == 1 || totalVoxels() == 0) { - return torch::empty({0}, torch::TensorOptions().dtype(fvdb::JIdxScalarType).device(device())); + return torch::empty( + { 0 }, torch::TensorOptions().dtype(fvdb::JIdxScalarType).device(device())); } return ops::dispatchJIdxForGrid(*this, ignoreDisabledVoxels); }); } -torch::Tensor GridBatchImpl::jlidx(bool ignoreDisabledVoxels) const { +torch::Tensor +GridBatchImpl::jlidx(bool ignoreDisabledVoxels) const { return mListIndices; } -torch::Tensor GridBatchImpl::voxelOffsets(bool ignoreDisabledVoxels) const { +torch::Tensor +GridBatchImpl::voxelOffsets(bool ignoreDisabledVoxels) const { if (!isMutable() || ignoreDisabledVoxels) { return mBatchOffsets; - } else { + } else { // FIXME: This is slow for mutable grids - TORCH_CHECK(isMutable(), "This grid is not mutable, cannot get voxel offsets. This should never happen."); - torch::Tensor numEnabledPerGrid = torch::empty({batchSize() + 1}, torch::TensorOptions().dtype(fvdb::JOffsetsScalarType).device(torch::kCPU)); + TORCH_CHECK( + isMutable(), + "This grid is not mutable, cannot get voxel offsets. This should never happen."); + torch::Tensor numEnabledPerGrid = torch::empty( + { batchSize() + 1 }, + torch::TensorOptions().dtype(fvdb::JOffsetsScalarType).device(torch::kCPU)); auto acc = numEnabledPerGrid.accessor(); - acc[0] = 0; + acc[0] = 0; for (int i = 1; i < (batchSize() + 1); i += 1) { acc[i] = FVDB_DISPATCH_KERNEL_DEVICE(device(), [&]() { - return ops::dispatchCountEnabledVoxels(*this, i-1); + return ops::dispatchCountEnabledVoxels(*this, i - 1); }); } numEnabledPerGrid = numEnabledPerGrid.to(device()); @@ -634,39 +696,43 @@ torch::Tensor GridBatchImpl::voxelOffsets(bool ignoreDisabledVoxels) const { } } -torch::Tensor GridBatchImpl::serialize() const { +torch::Tensor +GridBatchImpl::serialize() const { return serializeV0(); } -c10::intrusive_ptr GridBatchImpl::deserialize(const torch::Tensor& serialized) { +c10::intrusive_ptr +GridBatchImpl::deserialize(const torch::Tensor &serialized) { return deserializeV0(serialized); } - -torch::Tensor GridBatchImpl::serializeV0() const { - c10::intrusive_ptr self = c10::intrusive_ptr::reclaim_copy((GridBatchImpl*) this); +torch::Tensor +GridBatchImpl::serializeV0() const { + c10::intrusive_ptr self = + c10::intrusive_ptr::reclaim_copy((GridBatchImpl *)this); if (!device().is_cpu()) { self = clone(torch::kCPU, true); } - int64_t numGrids = self->nanoGridHandle().gridCount(); + int64_t numGrids = self->nanoGridHandle().gridCount(); int64_t hdlBufSize = self->nanoGridHandle().buffer().size(); struct V01Header { - uint64_t magic = 0x0F0F0F0F0F0F0F0F; + uint64_t magic = 0x0F0F0F0F0F0F0F0F; uint64_t version = 0; uint64_t numGrids; uint64_t totalBytes; } header; - const int64_t headerSize = sizeof(V01Header) + numGrids * sizeof(GridMetadata) + sizeof(GridBatchMetadata); + const int64_t headerSize = + sizeof(V01Header) + numGrids * sizeof(GridMetadata) + sizeof(GridBatchMetadata); const int64_t totalByteSize = headerSize + hdlBufSize; header.totalBytes = totalByteSize; 
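The slow path of voxelOffsets() for mutable grids amounts to counting enabled voxels per grid and turning those counts into exclusive offsets. That counts-to-offsets step on its own, assuming 64-bit values:

    #include <torch/torch.h>
    #include <vector>

    torch::Tensor countsToOffsets(const std::vector<int64_t> &counts) {
        torch::Tensor offsets = torch::zeros({ static_cast<int64_t>(counts.size()) + 1 },
                                             torch::dtype(torch::kInt64));
        auto acc = offsets.accessor<int64_t, 1>();
        for (size_t i = 0; i < counts.size(); i += 1) {
            acc[i + 1] = counts[i]; // store counts shifted by one slot...
        }
        return offsets.cumsum(0);   // ...so the cumulative sum is an exclusive prefix sum
    }

    // Example: countsToOffsets({3, 0, 2}) == [0, 3, 3, 5]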
- header.numGrids = numGrids; + header.numGrids = numGrids; - torch::Tensor ret = torch::empty({totalByteSize}, torch::kInt8); - int8_t* retPtr = ret.data_ptr(); + torch::Tensor ret = torch::empty({ totalByteSize }, torch::kInt8); + int8_t *retPtr = ret.data_ptr(); memcpy(retPtr, &header, sizeof(V01Header)); retPtr += sizeof(V01Header); @@ -680,41 +746,49 @@ torch::Tensor GridBatchImpl::serializeV0() const { memcpy(retPtr, self->nanoGridHandle().buffer().data(), hdlBufSize); retPtr += hdlBufSize; - TORCH_CHECK(retPtr == (ret.data_ptr() + totalByteSize), "Something went wrong with serialization"); + TORCH_CHECK(retPtr == (ret.data_ptr() + totalByteSize), + "Something went wrong with serialization"); return ret; } -c10::intrusive_ptr GridBatchImpl::deserializeV0(const torch::Tensor& serialized) { +c10::intrusive_ptr +GridBatchImpl::deserializeV0(const torch::Tensor &serialized) { struct V01Header { - uint64_t magic = 0x0F0F0F0F0F0F0F0F; + uint64_t magic = 0x0F0F0F0F0F0F0F0F; uint64_t version = 0; uint64_t numGrids; uint64_t totalBytes; }; TORCH_CHECK(serialized.scalar_type() == torch::kInt8, "Serialized data must be of type int8"); - TORCH_CHECK(serialized.numel() >= sizeof(V01Header), "Serialized data is too small to be a valid grid handle"); + TORCH_CHECK(serialized.numel() >= sizeof(V01Header), + "Serialized data is too small to be a valid grid handle"); - const int8_t* serializedPtr = serialized.data_ptr(); + const int8_t *serializedPtr = serialized.data_ptr(); - const V01Header* header = reinterpret_cast(serializedPtr); + const V01Header *header = reinterpret_cast(serializedPtr); TORCH_CHECK(header->magic == 0x0F0F0F0F0F0F0F0F, "Serialized data is not a valid grid handle"); TORCH_CHECK(header->version == 0, "Serialized data is not a valid grid handle"); - TORCH_CHECK(serialized.numel() == header->totalBytes, "Serialized data is not a valid grid handle"); + TORCH_CHECK(serialized.numel() == header->totalBytes, + "Serialized data is not a valid grid handle"); const uint64_t numGrids = header->numGrids; - const GridBatchMetadata* batchMetadata = reinterpret_cast(serializedPtr + sizeof(V01Header)); + const GridBatchMetadata *batchMetadata = + reinterpret_cast(serializedPtr + sizeof(V01Header)); TORCH_CHECK(batchMetadata->version == 1, "Serialized data is not a valid grid handle"); - const GridMetadata* gridMetadata = reinterpret_cast(serializedPtr + sizeof(V01Header) + sizeof(GridBatchMetadata)); + const GridMetadata *gridMetadata = reinterpret_cast( + serializedPtr + sizeof(V01Header) + sizeof(GridBatchMetadata)); for (uint64_t i = 0; i < numGrids; i += 1) { TORCH_CHECK(gridMetadata[i].version == 1, "Serialized data is not a valid grid handle"); } - const int8_t* gridBuffer = serializedPtr + sizeof(V01Header) + sizeof(GridBatchMetadata) + numGrids * sizeof(GridMetadata); + const int8_t *gridBuffer = serializedPtr + sizeof(V01Header) + sizeof(GridBatchMetadata) + + numGrids * sizeof(GridMetadata); - const uint64_t sizeofMetadata = sizeof(V01Header) + sizeof(GridBatchMetadata) + numGrids * sizeof(GridMetadata); + const uint64_t sizeofMetadata = + sizeof(V01Header) + sizeof(GridBatchMetadata) + numGrids * sizeof(GridMetadata); const uint64_t sizeofGrid = header->totalBytes - sizeofMetadata; auto buf = TorchDeviceBuffer(sizeofGrid, nullptr, true /* host */, -1 /* deviceIndex */); diff --git a/fvdb/src/detail/GridBatchImpl.h b/fvdb/src/detail/GridBatchImpl.h index 72a678e58e..dbff710882 100644 --- a/fvdb/src/detail/GridBatchImpl.h +++ b/fvdb/src/detail/GridBatchImpl.h @@ -1,51 +1,55 @@ // 
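The V0 serialization format above is a fixed header (magic, version, grid count, total size) followed by the metadata and grid payload. A reduced sketch of just the header round trip, with the payload layout elided and the helper names illustrative:

    #include <cstdint>
    #include <cstring>
    #include <stdexcept>
    #include <vector>

    struct V0Header {
        uint64_t magic      = 0x0F0F0F0F0F0F0F0Full;
        uint64_t version    = 0;
        uint64_t numGrids   = 0;
        uint64_t totalBytes = 0;
    };

    std::vector<int8_t> serializeHeader(uint64_t numGrids, uint64_t payloadBytes) {
        V0Header header;
        header.numGrids   = numGrids;
        header.totalBytes = sizeof(V0Header) + payloadBytes;
        std::vector<int8_t> out(header.totalBytes);
        std::memcpy(out.data(), &header, sizeof(V0Header)); // metadata + grids would follow
        return out;
    }

    V0Header validateHeader(const std::vector<int8_t> &bytes) {
        if (bytes.size() < sizeof(V0Header))
            throw std::runtime_error("buffer too small to hold a header");
        V0Header header;
        std::memcpy(&header, bytes.data(), sizeof(V0Header));
        if (header.magic != 0x0F0F0F0F0F0F0F0Full)
            throw std::runtime_error("bad magic number");
        if (header.version != 0)
            throw std::runtime_error("unsupported version");
        if (header.totalBytes != bytes.size())
            throw std::runtime_error("size mismatch");
        return header;
    }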
Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once -#include +#ifndef FVDB_DETAIL_GRIDBATCHIMPL_H +#define FVDB_DETAIL_GRIDBATCHIMPL_H -#include +#include "TorchDeviceBuffer.h" +#include "VoxelCoordTransform.h" +#include "utils/Utils.h" + +#include -#include #include +#include -#include "VoxelCoordTransform.h" -#include "JaggedTensor.h" -#include "detail/utils/Utils.h" -#include "detail/TorchDeviceBuffer.h" +#include +#include #if !defined(__CUDACC__) && !defined(__restrict__) #define __restrict__ #endif - namespace fvdb { namespace detail { class GridBatchImpl : public torch::CustomClassHolder { - -public: + public: // Metadata about a single grid in the batch struct GridMetadata { - uint32_t version = 1; // Version of this struct - - int64_t mCumLeaves = 0; // Cumulative number of leaf nodes in the batch up to this grid - int64_t mCumVoxels = 0; // Cumulative number of voxels in the batch up to this grid - uint64_t mCumBytes = 0; // Cumulative number of bytes in the buffer of grids up to this grid - VoxelCoordTransform mPrimalTransform; // Primal Transform of this grid (i.e. transform which aligns origin with voxel center) - VoxelCoordTransform mDualTransform; // Dual Transform of this grid (i.e. transform which aligns origin with voxel corner) - nanovdb::Vec3d mVoxelSize; // Size of a single voxel in world space - uint32_t mNumLeaves; // Number of leaf nodes in this grid - int64_t mNumVoxels; // Number of voxels in this grid - uint64_t mNumBytes; // Number of bytes in the buffer of this grid - nanovdb::CoordBBox mBBox; // Bounding box of this grid - - nanovdb::Vec3d voxelOrigin() const { + uint32_t version = 1; // Version of this struct + + int64_t mCumLeaves = 0; // Cumulative number of leaf nodes in the batch up to this grid + int64_t mCumVoxels = 0; // Cumulative number of voxels in the batch up to this grid + uint64_t mCumBytes = 0; // Cumulative number of bytes in the buffer of grids up to this grid + VoxelCoordTransform mPrimalTransform; // Primal Transform of this grid (i.e. transform which + // aligns origin with voxel center) + VoxelCoordTransform mDualTransform; // Dual Transform of this grid (i.e. transform which + // aligns origin with voxel corner) + nanovdb::Vec3d mVoxelSize; // Size of a single voxel in world space + uint32_t mNumLeaves; // Number of leaf nodes in this grid + int64_t mNumVoxels; // Number of voxels in this grid + uint64_t mNumBytes; // Number of bytes in the buffer of this grid + nanovdb::CoordBBox mBBox; // Bounding box of this grid + + nanovdb::Vec3d + voxelOrigin() const { return mPrimalTransform.applyInv(0., 0., 0.); } - __hostdev__ void setTransform(const nanovdb::Vec3d& voxSize, const nanovdb::Vec3d& voxOrigin) { + __hostdev__ void + setTransform(const nanovdb::Vec3d &voxSize, const nanovdb::Vec3d &voxOrigin) { mVoxelSize = voxSize; voxelTransformForSizeAndOrigin(voxSize, voxOrigin, mPrimalTransform, mDualTransform); } @@ -77,63 +81,72 @@ class GridBatchImpl : public torch::CustomClassHolder { bool mIsContiguous = true; }; - -private: + private: // Metadata for each grid in the batch. There is a seperate host and device version of these. 
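The primal/dual distinction in GridMetadata is a half-voxel shift between a center-aligned and a corner-aligned index lattice. The sketch below shows one common way to express that relationship; the exact convention and API of VoxelCoordTransform are not reproduced here, so treat the sign of the offset as an assumption.

    #include <array>

    struct Transform {
        std::array<double, 3> scale;  // voxel size per axis
        std::array<double, 3> offset; // world-space position of index (0, 0, 0)

        // index -> world
        std::array<double, 3> indexToWorld(double i, double j, double k) const {
            return { i * scale[0] + offset[0], j * scale[1] + offset[1], k * scale[2] + offset[2] };
        }
    };

    // Primal: index (0,0,0) lands on a voxel *center* at `origin`.
    // Dual:   index (0,0,0) lands on a voxel *corner*, half a voxel away
    //         (the sign of the shift is a convention, assumed here).
    inline void makePrimalAndDual(const std::array<double, 3> &voxelSize,
                                  const std::array<double, 3> &origin,
                                  Transform &primal, Transform &dual) {
        primal = { voxelSize, origin };
        dual   = { voxelSize,
                   { origin[0] - 0.5 * voxelSize[0],
                     origin[1] - 0.5 * voxelSize[1],
                     origin[2] - 0.5 * voxelSize[2] } };
    }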
- // The caller of this class sets the host version and is responsible for syncing the device version - // with the host version by calling syncMetadataToDeviceIfCUDA - std::vector mHostGridMetadata; // CPU only - GridMetadata* mDeviceGridMetadata = nullptr; // CUDA only + // The caller of this class sets the host version and is responsible for syncing the device + // version with the host version by calling syncMetadataToDeviceIfCUDA + std::vector mHostGridMetadata; // CPU only + GridMetadata *mDeviceGridMetadata = nullptr; // CUDA only - GridBatchMetadata mBatchMetadata; // Metadata about the whole batch + GridBatchMetadata mBatchMetadata; // Metadata about the whole batch - std::shared_ptr> mGridHdl; // NanoVDB grid handle - torch::Tensor mLeafBatchIndices; // Indices of leaf nodes in the batch shape = [total_leafs] - torch::Tensor mBatchOffsets; // Batch indices for grid (ignores disabled) - torch::Tensor mListIndices; // List indices for grid (same as JaggedTensor, ignores disabled) + std::shared_ptr> mGridHdl; // NanoVDB grid handle + torch::Tensor mLeafBatchIndices; // Indices of leaf nodes in the batch shape = [total_leafs] + torch::Tensor mBatchOffsets; // Batch indices for grid (ignores disabled) + torch::Tensor mListIndices; // List indices for grid (same as JaggedTensor, ignores disabled) // Write back changes to host metadata to the device if we're a cuda handle void syncMetadataToDeviceIfCUDA(bool blocking); - inline std::pair fineVoxSizeAndOrigin(int64_t bi, nanovdb::Coord subdivFactor) const { - TORCH_CHECK(subdivFactor[0] > 0 && subdivFactor[1] > 0 && subdivFactor[2] > 0, "Subdivision factor must be greater than 0"); + inline std::pair + fineVoxSizeAndOrigin(int64_t bi, nanovdb::Coord subdivFactor) const { + TORCH_CHECK(subdivFactor[0] > 0 && subdivFactor[1] > 0 && subdivFactor[2] > 0, + "Subdivision factor must be greater than 0"); const nanovdb::Vec3d w = voxelSize(bi) / subdivFactor.asVec3d(); - const nanovdb::Vec3d tx = voxelOrigin(bi) - (subdivFactor.asVec3d() - nanovdb::Vec3d(1.0)) * w * 0.5; + const nanovdb::Vec3d tx = + voxelOrigin(bi) - (subdivFactor.asVec3d() - nanovdb::Vec3d(1.0)) * w * 0.5; return std::make_pair(w, tx); } - inline std::pair coarseVoxSizeAndOrigin(int64_t bi, nanovdb::Coord branchingFactor) const { - TORCH_CHECK(branchingFactor[0] > 0 && branchingFactor[1] > 0 && branchingFactor[2] > 0, "Coarsening factor must be greater than 0"); + inline std::pair + coarseVoxSizeAndOrigin(int64_t bi, nanovdb::Coord branchingFactor) const { + TORCH_CHECK(branchingFactor[0] > 0 && branchingFactor[1] > 0 && branchingFactor[2] > 0, + "Coarsening factor must be greater than 0"); const nanovdb::Vec3d w = branchingFactor.asVec3d() * voxelSize(bi); - const nanovdb::Vec3d tx = (branchingFactor.asVec3d() - nanovdb::Vec3d(1.0)) * voxelSize(bi) * 0.5 + voxelOrigin(bi); + const nanovdb::Vec3d tx = + (branchingFactor.asVec3d() - nanovdb::Vec3d(1.0)) * voxelSize(bi) * 0.5 + + voxelOrigin(bi); return std::make_pair(w, tx); } - inline int64_t negativeToPositiveIndexWithRangecheck(int64_t bi) const { + inline int64_t + negativeToPositiveIndexWithRangecheck(int64_t bi) const { if (bi < 0) { bi += batchSize(); } - TORCH_CHECK_INDEX(bi >= 0 && bi < batchSize(), "Batch index ", bi, " is out of range for grid batch of size " - + std::to_string(batchSize())); + TORCH_CHECK_INDEX(bi >= 0 && bi < batchSize(), "Batch index ", bi, + " is out of range for grid batch of size " + std::to_string(batchSize())); return static_cast(bi); } void recomputeBatchOffsets(); template - 
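The fine/coarse transforms above reduce to simple per-axis formulas: subdividing by a factor f shrinks the voxel size by f and shifts the origin back by half of the removed extent, and coarsening is the inverse. Written out for one axis:

    #include <utility>

    // Matches fineVoxSizeAndOrigin: w = voxelSize / f, tx = origin - (f - 1) * w / 2.
    std::pair<double, double>
    fineSizeAndOrigin(double voxelSize, double origin, int subdivFactor) {
        const double w  = voxelSize / subdivFactor;
        const double tx = origin - (subdivFactor - 1) * w * 0.5;
        return { w, tx };
    }

    // Matches coarseVoxSizeAndOrigin: w = f * voxelSize, tx = origin + (f - 1) * voxelSize / 2.
    std::pair<double, double>
    coarseSizeAndOrigin(double voxelSize, double origin, int coarseningFactor) {
        const double w  = coarseningFactor * voxelSize;
        const double tx = (coarseningFactor - 1) * voxelSize * 0.5 + origin;
        return { w, tx };
    }

    // Round trip: fineSizeAndOrigin(1.0, 0.0, 2) == {0.5, -0.25} and
    // coarseSizeAndOrigin(0.5, -0.25, 2) == {1.0, 0.0}.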
c10::intrusive_ptr indexInternal(const Indexable& idx, int64_t size) const { + c10::intrusive_ptr + indexInternal(const Indexable &idx, int64_t size) const { if (size == 0) { return c10::make_intrusive(device(), isMutable()); } - TORCH_CHECK(size >= 0, "Indexing with negative size is not supported (this should never happen)"); + TORCH_CHECK(size >= 0, + "Indexing with negative size is not supported (this should never happen)"); c10::intrusive_ptr ret = c10::make_intrusive(); - ret->mGridHdl = mGridHdl; + ret->mGridHdl = mGridHdl; - int64_t cumVoxels = 0; - int64_t cumLeaves = 0; - int64_t maxVoxels = 0; - uint32_t maxLeafCount = 0; - int64_t count = 0; + int64_t cumVoxels = 0; + int64_t cumLeaves = 0; + int64_t maxVoxels = 0; + uint32_t maxLeafCount = 0; + int64_t count = 0; nanovdb::CoordBBox totalBbox; std::vector leafBatchIdxs; @@ -142,18 +155,19 @@ class GridBatchImpl : public torch::CustomClassHolder { bool isContiguous = mBatchMetadata.mIsContiguous; for (size_t i = 0; i < size; i += 1) { int64_t bi = idx[i]; - bi = negativeToPositiveIndexWithRangecheck(bi); + bi = negativeToPositiveIndexWithRangecheck(bi); - // If indices are not contiguous or the grid we're viewing is not contiguous, then we're no longer contiguous + // If indices are not contiguous or the grid we're viewing is not contiguous, then we're + // no longer contiguous isContiguous = isContiguous && (bi == count); - const uint32_t numLeaves = mHostGridMetadata[bi].mNumLeaves; - const int64_t numVoxels = mHostGridMetadata[bi].mNumVoxels; - const nanovdb::CoordBBox& bbox = mHostGridMetadata[bi].mBBox; + const uint32_t numLeaves = mHostGridMetadata[bi].mNumLeaves; + const int64_t numVoxels = mHostGridMetadata[bi].mNumVoxels; + const nanovdb::CoordBBox &bbox = mHostGridMetadata[bi].mBBox; ret->mHostGridMetadata.push_back(mHostGridMetadata[bi]); ret->mHostGridMetadata[count].mCumLeaves = cumLeaves; - ret->mHostGridMetadata[count].mCumVoxels= cumVoxels; + ret->mHostGridMetadata[count].mCumVoxels = cumVoxels; if (count == 0) { totalBbox = bbox; @@ -162,20 +176,23 @@ class GridBatchImpl : public torch::CustomClassHolder { } cumLeaves += numLeaves; cumVoxels += numVoxels; - maxVoxels = std::max(maxVoxels, numVoxels); + maxVoxels = std::max(maxVoxels, numVoxels); maxLeafCount = std::max(maxLeafCount, numLeaves); - leafBatchIdxs.push_back(torch::full({numLeaves}, torch::Scalar(count), torch::TensorOptions().dtype(fvdb::JIdxScalarType).device(device()))); + leafBatchIdxs.push_back( + torch::full({ numLeaves }, torch::Scalar(count), + torch::TensorOptions().dtype(fvdb::JIdxScalarType).device(device()))); count += 1; } - // If all the indices were contiguous and the grid we're viewing is contiguous, then we're contiguous + // If all the indices were contiguous and the grid we're viewing is contiguous, then we're + // contiguous ret->mBatchMetadata.mIsContiguous = isContiguous && (count == batchSize()); - ret->mBatchMetadata.mTotalLeaves = cumLeaves; - ret->mBatchMetadata.mTotalVoxels = cumVoxels; - ret->mBatchMetadata.mMaxVoxels = maxVoxels; + ret->mBatchMetadata.mTotalLeaves = cumLeaves; + ret->mBatchMetadata.mTotalVoxels = cumVoxels; + ret->mBatchMetadata.mMaxVoxels = maxVoxels; ret->mBatchMetadata.mMaxLeafCount = maxLeafCount; - ret->mBatchMetadata.mTotalBBox = totalBbox; - ret->mBatchMetadata.mIsMutable = isMutable(); + ret->mBatchMetadata.mTotalBBox = totalBbox; + ret->mBatchMetadata.mIsMutable = isMutable(); if (leafBatchIdxs.size() > 0) { ret->mLeafBatchIndices = torch::cat(leafBatchIdxs, 0); @@ -192,123 +209,137 @@ class 
GridBatchImpl : public torch::CustomClassHolder { return ret; } -public: - template - class Accessor { + public: + template class Accessor { friend class GridBatchImpl; - const GridBatchImpl::GridMetadata* __restrict__ mMetadata = nullptr; // 8 bytes - const nanovdb::NanoGrid* __restrict__ mGridPtr = nullptr; // 8 bytes - fvdb::JIdxType* __restrict__ mLeafBatchIndices = nullptr; // 8 bytes - int64_t mTotalVoxels = 0; // 8 bytes - int64_t mTotalLeaves = 0; // 8 bytes - int64_t mMaxVoxels = 0; // 8 bytes - uint32_t mMaxLeafCount = 0; // 4 bytes - int64_t mGridCount = 0; // 8 bytes - - private: - __hostdev__ inline int64_t negativeToPositiveIndexWithRangecheck(int64_t bi) const { - if (bi < 0) { - bi += batchSize(); + const GridBatchImpl::GridMetadata *__restrict__ mMetadata = nullptr; // 8 bytes + const nanovdb::NanoGrid *__restrict__ mGridPtr = nullptr; // 8 bytes + fvdb::JIdxType *__restrict__ mLeafBatchIndices = nullptr; // 8 bytes + int64_t mTotalVoxels = 0; // 8 bytes + int64_t mTotalLeaves = 0; // 8 bytes + int64_t mMaxVoxels = 0; // 8 bytes + uint32_t mMaxLeafCount = 0; // 4 bytes + int64_t mGridCount = 0; // 8 bytes + + private: + __hostdev__ inline int64_t + negativeToPositiveIndexWithRangecheck(int64_t bi) const { + if (bi < 0) { + bi += batchSize(); + } + assert(bi >= 0 && bi < batchSize()); + return static_cast(bi); } - assert(bi >= 0 && bi < batchSize()); - return static_cast(bi); - } - public: - - __hostdev__ const nanovdb::NanoGrid* grid(int64_t bi) const { + public: + __hostdev__ const nanovdb::NanoGrid * + grid(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); - return reinterpret_cast*>( - reinterpret_cast(mGridPtr) + mMetadata[bi].mCumBytes); + return reinterpret_cast *>( + reinterpret_cast(mGridPtr) + mMetadata[bi].mCumBytes); } - __hostdev__ nanovdb::CoordBBox bbox(int64_t bi) const { + __hostdev__ nanovdb::CoordBBox + bbox(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return grid(bi)->tree().bbox(); } - __hostdev__ nanovdb::CoordBBox dualBbox(int64_t bi) const { + __hostdev__ nanovdb::CoordBBox + dualBbox(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); nanovdb::CoordBBox dualBbox(bbox(bi)); dualBbox.mCoord[1] += nanovdb::Coord(1, 1, 1); return dualBbox; } - __hostdev__ int64_t batchSize() const { + __hostdev__ int64_t + batchSize() const { return mGridCount; } - __hostdev__ int64_t voxelOffset(int64_t bi) const { + __hostdev__ int64_t + voxelOffset(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mMetadata[bi].mCumVoxels; } - __hostdev__ int64_t leafOffset(int64_t bi) const { + __hostdev__ int64_t + leafOffset(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mMetadata[bi].mCumLeaves; } - __hostdev__ int64_t maxVoxels() const { + __hostdev__ int64_t + maxVoxels() const { return mMaxVoxels; } - __hostdev__ uint32_t maxLeafCount() const { + __hostdev__ uint32_t + maxLeafCount() const { return mMaxLeafCount; } - __hostdev__ int64_t totalVoxels() const { + __hostdev__ int64_t + totalVoxels() const { return mTotalVoxels; } - __hostdev__ int64_t totalLeaves() const { + __hostdev__ int64_t + totalLeaves() const { return mTotalLeaves; } - __hostdev__ const VoxelCoordTransform& primalTransform(int64_t bi) const { + __hostdev__ const VoxelCoordTransform & + primalTransform(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mMetadata[bi].mPrimalTransform; } - __hostdev__ const VoxelCoordTransform& dualTransform(int64_t bi) const { + __hostdev__ 
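Two accessor idioms worth calling out in isolation: grids in a batch are addressed by a cumulative byte offset from the first grid, and the dual bounding box is the primal one grown by one voxel at its max corner. GridData below is an opaque placeholder, not the real nanovdb type.

    #include <cstdint>

    struct GridData; // opaque stand-in for the serialized grid header

    // Address grid bi inside a packed batch by its cumulative byte offset.
    inline const GridData *
    gridAt(const GridData *firstGrid, const uint64_t *cumBytes, int64_t bi) {
        return reinterpret_cast<const GridData *>(
            reinterpret_cast<const char *>(firstGrid) + cumBytes[bi]);
    }

    struct CoordBBox {
        int32_t min[3];
        int32_t max[3];
    };

    // The dual (corner) lattice has one extra voxel along each axis.
    inline CoordBBox
    dualBBox(CoordBBox b) {
        b.max[0] += 1;
        b.max[1] += 1;
        b.max[2] += 1;
        return b;
    }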
const VoxelCoordTransform & + dualTransform(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mMetadata[bi].mDualTransform; } - __hostdev__ fvdb::JIdxType leafBatchIndex(int64_t leaf_idx) const { + __hostdev__ fvdb::JIdxType + leafBatchIndex(int64_t leaf_idx) const { return mLeafBatchIndices[leaf_idx]; } }; template - Accessor hostAccessor() const { + Accessor + hostAccessor() const { TORCH_CHECK(!isEmpty(), "Cannot access empty grid"); Accessor ret; ret.mMetadata = mHostGridMetadata.data(); - ret.mGridPtr = mGridHdl->template grid(); + ret.mGridPtr = mGridHdl->template grid(); TORCH_CHECK(ret.mGridPtr != nullptr, "Failed to get host grid pointer"); - ret.mTotalVoxels = mBatchMetadata.mTotalVoxels; - ret.mTotalLeaves = mBatchMetadata.mTotalLeaves; - ret.mMaxVoxels = mBatchMetadata.mMaxVoxels; - ret.mMaxLeafCount = mBatchMetadata.mMaxLeafCount; - ret.mGridCount = static_cast(mHostGridMetadata.size()); + ret.mTotalVoxels = mBatchMetadata.mTotalVoxels; + ret.mTotalLeaves = mBatchMetadata.mTotalLeaves; + ret.mMaxVoxels = mBatchMetadata.mMaxVoxels; + ret.mMaxLeafCount = mBatchMetadata.mMaxLeafCount; + ret.mGridCount = static_cast(mHostGridMetadata.size()); ret.mLeafBatchIndices = mLeafBatchIndices.data_ptr(); return ret; } template - Accessor deviceAccessor() const { + Accessor + deviceAccessor() const { TORCH_CHECK(!isEmpty(), "Cannot access empty grid"); TORCH_CHECK(device().is_cuda(), "Cannot access device accessor on non-CUDA device"); Accessor ret; ret.mMetadata = mDeviceGridMetadata; - ret.mGridPtr = mGridHdl->template deviceGrid(); + ret.mGridPtr = mGridHdl->template deviceGrid(); TORCH_CHECK(ret.mGridPtr != nullptr, "Failed to get device grid pointer"); - ret.mTotalVoxels = mBatchMetadata.mTotalVoxels; - ret.mTotalLeaves = mBatchMetadata.mTotalLeaves; - ret.mMaxVoxels = mBatchMetadata.mMaxVoxels; - ret.mMaxLeafCount = mBatchMetadata.mMaxLeafCount; - ret.mGridCount = static_cast(mHostGridMetadata.size()); + ret.mTotalVoxels = mBatchMetadata.mTotalVoxels; + ret.mTotalLeaves = mBatchMetadata.mTotalLeaves; + ret.mMaxVoxels = mBatchMetadata.mMaxVoxels; + ret.mMaxLeafCount = mBatchMetadata.mMaxLeafCount; + ret.mGridCount = static_cast(mHostGridMetadata.size()); ret.mLeafBatchIndices = mLeafBatchIndices.data_ptr(); return ret; @@ -318,22 +349,21 @@ class GridBatchImpl : public torch::CustomClassHolder { GridBatchImpl(torch::Device device, bool isMutable); - GridBatchImpl(nanovdb::GridHandle&& gridHdl, - const std::vector& voxelSizes, - const std::vector& voxelOrigins); + GridBatchImpl(nanovdb::GridHandle &&gridHdl, + const std::vector &voxelSizes, + const std::vector &voxelOrigins); - GridBatchImpl(nanovdb::GridHandle&& gridHdl, - const nanovdb::Vec3d& globalVoxelSize, - const nanovdb::Vec3d& globalVoxelOrigin); + GridBatchImpl(nanovdb::GridHandle &&gridHdl, + const nanovdb::Vec3d &globalVoxelSize, const nanovdb::Vec3d &globalVoxelOrigin); ~GridBatchImpl(); // Cannot move make copies of this handle. There is only one owner of the underlying buffer. 
// This class should only be created and copied through c10::intrusive_ptr - GridBatchImpl& operator=(GridBatchImpl&& other) = delete; - GridBatchImpl(GridBatchImpl&& other) = delete; - GridBatchImpl(GridBatchImpl& other) = delete; - GridBatchImpl& operator=(GridBatchImpl& other) = delete; + GridBatchImpl &operator=(GridBatchImpl &&other) = delete; + GridBatchImpl(GridBatchImpl &&other) = delete; + GridBatchImpl(GridBatchImpl &other) = delete; + GridBatchImpl &operator=(GridBatchImpl &other) = delete; torch::Tensor voxelOffsets(bool ignoreDisabledVoxels) const; @@ -341,89 +371,107 @@ class GridBatchImpl : public torch::CustomClassHolder { torch::Tensor jidx(bool ignoreDisabledVoxels) const; - int64_t totalLeaves() const { + int64_t + totalLeaves() const { return mBatchMetadata.mTotalLeaves; } - int64_t totalVoxels() const { + int64_t + totalVoxels() const { return mBatchMetadata.mTotalVoxels; } int64_t totalEnabledVoxels(bool ignoreDisabledVoxels) const; - int64_t maxVoxelsPerGrid() const { + int64_t + maxVoxelsPerGrid() const { return mBatchMetadata.mMaxVoxels; } - int64_t maxLeavesPerGrid() const { + int64_t + maxLeavesPerGrid() const { return static_cast(mBatchMetadata.mMaxLeafCount); } - int64_t batchSize() const { + int64_t + batchSize() const { return static_cast(mHostGridMetadata.size()); } - uint64_t totalBytes() const { + uint64_t + totalBytes() const { uint64_t sum = 0; - for (const auto& grid : mHostGridMetadata) { + for (const auto &grid: mHostGridMetadata) { sum += grid.mNumBytes; } return sum; } - const nanovdb::GridHandle& nanoGridHandle() const { + const nanovdb::GridHandle & + nanoGridHandle() const { return *mGridHdl; } - bool isMutable() const { + bool + isMutable() const { return mBatchMetadata.mIsMutable; } - const c10::Device device() const { + const c10::Device + device() const { return mGridHdl->buffer().device(); } - bool isEmpty() const { + bool + isEmpty() const { return mGridHdl->buffer().isEmpty(); } - uint32_t numLeaves(int64_t bi) const { + uint32_t + numLeaves(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mHostGridMetadata[bi].mNumLeaves; } - int64_t numVoxels(int64_t bi) const { + int64_t + numVoxels(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mHostGridMetadata[bi].mNumVoxels; } - int64_t cumVoxels(int64_t bi) const { + int64_t + cumVoxels(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mHostGridMetadata[bi].mCumVoxels; } - uint64_t numBytes(int64_t bi) const { + uint64_t + numBytes(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mHostGridMetadata[bi].mNumBytes; } - uint64_t cumBytes(int64_t bi) const { + uint64_t + cumBytes(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mHostGridMetadata[bi].mCumBytes; } - const VoxelCoordTransform& primalTransform(int64_t bi) const { + const VoxelCoordTransform & + primalTransform(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mHostGridMetadata[bi].mPrimalTransform; } - const VoxelCoordTransform& dualTransform(int64_t bi) const { + const VoxelCoordTransform & + dualTransform(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mHostGridMetadata[bi].mDualTransform; } - void gridVoxelSizesAndOrigins(std::vector& outVoxelSizes, - std::vector& outVoxelOrigins) const { + void + gridVoxelSizesAndOrigins(std::vector &outVoxelSizes, + std::vector &outVoxelOrigins) const { outVoxelSizes.clear(); outVoxelOrigins.clear(); for (int64_t i = 0; i 
< batchSize(); ++i) { @@ -432,29 +480,34 @@ class GridBatchImpl : public torch::CustomClassHolder { } } - const nanovdb::CoordBBox& totalBBox() const { + const nanovdb::CoordBBox & + totalBBox() const { return mBatchMetadata.mTotalBBox; } - const nanovdb::CoordBBox& bbox(int64_t bi) const { + const nanovdb::CoordBBox & + bbox(int64_t bi) const { checkNonEmptyGrid(); bi = negativeToPositiveIndexWithRangecheck(bi); return mHostGridMetadata[bi].mBBox; } - const nanovdb::CoordBBox dualBbox(int64_t bi) const { - bi = negativeToPositiveIndexWithRangecheck(bi); + const nanovdb::CoordBBox + dualBbox(int64_t bi) const { + bi = negativeToPositiveIndexWithRangecheck(bi); nanovdb::CoordBBox dualBbox = bbox(bi); dualBbox.mCoord[1] += nanovdb::Coord(1, 1, 1); return dualBbox; } - const nanovdb::Vec3d& voxelSize(int64_t bi) const { + const nanovdb::Vec3d & + voxelSize(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mHostGridMetadata[bi].mVoxelSize; } - const nanovdb::Vec3d voxelOrigin(int64_t bi) const { + const nanovdb::Vec3d + voxelOrigin(int64_t bi) const { bi = negativeToPositiveIndexWithRangecheck(bi); return mHostGridMetadata[bi].voxelOrigin(); } @@ -465,51 +518,59 @@ class GridBatchImpl : public torch::CustomClassHolder { c10::intrusive_ptr clone(torch::Device device, bool blocking = false) const; - void checkNonEmptyGrid() const { + void + checkNonEmptyGrid() const { TORCH_CHECK(!isEmpty(), "Empty grid"); } - void checkDevice(const torch::Tensor& t) const { + void + checkDevice(const torch::Tensor &t) const { torch::Device hdlDevice = mGridHdl->buffer().device(); - TORCH_CHECK(hdlDevice == t.device(), "All tensors must be on the same device (" + hdlDevice.str() + - ") as index grid but got " + t.device().str()); + TORCH_CHECK(hdlDevice == t.device(), "All tensors must be on the same device (" + + hdlDevice.str() + ") as index grid but got " + + t.device().str()); } - void checkDevice(const JaggedTensor& t) const { + void + checkDevice(const JaggedTensor &t) const { torch::Device hdlDevice = mGridHdl->buffer().device(); - TORCH_CHECK(hdlDevice == t.device(), "All tensors must be on the same device (" + hdlDevice.str() + - ") as index grid but got " + t.device().str()); + TORCH_CHECK(hdlDevice == t.device(), "All tensors must be on the same device (" + + hdlDevice.str() + ") as index grid but got " + + t.device().str()); } - JaggedTensor jaggedTensor(const torch::Tensor& data, bool ignoreDisabledVoxels) const; + JaggedTensor jaggedTensor(const torch::Tensor &data, bool ignoreDisabledVoxels) const; - void setGlobalPrimalTransform(const VoxelCoordTransform& transform, bool syncToDevice = true); - void setGlobalDualTransform(const VoxelCoordTransform& transform, bool syncToDevice = true); - void setGlobalVoxelSize(const nanovdb::Vec3d& voxelSize, bool syncToDevice = true); - void setGlobalVoxelOrigin(const nanovdb::Vec3d& voxelOrigin, bool syncToDevice = true); - void setGlobalVoxelSizeAndOrigin(const nanovdb::Vec3d& voxelSize, const nanovdb::Vec3d& voxelOrigin, bool syncToDevice = true); + void setGlobalPrimalTransform(const VoxelCoordTransform &transform, bool syncToDevice = true); + void setGlobalDualTransform(const VoxelCoordTransform &transform, bool syncToDevice = true); + void setGlobalVoxelSize(const nanovdb::Vec3d &voxelSize, bool syncToDevice = true); + void setGlobalVoxelOrigin(const nanovdb::Vec3d &voxelOrigin, bool syncToDevice = true); + void setGlobalVoxelSizeAndOrigin(const nanovdb::Vec3d &voxelSize, + const nanovdb::Vec3d &voxelOrigin, bool syncToDevice = 
true); - void setFineTransformFromCoarseGrid(const GridBatchImpl& coarseBatch, nanovdb::Coord subdivisionFactor); - void setCoarseTransformFromFineGrid(const GridBatchImpl& fineBatch, nanovdb::Coord coarseningFactor); - void setPrimalTransformFromDualGrid(const GridBatchImpl& dualBatch); + void setFineTransformFromCoarseGrid(const GridBatchImpl &coarseBatch, + nanovdb::Coord subdivisionFactor); + void setCoarseTransformFromFineGrid(const GridBatchImpl &fineBatch, + nanovdb::Coord coarseningFactor); + void setPrimalTransformFromDualGrid(const GridBatchImpl &dualBatch); - void setGrid(nanovdb::GridHandle&& gridHdl, - const torch::Tensor listIndices, - const std::vector& voxelSizes, - const std::vector& voxelOrigins, - bool blocking = false); + void setGrid(nanovdb::GridHandle &&gridHdl, const torch::Tensor listIndices, + const std::vector &voxelSizes, + const std::vector &voxelOrigins, bool blocking = false); c10::intrusive_ptr index(int64_t bi) const; c10::intrusive_ptr index(ssize_t start, ssize_t stop, ssize_t step) const; - c10::intrusive_ptr index(const torch::Tensor& indices) const; - c10::intrusive_ptr index(const std::vector& indices) const; - c10::intrusive_ptr index(const std::vector& indices) const; + c10::intrusive_ptr index(const torch::Tensor &indices) const; + c10::intrusive_ptr index(const std::vector &indices) const; + c10::intrusive_ptr index(const std::vector &indices) const; - static c10::intrusive_ptr concatenate(const std::vector>& elements); + static c10::intrusive_ptr + concatenate(const std::vector> &elements); static c10::intrusive_ptr contiguous(c10::intrusive_ptr input); - bool isContiguous() const { + bool + isContiguous() const { return mBatchMetadata.mIsContiguous; } @@ -517,18 +578,17 @@ class GridBatchImpl : public torch::CustomClassHolder { torch::Tensor serialize() const; // Load a CPU int8 tensor into a grid batch handle - static c10::intrusive_ptr deserialize(const torch::Tensor& serialized); + static c10::intrusive_ptr deserialize(const torch::Tensor &serialized); -private: + private: // We're going to version serialization. 
These are v0 - torch::Tensor serializeV0() const; - static c10::intrusive_ptr deserializeV0(const torch::Tensor& serialized); - + torch::Tensor serializeV0() const; + static c10::intrusive_ptr deserializeV0(const torch::Tensor &serialized); }; -template -using BatchGridAccessor = typename GridBatchImpl::Accessor; +template using BatchGridAccessor = typename GridBatchImpl::Accessor; +} // namespace detail +} // namespace fvdb -} // namespace detail -} // namespace fvdb +#endif // FVDB_DETAIL_GRIDBATCHIMPL_H \ No newline at end of file diff --git a/fvdb/src/detail/TorchDeviceBuffer.cpp b/fvdb/src/detail/TorchDeviceBuffer.cpp index 0293050cbb..f498a66802 100644 --- a/fvdb/src/detail/TorchDeviceBuffer.cpp +++ b/fvdb/src/detail/TorchDeviceBuffer.cpp @@ -6,9 +6,9 @@ #include #include -#include // for cudaMalloc/cudaMallocManaged/cudaFree #include #include +#include // for cudaMalloc/cudaMallocManaged/cudaFree namespace nanovdb { @@ -17,89 +17,98 @@ namespace nanovdb { // TODO: Pass in synchronous option template <> template <> -GridHandle GridHandle::copy(const fvdb::detail::TorchDeviceBuffer& guide) const { +GridHandle +GridHandle::copy( + const fvdb::detail::TorchDeviceBuffer &guide) const { if (mBuffer.isEmpty()) { fvdb::detail::TorchDeviceBuffer retbuf(0, nullptr); retbuf.setDevice(guide.device(), false); - return GridHandle(std::move(retbuf));// return an empty handle + return GridHandle( + std::move(retbuf)); // return an empty handle } - const bool guideIsHost = guide.device().is_cpu(); - const bool iAmHost = mBuffer.device().is_cpu(); + const bool guideIsHost = guide.device().is_cpu(); + const bool iAmHost = mBuffer.device().is_cpu(); const bool guideIsDevice = !guideIsHost; - const bool iAmDevice = !iAmHost; + const bool iAmDevice = !iAmHost; auto buffer = fvdb::detail::TorchDeviceBuffer::create(mBuffer.size(), &guide, guideIsHost); if (iAmHost && guideIsHost) { - std::memcpy(buffer.data(), mBuffer.data(), mBuffer.size()); // deep copy of buffer in CPU RAM + std::memcpy(buffer.data(), mBuffer.data(), + mBuffer.size()); // deep copy of buffer in CPU RAM } else if (iAmHost && guideIsDevice) { at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(guide.device().index()); - cudaCheck(cudaMemcpyAsync(buffer.deviceData(), mBuffer.data(), mBuffer.size(), cudaMemcpyHostToDevice, defaultStream.stream())); + cudaCheck(cudaMemcpyAsync(buffer.deviceData(), mBuffer.data(), mBuffer.size(), + cudaMemcpyHostToDevice, defaultStream.stream())); cudaCheck(cudaStreamSynchronize(defaultStream.stream())); } else if (iAmDevice && guideIsHost) { - at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(mBuffer.device().index()); - cudaCheck(cudaMemcpyAsync(buffer.data(), mBuffer.deviceData(), mBuffer.size(), cudaMemcpyDeviceToHost, defaultStream.stream())); + at::cuda::CUDAStream defaultStream = + at::cuda::getCurrentCUDAStream(mBuffer.device().index()); + cudaCheck(cudaMemcpyAsync(buffer.data(), mBuffer.deviceData(), mBuffer.size(), + cudaMemcpyDeviceToHost, defaultStream.stream())); cudaCheck(cudaStreamSynchronize(defaultStream.stream())); } else if (iAmDevice && guideIsDevice) { if (mBuffer.device() == guide.device()) { - at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(mBuffer.device().index()); - cudaCheck(cudaMemcpyAsync(buffer.deviceData(), mBuffer.deviceData(), mBuffer.size(), cudaMemcpyDeviceToDevice, defaultStream.stream())); + at::cuda::CUDAStream defaultStream = + at::cuda::getCurrentCUDAStream(mBuffer.device().index()); + 
cudaCheck(cudaMemcpyAsync(buffer.deviceData(), mBuffer.deviceData(), mBuffer.size(), + cudaMemcpyDeviceToDevice, defaultStream.stream())); cudaCheck(cudaStreamSynchronize(defaultStream.stream())); } else { std::unique_ptr buf(new uint8_t[mBuffer.size()]); - at::cuda::CUDAStream mBufferStream = at::cuda::getCurrentCUDAStream(mBuffer.device().index()); - at::cuda::CUDAStream outBufferStream = at::cuda::getCurrentCUDAStream(buffer.device().index()); - cudaCheck(cudaMemcpyAsync(buf.get(), mBuffer.deviceData(), mBuffer.size(), cudaMemcpyDeviceToHost, mBufferStream.stream())); + at::cuda::CUDAStream mBufferStream = + at::cuda::getCurrentCUDAStream(mBuffer.device().index()); + at::cuda::CUDAStream outBufferStream = + at::cuda::getCurrentCUDAStream(buffer.device().index()); + cudaCheck(cudaMemcpyAsync(buf.get(), mBuffer.deviceData(), mBuffer.size(), + cudaMemcpyDeviceToHost, mBufferStream.stream())); cudaCheck(cudaStreamSynchronize(mBufferStream.stream())); - cudaCheck(cudaMemcpyAsync(buffer.deviceData(), buf.get(), mBuffer.size(), cudaMemcpyHostToDevice, outBufferStream.stream())); + cudaCheck(cudaMemcpyAsync(buffer.deviceData(), buf.get(), mBuffer.size(), + cudaMemcpyHostToDevice, outBufferStream.stream())); cudaCheck(cudaStreamSynchronize(outBufferStream.stream())); } } return GridHandle(std::move(buffer)); } -} +} // namespace nanovdb namespace fvdb { namespace detail { -TorchDeviceBuffer::TorchDeviceBuffer(uint64_t size /* = 0*/, void* data /* = nullptr*/, bool host /* = true*/, int deviceIndex /* = -1*/) - : mSize(0) - , mCpuData(nullptr) - , mGpuData(nullptr) - , mDevice(host ? torch::kCPU : torch::kCUDA, deviceIndex) { - - TORCH_CHECK(host || (!host && deviceIndex >= 0), "You must set deviceIndex when setting host to false"); +TorchDeviceBuffer::TorchDeviceBuffer(uint64_t size /* = 0*/, void *data /* = nullptr*/, + bool host /* = true*/, int deviceIndex /* = -1*/) + : mSize(0), mCpuData(nullptr), mGpuData(nullptr), + mDevice(host ? 
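The copy specialization above is a four-way dispatch on where the source buffer and the guide buffer live, with cross-GPU copies staged through a temporary host buffer. A reduced sketch, with stream and error handling trimmed to the essentials (the helper name and the single-stream simplification are assumptions):

    #include <cuda_runtime.h>
    #include <cstring>
    #include <memory>

    void copyBytes(void *dst, bool dstOnDevice, const void *src, bool srcOnDevice,
                   size_t size, cudaStream_t stream, bool sameDevice) {
        if (!srcOnDevice && !dstOnDevice) {
            std::memcpy(dst, src, size);                                       // CPU -> CPU
        } else if (!srcOnDevice && dstOnDevice) {
            cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream);   // CPU -> GPU
            cudaStreamSynchronize(stream);
        } else if (srcOnDevice && !dstOnDevice) {
            cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream);   // GPU -> CPU
            cudaStreamSynchronize(stream);
        } else if (sameDevice) {
            cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream); // same GPU
            cudaStreamSynchronize(stream);
        } else {
            // Different GPUs: stage the bytes through a temporary host buffer.
            std::unique_ptr<unsigned char[]> staging(new unsigned char[size]);
            cudaMemcpyAsync(staging.get(), src, size, cudaMemcpyDeviceToHost, stream);
            cudaStreamSynchronize(stream);
            cudaMemcpyAsync(dst, staging.get(), size, cudaMemcpyHostToDevice, stream);
            cudaStreamSynchronize(stream);
        }
    }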
torch::kCPU : torch::kCUDA, deviceIndex) { + TORCH_CHECK(host || (!host && deviceIndex >= 0), + "You must set deviceIndex when setting host to false"); this->init(size, data, host); } - -TorchDeviceBuffer::TorchDeviceBuffer(TorchDeviceBuffer&& other) noexcept - : mSize(other.mSize) - , mCpuData(other.mCpuData) - , mGpuData(other.mGpuData) - , mDevice(other.mDevice) { - other.mSize = 0; +TorchDeviceBuffer::TorchDeviceBuffer(TorchDeviceBuffer &&other) noexcept + : mSize(other.mSize), mCpuData(other.mCpuData), mGpuData(other.mGpuData), + mDevice(other.mDevice) { + other.mSize = 0; other.mCpuData = nullptr; other.mGpuData = nullptr; } - -TorchDeviceBuffer& TorchDeviceBuffer::operator=(TorchDeviceBuffer&& other) noexcept { +TorchDeviceBuffer & +TorchDeviceBuffer::operator=(TorchDeviceBuffer &&other) noexcept { clear(); - mSize = other.mSize; - mCpuData = other.mCpuData; - mGpuData = other.mGpuData; - mDevice = other.mDevice; - other.mSize = 0; + mSize = other.mSize; + mCpuData = other.mCpuData; + mGpuData = other.mGpuData; + mDevice = other.mDevice; + other.mSize = 0; other.mCpuData = nullptr; other.mGpuData = nullptr; return *this; } - -void TorchDeviceBuffer::setDevice(const torch::Device& toDevice, bool blocking) { +void +TorchDeviceBuffer::setDevice(const torch::Device &toDevice, bool blocking) { // Same device, no-op if (toDevice == mDevice) { return; @@ -122,11 +131,10 @@ void TorchDeviceBuffer::setDevice(const torch::Device& toDevice, bool blocking) } else { TORCH_CHECK(false, "Only CPU and CUDA devices are supported") } - } -void TorchDeviceBuffer::toCpu(bool blocking) { - +void +TorchDeviceBuffer::toCpu(bool blocking) { // Empty buffer, set the device and return if (mGpuData == nullptr && mCpuData == nullptr) { mDevice = torch::kCPU; @@ -148,7 +156,8 @@ void TorchDeviceBuffer::toCpu(bool blocking) { mDevice = torch::kCPU; } -void TorchDeviceBuffer::toCuda(torch::Device toDevice, bool blocking) { +void +TorchDeviceBuffer::toCuda(torch::Device toDevice, bool blocking) { TORCH_CHECK(toDevice.is_cuda(), "Invalid device must be a CUDA device"); TORCH_CHECK(toDevice.has_index(), "Invalid device must specify device index"); @@ -168,38 +177,42 @@ void TorchDeviceBuffer::toCuda(torch::Device toDevice, bool blocking) { { c10::cuda::CUDAGuard deviceGuard(mDevice); at::cuda::CUDAStream currentStream = at::cuda::getCurrentCUDAStream(mDevice.index()); - cudaCheck(cudaMemcpyAsync(buf.get(), mGpuData, mSize, cudaMemcpyDeviceToHost, currentStream.stream())); + cudaCheck(cudaMemcpyAsync(buf.get(), mGpuData, mSize, cudaMemcpyDeviceToHost, + currentStream.stream())); cudaCheck(cudaStreamSynchronize(currentStream.stream())); c10::cuda::CUDACachingAllocator::raw_delete(mGpuData); } { c10::cuda::CUDAGuard deviceGuard(toDevice); at::cuda::CUDAStream toStream = at::cuda::getCurrentCUDAStream(toDevice.index()); - mGpuData = reinterpret_cast(c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(mSize, toStream.stream())); - cudaCheck(cudaMemcpyAsync(mGpuData, buf.get(), mSize, cudaMemcpyHostToDevice, toStream.stream())); + mGpuData = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(mSize, toStream.stream())); + cudaCheck(cudaMemcpyAsync(mGpuData, buf.get(), mSize, cudaMemcpyHostToDevice, + toStream.stream())); } mDevice = toDevice; - } else if (mDevice.is_cpu()) { // CPU -> CUDA + } else if (mDevice.is_cpu()) { // CPU -> CUDA TORCH_CHECK(toDevice.has_index(), "Invalid device must specify device index"); c10::cuda::CUDAGuard deviceGuard(toDevice); at::cuda::CUDAStream stream = 
at::cuda::getCurrentCUDAStream(toDevice.index()); - copyHostToDeviceAndFreeHost((void*) stream.stream(), blocking); + copyHostToDeviceAndFreeHost((void *)stream.stream(), blocking); mDevice = toDevice; } else { TORCH_CHECK(false, "This should never happen. File a bug.") } } - -void TorchDeviceBuffer::init(uint64_t size, void* data /* = nullptr */, bool host /* = true */) { - TORCH_CHECK((host && mDevice.is_cpu()) || (!host && mDevice.is_cuda()), "Invalid device for host argument to TorchDeviceBuffer::init"); +void +TorchDeviceBuffer::init(uint64_t size, void *data /* = nullptr */, bool host /* = true */) { + TORCH_CHECK((host && mDevice.is_cpu()) || (!host && mDevice.is_cuda()), + "Invalid device for host argument to TorchDeviceBuffer::init"); if (size == mSize) { // If we already initialized the buffer with the same size, just return return; } - if (mSize >= 0) { // If we're initializing to a different size, need to free the old buffer + if (mSize >= 0) { // If we're initializing to a different size, need to free the old buffer this->clear(); } - if (size == 0) { // If we're initializing to a size of 0, just return + if (size == 0) { // If we're initializing to a size of 0, just return return; } @@ -209,26 +222,30 @@ void TorchDeviceBuffer::init(uint64_t size, void* data /* = nullptr */, bool hos // Initalize on the host if (host) { if (data) { - mCpuData = (uint8_t*) data; + mCpuData = (uint8_t *)data; } else { - // cudaCheck(cudaMallocHost((void**)&mCpuData, size)); // un-managed pinned memory on the host (can be slow to access!). Always 32B aligned - mCpuData = (uint8_t*) malloc(size); + // cudaCheck(cudaMallocHost((void**)&mCpuData, size)); // un-managed pinned memory on + // the host (can be slow to access!). Always 32B aligned + mCpuData = (uint8_t *)malloc(size); } // checkPtr(mCpuData, "failed to allocate host data"); - // Initalize on the device + // Initalize on the device } else { if (data) { - mGpuData = (uint8_t*) data; + mGpuData = (uint8_t *)data; } else { c10::cuda::CUDAGuard deviceGuard(mDevice); at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(mDevice.index()); - mGpuData = reinterpret_cast(c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(size, defaultStream.stream())); + mGpuData = + reinterpret_cast(c10::cuda::CUDACachingAllocator::raw_alloc_with_stream( + size, defaultStream.stream())); checkPtr(mGpuData, "failed to allocate device data"); } } } -void TorchDeviceBuffer::clear() { +void +TorchDeviceBuffer::clear() { if (mGpuData) { c10::cuda::CUDACachingAllocator::raw_delete(mGpuData); } @@ -237,30 +254,37 @@ void TorchDeviceBuffer::clear() { free(mCpuData); } mCpuData = mGpuData = nullptr; - mSize = 0; + mSize = 0; } -TorchDeviceBuffer TorchDeviceBuffer::create(uint64_t size, const TorchDeviceBuffer* proto, bool host, void* stream) { - // This is a hack to pass in the device index when creating grids from nanovdb. Since we can't pass arguments - // through nanovdb creation functions, we use a prototype grid to pass in the device index. +TorchDeviceBuffer +TorchDeviceBuffer::create(uint64_t size, const TorchDeviceBuffer *proto, bool host, void *stream) { + // This is a hack to pass in the device index when creating grids from nanovdb. Since we can't + // pass arguments through nanovdb creation functions, we use a prototype grid to pass in the + // device index. 
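    // Illustrative sketch of the guide-buffer mechanism described above (not taken from this
    // patch; `gridByteSize` is a hypothetical size, and the calls assume the public API declared
    // in TorchDeviceBuffer.h further below):
    //
    //   TorchDeviceBuffer guide(0, nullptr, /*host=*/false, /*deviceIndex=*/1); // empty prototype on cuda:1
    //   TorchDeviceBuffer buf = TorchDeviceBuffer::create(gridByteSize, &guide, /*host=*/false);
    //   // buf.device() should now report cuda:1 and buf.deviceData() point at device memory.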
int deviceId = -1; if (proto != nullptr) { - TORCH_CHECK((host && proto->device().is_cpu()) || (!host && proto->device().is_cuda()), "Invalid guide buffer device for host argument to TorchDeviceBuffer::create"); + TORCH_CHECK((host && proto->device().is_cpu()) || (!host && proto->device().is_cuda()), + "Invalid guide buffer device for host argument to TorchDeviceBuffer::create"); deviceId = proto->mDevice.index(); } return TorchDeviceBuffer(size, nullptr, host, host ? -1 : deviceId); } -void TorchDeviceBuffer::copyDeviceToHostAndFreeDevice(void* streamPtr /* = 0*/, bool blocking /* = true*/) { +void +TorchDeviceBuffer::copyDeviceToHostAndFreeDevice(void *streamPtr /* = 0*/, + bool blocking /* = true*/) { cudaStream_t stream = reinterpret_cast(streamPtr); TORCH_CHECK(mGpuData, "uninitialized cpu data, this should never happen"); if (mCpuData == nullptr) { // Allocate CPU data if we upload to the device - // cudaCheck(cudaMallocHost((void**)&mCpuData, mSize)); // un-managed pinned memory on the host (can be slow to access!). Always 32B aligned - mCpuData = (uint8_t*) malloc(mSize); + // cudaCheck(cudaMallocHost((void**)&mCpuData, mSize)); // un-managed pinned memory on the + // host (can be slow to access!). Always 32B aligned + mCpuData = (uint8_t *)malloc(mSize); } // Copy to the host buffer - cudaCheck(cudaMemcpyAsync(mCpuData, mGpuData, mSize, cudaMemcpyDeviceToHost, reinterpret_cast(stream))); + cudaCheck(cudaMemcpyAsync(mCpuData, mGpuData, mSize, cudaMemcpyDeviceToHost, + reinterpret_cast(stream))); if (blocking) { cudaCheck(cudaStreamSynchronize(reinterpret_cast(stream))); } @@ -268,12 +292,15 @@ void TorchDeviceBuffer::copyDeviceToHostAndFreeDevice(void* streamPtr /* = 0*/, c10::cuda::CUDACachingAllocator::raw_delete(mGpuData); } -void TorchDeviceBuffer::copyHostToDeviceAndFreeHost(void* streamPtr /* = 0*/, bool blocking /* = true*/) { +void +TorchDeviceBuffer::copyHostToDeviceAndFreeHost(void *streamPtr /* = 0*/, + bool blocking /* = true*/) { cudaStream_t stream = reinterpret_cast(streamPtr); TORCH_CHECK(mCpuData, "uninitialized cpu data, this should never happen"); - if (mGpuData == nullptr) { // Allocate a new CUDA buffer - mGpuData = reinterpret_cast(c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(mSize, stream)); + if (mGpuData == nullptr) { // Allocate a new CUDA buffer + mGpuData = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(mSize, stream)); } // Copy the data to the CUDA buffer cudaCheck(cudaMemcpyAsync(mGpuData, mCpuData, mSize, cudaMemcpyHostToDevice, stream)); diff --git a/fvdb/src/detail/TorchDeviceBuffer.h b/fvdb/src/detail/TorchDeviceBuffer.h index 4a89d8eb59..bda65b912e 100644 --- a/fvdb/src/detail/TorchDeviceBuffer.h +++ b/fvdb/src/detail/TorchDeviceBuffer.h @@ -1,12 +1,12 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once - -#include +#ifndef FVDB_DETAIL_TORCHDEVICEBUFFER_H +#define FVDB_DETAIL_TORCHDEVICEBUFFER_H #include // for BufferTraits +#include namespace fvdb { namespace detail { @@ -16,56 +16,65 @@ namespace detail { /// @brief Simple memory buffer using un-managed pinned host memory when compiled with NVCC. /// Obviously this class is making explicit used of CUDA so replace it with your own memory /// allocator if you are not using CUDA. 
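/// A minimal usage sketch (illustrative only, not part of this patch; it assumes the member
/// functions declared below, with blocking copies):
///   TorchDeviceBuffer buf(1024);                                      // 1 KiB allocated on the host
///   buf.setDevice(torch::Device(torch::kCUDA, 0), /*blocking=*/true); // copy to cuda:0, host copy is freed
///   uint8_t *d = buf.deviceData();                                    // raw device pointer
///   buf.setDevice(torch::kCPU, /*blocking=*/true);                    // copy back and free the device memory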
-/// @note While CUDA's pinned host memory allows for asynchronous memory copy between host and device +/// @note While CUDA's pinned host memory allows for asynchronous memory copy between host and +/// device /// it is significantly slower then cached (un-pinned) memory on the host. -class TorchDeviceBuffer -{ - uint64_t mSize; // total number of bytes for the NanoVDB grid. - uint8_t *mCpuData, *mGpuData; // raw buffer for the NanoVDB grid. +class TorchDeviceBuffer { + uint64_t mSize; // total number of bytes for the NanoVDB grid. + uint8_t *mCpuData, *mGpuData; // raw buffer for the NanoVDB grid. torch::Device mDevice = torch::Device(torch::kCPU); - /// @brief Helper function to move this buffer to the CPU. If the buffer is on the GPU, the GPU memory will be freed. - /// @param blocking If set to false, then memory allocations and copies are performed asynchronously + /// @brief Helper function to move this buffer to the CPU. If the buffer is on the GPU, the GPU + /// memory will be freed. + /// @param blocking If set to false, then memory allocations and copies are performed + /// asynchronously void toCpu(bool blocking); /// @brief Helper function to move this buffer to the specified CUDA device. /// @param device The device on which the buffer should be moved to - /// @param blocking If set to false, then memory allocations and copies are performed asynchronously + /// @param blocking If set to false, then memory allocations and copies are performed + /// asynchronously void toCuda(torch::Device device, bool blocking); - /// @brief Helper function to copy from the host to the device and then free the host buffer. If @c blocking is false the memory copy is asynchronous! + /// @brief Helper function to copy from the host to the device and then free the host buffer. If + /// @c blocking is false the memory copy is asynchronous! /// /// @note This will allocate memory on the GPU/device if it is not already allocated. /// @note The device of this buffer must be CPU - void copyHostToDeviceAndFreeHost(void* stream = 0, bool blocking = true); // Delete + void copyHostToDeviceAndFreeHost(void *stream = 0, bool blocking = true); // Delete - /// @brief Helper function to copy from a device to the host and then free the device buffer. If @c blocking is false the memory copy is asynchronous! + /// @brief Helper function to copy from a device to the host and then free the device buffer. If + /// @c blocking is false the memory copy is asynchronous! /// /// @note This will allocate memory on the host if it is not already allocated. /// @note The device of this buffer must be CPU - void copyDeviceToHostAndFreeDevice(void* stream = 0, bool blocking = true); // Delete + void copyDeviceToHostAndFreeDevice(void *stream = 0, bool blocking = true); // Delete -public: - /// @brief Default constructor initializes a buffer with the given size and device specified by host and deviceIndex. - /// @note This has a weird API because it has to match other buffer classes in nanovdb like nanovdb::HostBuffer + public: + /// @brief Default constructor initializes a buffer with the given size and device specified by + /// host and deviceIndex. 
+ /// @note This has a weird API because it has to match other buffer classes in nanovdb like + /// nanovdb::HostBuffer /// @param size The size (in bytes to allocate for this buffer) /// @param data If non-null, the data pointer to use for this buffer /// @param host If true buffer is initialized only on the host/CPU, else on the device/GPU - /// @param deviceIndex If host is false, then this specifies the device index to use for the buffer + /// @param deviceIndex If host is false, then this specifies the device index to use for the + /// buffer /// (must be set to a nonzero value when host is false) - TorchDeviceBuffer(uint64_t size = 0, void* data = nullptr, bool host = true, int deviceIndex = -1); + TorchDeviceBuffer(uint64_t size = 0, void *data = nullptr, bool host = true, + int deviceIndex = -1); /// @brief Disallow copy-construction - TorchDeviceBuffer(const TorchDeviceBuffer&) = delete; + TorchDeviceBuffer(const TorchDeviceBuffer &) = delete; /// @brief Move copy-constructor - TorchDeviceBuffer(TorchDeviceBuffer&& other) noexcept; + TorchDeviceBuffer(TorchDeviceBuffer &&other) noexcept; /// @brief Disallow copy assignment operation - TorchDeviceBuffer& operator=(const TorchDeviceBuffer&) = delete; + TorchDeviceBuffer &operator=(const TorchDeviceBuffer &) = delete; /// @brief Move copy assignment operation - TorchDeviceBuffer& operator=(TorchDeviceBuffer&& other) noexcept; + TorchDeviceBuffer &operator=(TorchDeviceBuffer &&other) noexcept; /// @brief Destructor frees memory on both the host and device ~TorchDeviceBuffer() { this->clear(); }; @@ -76,56 +85,74 @@ class TorchDeviceBuffer /// The selected device will be this->device which must be a cuda device /// @note All existing buffers are first cleared /// @warning size is expected to be non-zero. Use clear() clear buffer! - void init(uint64_t size, void* data = nullptr, bool host = true); + void init(uint64_t size, void *data = nullptr, bool host = true); /// @brief Set the device of this buffer and shuffle data around accordingly /// @param device The device to be used by this buffer (if CUDA, must specify a device index) /// @param blocking If true the memory copy is synchronous, else asynchronous - void setDevice(const torch::Device& device, bool blocking); + void setDevice(const torch::Device &device, bool blocking); /// @brief Returns the device used by this buffer /// @return The device used by this buffer - const torch::Device& device() const { + const torch::Device & + device() const { return mDevice; } /// @brief Retuns a pointer to the raw memory buffer managed by this allocator. /// @warning Note that the pointer can be NULL is the allocator was not initialized! - uint8_t* data() const { return mCpuData; } - uint8_t* deviceData() const { return mGpuData; } + uint8_t * + data() const { + return mCpuData; + } + uint8_t * + deviceData() const { + return mGpuData; + } /// @brief Returns the size in bytes of the raw memory buffer managed by this allocator. - uint64_t size() const { return mSize; } + uint64_t + size() const { + return mSize; + } /// @brief Returns true if this allocator is empty, i.e. 
has no allocated memory - bool empty() const { return mSize == 0 && mCpuData == nullptr && mGpuData == nullptr; } - bool isEmpty() const { return empty(); } + bool + empty() const { + return mSize == 0 && mCpuData == nullptr && mGpuData == nullptr; + } + bool + isEmpty() const { + return empty(); + } /// @brief De-allocate all memory managed by this allocator and set all pointer to NULL void clear(); /// @brief Static factory method that return an instance of this buffer /// @param size byte size of buffer to be initialized - /// @param guide this argument is there to match the signature of the other create() methods (e.g. nanovdb::HostBuffer) + /// @param guide this argument is there to match the signature of the other create() methods + /// (e.g. nanovdb::HostBuffer) /// and to provide a way to specify the device to be used for the buffer. - /// i.e. if guide is non-null, the created buffer will be on the same device as guide! - /// note you must also set the host argument to match the guide buffer device - /// @param host If true buffer is initialized only on the host/CPU, else on the device/GPU. If you passed in a guide + /// i.e. if guide is non-null, the created buffer will be on the same device as + /// guide! note you must also set the host argument to match the guide buffer + /// device + /// @param host If true buffer is initialized only on the host/CPU, else on the device/GPU. If + /// you passed in a guide /// buffer, then this must match the device of the guide buffer! /// @return An instance of this class using move semantics - static TorchDeviceBuffer create(uint64_t size, const TorchDeviceBuffer* guide = nullptr, bool host = true, void* stream = nullptr); + static TorchDeviceBuffer create(uint64_t size, const TorchDeviceBuffer *guide = nullptr, + bool host = true, void *stream = nullptr); }; // TorchDeviceBuffer class - } // namespace detail } // namespace fvdb - namespace nanovdb { - template<> - struct BufferTraits - { - static const bool hasDeviceDual = true; - }; +template <> struct BufferTraits { + static const bool hasDeviceDual = true; +}; } // namespace nanovdb + +#endif // FVDB_DETAIL_TORCHDEVICEBUFFER_H \ No newline at end of file diff --git a/fvdb/src/detail/TypesImpl.h b/fvdb/src/detail/TypesImpl.h index eb957f713d..028fa49bfb 100644 --- a/fvdb/src/detail/TypesImpl.h +++ b/fvdb/src/detail/TypesImpl.h @@ -1,41 +1,47 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include +#ifndef FVDB_DETAIL_TYPESIMPL_H +#define FVDB_DETAIL_TYPESIMPL_H + +#include +#include namespace fvdb { namespace detail { -template -class Vec3dImpl { +template class Vec3dImpl { nanovdb::Vec3d mValue; - bool mWasScalar = false; + bool mWasScalar = false; -public: + public: static constexpr bool SupportsScalarCast = AllowScalar; - using ValueType = nanovdb::Vec3d::ValueType; + using ValueType = nanovdb::Vec3d::ValueType; Vec3dImpl() : mValue(0.0, 0.0, 0.0) {} - Vec3dImpl(const nanovdb::Vec3d& coord) : mValue(coord) {} - Vec3dImpl(const nanovdb::Vec3f& coord) : mValue(coord[0], coord[1], coord[2]) {} - Vec3dImpl(const torch::Tensor& coordTensor) : mValue(fvdb::tensorToVec3d(coordTensor)) {} - template - Vec3dImpl(const std::vector& coordVec) { - static_assert(std::is_arithmetic::value, "Coord3D can only be constructed from integral types"); - TORCH_CHECK_VALUE(coordVec.size() == 3, "Coord3D can only be constructed from a vector of size 3"); + Vec3dImpl(const nanovdb::Vec3d &coord) : mValue(coord) {} + Vec3dImpl(const nanovdb::Vec3f &coord) : 
mValue(coord[0], coord[1], coord[2]) {} + Vec3dImpl(const torch::Tensor &coordTensor) : mValue(fvdb::tensorToVec3d(coordTensor)) {} + template Vec3dImpl(const std::vector &coordVec) { + static_assert(std::is_arithmetic::value, + "Coord3D can only be constructed from integral types"); + TORCH_CHECK_VALUE(coordVec.size() == 3, + "Coord3D can only be constructed from a vector of size 3"); mValue = nanovdb::Vec3d(coordVec[0], coordVec[1], coordVec[2]); } - template - Vec3dImpl(T scalar) { - static_assert(AllowScalar, "Vec3d can only be constructed from a scalar if AllowScalar is true"); - static_assert(std::is_arithmetic::value, "Vec3d can only be constructed from numeric types"); - mValue = nanovdb::Vec3d(scalar, scalar, scalar); + template Vec3dImpl(T scalar) { + static_assert(AllowScalar, + "Vec3d can only be constructed from a scalar if AllowScalar is true"); + static_assert(std::is_arithmetic::value, + "Vec3d can only be constructed from numeric types"); + mValue = nanovdb::Vec3d(scalar, scalar, scalar); mWasScalar = true; } - const nanovdb::Vec3d& value() const { + const nanovdb::Vec3d & + value() const { if constexpr (!AllowScalar) { TORCH_CHECK_VALUE(!mWasScalar, "Expected a vector, but got a scalar"); } @@ -43,83 +49,88 @@ class Vec3dImpl { } }; - -template -class Coord3Impl { +template class Coord3Impl { nanovdb::Coord mValue; - bool mWasScalar = false; + bool mWasScalar = false; -public: + public: static constexpr bool SupportsScalarCast = AllowScalar; - using ValueType = nanovdb::Coord::ValueType; + using ValueType = nanovdb::Coord::ValueType; Coord3Impl() : mValue(0, 0, 0) {} - Coord3Impl(const nanovdb::Coord& coord) : mValue(coord) {} - Coord3Impl(const nanovdb::Vec3i& coord) : mValue(coord[0], coord[1], coord[2]) {} - Coord3Impl(const nanovdb::Vec3u& coord) : mValue(coord[0], coord[1], coord[2]) {} - Coord3Impl(const torch::Tensor& coordTensor) : mValue(fvdb::tensorToCoord(coordTensor)) {} - template - Coord3Impl(const std::vector& coordVec) { - static_assert(std::is_integral::value, "Coord can only be constructed from integral types"); - TORCH_CHECK_VALUE(coordVec.size() == 3, "Coord can only be constructed from a vector of size 3"); + Coord3Impl(const nanovdb::Coord &coord) : mValue(coord) {} + Coord3Impl(const nanovdb::Vec3i &coord) : mValue(coord[0], coord[1], coord[2]) {} + Coord3Impl(const nanovdb::Vec3u &coord) : mValue(coord[0], coord[1], coord[2]) {} + Coord3Impl(const torch::Tensor &coordTensor) : mValue(fvdb::tensorToCoord(coordTensor)) {} + template Coord3Impl(const std::vector &coordVec) { + static_assert(std::is_integral::value, + "Coord can only be constructed from integral types"); + TORCH_CHECK_VALUE(coordVec.size() == 3, + "Coord can only be constructed from a vector of size 3"); mValue = nanovdb::Coord(coordVec[0], coordVec[1], coordVec[2]); } - template - Coord3Impl(T scalar) { - static_assert(AllowScalar, "Coord3 can only be constructed from a scalar if AllowScalar is true"); - static_assert(std::is_integral::value, "Coord3D can only be constructed from integral types"); - mValue = nanovdb::Coord(scalar, scalar, scalar); + template Coord3Impl(T scalar) { + static_assert(AllowScalar, + "Coord3 can only be constructed from a scalar if AllowScalar is true"); + static_assert(std::is_integral::value, + "Coord3D can only be constructed from integral types"); + mValue = nanovdb::Coord(scalar, scalar, scalar); mWasScalar = true; } - const nanovdb::Coord& value() const { + const nanovdb::Coord & + value() const { if constexpr (!AllowScalar) { 
TORCH_CHECK_VALUE(!mWasScalar, "Expected a vector, but got a scalar"); } return mValue; } - std::string toString() const { - return "{" + std::to_string(mValue[0]) + ", " + std::to_string(mValue[1]) + ", " + std::to_string(mValue[2]) + "}"; + std::string + toString() const { + return "{" + std::to_string(mValue[0]) + ", " + std::to_string(mValue[1]) + ", " + + std::to_string(mValue[2]) + "}"; } }; - -template -class Coord4Impl { +template class Coord4Impl { nanovdb::Vec4i mValue; - static_assert(!AllowScalar, "Coord does not allow scalar conversion. We may wish to change this in the future."); + static_assert( + !AllowScalar, + "Coord does not allow scalar conversion. We may wish to change this in the future."); -public: + public: static constexpr bool SupportsScalarCast = AllowScalar; - using ValueType = nanovdb::Coord::ValueType; + using ValueType = nanovdb::Coord::ValueType; Coord4Impl() : mValue(0, 0, 0, 0) {} - Coord4Impl(const nanovdb::Vec4i& coord) : mValue(coord) {} - Coord4Impl(const torch::Tensor& coordTensor) : mValue(fvdb::tensorToCoord4(coordTensor)) {} - template - Coord4Impl(const std::vector& coordVec) { - static_assert(std::is_integral::value, "Vec4i can only be constructed from integral types"); - TORCH_CHECK_VALUE(coordVec.size() == 4, "Vec4i can only be constructed from a vector of size 4"); + Coord4Impl(const nanovdb::Vec4i &coord) : mValue(coord) {} + Coord4Impl(const torch::Tensor &coordTensor) : mValue(fvdb::tensorToCoord4(coordTensor)) {} + template Coord4Impl(const std::vector &coordVec) { + static_assert(std::is_integral::value, + "Vec4i can only be constructed from integral types"); + TORCH_CHECK_VALUE(coordVec.size() == 4, + "Vec4i can only be constructed from a vector of size 4"); mValue = nanovdb::Vec4i(coordVec[0], coordVec[1], coordVec[2], coordVec[3]); } - const nanovdb::Vec4i& value() const { + const nanovdb::Vec4i & + value() const { return mValue; } }; - -template -class Vec3BatchImpl { -private: +template class Vec3BatchImpl { + private: std::vector mValue; - bool isScalar = false; - bool isSingle = false; + bool isScalar = false; + bool isSingle = false; - std::vector repeatIt(int64_t batchSize, bool onlyPositive) const { + std::vector + repeatIt(int64_t batchSize, bool onlyPositive) const { if (onlyPositive) { - TORCH_CHECK_VALUE(mValue[0][0] > 0 && mValue[0][1] > 0 && mValue[0][2] > 0, "Expected all coordinates to be positive"); + TORCH_CHECK_VALUE(mValue[0][0] > 0 && mValue[0][1] > 0 && mValue[0][2] > 0, + "Expected all coordinates to be positive"); } std::vector result; result.reserve(batchSize); @@ -129,16 +140,16 @@ class Vec3BatchImpl { return result; } -public: - static constexpr bool SupportsBroadcast = AllowBroadcast; + public: + static constexpr bool SupportsBroadcast = AllowBroadcast; static constexpr bool SupportsScalarCast = AllowScalar; using ValueType = typename VecT::ValueType; - using VecType = VecT; + using VecType = VecT; Vec3BatchImpl() : mValue() {} - Vec3BatchImpl(const torch::Tensor& tensor) { + Vec3BatchImpl(const torch::Tensor &tensor) { torch::Tensor squeezed = tensor.squeeze().cpu(); if constexpr (AllowScalar) { @@ -151,60 +162,74 @@ class Vec3BatchImpl { if constexpr (AllowBroadcast) { if (squeezed.numel() == 3) { - mValue.push_back(VecT(squeezed[0].item(), squeezed[1].item(), squeezed[2].item())); + mValue.push_back(VecT(squeezed[0].item(), squeezed[1].item(), + squeezed[2].item())); isSingle = true; return; } } - TORCH_CHECK_VALUE(squeezed.dim() == 2, "Expected a batch of 3D coordinates with size [B, 3]"); - 
TORCH_CHECK_VALUE(squeezed.size(1) == 3, "Expected a batch of 3D coordinates with size [B, 3]"); + TORCH_CHECK_VALUE(squeezed.dim() == 2, + "Expected a batch of 3D coordinates with size [B, 3]"); + TORCH_CHECK_VALUE(squeezed.size(1) == 3, + "Expected a batch of 3D coordinates with size [B, 3]"); mValue.reserve(squeezed.size(0)); for (int i = 0; i < squeezed.size(0); ++i) { - mValue.push_back(VecT(squeezed[i][0].item(), squeezed[i][1].item(), squeezed[i][2].item())); + mValue.push_back(VecT(squeezed[i][0].item(), squeezed[i][1].item(), + squeezed[i][2].item())); } } - template - Vec3BatchImpl(const std::vector>& vectorData) { + template Vec3BatchImpl(const std::vector> &vectorData) { if constexpr (nanovdb::util::is_same::value) { - static_assert(std::is_integral::value, "Vec3Batch can only be constructed from integral types"); + static_assert(std::is_integral::value, + "Vec3Batch can only be constructed from integral types"); } - static_assert(std::is_arithmetic::value, "Vec3Batch can only be constructed from numeric types"); + static_assert(std::is_arithmetic::value, + "Vec3Batch can only be constructed from numeric types"); size_t batchSize = vectorData.size(); TORCH_CHECK_VALUE(batchSize > 0, "Expected a batch of coordinates with size [B, 3]"); for (size_t i = 0; i < batchSize; i += 1) { - TORCH_CHECK_VALUE(vectorData[i].size() == 3, "Expected a batch of 3D coordinates with size [B, 3]"); + TORCH_CHECK_VALUE(vectorData[i].size() == 3, + "Expected a batch of 3D coordinates with size [B, 3]"); mValue.push_back(VecT(vectorData[i][0], vectorData[i][1], vectorData[i][2])); } } - template - Vec3BatchImpl(const T& scalar) { - static_assert(AllowScalar, "Cannot construct Vec3Batch from scalar when AllowScalar is set to false"); + template Vec3BatchImpl(const T &scalar) { + static_assert(AllowScalar, + "Cannot construct Vec3Batch from scalar when AllowScalar is set to false"); if constexpr (nanovdb::util::is_same::value) { - static_assert(std::is_integral::value, "Vec3Batch can only be constructed from integral types"); + static_assert(std::is_integral::value, + "Vec3Batch can only be constructed from integral types"); } - static_assert(std::is_arithmetic::value, "Vec3Batch can only be constructed from numeric types"); - mValue.push_back(VecT((double) scalar)); + static_assert(std::is_arithmetic::value, + "Vec3Batch can only be constructed from numeric types"); + mValue.push_back(VecT((double)scalar)); isScalar = true; } - template - Vec3BatchImpl(const std::vector& vec) { - static_assert(AllowBroadcast, "Cannot construct Vec3Batch from single vector when AllowBroadcast is set to false"); + template Vec3BatchImpl(const std::vector &vec) { + static_assert( + AllowBroadcast, + "Cannot construct Vec3Batch from single vector when AllowBroadcast is set to false"); if constexpr (nanovdb::util::is_same::value) { - static_assert(std::is_integral::value, "Vec3Batch can only be constructed from integral types"); + static_assert(std::is_integral::value, + "Vec3Batch can only be constructed from integral types"); } - static_assert(std::is_arithmetic::value, "Vec3Batch can only be constructed from numeric types"); - TORCH_CHECK_VALUE(vec.size() == 3, "Expected a batch of 3D coordinates with size [B, 3] or a single coordinate of size [3,]"); + static_assert(std::is_arithmetic::value, + "Vec3Batch can only be constructed from numeric types"); + TORCH_CHECK_VALUE( + vec.size() == 3, + "Expected a batch of 3D coordinates with size [B, 3] or a single coordinate of size [3,]"); mValue.push_back(VecT(vec[0], vec[1], 
vec[2])); isSingle = true; } - std::vector value(uint64_t batchSize, bool onlyPositive, std::string name) const { + std::vector + value(uint64_t batchSize, bool onlyPositive, std::string name) const { TORCH_CHECK(batchSize > 0, "Can't request empty batch of coordinates"); TORCH_CHECK(mValue.size() > 0, "Can't request empty batch of coordinates"); @@ -212,7 +237,6 @@ class Vec3BatchImpl { if (isScalar) { return repeatIt(batchSize, onlyPositive); } - } if constexpr (AllowBroadcast) { if (isSingle && batchSize != 1) { @@ -222,20 +246,24 @@ class Vec3BatchImpl { if (onlyPositive) { for (size_t i = 0; i < mValue.size(); ++i) { - TORCH_CHECK_VALUE(mValue[i][0] > 0 && mValue[i][1] > 0 && mValue[i][2] > 0, "Expected all coordinates of " + name + " to be positive"); + TORCH_CHECK_VALUE(mValue[i][0] > 0 && mValue[i][1] > 0 && mValue[i][2] > 0, + "Expected all coordinates of " + name + " to be positive"); } } - TORCH_CHECK_VALUE(batchSize == mValue.size(), "Expected " + name + " batch of 3D coordinates to have size [" + std::to_string(batchSize) + ", 3]"); + TORCH_CHECK_VALUE(batchSize == mValue.size(), + "Expected " + name + " batch of 3D coordinates to have size [" + + std::to_string(batchSize) + ", 3]"); return mValue; } - torch::Tensor tensorValue(uint64_t batchSize, bool onlyPositive, std::string name) const { + torch::Tensor + tensorValue(uint64_t batchSize, bool onlyPositive, std::string name) const { std::vector vec = value(batchSize, onlyPositive, name); if constexpr (nanovdb::util::is_same::value) { - return torch::from_blob(vec.data(), { (int64_t) vec.size(), 3 }, torch::kInt32).clone(); + return torch::from_blob(vec.data(), { (int64_t)vec.size(), 3 }, torch::kInt32).clone(); } else if constexpr (nanovdb::util::is_same::value) { - return torch::from_blob(vec.data(), { (int64_t) vec.size(), 3 }, torch::kDouble).clone(); + return torch::from_blob(vec.data(), { (int64_t)vec.size(), 3 }, torch::kDouble).clone(); } else { static_assert("Only Coord and Vec3d are supported for now"); } @@ -244,3 +272,5 @@ class Vec3BatchImpl { } // namespace detail } // namespace fvdb + +#endif // FVDB_DETAIL_TYPESIMPL_H \ No newline at end of file diff --git a/fvdb/src/detail/VoxelCoordTransform.h b/fvdb/src/detail/VoxelCoordTransform.h index bf69328c5f..f98859732a 100644 --- a/fvdb/src/detail/VoxelCoordTransform.h +++ b/fvdb/src/detail/VoxelCoordTransform.h @@ -1,86 +1,93 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_VOXELCOORDTRANSFORM_H +#define FVDB_DETAIL_VOXELCOORDTRANSFORM_H + +#include "utils/Utils.h" #include #include -#include "detail/utils/Utils.h" - - namespace fvdb { namespace detail { /// @brief A class representing the the transformation from world space (xyz) to voxel space (ijk) -/// its inverse, and gradient. It can be applied to points, vectors and rays. It stores the transformation in -/// float16, float32 and float64 precision, using the appropriate representation depending on the -/// input types. +/// its inverse, and gradient. It can be applied to points, vectors and rays. It stores the +/// transformation in float16, float32 and float64 precision, using the appropriate +/// representation depending on the input types. /// @note This class currently only supports translation and non-uniform scaling transformations. 
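// Worked example (illustrative, not from this patch) of the world-to-voxel mapping this class
// implements, using the primalVoxelTransformForSizeAndOrigin() helper defined at the end of this
// header. A voxel size of 0.1 with the [0, 0, 0] voxel centred at the world origin gives a scale
// of 10 and zero translation:
//
//   VoxelCoordTransform xform = primalVoxelTransformForSizeAndOrigin(
//       nanovdb::Vec3d(0.1, 0.1, 0.1), nanovdb::Vec3d(0.0, 0.0, 0.0));
//   nanovdb::math::Vec3<double> ijk = xform.apply(1.0, 2.0, 3.0); // world (1, 2, 3) -> voxel (10, 20, 30)
//   nanovdb::math::Vec3<double> xyz = xform.applyInv(ijk);        // voxel -> world, recovers (1, 2, 3)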
struct VoxelCoordTransform { - /// @brief Construct a voxel coordinate transform with identity transformation /// @return The voxel coordinate transform __hostdev__ VoxelCoordTransform() {}; - /// @brief Construct a voxel coordinate transform that scales and translates each input point when mappint to voxel coordinates + /// @brief Construct a voxel coordinate transform that scales and translates each input point + /// when mappint to voxel coordinates /// @param scale The 3D scale to apply to each input point /// @param translate The 3D translation to apply to each input point - __hostdev__ VoxelCoordTransform(const nanovdb::Vec3d& scale, const nanovdb::Vec3d& translate) : mTransform(scale, translate) {} + __hostdev__ + VoxelCoordTransform(const nanovdb::Vec3d &scale, const nanovdb::Vec3d &translate) + : mTransform(scale, translate) {} /// @brief Apply the gradient of the transformation (from xyz to ijk) to an input point xyz /// @tparam ScalarT The scalar type of the input point xyz /// @param xyz The input point to apply the gradient to /// @return The gradient dT/dxyz of the transformation applied to xyz template - __hostdev__ nanovdb::math::Vec3 applyGrad(const nanovdb::math::Vec3& xyz) const { + __hostdev__ nanovdb::math::Vec3 + applyGrad(const nanovdb::math::Vec3 &xyz) const { static_assert(is_floating_point_or_half::value); return mTransform.scale(); } - /// @brief Apply the gradient of the transformation (from xyz to ijk) to an input point (x, y, z) + /// @brief Apply the gradient of the transformation (from xyz to ijk) to an input point (x, y, + /// z) /// @tparam ScalarT The scalar type of the input point (x, y, z) /// @param x The x component of the input point to apply the gradient to /// @param y The y component of the input point to apply the gradient to /// @param z The z component of the input point to apply the gradient to /// @return The gradient dT/d(x, y, z) of the transformation applied to (x, y, z) template - __hostdev__ nanovdb::math::Vec3 applyGrad(ScalarT x, ScalarT y, ScalarT z) const { + __hostdev__ nanovdb::math::Vec3 + applyGrad(ScalarT x, ScalarT y, ScalarT z) const { static_assert(is_floating_point_or_half::value); return mTransform.scale(); } - - /// @brief Apply the gradient of the inverse transformation (from ijk to xyz) to an input coordinate ijk + /// @brief Apply the gradient of the inverse transformation (from ijk to xyz) to an input + /// coordinate ijk /// @tparam ScalarT The scalar type of the input coordinate ijk /// @param ijk The input point to apply the gradient to /// @return The gradient dT^-1/dijk of the inverse transformation applied to ijk template - __hostdev__ nanovdb::math::Vec3 applyInvGrad(const nanovdb::math::Vec3& ijk) const { + __hostdev__ nanovdb::math::Vec3 + applyInvGrad(const nanovdb::math::Vec3 &ijk) const { static_assert(is_floating_point_or_half::value); return nanovdb::math::Vec3(1.0, 1.0, 1.0) / mTransform.scale(); } - /// @brief Apply the gradient of the inverse transformation (from ijk to xyz) to an input coordinate (i, j, k) + /// @brief Apply the gradient of the inverse transformation (from ijk to xyz) to an input + /// coordinate (i, j, k) /// @tparam ScalarT ScalarT The scalar type of the input coordinate (i, j, k) /// @param i The i component of the input coordinate to apply the gradient to /// @param j The j component of the input coordinate to apply the gradient to /// @param k The k component of the input coordinate to apply the gradient to /// @return The gradient dT^-1/d(i, j, k) of the inverse transformation 
applied to (i, j, k) template - __hostdev__ nanovdb::math::Vec3 applyInvGrad(ScalarT i, ScalarT j, ScalarT k) const { + __hostdev__ nanovdb::math::Vec3 + applyInvGrad(ScalarT i, ScalarT j, ScalarT k) const { static_assert(is_floating_point_or_half::value); return nanovdb::math::Vec3(1.0, 1.0, 1.0) / mTransform.scale(); } - - /// @brief Apply the transformation (from xyz to ijk) to an input point xyz /// @tparam ScalarT The scalar type of the input point xyz /// @param xyz The input point to apply the transformation to /// @return The transformed point T(xyz) template - __hostdev__ nanovdb::math::Vec3 apply(const nanovdb::math::Vec3& xyz) const { + __hostdev__ nanovdb::math::Vec3 + apply(const nanovdb::math::Vec3 &xyz) const { static_assert(is_floating_point_or_half::value); return xyz * mTransform.scale() + mTransform.translate(); } @@ -92,31 +99,34 @@ struct VoxelCoordTransform { /// @param z The z component of the input point to apply the transformation to /// @return The transformed point T(x, y, z) template - __hostdev__ nanovdb::math::Vec3 apply(ScalarT x, ScalarT y, ScalarT z) const { + __hostdev__ nanovdb::math::Vec3 + apply(ScalarT x, ScalarT y, ScalarT z) const { static_assert(is_floating_point_or_half::value); const nanovdb::math::Vec3 xyz(x, y, z); return xyz * mTransform.scale() + mTransform.translate(); } - /// @brief Apply the transformation (from xyz to ijk) to an input point xyz which is of an indexable type + /// @brief Apply the transformation (from xyz to ijk) to an input point xyz which is of an + /// indexable type /// @tparam ScalarT The scalar type of the input point xyz - /// @param xyz The input point to apply the transformation to (must support indexing with [0], [1], [2]) + /// @param xyz The input point to apply the transformation to (must support indexing with [0], + /// [1], [2]) /// @return The transformed point T(xyz) template - __hostdev__ nanovdb::math::Vec3 apply(const InVec3T& xyz) const { + __hostdev__ nanovdb::math::Vec3 + apply(const InVec3T &xyz) const { static_assert(is_floating_point_or_half::value); const nanovdb::math::Vec3 pt(xyz[0], xyz[1], xyz[2]); return pt * mTransform.scale() + mTransform.translate(); } - - /// @brief Apply the inverse transformation (from ijk to xyz) to an input coordinate ijk /// @tparam ScalarT The scalar type of the input coordinate ijk /// @param ijk The input coordinate to apply the inverse transformation to /// @return The transformed coordinate T^-1(ijk) template - __hostdev__ nanovdb::math::Vec3 applyInv(const nanovdb::math::Vec3& ijk) const { + __hostdev__ nanovdb::math::Vec3 + applyInv(const nanovdb::math::Vec3 &ijk) const { static_assert(is_floating_point_or_half::value); return (ijk - mTransform.translate()) / mTransform.scale(); } @@ -128,31 +138,34 @@ struct VoxelCoordTransform { /// @param k The k component of the input coordinate to apply the inverse transformation to /// @return The transformed coordinate T^-1(i, j, k) template - __hostdev__ nanovdb::math::Vec3 applyInv(ScalarT i, ScalarT j, ScalarT k) const { + __hostdev__ nanovdb::math::Vec3 + applyInv(ScalarT i, ScalarT j, ScalarT k) const { static_assert(is_floating_point_or_half::value); const nanovdb::math::Vec3 ijk(i, j, k); return (ijk - mTransform.translate()) / mTransform.scale(); } - /// @brief Apply the inverse transformation (from ijk to xyz) to an input coordinate ijk which is of an indexable type + /// @brief Apply the inverse transformation (from ijk to xyz) to an input coordinate ijk which + /// is of an indexable type /// @tparam ScalarT 
The scalar type of the input coordinate ijk - /// @param ijk The input coordinate to apply the inverse transformation to (must support indexing with [0], [1], [2]) + /// @param ijk The input coordinate to apply the inverse transformation to (must support + /// indexing with [0], [1], [2]) /// @return The transformed coordinate T^-1(ijk) template - __hostdev__ nanovdb::math::Vec3 applyInv(const InVec3T& ijk) const { + __hostdev__ nanovdb::math::Vec3 + applyInv(const InVec3T &ijk) const { static_assert(is_floating_point_or_half::value); const nanovdb::math::Vec3 coord(ijk[0], ijk[1], ijk[2]); return (coord - mTransform.translate()) / mTransform.scale(); } - - /// @brief Apply the transformation (from xyz to ijk) to an input ray /// @tparam ScalarT The scalar type of the input ray /// @param ray The input ray to apply the transformation to /// @return The transformed ray T(ray) template - __hostdev__ nanovdb::math::Ray applyToRay(nanovdb::math::Ray ray) const { + __hostdev__ nanovdb::math::Ray + applyToRay(nanovdb::math::Ray ray) const { static_assert(is_floating_point_or_half::value); const nanovdb::math::Vec3 oVox = apply(ray.eye()); const nanovdb::math::Vec3 dVox = ray.dir() * mTransform.scale(); @@ -171,13 +184,14 @@ struct VoxelCoordTransform { /// @param t1 The maximum ray time parameter /// @return The transformed ray T(ray) template - __hostdev__ nanovdb::math::Ray applyToRay(ScalarT rayOx, ScalarT rayOy, ScalarT rayOz, - ScalarT rayDx, ScalarT rayDy, ScalarT rayDz, - ScalarT t0 = static_cast(0), - ScalarT t1 = std::numeric_limits::infinity()) const { + __hostdev__ nanovdb::math::Ray + applyToRay(ScalarT rayOx, ScalarT rayOy, ScalarT rayOz, ScalarT rayDx, ScalarT rayDy, + ScalarT rayDz, ScalarT t0 = static_cast(0), + ScalarT t1 = std::numeric_limits::infinity()) const { static_assert(is_floating_point_or_half::value); const nanovdb::math::Vec3 oVox = apply(rayOx, rayOy, rayOz); - const nanovdb::math::Vec3 dVox = nanovdb::math::Vec3(rayDx, rayDy, rayDz) * mTransform.scale(); + const nanovdb::math::Vec3 dVox = + nanovdb::math::Vec3(rayDx, rayDy, rayDz) * mTransform.scale(); return nanovdb::math::Ray(oVox, dVox, t0, t1); } @@ -185,7 +199,8 @@ struct VoxelCoordTransform { /// @tparam ScalarT The scalar type to return the scale in /// @return The scale component of this transformation template - __hostdev__ nanovdb::math::Vec3 scale() const { + __hostdev__ nanovdb::math::Vec3 + scale() const { return mTransform.scale(); } @@ -193,105 +208,124 @@ struct VoxelCoordTransform { /// @tparam ScalarT The scalar type to return the translation in /// @return The translation component of this transformation template - __hostdev__ nanovdb::math::Vec3 translate() const { + __hostdev__ nanovdb::math::Vec3 + translate() const { return mTransform.translate(); } - -private: + private: /// @brief A struct representing the transformation from world space (xyz) to voxel space (ijk) - /// in float16, float32, and float64. You can access the scale and translation in any of these - /// by calling methods with the appropriate template paramter + /// in float16, float32, and float64. 
You can access the scale and translation in any of + /// these by calling methods with the appropriate template paramter struct Transform { /// @brief Construct an identity transformation __hostdev__ Transform() {}; /// @brief Construct a transformation that scales and translates each input point - __hostdev__ Transform(nanovdb::Vec3d scale, const nanovdb::Vec3d& translate) : - mScaleh(nanovdb::math::Vec3(c10::Half(float(scale[0])), c10::Half(float(scale[1])), c10::Half(float(scale[2])))), - mTranslateh(nanovdb::math::Vec3(c10::Half(float(translate[0])), c10::Half(float(translate[1])), c10::Half(float(translate[2])))), - mScalef(nanovdb::Vec3f(scale[0], scale[1], scale[2])), - mTranslatef(nanovdb::Vec3f(translate[0], translate[1], translate[2])), - mScaled(scale), - mTranslated(translate) {} - - nanovdb::math::Vec3 mScaleh = nanovdb::math::Vec3(c10::Half(1.0f), c10::Half(1.0f), c10::Half(1.0f)); - nanovdb::math::Vec3 mTranslateh = nanovdb::math::Vec3(c10::Half(0.0f), c10::Half(0.0f), c10::Half(0.0f)); - nanovdb::Vec3f mScalef = nanovdb::Vec3f(1.0f, 1.0f, 1.0f); + __hostdev__ + Transform(nanovdb::Vec3d scale, const nanovdb::Vec3d &translate) + : mScaleh(nanovdb::math::Vec3(c10::Half(float(scale[0])), + c10::Half(float(scale[1])), + c10::Half(float(scale[2])))), + mTranslateh(nanovdb::math::Vec3(c10::Half(float(translate[0])), + c10::Half(float(translate[1])), + c10::Half(float(translate[2])))), + mScalef(nanovdb::Vec3f(scale[0], scale[1], scale[2])), + mTranslatef(nanovdb::Vec3f(translate[0], translate[1], translate[2])), mScaled(scale), + mTranslated(translate) {} + + nanovdb::math::Vec3 mScaleh = + nanovdb::math::Vec3(c10::Half(1.0f), c10::Half(1.0f), c10::Half(1.0f)); + nanovdb::math::Vec3 mTranslateh = + nanovdb::math::Vec3(c10::Half(0.0f), c10::Half(0.0f), c10::Half(0.0f)); + nanovdb::Vec3f mScalef = nanovdb::Vec3f(1.0f, 1.0f, 1.0f); nanovdb::Vec3f mTranslatef = nanovdb::Vec3f(0.0f, 0.0f, 0.0f); - nanovdb::Vec3d mScaled = nanovdb::Vec3d(1.0, 1.0, 1.0); + nanovdb::Vec3d mScaled = nanovdb::Vec3d(1.0, 1.0, 1.0); nanovdb::Vec3d mTranslated = nanovdb::Vec3d(0.0, 0.0, 0.0); /// @brief Get the scale component of this transformation /// @tparam T The scalar type to return the scale in /// @return The scale component of this transformation - template - __hostdev__ inline const nanovdb::math::Vec3& scale() const; + template __hostdev__ inline const nanovdb::math::Vec3 &scale() const; /// @brief Get the translation component of this transformation /// @tparam T The scalar type to return the translation in /// @return The translation component of this transformation - template - __hostdev__ inline const nanovdb::math::Vec3& translate() const; + template __hostdev__ inline const nanovdb::math::Vec3 &translate() const; } mTransform; }; // Template specializations to return the appropriate types template <> -__hostdev__ inline const nanovdb::math::Vec3& VoxelCoordTransform::Transform::scale() const { +__hostdev__ inline const nanovdb::math::Vec3 & +VoxelCoordTransform::Transform::scale() const { return mScaleh; } template <> -__hostdev__ inline const nanovdb::Vec3f& VoxelCoordTransform::Transform::scale() const { +__hostdev__ inline const nanovdb::Vec3f & +VoxelCoordTransform::Transform::scale() const { return mScalef; } template <> -__hostdev__ inline const nanovdb::Vec3d& VoxelCoordTransform::Transform::scale() const { +__hostdev__ inline const nanovdb::Vec3d & +VoxelCoordTransform::Transform::scale() const { return mScaled; } template <> -__hostdev__ inline const nanovdb::math::Vec3& 
VoxelCoordTransform::Transform::translate() const { +__hostdev__ inline const nanovdb::math::Vec3 & +VoxelCoordTransform::Transform::translate() const { return mTranslateh; } template <> -__hostdev__ inline const nanovdb::Vec3f& VoxelCoordTransform::Transform::translate() const { +__hostdev__ inline const nanovdb::Vec3f & +VoxelCoordTransform::Transform::translate() const { return mTranslatef; } template <> -__hostdev__ inline const nanovdb::Vec3d& VoxelCoordTransform::Transform::translate() const { +__hostdev__ inline const nanovdb::Vec3d & +VoxelCoordTransform::Transform::translate() const { return mTranslated; } -/// @brief Get a primal voxel transform given a voxel size and the coordinate of the [0, 0, 0] voxel center +/// @brief Get a primal voxel transform given a voxel size and the coordinate of the [0, 0, 0] voxel +/// center /// @param voxSize The size of each voxel in the grid /// @param voxOrigin The coordinate of the [0, 0, 0] voxel center /// @return The primal voxel transform -inline __hostdev__ VoxelCoordTransform primalVoxelTransformForSizeAndOrigin(const nanovdb::Vec3d& voxSize, const nanovdb::Vec3d& voxOrigin) { - // TORCH_CHECK_VALUE(voxSize[0] > 0.0 && voxSize[1] > 0.0 && voxSize[2] > 0.0, "voxel_size must be positive"); - const nanovdb::Vec3d& w = voxSize; - const nanovdb::Vec3d& tx = voxOrigin; - const nanovdb::Vec3d invW = nanovdb::Vec3d(1.0, 1.0, 1.0) / w; - const nanovdb::Vec3d half(0.5, 0.5, 0.5); +inline __hostdev__ VoxelCoordTransform +primalVoxelTransformForSizeAndOrigin(const nanovdb::Vec3d &voxSize, + const nanovdb::Vec3d &voxOrigin) { + // TORCH_CHECK_VALUE(voxSize[0] > 0.0 && voxSize[1] > 0.0 && voxSize[2] > 0.0, "voxel_size must + // be positive"); + const nanovdb::Vec3d &w = voxSize; + const nanovdb::Vec3d &tx = voxOrigin; + const nanovdb::Vec3d invW = nanovdb::Vec3d(1.0, 1.0, 1.0) / w; + const nanovdb::Vec3d half(0.5, 0.5, 0.5); return VoxelCoordTransform(invW, -tx / w); } -/// @brief Get the primal and dual transforms for a grid given a voxel size and the coordinate of the [0, 0, 0] voxel center +/// @brief Get the primal and dual transforms for a grid given a voxel size and the coordinate of +/// the [0, 0, 0] voxel center /// @param voxSize The size of each voxel in the grid /// @param voxOrigin The coordinate of the [0, 0, 0] voxel center /// @param outPrimal Output primal transform /// @param outDual Output dual transform -inline __hostdev__ void voxelTransformForSizeAndOrigin(const nanovdb::Vec3d& voxSize, const nanovdb::Vec3d& voxOrigin, - VoxelCoordTransform& outPrimal, VoxelCoordTransform& outDual) { - // TORCH_CHECK_VALUE(voxSize[0] > 0.0 && voxSize[1] > 0.0 && voxSize[2] > 0.0, "voxel_size must be positive"); - const nanovdb::Vec3d& w = voxSize; - const nanovdb::Vec3d& tx = voxOrigin; - const nanovdb::Vec3d invW = nanovdb::Vec3d(1.0, 1.0, 1.0) / w; - const nanovdb::Vec3d half(0.5, 0.5, 0.5); +inline __hostdev__ void +voxelTransformForSizeAndOrigin(const nanovdb::Vec3d &voxSize, const nanovdb::Vec3d &voxOrigin, + VoxelCoordTransform &outPrimal, VoxelCoordTransform &outDual) { + // TORCH_CHECK_VALUE(voxSize[0] > 0.0 && voxSize[1] > 0.0 && voxSize[2] > 0.0, "voxel_size must + // be positive"); + const nanovdb::Vec3d &w = voxSize; + const nanovdb::Vec3d &tx = voxOrigin; + const nanovdb::Vec3d invW = nanovdb::Vec3d(1.0, 1.0, 1.0) / w; + const nanovdb::Vec3d half(0.5, 0.5, 0.5); outPrimal = VoxelCoordTransform(invW, -tx / w); - outDual = VoxelCoordTransform(invW, -tx / w + half); + outDual = VoxelCoordTransform(invW, -tx / w + half); } } // 
namespace detail } // namespace fvdb + +#endif // FVDB_DETAIL_VOXELCOORDTRANSFORM_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/Attention.cpp b/fvdb/src/detail/autograd/Attention.cpp index 6f3f99ee77..a5dc8813d3 100644 --- a/fvdb/src/detail/autograd/Attention.cpp +++ b/fvdb/src/detail/autograd/Attention.cpp @@ -3,23 +3,21 @@ // #include "Attention.h" -#include "detail/ops/Ops.h" -#include "detail/utils/Utils.h" +#include +#include namespace fvdb { namespace detail { namespace autograd { -Attention::variable_list Attention::forward(Attention::AutogradContext *ctx, - const Attention::Variable& query, - const Attention::Variable& key, - const Attention::Variable& value, - const Attention::Variable& qLengths, - const Attention::Variable& kvLengths, - float scale) { +Attention::variable_list +Attention::forward(Attention::AutogradContext *ctx, const Attention::Variable &query, + const Attention::Variable &key, const Attention::Variable &value, + const Attention::Variable &qLengths, const Attention::Variable &kvLengths, + float scale) { torch::Tensor out = FVDB_DISPATCH_KERNEL_DEVICE(query.device(), [&]() { - return ops::dispatchScaledDotProductAttention( - query, key, value, qLengths, kvLengths, true, scale); + return ops::dispatchScaledDotProductAttention(query, key, value, qLengths, + kvLengths, true, scale); }); // ctx->saved_data["tsmtThreshold"] = tsmtThreshold; @@ -29,15 +27,14 @@ Attention::variable_list Attention::forward(Attention::AutogradContext *ctx, // outOpacity, outDepth, outRgb, outWs // }); - return { out}; + return { out }; } -Attention::variable_list Attention::backward(Attention::AutogradContext *ctx, - Attention::variable_list grad_output) { +Attention::variable_list +Attention::backward(Attention::AutogradContext *ctx, Attention::variable_list grad_output) { TORCH_CHECK(false, "Not implemented"); } - } // namespace autograd } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/autograd/Attention.h b/fvdb/src/detail/autograd/Attention.h index 050ed643e3..9ee04a49ca 100644 --- a/fvdb/src/detail/autograd/Attention.h +++ b/fvdb/src/detail/autograd/Attention.h @@ -1,33 +1,29 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_ATTENTION_H +#define FVDB_DETAIL_AUTOGRAD_ATTENTION_H #include - namespace fvdb { namespace detail { namespace autograd { -struct Attention : public torch::autograd::Function -{ - using variable_list = torch::autograd::variable_list; +struct Attention : public torch::autograd::Function { + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - const Variable& query, - const Variable& key, - const Variable& value, - const Variable& qLengths, - const Variable& kvLengths, - float scale); + static variable_list forward(AutogradContext *ctx, const Variable &query, const Variable &key, + const Variable &value, const Variable &qLengths, + const Variable &kvLengths, float scale); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_ATTENTION_H \ No newline at end of file diff --git 
a/fvdb/src/detail/autograd/Autograd.h b/fvdb/src/detail/autograd/Autograd.h index 10873c9255..c63be9fac8 100644 --- a/fvdb/src/detail/autograd/Autograd.h +++ b/fvdb/src/detail/autograd/Autograd.h @@ -1,18 +1,23 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include "MaxPoolGrid.h" +#ifndef FVDB_DETAIL_AUTOGRAD_AUTOGRAD_H +#define FVDB_DETAIL_AUTOGRAD_AUTOGRAD_H + +#include "Attention.h" #include "AvgPoolGrid.h" +#include "FillToGrid.h" +#include "JaggedReduce.h" +#include "MaxPoolGrid.h" +#include "ReadFromDense.h" +#include "ReadIntoDense.h" #include "SampleGrid.h" +#include "SparseConvolutionHalo.h" +#include "SparseConvolutionImplicitGEMM.h" +#include "SparseConvolutionKernelMap.h" #include "SplatIntoGrid.h" -#include "UpsampleGrid.h" #include "TransformPoints.h" +#include "UpsampleGrid.h" #include "VolumeRender.h" -#include "SparseConvolutionKernelMap.h" -#include "SparseConvolutionHalo.h" -#include "SparseConvolutionImplicitGEMM.h" -#include "ReadIntoDense.h" -#include "ReadFromDense.h" -#include "FillToGrid.h" -#include "JaggedReduce.h" -#include "Attention.h" \ No newline at end of file + +#endif // FVDB_DETAIL_AUTOGRAD_AUTOGRAD_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/AvgPoolGrid.cpp b/fvdb/src/detail/autograd/AvgPoolGrid.cpp index 0078deb660..cf76ad295e 100644 --- a/fvdb/src/detail/autograd/AvgPoolGrid.cpp +++ b/fvdb/src/detail/autograd/AvgPoolGrid.cpp @@ -3,70 +3,60 @@ // #include "AvgPoolGrid.h" -#include - -#include "detail/ops/Ops.h" -#include "detail/utils/Utils.h" +#include +#include +#include namespace fvdb { namespace detail { namespace autograd { -AvgPoolGrid::variable_list AvgPoolGrid::forward(AvgPoolGrid::AutogradContext *ctx, - c10::intrusive_ptr fineGrid, - c10::intrusive_ptr coarseGrid, - nanovdb::Coord poolingFactor, - nanovdb::Coord stride, - AvgPoolGrid::Variable fineData) { - +AvgPoolGrid::variable_list +AvgPoolGrid::forward(AvgPoolGrid::AutogradContext *ctx, c10::intrusive_ptr fineGrid, + c10::intrusive_ptr coarseGrid, nanovdb::Coord poolingFactor, + nanovdb::Coord stride, AvgPoolGrid::Variable fineData) { torch::Tensor outCoarseData = FVDB_DISPATCH_KERNEL_DEVICE(fineData.device(), [&]() { - return ops::dispatchDownsampleGridAvgPool( - *fineGrid, *coarseGrid, fineData, poolingFactor, stride); + return ops::dispatchDownsampleGridAvgPool(*fineGrid, *coarseGrid, fineData, + poolingFactor, stride); }); - ctx->save_for_backward({fineData}); - ctx->saved_data["fine_grid"] = fineGrid; - ctx->saved_data["coarse_grid"] = coarseGrid; - ctx->saved_data["pooling_factor_x"] = (int64_t) poolingFactor[0]; - ctx->saved_data["pooling_factor_y"] = (int64_t) poolingFactor[1]; - ctx->saved_data["pooling_factor_z"] = (int64_t) poolingFactor[2]; - ctx->saved_data["stride_x"] = (int64_t) stride[0]; - ctx->saved_data["stride_y"] = (int64_t) stride[1]; - ctx->saved_data["stride_z"] = (int64_t) stride[2]; + ctx->save_for_backward({ fineData }); + ctx->saved_data["fine_grid"] = fineGrid; + ctx->saved_data["coarse_grid"] = coarseGrid; + ctx->saved_data["pooling_factor_x"] = (int64_t)poolingFactor[0]; + ctx->saved_data["pooling_factor_y"] = (int64_t)poolingFactor[1]; + ctx->saved_data["pooling_factor_z"] = (int64_t)poolingFactor[2]; + ctx->saved_data["stride_x"] = (int64_t)stride[0]; + ctx->saved_data["stride_y"] = (int64_t)stride[1]; + ctx->saved_data["stride_z"] = (int64_t)stride[2]; - return variable_list({outCoarseData}); + return variable_list({ outCoarseData }); } -AvgPoolGrid::variable_list 
AvgPoolGrid::backward(AvgPoolGrid::AutogradContext *ctx, - AvgPoolGrid::variable_list grad_output) { - +AvgPoolGrid::variable_list +AvgPoolGrid::backward(AvgPoolGrid::AutogradContext *ctx, AvgPoolGrid::variable_list grad_output) { // Use data saved in forward - variable_list saved = ctx->get_saved_variables(); - Variable fineData = saved.at(0); - auto fineGrid = ctx->saved_data["fine_grid"].toCustomClass(); - auto coarseGrid = ctx->saved_data["coarse_grid"].toCustomClass(); - const int64_t poolingFactorX = ctx->saved_data["pooling_factor_x"].toInt(); - const int64_t poolingFactorY = ctx->saved_data["pooling_factor_y"].toInt(); - const int64_t poolingFactorZ = ctx->saved_data["pooling_factor_z"].toInt(); - const int64_t strideX = ctx->saved_data["stride_x"].toInt(); - const int64_t strideY = ctx->saved_data["stride_y"].toInt(); - const int64_t strideZ = ctx->saved_data["stride_z"].toInt(); + variable_list saved = ctx->get_saved_variables(); + Variable fineData = saved.at(0); + auto fineGrid = ctx->saved_data["fine_grid"].toCustomClass(); + auto coarseGrid = ctx->saved_data["coarse_grid"].toCustomClass(); + const int64_t poolingFactorX = ctx->saved_data["pooling_factor_x"].toInt(); + const int64_t poolingFactorY = ctx->saved_data["pooling_factor_y"].toInt(); + const int64_t poolingFactorZ = ctx->saved_data["pooling_factor_z"].toInt(); + const int64_t strideX = ctx->saved_data["stride_x"].toInt(); + const int64_t strideY = ctx->saved_data["stride_y"].toInt(); + const int64_t strideZ = ctx->saved_data["stride_z"].toInt(); const nanovdb::Coord poolingFactor(poolingFactorX, poolingFactorY, poolingFactorZ); const nanovdb::Coord stride(strideX, strideY, strideZ); - Variable gradOut = grad_output.at(0).contiguous(); // [#coarse_voxels | #coarse_corners, *] + Variable gradOut = grad_output.at(0).contiguous(); // [#coarse_voxels | #coarse_corners, *] Variable outGradIn = FVDB_DISPATCH_KERNEL_DEVICE(gradOut.device(), [&]() { return ops::dispatchDownsampleGridAvgPoolBackward( - *coarseGrid, *fineGrid, - fineData, - gradOut, - poolingFactor, - stride - ); + *coarseGrid, *fineGrid, fineData, gradOut, poolingFactor, stride); }); - return {torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor(), outGradIn}; + return { torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor(), outGradIn }; } } // namespace autograd diff --git a/fvdb/src/detail/autograd/AvgPoolGrid.h b/fvdb/src/detail/autograd/AvgPoolGrid.h index ec6211d7be..b8d6247baa 100644 --- a/fvdb/src/detail/autograd/AvgPoolGrid.h +++ b/fvdb/src/detail/autograd/AvgPoolGrid.h @@ -1,33 +1,32 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once - -#include +#ifndef FVDB_DETAIL_AUTOGRAD_AVGPOOLGRID_H +#define FVDB_DETAIL_AUTOGRAD_AVGPOOLGRID_H #include "detail/GridBatchImpl.h" +#include namespace fvdb { namespace detail { namespace autograd { struct AvgPoolGrid : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr fineGrid, + static variable_list forward(AutogradContext *ctx, c10::intrusive_ptr fineGrid, c10::intrusive_ptr coarseGrid, - nanovdb::Coord poolingFactor, - nanovdb::Coord stride, + nanovdb::Coord poolingFactor, nanovdb::Coord stride, Variable fineData); - static 
variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_AVGPOOLGRID_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/FillToGrid.h b/fvdb/src/detail/autograd/FillToGrid.h index c1a43aa652..7af2a74730 100644 --- a/fvdb/src/detail/autograd/FillToGrid.h +++ b/fvdb/src/detail/autograd/FillToGrid.h @@ -1,78 +1,81 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_FILLTOGRID_H +#define FVDB_DETAIL_AUTOGRAD_FILLTOGRID_H -#include +#include +#include +#include +#include #include #include -#include "detail/ops/Ops.h" -#include "detail/utils/Utils.h" - -#include "detail/GridBatchImpl.h" -#include "Types.h" - +#include namespace fvdb { namespace detail { namespace autograd { struct FillToGrid : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr fromGrid, - c10::intrusive_ptr toGrid, - Variable fromFeatures, - const int default_value=0.0) { - TORCH_CHECK_VALUE(fromFeatures.size(0) == fromGrid->totalVoxels(), "fromFeatures must conform to fromGrid"); - TORCH_CHECK_VALUE(fromGrid->batchSize() == toGrid->batchSize(), "fromGrid and toGrid must have the same batch size"); + static variable_list + forward(AutogradContext *ctx, c10::intrusive_ptr fromGrid, + c10::intrusive_ptr toGrid, Variable fromFeatures, + const int default_value = 0.0) { + TORCH_CHECK_VALUE(fromFeatures.size(0) == fromGrid->totalVoxels(), + "fromFeatures must conform to fromGrid"); + TORCH_CHECK_VALUE(fromGrid->batchSize() == toGrid->batchSize(), + "fromGrid and toGrid must have the same batch size"); torch::Tensor fromFeaturesReshape = featureCoalescedView(fromFeatures); - torch::Tensor ret = torch::full({toGrid->totalVoxels(), fromFeaturesReshape.size(1)}, + torch::Tensor ret = torch::full({ toGrid->totalVoxels(), fromFeaturesReshape.size(1) }, default_value, fromFeaturesReshape.options()); - auto outShape = spliceShape({toGrid->totalVoxels()}, fromFeatures, 1); // [B*M, *] + auto outShape = spliceShape({ toGrid->totalVoxels() }, fromFeatures, 1); // [B*M, *] // Dispatch to kernel. 
FVDB_DISPATCH_KERNEL_DEVICE(fromGrid->device(), [&]() { - ops::dispatchFillToGrid( - *fromGrid, *toGrid, fromFeaturesReshape, ret); + ops::dispatchFillToGrid(*fromGrid, *toGrid, fromFeaturesReshape, ret); }); ctx->saved_data["from_grid"] = fromGrid; - ctx->saved_data["to_grid"] = toGrid; + ctx->saved_data["to_grid"] = toGrid; - return variable_list({ret.reshape(outShape)}); + return variable_list({ ret.reshape(outShape) }); } - static variable_list backward(AutogradContext *ctx, - variable_list grad_output) { - torch::Tensor gradFeatures = grad_output[0]; + static variable_list + backward(AutogradContext *ctx, variable_list grad_output) { + torch::Tensor gradFeatures = grad_output[0]; torch::Tensor gradFeaturesReshape = featureCoalescedView(gradFeatures); auto fromGrid = ctx->saved_data["from_grid"].toCustomClass(); - auto toGrid = ctx->saved_data["to_grid"].toCustomClass(); - auto outShape = spliceShape({fromGrid->totalVoxels()}, gradFeatures, 1); // [B*M, *] + auto toGrid = ctx->saved_data["to_grid"].toCustomClass(); + auto outShape = spliceShape({ fromGrid->totalVoxels() }, gradFeatures, 1); // [B*M, *] - // The default grad_input is always 0.0, since gradient will only propagate for overlapped voxels. - torch::Tensor gradInput = torch::zeros({fromGrid->totalVoxels(), gradFeaturesReshape.size(1)}, - gradFeaturesReshape.options()); + // The default grad_input is always 0.0, since gradient will only propagate for overlapped + // voxels. + torch::Tensor gradInput = + torch::zeros({ fromGrid->totalVoxels(), gradFeaturesReshape.size(1) }, + gradFeaturesReshape.options()); // Dispatch same kernel but with to and from switched. FVDB_DISPATCH_KERNEL_DEVICE(fromGrid->device(), [&]() { - ops::dispatchFillToGrid( - *toGrid, *fromGrid, gradFeaturesReshape, gradInput); + ops::dispatchFillToGrid(*toGrid, *fromGrid, gradFeaturesReshape, gradInput); }); - return variable_list({torch::Tensor(), torch::Tensor(), gradInput.reshape(outShape), torch::Tensor()}); + return variable_list( + { torch::Tensor(), torch::Tensor(), gradInput.reshape(outShape), torch::Tensor() }); } }; } // namespace autograd } // namespace detail } // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_FILLTOGRID_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/JaggedReduce.cpp b/fvdb/src/detail/autograd/JaggedReduce.cpp index f2c0e99451..37f9f09b8c 100644 --- a/fvdb/src/detail/autograd/JaggedReduce.cpp +++ b/fvdb/src/detail/autograd/JaggedReduce.cpp @@ -3,17 +3,17 @@ // #include "JaggedReduce.h" -#include - -#include "detail/ops/jagged/JaggedOps.h" -#include "detail/utils/Utils.h" +#include +#include +#include namespace fvdb { namespace detail { namespace autograd { -static inline std::vector list2vec(const c10::List list) { +static inline std::vector +list2vec(const c10::List list) { std::vector result; result.reserve(list.size()); for (size_t i = 0; i < list.size(); i++) @@ -21,109 +21,107 @@ static inline std::vector list2vec(const c10::List list) { return result; } -JaggedSum::variable_list JaggedSum::forward(JaggedSum::AutogradContext *ctx, - JaggedSum::Variable jdata, - JaggedSum::Variable jidx, - JaggedSum::Variable joffsets, - int64_t dim_size) { +JaggedSum::variable_list +JaggedSum::forward(JaggedSum::AutogradContext *ctx, JaggedSum::Variable jdata, + JaggedSum::Variable jidx, JaggedSum::Variable joffsets, int64_t dim_size) { TORCH_CHECK_VALUE(jdata.device() == jidx.device(), "jdata and jidx must be on the same device"); - TORCH_CHECK_VALUE(jdata.device() == joffsets.device(), "jdata and joffsets must be on 
the same device"); + TORCH_CHECK_VALUE(jdata.device() == joffsets.device(), + "jdata and joffsets must be on the same device"); torch::Tensor outData = FVDB_DISPATCH_KERNEL_DEVICE(jdata.device(), [&]() { return ops::dispatchJaggedSum(jdata, jidx, joffsets, dim_size); }); - ctx->save_for_backward({jidx}); - return variable_list({outData}); + ctx->save_for_backward({ jidx }); + return variable_list({ outData }); } -JaggedSum::variable_list JaggedSum::backward(JaggedSum::AutogradContext *ctx, - JaggedSum::variable_list grad_output) { - variable_list saved = ctx->get_saved_variables(); - Variable jidx = saved.at(0); - Variable gradIn = grad_output.at(0).index({jidx.to(torch::kInt32)}); - return {gradIn, torch::Tensor(), torch::Tensor(), torch::Tensor()}; +JaggedSum::variable_list +JaggedSum::backward(JaggedSum::AutogradContext *ctx, JaggedSum::variable_list grad_output) { + variable_list saved = ctx->get_saved_variables(); + Variable jidx = saved.at(0); + Variable gradIn = grad_output.at(0).index({ jidx.to(torch::kInt32) }); + return { gradIn, torch::Tensor(), torch::Tensor(), torch::Tensor() }; } -JaggedMin::variable_list JaggedMin::forward(JaggedMin::AutogradContext *ctx, - JaggedMin::Variable jdata, - JaggedMin::Variable jidx, - JaggedMin::Variable joffsets, - int64_t dim_size) { +JaggedMin::variable_list +JaggedMin::forward(JaggedMin::AutogradContext *ctx, JaggedMin::Variable jdata, + JaggedMin::Variable jidx, JaggedMin::Variable joffsets, int64_t dim_size) { TORCH_CHECK_VALUE(jdata.device() == jidx.device(), "jdata and jidx must be on the same device"); - TORCH_CHECK_VALUE(jdata.device() == joffsets.device(), "jdata and joffsets must be on the same device"); + TORCH_CHECK_VALUE(jdata.device() == joffsets.device(), + "jdata and joffsets must be on the same device"); - auto minOut = FVDB_DISPATCH_KERNEL_DEVICE(jdata.device(), [&]() { + auto minOut = FVDB_DISPATCH_KERNEL_DEVICE(jdata.device(), [&]() { return ops::dispatchJaggedMin(jdata, jidx, joffsets, dim_size); }); torch::Tensor minData = minOut[0]; - torch::Tensor minIdx = minOut[1]; - ctx->save_for_backward({minIdx, joffsets}); + torch::Tensor minIdx = minOut[1]; + ctx->save_for_backward({ minIdx, joffsets }); ctx->saved_data["src_shape"] = jdata.sizes(); - return variable_list({minData, minIdx}); + return variable_list({ minData, minIdx }); } -JaggedMin::variable_list JaggedMin::backward(JaggedMin::AutogradContext *ctx, - JaggedMin::variable_list grad_output) { - variable_list saved = ctx->get_saved_variables(); - Variable gradOut = grad_output.at(0); - Variable minIdx = saved.at(0); - Variable joffsets0 = saved.at(1).index({torch::indexing::Slice(0, -1)}); +JaggedMin::variable_list +JaggedMin::backward(JaggedMin::AutogradContext *ctx, JaggedMin::variable_list grad_output) { + variable_list saved = ctx->get_saved_variables(); + Variable gradOut = grad_output.at(0); + Variable minIdx = saved.at(0); + Variable joffsets0 = saved.at(1).index({ torch::indexing::Slice(0, -1) }); for (int i = 0; i < minIdx.dim() - 1; i += 1) { joffsets0 = joffsets0.unsqueeze(-1); } auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); - // For output that receives no input, propagate to position -1 will result in memory out-of-bound error. + // For output that receives no input, propagate to position -1 will result in memory + // out-of-bound error. // Therefore, we need to add a dummy zero at the beginning of the index tensor. 
// src_shape[0] += 1; Variable gradIn = torch::zeros(src_shape, gradOut.options()); gradIn.scatter_(0, minIdx + joffsets0, gradOut); // gradIn = gradIn.narrow(0, 1, src_shape[0] - 1); - return {gradIn, torch::Tensor(), torch::Tensor(), torch::Tensor()}; + return { gradIn, torch::Tensor(), torch::Tensor(), torch::Tensor() }; } -JaggedMax::variable_list JaggedMax::forward(JaggedMax::AutogradContext *ctx, - JaggedMax::Variable jdata, - JaggedMax::Variable jidx, - JaggedMax::Variable joffsets, - int64_t dim_size) { +JaggedMax::variable_list +JaggedMax::forward(JaggedMax::AutogradContext *ctx, JaggedMax::Variable jdata, + JaggedMax::Variable jidx, JaggedMax::Variable joffsets, int64_t dim_size) { TORCH_CHECK_VALUE(jdata.device() == jidx.device(), "jdata and jidx must be on the same device"); - TORCH_CHECK_VALUE(jdata.device() == joffsets.device(), "jdata and joffsets must be on the same device"); + TORCH_CHECK_VALUE(jdata.device() == joffsets.device(), + "jdata and joffsets must be on the same device"); - auto maxOut = FVDB_DISPATCH_KERNEL_DEVICE(jdata.device(), [&]() { + auto maxOut = FVDB_DISPATCH_KERNEL_DEVICE(jdata.device(), [&]() { return ops::dispatchJaggedMax(jdata, jidx, joffsets, dim_size); }); torch::Tensor maxData = maxOut[0]; - torch::Tensor maxIdx = maxOut[1]; + torch::Tensor maxIdx = maxOut[1]; - ctx->save_for_backward({maxIdx, joffsets}); + ctx->save_for_backward({ maxIdx, joffsets }); ctx->saved_data["src_shape"] = jdata.sizes(); - return variable_list({maxData, maxIdx}); + return variable_list({ maxData, maxIdx }); } -JaggedMax::variable_list JaggedMax::backward(JaggedMax::AutogradContext *ctx, - JaggedMax::variable_list grad_output) { - variable_list saved = ctx->get_saved_variables(); - Variable gradOut = grad_output.at(0); - Variable maxIdx = saved.at(0); - Variable joffsets0 = saved.at(1).index({torch::indexing::Slice(0, -1)}); +JaggedMax::variable_list +JaggedMax::backward(JaggedMax::AutogradContext *ctx, JaggedMax::variable_list grad_output) { + variable_list saved = ctx->get_saved_variables(); + Variable gradOut = grad_output.at(0); + Variable maxIdx = saved.at(0); + Variable joffsets0 = saved.at(1).index({ torch::indexing::Slice(0, -1) }); for (int i = 0; i < maxIdx.dim() - 1; i += 1) { joffsets0 = joffsets0.unsqueeze(-1); } auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); - // For output that receives no input, propagate to position -1 will result in memory out-of-bound error. + // For output that receives no input, propagate to position -1 will result in memory + // out-of-bound error. // Therefore, we need to add a dummy zero at the beginning of the index tensor. 
// src_shape[0] += 1; Variable gradIn = torch::zeros(src_shape, gradOut.options()); gradIn.scatter_(0, maxIdx + joffsets0, gradOut); // gradIn = gradIn.narrow(0, 1, src_shape[0] - 1); - return {gradIn, torch::Tensor(), torch::Tensor(), torch::Tensor()}; + return { gradIn, torch::Tensor(), torch::Tensor(), torch::Tensor() }; } - } // namespace autograd } // namespace detail } // namespace fvdb \ No newline at end of file diff --git a/fvdb/src/detail/autograd/JaggedReduce.h b/fvdb/src/detail/autograd/JaggedReduce.h index 151a2cda81..e9e3c24107 100644 --- a/fvdb/src/detail/autograd/JaggedReduce.h +++ b/fvdb/src/detail/autograd/JaggedReduce.h @@ -1,56 +1,52 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once - -#include +#ifndef FVDB_DETAIL_AUTOGRAD_JAGGEDREDUCE_H +#define FVDB_DETAIL_AUTOGRAD_JAGGEDREDUCE_H #include "detail/GridBatchImpl.h" +#include namespace fvdb { namespace detail { namespace autograd { struct JaggedSum : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - Variable jdata, Variable jidx, + static variable_list forward(AutogradContext *ctx, Variable jdata, Variable jidx, Variable joffsets, int64_t dim_size); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; struct JaggedMin : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - Variable jdata, Variable jidx, + static variable_list forward(AutogradContext *ctx, Variable jdata, Variable jidx, Variable joffsets, int64_t dim_size); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; struct JaggedMax : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - Variable jdata, Variable jidx, + static variable_list forward(AutogradContext *ctx, Variable jdata, Variable jidx, Variable joffsets, int64_t dim_size); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_JAGGEDREDUCE_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/MaxPoolGrid.cpp b/fvdb/src/detail/autograd/MaxPoolGrid.cpp index f2f9c51a5f..38ef2c8218 100644 --- a/fvdb/src/detail/autograd/MaxPoolGrid.cpp +++ b/fvdb/src/detail/autograd/MaxPoolGrid.cpp @@ -3,71 +3,61 @@ // #include "MaxPoolGrid.h" -#include - -#include 
"detail/ops/Ops.h" -#include "detail/utils/Utils.h" +#include +#include +#include namespace fvdb { namespace detail { namespace autograd { -MaxPoolGrid::variable_list MaxPoolGrid::forward(MaxPoolGrid::AutogradContext *ctx, - c10::intrusive_ptr fineGrid, - c10::intrusive_ptr coarseGrid, - nanovdb::Coord poolingFactor, - nanovdb::Coord stride, - MaxPoolGrid::Variable fineData) { - +MaxPoolGrid::variable_list +MaxPoolGrid::forward(MaxPoolGrid::AutogradContext *ctx, c10::intrusive_ptr fineGrid, + c10::intrusive_ptr coarseGrid, nanovdb::Coord poolingFactor, + nanovdb::Coord stride, MaxPoolGrid::Variable fineData) { torch::Tensor outCoarseData = FVDB_DISPATCH_KERNEL_DEVICE(fineData.device(), [&]() { - return ops::dispatchDownsampleGridMaxPool( - *fineGrid, *coarseGrid, fineData, poolingFactor, stride); + return ops::dispatchDownsampleGridMaxPool(*fineGrid, *coarseGrid, fineData, + poolingFactor, stride); }); - ctx->save_for_backward({fineData}); - ctx->saved_data["fine_grid"] = fineGrid; - ctx->saved_data["coarse_grid"] = coarseGrid; - ctx->saved_data["pooling_factor_x"] = (int64_t) poolingFactor[0]; - ctx->saved_data["pooling_factor_y"] = (int64_t) poolingFactor[1]; - ctx->saved_data["pooling_factor_z"] = (int64_t) poolingFactor[2]; - ctx->saved_data["stride_x"] = (int64_t) stride[0]; - ctx->saved_data["stride_y"] = (int64_t) stride[1]; - ctx->saved_data["stride_z"] = (int64_t) stride[2]; + ctx->save_for_backward({ fineData }); + ctx->saved_data["fine_grid"] = fineGrid; + ctx->saved_data["coarse_grid"] = coarseGrid; + ctx->saved_data["pooling_factor_x"] = (int64_t)poolingFactor[0]; + ctx->saved_data["pooling_factor_y"] = (int64_t)poolingFactor[1]; + ctx->saved_data["pooling_factor_z"] = (int64_t)poolingFactor[2]; + ctx->saved_data["stride_x"] = (int64_t)stride[0]; + ctx->saved_data["stride_y"] = (int64_t)stride[1]; + ctx->saved_data["stride_z"] = (int64_t)stride[2]; - return variable_list({outCoarseData}); + return variable_list({ outCoarseData }); } -MaxPoolGrid::variable_list MaxPoolGrid::backward(MaxPoolGrid::AutogradContext *ctx, - MaxPoolGrid::variable_list grad_output) { - +MaxPoolGrid::variable_list +MaxPoolGrid::backward(MaxPoolGrid::AutogradContext *ctx, MaxPoolGrid::variable_list grad_output) { // Use data saved in forward - variable_list saved = ctx->get_saved_variables(); - Variable fineData = saved.at(0); - auto fineGrid = ctx->saved_data["fine_grid"].toCustomClass(); - auto coarseGrid = ctx->saved_data["coarse_grid"].toCustomClass(); - const int64_t poolingFactorX = ctx->saved_data["pooling_factor_x"].toInt(); - const int64_t poolingFactorY = ctx->saved_data["pooling_factor_y"].toInt(); - const int64_t poolingFactorZ = ctx->saved_data["pooling_factor_z"].toInt(); - const int64_t strideX = ctx->saved_data["stride_x"].toInt(); - const int64_t strideY = ctx->saved_data["stride_y"].toInt(); - const int64_t strideZ = ctx->saved_data["stride_z"].toInt(); + variable_list saved = ctx->get_saved_variables(); + Variable fineData = saved.at(0); + auto fineGrid = ctx->saved_data["fine_grid"].toCustomClass(); + auto coarseGrid = ctx->saved_data["coarse_grid"].toCustomClass(); + const int64_t poolingFactorX = ctx->saved_data["pooling_factor_x"].toInt(); + const int64_t poolingFactorY = ctx->saved_data["pooling_factor_y"].toInt(); + const int64_t poolingFactorZ = ctx->saved_data["pooling_factor_z"].toInt(); + const int64_t strideX = ctx->saved_data["stride_x"].toInt(); + const int64_t strideY = ctx->saved_data["stride_y"].toInt(); + const int64_t strideZ = ctx->saved_data["stride_z"].toInt(); 
const nanovdb::Coord poolingFactor(poolingFactorX, poolingFactorY, poolingFactorZ); const nanovdb::Coord stride(strideX, strideY, strideZ); - Variable gradOut = grad_output.at(0).contiguous(); // [#coarse_voxels | #coarse_corners, *] + Variable gradOut = grad_output.at(0).contiguous(); // [#coarse_voxels | #coarse_corners, *] Variable outGradIn = FVDB_DISPATCH_KERNEL_DEVICE(gradOut.device(), [&]() { return ops::dispatchDownsampleGridMaxPoolBackward( - *coarseGrid, *fineGrid, - fineData, - gradOut, - poolingFactor, - stride - ); + *coarseGrid, *fineGrid, fineData, gradOut, poolingFactor, stride); }); - return {torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor(), outGradIn}; + return { torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor(), outGradIn }; } } // namespace autograd diff --git a/fvdb/src/detail/autograd/MaxPoolGrid.h b/fvdb/src/detail/autograd/MaxPoolGrid.h index 51bd617a90..4f02314c89 100644 --- a/fvdb/src/detail/autograd/MaxPoolGrid.h +++ b/fvdb/src/detail/autograd/MaxPoolGrid.h @@ -1,33 +1,32 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_MAXPOOLGRID_H +#define FVDB_DETAIL_AUTOGRAD_MAXPOOLGRID_H -#include - -#include "detail/GridBatchImpl.h" +#include +#include namespace fvdb { namespace detail { namespace autograd { struct MaxPoolGrid : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr fineGrid, + static variable_list forward(AutogradContext *ctx, c10::intrusive_ptr fineGrid, c10::intrusive_ptr coarseGrid, - nanovdb::Coord poolingFactor, - nanovdb::Coord stride, + nanovdb::Coord poolingFactor, nanovdb::Coord stride, Variable fineData); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_MAXPOOLGRID_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/ReadFromDense.h b/fvdb/src/detail/autograd/ReadFromDense.h index a4a8f81217..101fcceb6a 100644 --- a/fvdb/src/detail/autograd/ReadFromDense.h +++ b/fvdb/src/detail/autograd/ReadFromDense.h @@ -1,36 +1,35 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_READFROMDENSE_H +#define FVDB_DETAIL_AUTOGRAD_READFROMDENSE_H -#include +#include +#include +#include +#include #include #include -#include "detail/ops/Ops.h" -#include "detail/utils/Utils.h" - -#include "detail/GridBatchImpl.h" -#include "Types.h" - +#include namespace fvdb { namespace detail { namespace autograd { struct ReadFromDense : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr grid, - Variable denseData, - const Vec3iBatch& denseOrigins) { + static variable_list + 
forward(AutogradContext *ctx, c10::intrusive_ptr grid, Variable denseData, + const Vec3iBatch &denseOrigins) { TORCH_CHECK_VALUE(denseData.dim() > 4, "dense data must have shape [B, W, H, D, *]"); - TORCH_CHECK_VALUE(denseData.size(0) == grid->batchSize(), "dense data must have shape [B, W, H, D, *]"); + TORCH_CHECK_VALUE(denseData.size(0) == grid->batchSize(), + "dense data must have shape [B, W, H, D, *]"); TORCH_CHECK_VALUE(denseData.is_contiguous(), "sparse_data must be contiguous"); grid->checkDevice(denseData); @@ -41,49 +40,55 @@ struct ReadFromDense : public torch::autograd::Function { torch::Tensor denseDataReshape = featureCoalescedView(denseData, 4); // [N, -1] - torch::Tensor ret = torch::zeros({grid->totalVoxels(), denseDataReshape.size(4)}, denseData.options()); + torch::Tensor ret = + torch::zeros({ grid->totalVoxels(), denseDataReshape.size(4) }, denseData.options()); // nanovdb::Coord denseOriginNvdb = tensorToCoord(denseOrigins); // NanoVDB coordinates are int32 - torch::Tensor denseOriginsI32 = denseOrigins.tensorValue(grid->batchSize(), false /*onlyPositive*/, "dense_origins").to(denseData.device()); + torch::Tensor denseOriginsI32 = + denseOrigins.tensorValue(grid->batchSize(), false /*onlyPositive*/, "dense_origins") + .to(denseData.device()); FVDB_DISPATCH_KERNEL_DEVICE(grid->device(), [&]() { - ops::dispatchReadFromDense( - *grid, denseDataReshape, denseOriginsI32, ret, false); + ops::dispatchReadFromDense(*grid, denseDataReshape, denseOriginsI32, ret, + false); }); // Reshape [B, N, -1] to [B, N, *] given [B, W, H, D, *] - torch::Tensor retReshape = ret.view( - spliceShape({grid->totalVoxels()}, denseData, 4)); + torch::Tensor retReshape = ret.view(spliceShape({ grid->totalVoxels() }, denseData, 4)); // Save shape information for backward ctx->saved_data["dense_origin"] = denseOriginsI32; - ctx->saved_data["grid_size"] = coordToTensor(nanovdb::Coord(denseData.size(1), denseData.size(2), denseData.size(3))); - ctx->saved_data["grid"] = grid; - ctx->saved_data["dummy_tensor"] = torch::empty({0}, denseData.options()); - torch::Tensor retShape = torch::empty({(int64_t) denseData.dim()}, torch::TensorOptions().dtype(torch::kLong)); + ctx->saved_data["grid_size"] = + coordToTensor(nanovdb::Coord(denseData.size(1), denseData.size(2), denseData.size(3))); + ctx->saved_data["grid"] = grid; + ctx->saved_data["dummy_tensor"] = torch::empty({ 0 }, denseData.options()); + torch::Tensor retShape = + torch::empty({ (int64_t)denseData.dim() }, torch::TensorOptions().dtype(torch::kLong)); auto acc = retShape.accessor(); for (int i = 0; i < denseData.dim(); i++) { acc[i] = denseData.size(i); } ctx->saved_data["final_shape"] = retShape; - return variable_list({retReshape}); // [N, *] + return variable_list({ retReshape }); // [N, *] } - static variable_list backward(AutogradContext *ctx, - variable_list grad_output) { - + static variable_list + backward(AutogradContext *ctx, variable_list grad_output) { // Use data saved in forward - torch::Tensor denseOrigins = ctx->saved_data["dense_origin"].toTensor(); // [B, 3] - nanovdb::Coord gridSize = tensorToCoord(ctx->saved_data["grid_size"].toTensor()); - auto grid = ctx->saved_data["grid"].toCustomClass(); + torch::Tensor denseOrigins = ctx->saved_data["dense_origin"].toTensor(); // [B, 3] + nanovdb::Coord gridSize = tensorToCoord(ctx->saved_data["grid_size"].toTensor()); + auto grid = ctx->saved_data["grid"].toCustomClass(); torch::TensorOptions denseDataOpts = ctx->saved_data["dummy_tensor"].toTensor().options(); - std::vector 
finalShapeTensor = intTensor1DToStdVector(ctx->saved_data["final_shape"].toTensor()); + std::vector finalShapeTensor = + intTensor1DToStdVector(ctx->saved_data["final_shape"].toTensor()); - Variable gradOut = grad_output.at(0); // [N, *] + Variable gradOut = grad_output.at(0); // [N, *] torch::Tensor gradOutReshape = featureCoalescedView(gradOut); // [N, -1] - torch::Tensor ret = torch::zeros({grid->batchSize(), gridSize[0], gridSize[1], gridSize[2], gradOutReshape.size(1)}, denseDataOpts); // [B, W, H, D, -1] + torch::Tensor ret = torch::zeros( + { grid->batchSize(), gridSize[0], gridSize[1], gridSize[2], gradOutReshape.size(1) }, + denseDataOpts); // [B, W, H, D, -1] FVDB_DISPATCH_KERNEL_DEVICE(grid->device(), [&]() { ops::dispatchReadIntoDense(*grid, gradOutReshape, denseOrigins, ret, false); @@ -91,10 +96,12 @@ struct ReadFromDense : public torch::autograd::Function { torch::Tensor retReshape = ret.view(finalShapeTensor); // [B, W, H, D, *] - return {torch::Tensor(), retReshape, torch::Tensor()}; + return { torch::Tensor(), retReshape, torch::Tensor() }; } }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_READFROMDENSE_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/ReadIntoDense.cpp b/fvdb/src/detail/autograd/ReadIntoDense.cpp index 5d05aef23d..9572107ce4 100644 --- a/fvdb/src/detail/autograd/ReadIntoDense.cpp +++ b/fvdb/src/detail/autograd/ReadIntoDense.cpp @@ -3,97 +3,113 @@ // #include "ReadIntoDense.h" -#include - -#include "detail/ops/Ops.h" -#include "detail/utils/Utils.h" +#include +#include +#include namespace fvdb { namespace detail { namespace autograd { -ReadIntoDense::variable_list ReadIntoDense::forward(ReadIntoDense::AutogradContext *ctx, - c10::intrusive_ptr grid, - ReadIntoDense::Variable sparseData, - const torch::optional& maybeMinCoord, - const torch::optional& maybeGridSize) { +ReadIntoDense::variable_list +ReadIntoDense::forward(ReadIntoDense::AutogradContext *ctx, c10::intrusive_ptr grid, + ReadIntoDense::Variable sparseData, + const torch::optional &maybeMinCoord, + const torch::optional &maybeGridSize) { TORCH_CHECK_VALUE(sparseData.dim() > 1, "sparse_data must have shape [num_voxels, *]"); - TORCH_CHECK_VALUE(sparseData.size(0) == grid->totalVoxels(), "sparseData must have shape (num_voxels, *) where num_voxels = " + std::to_string(grid->totalVoxels())); + TORCH_CHECK_VALUE(sparseData.size(0) == grid->totalVoxels(), + "sparseData must have shape (num_voxels, *) where num_voxels = " + + std::to_string(grid->totalVoxels())); TORCH_CHECK_VALUE(sparseData.is_contiguous(), "sparse_data must be contiguous"); grid->checkDevice(sparseData); // Non empty grid->checkNonEmptyGrid(); - nanovdb::CoordBBox gridbb = grid->totalBBox(); // FIXME: Batched should use maximum bounding box which we need to compute + nanovdb::CoordBBox gridbb = grid->totalBBox(); // FIXME: Batched should use maximum bounding box + // which we need to compute - // Min coord is an integer tensor of shape [3,] or [B, 3] representing the minimum coordinate of the dense tensor + // Min coord is an integer tensor of shape [3,] or [B, 3] representing the minimum coordinate of + // the dense tensor torch::Tensor denseOrigins; if (maybeMinCoord.has_value()) { - denseOrigins = maybeMinCoord.value().tensorValue(grid->batchSize(), false /*onlyPositive*/, "min_coord").to(sparseData.device()); + denseOrigins = maybeMinCoord.value() + .tensorValue(grid->batchSize(), false /*onlyPositive*/, 
"min_coord") + .to(sparseData.device()); } else { - denseOrigins = coordToTensor(gridbb.min()).to(torch::kInt32).unsqueeze(0).repeat({grid->batchSize(), 1}).to(sparseData.device()); + denseOrigins = coordToTensor(gridbb.min()) + .to(torch::kInt32) + .unsqueeze(0) + .repeat({ grid->batchSize(), 1 }) + .to(sparseData.device()); } TORCH_CHECK_VALUE(denseOrigins.dim() == 2, "min_coord must have shape [3,] or [B, 3]"); - TORCH_CHECK_VALUE(denseOrigins.size(0) == grid->batchSize(), "min_coord must have shape [3,] or [B, 3]"); + TORCH_CHECK_VALUE(denseOrigins.size(0) == grid->batchSize(), + "min_coord must have shape [3,] or [B, 3]"); TORCH_CHECK_VALUE(denseOrigins.size(1) == 3, "min_coord must have shape [3,] or [B, 3]"); nanovdb::Coord gridSize = gridbb.dim(); if (maybeGridSize.has_value()) { gridSize = maybeGridSize.value().value(); } - TORCH_CHECK_VALUE(gridSize[0] >= 0 && gridSize[1] >= 0 && gridSize[2] >= 0, "grid_size must be non-negative"); + TORCH_CHECK_VALUE(gridSize[0] >= 0 && gridSize[1] >= 0 && gridSize[2] >= 0, + "grid_size must be non-negative"); - torch::Tensor sparseDataReshape = featureCoalescedView(sparseData); // [N, -1] + torch::Tensor sparseDataReshape = featureCoalescedView(sparseData); // [N, -1] TORCH_CHECK_VALUE(sparseDataReshape.is_contiguous(), "sparse_data must be contiguous"); - torch::Tensor ret = torch::zeros({grid->batchSize(), gridSize[0], gridSize[1], gridSize[2], sparseDataReshape.size(1)}, sparseData.options()); // [B, W, H, D, -1] + torch::Tensor ret = torch::zeros( + { grid->batchSize(), gridSize[0], gridSize[1], gridSize[2], sparseDataReshape.size(1) }, + sparseData.options()); // [B, W, H, D, -1] FVDB_DISPATCH_KERNEL_DEVICE(grid->device(), [&]() { ops::dispatchReadIntoDense(*grid, sparseDataReshape, denseOrigins, ret, false); }); - torch::Tensor retReshape = ret.view(spliceShape({grid->batchSize(), gridSize[0], gridSize[1], gridSize[2]}, sparseData)); + torch::Tensor retReshape = ret.view( + spliceShape({ grid->batchSize(), gridSize[0], gridSize[1], gridSize[2] }, sparseData)); TORCH_CHECK(retReshape.is_contiguous(), "retReshape must be contiguous"); // Save shape information for backward ctx->saved_data["dense_origins"] = denseOrigins; - ctx->saved_data["grid_size"] = coordToTensor(gridSize); - torch::Tensor retShape = torch::empty({(int64_t) sparseData.dim()}, torch::TensorOptions().dtype(torch::kLong)); + ctx->saved_data["grid_size"] = coordToTensor(gridSize); + torch::Tensor retShape = + torch::empty({ (int64_t)sparseData.dim() }, torch::TensorOptions().dtype(torch::kLong)); auto acc = retShape.accessor(); for (int i = 0; i < sparseData.dim(); i++) { acc[i] = sparseData.size(i); } - ctx->saved_data["final_shape"] = retShape; - ctx->saved_data["first_dim"] = sparseDataReshape.size(0); - ctx->saved_data["last_dim"] = sparseDataReshape.size(1); - ctx->saved_data["dummy_tensor"] = torch::empty({0}, sparseData.options()); - ctx->saved_data["grid"] = grid; + ctx->saved_data["final_shape"] = retShape; + ctx->saved_data["first_dim"] = sparseDataReshape.size(0); + ctx->saved_data["last_dim"] = sparseDataReshape.size(1); + ctx->saved_data["dummy_tensor"] = torch::empty({ 0 }, sparseData.options()); + ctx->saved_data["grid"] = grid; - return variable_list({retReshape}); + return variable_list({ retReshape }); } -ReadIntoDense::variable_list ReadIntoDense::backward(ReadIntoDense::AutogradContext *ctx, - ReadIntoDense::variable_list grad_output) { - +ReadIntoDense::variable_list +ReadIntoDense::backward(ReadIntoDense::AutogradContext *ctx, + 
ReadIntoDense::variable_list grad_output) { // Use data saved in forward - torch::Tensor denseOrigins = ctx->saved_data["dense_origins"].toTensor(); // [B, 3] - int64_t firstDim = ctx->saved_data["first_dim"].toInt(); - int64_t lastDim = ctx->saved_data["last_dim"].toInt(); - std::vector finalShapeTensor = intTensor1DToStdVector(ctx->saved_data["final_shape"].toTensor()); + torch::Tensor denseOrigins = ctx->saved_data["dense_origins"].toTensor(); // [B, 3] + int64_t firstDim = ctx->saved_data["first_dim"].toInt(); + int64_t lastDim = ctx->saved_data["last_dim"].toInt(); + std::vector finalShapeTensor = + intTensor1DToStdVector(ctx->saved_data["final_shape"].toTensor()); torch::TensorOptions sparseDataOpts = ctx->saved_data["dummy_tensor"].toTensor().options(); - auto grid = ctx->saved_data["grid"].toCustomClass(); - Variable gradOut = grad_output.at(0); // [B, W, H, D, *] + auto grid = ctx->saved_data["grid"].toCustomClass(); + Variable gradOut = grad_output.at(0); // [B, W, H, D, *] - torch::Tensor gradOutReshape = featureCoalescedView(gradOut, 4); // [B, W, H, D, -1] + torch::Tensor gradOutReshape = featureCoalescedView(gradOut, 4); // [B, W, H, D, -1] - torch::Tensor ret = torch::zeros({firstDim, lastDim}, sparseDataOpts); // [N, -1] + torch::Tensor ret = torch::zeros({ firstDim, lastDim }, sparseDataOpts); // [N, -1] FVDB_DISPATCH_KERNEL_DEVICE(grid->device(), [&]() { ops::dispatchReadFromDense(*grid, gradOutReshape, denseOrigins, ret, false); }); - torch::Tensor retReshape = ret.view(finalShapeTensor); // [N, *] + torch::Tensor retReshape = ret.view(finalShapeTensor); // [N, *] - return {torch::Tensor(), retReshape, torch::Tensor(), torch::Tensor()}; + return { torch::Tensor(), retReshape, torch::Tensor(), torch::Tensor() }; } } // namespace autograd diff --git a/fvdb/src/detail/autograd/ReadIntoDense.h b/fvdb/src/detail/autograd/ReadIntoDense.h index c2d8736e97..d6bed67feb 100644 --- a/fvdb/src/detail/autograd/ReadIntoDense.h +++ b/fvdb/src/detail/autograd/ReadIntoDense.h @@ -1,7 +1,8 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_READINTODENSE_H +#define FVDB_DETAIL_AUTOGRAD_READINTODENSE_H #include @@ -14,20 +15,20 @@ namespace detail { namespace autograd { struct ReadIntoDense : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr grid, - Variable sparseData, - const torch::optional& maybeMinCoord, - const torch::optional& maybeGridSize); + static variable_list forward(AutogradContext *ctx, c10::intrusive_ptr grid, + Variable sparseData, + const torch::optional &maybeMinCoord, + const torch::optional &maybeGridSize); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_READINTODENSE_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/SampleGrid.cpp b/fvdb/src/detail/autograd/SampleGrid.cpp index 9cf5fd4c10..7a87a6b756 100644 --- a/fvdb/src/detail/autograd/SampleGrid.cpp +++ 
b/fvdb/src/detail/autograd/SampleGrid.cpp @@ -3,45 +3,44 @@ // #include "SampleGrid.h" -#include "detail/ops/Ops.h" +#include +#include -#include "detail/utils/Utils.h" - - - - -void checkForwardInputs(c10::intrusive_ptr grid, - fvdb::detail::autograd::SampleGridTrilinear::JaggedVariable points, - fvdb::detail::autograd::SampleGridTrilinear::Variable data, - bool returnGrad) { +void +checkForwardInputs(c10::intrusive_ptr grid, + fvdb::detail::autograd::SampleGridTrilinear::JaggedVariable points, + fvdb::detail::autograd::SampleGridTrilinear::Variable data, bool returnGrad) { grid->checkNonEmptyGrid(); - TORCH_CHECK_VALUE(points.device() == data.device(), "points and data must be on the same device"); + TORCH_CHECK_VALUE(points.device() == data.device(), + "points and data must be on the same device"); grid->checkDevice(points); grid->checkDevice(data); points.check_valid(); TORCH_CHECK_TYPE(points.is_floating_point(), "points must have a floating point type"); TORCH_CHECK_TYPE(points.dtype() == data.dtype(), "all tensors must have the same type"); - TORCH_CHECK_VALUE(points.rdim() == 2, "Expected points to have shape [B*M, 3] (wrong number of dimensions)"); + TORCH_CHECK_VALUE(points.rdim() == 2, + "Expected points to have shape [B*M, 3] (wrong number of dimensions)"); TORCH_CHECK(points.numel() > 0, "Empty tensor (points)"); TORCH_CHECK(points.rsize(1) == 3, "points must have shape [B, M, 3] (points must be 3D)"); TORCH_CHECK_TYPE(data.is_floating_point(), "data must have a floating point type"); - TORCH_CHECK_VALUE(data.dim() >= 2, "Expected data to have shape [N, *] (at least 2 dimensions)"); + TORCH_CHECK_VALUE(data.dim() >= 2, + "Expected data to have shape [N, *] (at least 2 dimensions)"); TORCH_CHECK(data.numel() > 0, "Empty tensor (data)"); - TORCH_CHECK(data.size(0) == grid->totalVoxels(), "grid_data must have one value per voxel (shape [N, *]) (wrong first dimension)"); + TORCH_CHECK(data.size(0) == grid->totalVoxels(), + "grid_data must have one value per voxel (shape [N, *]) (wrong first dimension)"); } - namespace fvdb { namespace detail { namespace autograd { -SampleGridTrilinear::variable_list SampleGridTrilinear::forward(SampleGridTrilinear::AutogradContext *ctx, - c10::intrusive_ptr grid, - SampleGridTrilinear::JaggedVariable points, - SampleGridTrilinear::Variable data, - bool returnGrad) { +SampleGridTrilinear::variable_list +SampleGridTrilinear::forward(SampleGridTrilinear::AutogradContext *ctx, + c10::intrusive_ptr grid, + SampleGridTrilinear::JaggedVariable points, + SampleGridTrilinear::Variable data, bool returnGrad) { checkForwardInputs(grid, points, data, returnGrad); auto ret = FVDB_DISPATCH_KERNEL_DEVICE(points.device(), [&]() { @@ -53,59 +52,54 @@ SampleGridTrilinear::variable_list SampleGridTrilinear::forward(SampleGridTrilin }); // Save data for backward in context - ctx->save_for_backward({data, points.jdata(), points.joffsets(), points.jlidx()}); - ctx->saved_data["grid"] = grid; + ctx->save_for_backward({ data, points.jdata(), points.joffsets(), points.jlidx() }); + ctx->saved_data["grid"] = grid; ctx->saved_data["return_grad"] = returnGrad; return ret; } - - - -SampleGridTrilinear::variable_list SampleGridTrilinear::backward(SampleGridTrilinear::AutogradContext *ctx, - SampleGridTrilinear::variable_list grad_output) { - +SampleGridTrilinear::variable_list +SampleGridTrilinear::backward(SampleGridTrilinear::AutogradContext *ctx, + SampleGridTrilinear::variable_list grad_output) { // Use data saved in forward variable_list saved = 
ctx->get_saved_variables(); - Variable data = saved.at(0); + Variable data = saved.at(0); - Variable pointCoords = saved.at(1); + Variable pointCoords = saved.at(1); Variable pointJOffsets = saved.at(2); - Variable pointsJLidx = saved.at(3); + Variable pointsJLidx = saved.at(3); - auto grid = ctx->saved_data["grid"].toCustomClass(); - bool returnGrad = ctx->saved_data["return_grad"].toBool(); - Variable gradOut = grad_output.at(0); // [B*M, *] + auto grid = ctx->saved_data["grid"].toCustomClass(); + bool returnGrad = ctx->saved_data["return_grad"].toBool(); + Variable gradOut = grad_output.at(0); // [B*M, *] torch::Tensor outGrad = FVDB_DISPATCH_KERNEL_DEVICE(gradOut.device(), [&]() { if (returnGrad) { - Variable gradPtsGrad = grad_output.at(1); // [B*M, -1, 3] + Variable gradPtsGrad = grad_output.at(1); // [B*M, -1, 3] return ops::dispatchSampleGridTrilinearWithGradBackward( - *grid, JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, pointsJLidx), data, gradOut, gradPtsGrad); + *grid, + JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, + pointsJLidx), + data, gradOut, gradPtsGrad); } else { return ops::dispatchSplatIntoGridTrilinear( - *grid, JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, pointsJLidx), gradOut); + *grid, + JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, + pointsJLidx), + gradOut); } }); - return {torch::Tensor(), torch::Tensor(), outGrad, torch::Tensor()}; + return { torch::Tensor(), torch::Tensor(), outGrad, torch::Tensor() }; } - - - - - - - -SampleGridBezier::variable_list SampleGridBezier::forward(SampleGridBezier::AutogradContext *ctx, - c10::intrusive_ptr grid, - SampleGridBezier::JaggedVariable points, - SampleGridBezier::Variable data, - bool returnGrad) { +SampleGridBezier::variable_list +SampleGridBezier::forward(SampleGridBezier::AutogradContext *ctx, + c10::intrusive_ptr grid, + SampleGridBezier::JaggedVariable points, SampleGridBezier::Variable data, + bool returnGrad) { checkForwardInputs(grid, points, data, returnGrad); - std::vector ret = FVDB_DISPATCH_KERNEL_DEVICE(points.device(), [&]() { if (returnGrad) { return ops::dispatchSampleGridBezierWithGrad(*grid, points, data); @@ -115,45 +109,48 @@ SampleGridBezier::variable_list SampleGridBezier::forward(SampleGridBezier::Auto }); // Save data for backward in context - ctx->save_for_backward({data, points.jdata(), points.joffsets(), points.jlidx()}); - ctx->saved_data["grid"] = grid; + ctx->save_for_backward({ data, points.jdata(), points.joffsets(), points.jlidx() }); + ctx->saved_data["grid"] = grid; ctx->saved_data["return_grad"] = returnGrad; return ret; } - -SampleGridBezier::variable_list SampleGridBezier::backward(SampleGridBezier::AutogradContext *ctx, - SampleGridBezier::variable_list grad_output) { - +SampleGridBezier::variable_list +SampleGridBezier::backward(SampleGridBezier::AutogradContext *ctx, + SampleGridBezier::variable_list grad_output) { // Use data saved in forward variable_list saved = ctx->get_saved_variables(); - Variable data = saved.at(0); + Variable data = saved.at(0); - Variable pointCoords = saved.at(1); + Variable pointCoords = saved.at(1); Variable pointJOffsets = saved.at(2); - Variable pointsJLidx = saved.at(3); + Variable pointsJLidx = saved.at(3); - auto grid = ctx->saved_data["grid"].toCustomClass(); - bool returnGrad = ctx->saved_data["return_grad"].toBool(); - Variable gradOut = grad_output.at(0); // [B*M, *] + auto grid = ctx->saved_data["grid"].toCustomClass(); + bool 
returnGrad = ctx->saved_data["return_grad"].toBool(); + Variable gradOut = grad_output.at(0); // [B*M, *] Variable outGrad = FVDB_DISPATCH_KERNEL_DEVICE(gradOut.device(), [&]() { if (returnGrad) { - Variable gradPtsGrad = grad_output.at(1); // [B*M, -1, 3] + Variable gradPtsGrad = grad_output.at(1); // [B*M, -1, 3] return ops::dispatchSampleGridBezierWithGradBackward( - *grid, JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, pointsJLidx), gradOut, gradPtsGrad, data); + *grid, + JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, + pointsJLidx), + gradOut, gradPtsGrad, data); } else { return ops::dispatchSplatIntoGridBezier( - *grid, JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, pointsJLidx), gradOut); + *grid, + JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, + pointsJLidx), + gradOut); } }); - return {torch::Tensor(), torch::Tensor(), outGrad, torch::Tensor()}; + return { torch::Tensor(), torch::Tensor(), outGrad, torch::Tensor() }; } - - -} // namespace autograd -} // namespace detail -} // namespace fvdb +} // namespace autograd +} // namespace detail +} // namespace fvdb diff --git a/fvdb/src/detail/autograd/SampleGrid.h b/fvdb/src/detail/autograd/SampleGrid.h index 424c6a3702..2e35dc4576 100644 --- a/fvdb/src/detail/autograd/SampleGrid.h +++ b/fvdb/src/detail/autograd/SampleGrid.h @@ -1,51 +1,43 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_SAMPLEGRID_H +#define FVDB_DETAIL_AUTOGRAD_SAMPLEGRID_H #include #include "detail/GridBatchImpl.h" - namespace fvdb { namespace detail { namespace autograd { - struct SampleGridTrilinear : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - using JaggedVariable = JaggedTensor; + using Variable = torch::autograd::Variable; + using JaggedVariable = JaggedTensor; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr grid, - JaggedTensor points, - Variable data, - bool returnGrad = false); + static variable_list forward(AutogradContext *ctx, c10::intrusive_ptr grid, + JaggedTensor points, Variable data, bool returnGrad = false); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; - struct SampleGridBezier : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - using JaggedVariable = JaggedTensor; + using Variable = torch::autograd::Variable; + using JaggedVariable = JaggedTensor; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr grid, - JaggedTensor points, - Variable data, - bool returnGrad = false); + static variable_list forward(AutogradContext *ctx, c10::intrusive_ptr grid, + JaggedTensor points, Variable data, bool returnGrad = false); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb 
+ +#endif // FVDB_DETAIL_AUTOGRAD_SAMPLEGRID_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/SparseConvolutionHalo.cpp b/fvdb/src/detail/autograd/SparseConvolutionHalo.cpp index 323cb9e226..64db08db78 100644 --- a/fvdb/src/detail/autograd/SparseConvolutionHalo.cpp +++ b/fvdb/src/detail/autograd/SparseConvolutionHalo.cpp @@ -3,66 +3,71 @@ // #include "SparseConvolutionHalo.h" -#include "detail/ops/convolution/backend/ConvOps.h" - -#include "detail/utils/Utils.h" +#include +#include namespace fvdb { namespace detail { namespace autograd { - -SparseConvolutionHalo::variable_list SparseConvolutionHalo::forward(SparseConvolutionHalo::AutogradContext *ctx, - c10::intrusive_ptr grid, - SparseConvolutionHalo::Variable inFeatures, - SparseConvolutionHalo::Variable kernels, - int variant) { - +SparseConvolutionHalo::variable_list +SparseConvolutionHalo::forward(SparseConvolutionHalo::AutogradContext *ctx, + c10::intrusive_ptr grid, + SparseConvolutionHalo::Variable inFeatures, + SparseConvolutionHalo::Variable kernels, int variant) { // Check kernels TORCH_CHECK_TYPE(kernels.is_floating_point(), "kernels must have a floating point type"); - TORCH_CHECK_VALUE(kernels.dim() == 5, std::string("Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") + - std::to_string(kernels.dim()) + " dimensions"); - TORCH_CHECK_NOT_IMPLEMENTED(kernels.size(2) == kernels.size(3) && kernels.size(3) == kernels.size(4) && kernels.size(2) == 3, - "sparse_conv_halo only supports kernels of size 3x3x3"); + TORCH_CHECK_VALUE( + kernels.dim() == 5, + std::string( + "Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") + + std::to_string(kernels.dim()) + " dimensions"); + TORCH_CHECK_NOT_IMPLEMENTED(kernels.size(2) == kernels.size(3) && + kernels.size(3) == kernels.size(4) && kernels.size(2) == 3, + "sparse_conv_halo only supports kernels of size 3x3x3"); // Check features TORCH_CHECK_VALUE(inFeatures.is_contiguous(), "features must be contiguous"); TORCH_CHECK_TYPE(inFeatures.is_floating_point(), "features must have a floating point type"); - TORCH_CHECK_VALUE(inFeatures.dim() == 2, std::string("Expected features to have 2 dimensions (shape (n, nF)) but got ") + - std::to_string(inFeatures.dim()) + " dimensions"); - TORCH_CHECK_VALUE(kernels.size(1) == inFeatures.size(1), - "Expected input channels of kernels (" + std::to_string(kernels.size(1)) + - ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); + TORCH_CHECK_VALUE( + inFeatures.dim() == 2, + std::string("Expected features to have 2 dimensions (shape (n, nF)) but got ") + + std::to_string(inFeatures.dim()) + " dimensions"); + TORCH_CHECK_VALUE( + kernels.size(1) == inFeatures.size(1), + "Expected input channels of kernels (" + std::to_string(kernels.size(1)) + + ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); // [O, I, 3, 3, 3] to [3, 3, 3, I, O] - kernels = kernels.permute({4, 3, 2, 1, 0}).contiguous(); + kernels = kernels.permute({ 4, 3, 2, 1, 0 }).contiguous(); torch::Tensor outFeatures = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { return ops::dispatchSparseConvolutionHalo(*grid, inFeatures, kernels, variant); }); // Save data for backward in context - ctx->save_for_backward({inFeatures, kernels}); - ctx->saved_data["grid"] = grid; + ctx->save_for_backward({ inFeatures, kernels }); + ctx->saved_data["grid"] = grid; ctx->saved_data["variant"] = variant; - return variable_list({outFeatures}); + return variable_list({ 
outFeatures }); } - -SparseConvolutionHalo::variable_list SparseConvolutionHalo::backward(AutogradContext *ctx, variable_list grad_output) { - +SparseConvolutionHalo::variable_list +SparseConvolutionHalo::backward(AutogradContext *ctx, variable_list grad_output) { variable_list saved = ctx->get_saved_variables(); - TORCH_CHECK(saved.size() > 0, "No backward context computed during forward. Please pass in training=True when calling kmap.build_implicit_gemm()"); - auto grid = ctx->saved_data["grid"].toCustomClass(); - int variant = ctx->saved_data["variant"].toInt(); + TORCH_CHECK( + saved.size() > 0, + "No backward context computed during forward. Please pass in training=True when calling kmap.build_implicit_gemm()"); + auto grid = ctx->saved_data["grid"].toCustomClass(); + int variant = ctx->saved_data["variant"].toInt(); Variable inFeatures = saved.at(0); - Variable kernels = saved.at(1); // [3, 3, 3, I, O] - Variable gradOut = grad_output.at(0); + Variable kernels = saved.at(1); // [3, 3, 3, I, O] + Variable gradOut = grad_output.at(0); - kernels = kernels.permute({0, 1, 2, 4, 3}).flip({0, 1, 2}).contiguous(); // [3, 3, 3, O, I] - torch::Tensor gradInput = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { + kernels = kernels.permute({ 0, 1, 2, 4, 3 }).flip({ 0, 1, 2 }).contiguous(); // [3, 3, 3, O, I] + torch::Tensor gradInput = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { return ops::dispatchSparseConvolutionHalo(*grid, gradOut, kernels, variant); }); torch::Tensor gradKernel = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { @@ -70,12 +75,11 @@ SparseConvolutionHalo::variable_list SparseConvolutionHalo::backward(AutogradCon }); // [3, 3, 3, I, O] to [O, I, 3, 3, 3] - gradKernel = gradKernel.permute({4, 3, 2, 1, 0}).contiguous(); + gradKernel = gradKernel.permute({ 4, 3, 2, 1, 0 }).contiguous(); - return {torch::Tensor(), gradInput, gradKernel, torch::Tensor()}; + return { torch::Tensor(), gradInput, gradKernel, torch::Tensor() }; } - } // namespace autograd } // namespace detail } // namespace fvdb \ No newline at end of file diff --git a/fvdb/src/detail/autograd/SparseConvolutionHalo.h b/fvdb/src/detail/autograd/SparseConvolutionHalo.h index 5d5cb2e1a5..eb577a900f 100644 --- a/fvdb/src/detail/autograd/SparseConvolutionHalo.h +++ b/fvdb/src/detail/autograd/SparseConvolutionHalo.h @@ -1,29 +1,25 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_SPARSECONVOLUTIONHALO_H +#define FVDB_DETAIL_AUTOGRAD_SPARSECONVOLUTIONHALO_H -#include - -#include "detail/ops/Ops.h" - -#include "SparseConvPackInfo.h" +#include +#include +#include namespace fvdb { namespace detail { namespace autograd { struct SparseConvolutionHalo : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr grid, - Variable inFeatures, - Variable kernels, - int variant); + static variable_list forward(AutogradContext *ctx, c10::intrusive_ptr grid, + Variable inFeatures, Variable kernels, int variant); static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; @@ -31,3 +27,5 @@ struct SparseConvolutionHalo : public torch::autograd::Function - -#include "detail/ops/Ops.h" - -#include 
"SparseConvPackInfo.h" +#include +#include +#include namespace fvdb { namespace detail { namespace autograd { -struct SparseConvolutionImplicitGEMM : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; +struct SparseConvolutionImplicitGEMM + : public torch::autograd::Function { + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - Variable inFeatures, Variable kernels, - SparseConvPackInfo packInfo, - bool transposed) { + static variable_list + forward(AutogradContext *ctx, Variable inFeatures, Variable kernels, + SparseConvPackInfo packInfo, bool transposed) { if (transposed) { packInfo = packInfo.transposed(); } @@ -30,155 +29,178 @@ struct SparseConvolutionImplicitGEMM : public torch::autograd::Function sizes = { (int) packInfo.sourceGrid().total_voxels(), (int) packInfo.targetGrid().total_voxels() }; - TORCH_CHECK(packInfo.sourceGrid().is_mutable() == packInfo.targetGrid().is_mutable(), "Source and target grids must both be mutable or immutable"); + const std::vector sizes = { (int)packInfo.sourceGrid().total_voxels(), + (int)packInfo.targetGrid().total_voxels() }; + TORCH_CHECK(packInfo.sourceGrid().is_mutable() == packInfo.targetGrid().is_mutable(), + "Source and target grids must both be mutable or immutable"); // Check features and kernels TORCH_CHECK_VALUE(inFeatures.is_contiguous(), "features must be contiguous"); - TORCH_CHECK_TYPE(inFeatures.is_floating_point(), "features must have a floating point type"); - TORCH_CHECK_VALUE(inFeatures.dim() == 2, std::string("Expected features to have 2 dimensions (shape (n, nF)) but got ") + - std::to_string(inFeatures.dim()) + " dimensions"); + TORCH_CHECK_TYPE(inFeatures.is_floating_point(), + "features must have a floating point type"); + TORCH_CHECK_VALUE( + inFeatures.dim() == 2, + std::string("Expected features to have 2 dimensions (shape (n, nF)) but got ") + + std::to_string(inFeatures.dim()) + " dimensions"); TORCH_CHECK_TYPE(kernels.is_floating_point(), "kernels must have a floating point type"); for (int i = 0; i < kernels.dim(); i += 1) { - TORCH_CHECK_VALUE(kernels.size(i) != 0, "kernels tensor has zero dimension (dim = " + std::to_string(i) + ")"); + TORCH_CHECK_VALUE(kernels.size(i) != 0, "kernels tensor has zero dimension (dim = " + + std::to_string(i) + ")"); } - auto opt = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); - torch::Tensor kWidth = torch::empty({3,}, opt); + auto opt = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); + torch::Tensor kWidth = torch::empty( + { + 3, + }, + opt); int inC, outC; if (!transposed) { - TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[0], "The number of input features must match the number of voxels"); - TORCH_CHECK_VALUE(kernels.dim() == 5, std::string("Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") + - std::to_string(kernels.dim()) + " dimensions"); - TORCH_CHECK_VALUE(kernels.size(1) == inFeatures.size(1), - "Expected input channels of kernels (" + std::to_string(kernels.size(1)) + - ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); - outC = kernels.size(0); inC = kernels.size(1); + TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[0], + "The number of input features must match the number of voxels"); + TORCH_CHECK_VALUE( + kernels.dim() == 5, + 
std::string( + "Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") + + std::to_string(kernels.dim()) + " dimensions"); + TORCH_CHECK_VALUE( + kernels.size(1) == inFeatures.size(1), + "Expected input channels of kernels (" + std::to_string(kernels.size(1)) + + ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); + outC = kernels.size(0); + inC = kernels.size(1); kWidth[0] = kernels.size(2); kWidth[1] = kernels.size(3); kWidth[2] = kernels.size(4); - kernels = kernels.permute({4, 3, 2, 1, 0}).reshape({-1, inC, outC}).contiguous(); + kernels = kernels.permute({ 4, 3, 2, 1, 0 }).reshape({ -1, inC, outC }).contiguous(); } else { - TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[1], "The number of input features must match the number of voxels"); - TORCH_CHECK_VALUE(kernels.dim() == 5, std::string("Expected kernels to have 5 dimensions (shape (in_ch, out_ch, d, h, w)) but got ") + - std::to_string(kernels.dim()) + " dimensions"); - TORCH_CHECK_VALUE(kernels.size(0) == inFeatures.size(1), - "Expected input channels of kernels (" + std::to_string(kernels.size(0)) + - ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); - int inC = kernels.size(0); outC = kernels.size(1); + TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[1], + "The number of input features must match the number of voxels"); + TORCH_CHECK_VALUE( + kernels.dim() == 5, + std::string( + "Expected kernels to have 5 dimensions (shape (in_ch, out_ch, d, h, w)) but got ") + + std::to_string(kernels.dim()) + " dimensions"); + TORCH_CHECK_VALUE( + kernels.size(0) == inFeatures.size(1), + "Expected input channels of kernels (" + std::to_string(kernels.size(0)) + + ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); + int inC = kernels.size(0); + outC = kernels.size(1); kWidth[0] = kernels.size(2); kWidth[1] = kernels.size(3); kWidth[2] = kernels.size(4); - kernels = kernels.permute({4, 3, 2, 0, 1}).reshape({-1, inC, outC}).contiguous(); + kernels = kernels.permute({ 4, 3, 2, 0, 1 }).reshape({ -1, inC, outC }).contiguous(); } torch::Tensor output; if (packInfo.targetGrid().total_voxels() > 0) { int outFeats = transposed ? sizes[0] : sizes[1]; // Emprically larger kernels do not work right now, default to non-sorted version. 
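// A quick sanity sketch of the weight layout handled above: the permute/reshape pair turns a
// PyTorch-style 5-D conv3d weight into the flat [kernelVolume, inC, outC] layout that the
// implicit-GEMM dispatch consumes. Plain libtorch, hypothetical sizes, no fvdb types; the
// reversal permutation is its own inverse, so the round trip recovers the original weight.
#include <torch/torch.h>

int main() {
    const int64_t outC = 8, inC = 4, kD = 3, kH = 3, kW = 3; // hypothetical sizes
    torch::Tensor kernels = torch::randn({ outC, inC, kD, kH, kW });

    // Same transformation as the non-transposed branch: (out, in, d, h, w) -> [kD*kH*kW, inC, outC].
    torch::Tensor flat =
        kernels.permute({ 4, 3, 2, 1, 0 }).reshape({ -1, inC, outC }).contiguous();
    TORCH_CHECK(flat.size(0) == kD * kH * kW && flat.size(1) == inC && flat.size(2) == outC);

    // Reshaping back to the permuted layout and applying the same (self-inverse) permutation
    // restores the 5-D weight; the weight-gradient reshape in backward() is the analogous inverse.
    torch::Tensor restored =
        flat.reshape({ kW, kH, kD, inC, outC }).permute({ 4, 3, 2, 1, 0 }).contiguous();
    TORCH_CHECK(torch::equal(restored, kernels));
    return 0;
}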
- bool canSort = !transposed && (packInfo.kernelSize().value() < fvdb::Vec3iOrScalar(4).value()); + bool canSort = + !transposed && (packInfo.kernelSize().value() < fvdb::Vec3iOrScalar(4).value()); if (packInfo.reoderOutInMap().has_value() && canSort) { output = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { return ops::dispatchSparseConvolutionImplicitGEMMSorted( - inFeatures, kernels, - packInfo.reoderOutInMap().value(), - packInfo.reducedSortedMask().value(), - packInfo.reorderLoc().value(), - outFeats, outC, useTF32, true); + inFeatures, kernels, packInfo.reoderOutInMap().value(), + packInfo.reducedSortedMask().value(), packInfo.reorderLoc().value(), + outFeats, outC, useTF32, true); }); } else { output = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { return ops::dispatchSparseConvolutionImplicitGEMM( - inFeatures, kernels, outInMap, outFeats, outC, useTF32, true); + inFeatures, kernels, outInMap, outFeats, outC, useTF32, true); }); } } else { auto opt = torch::TensorOptions().dtype(inFeatures.dtype()).device(inFeatures.device()); - output = torch::empty({0, kernels.size(-1)}, opt); + output = torch::empty({ 0, kernels.size(-1) }, opt); } // Save for backward (for training mode) if (packInfo.outInMapBwd().has_value()) { - ctx->save_for_backward({inFeatures, kernels, - packInfo.outInMapBwd().value(), - packInfo.reorderOutInMapBwd().value(), - packInfo.sortedMaskBwdW().value(), - packInfo.sortedMaskBwdD().value(), - packInfo.reorderLocBwd().value()}); + ctx->save_for_backward( + { inFeatures, kernels, packInfo.outInMapBwd().value(), + packInfo.reorderOutInMapBwd().value(), packInfo.sortedMaskBwdW().value(), + packInfo.sortedMaskBwdD().value(), packInfo.reorderLocBwd().value() }); } - ctx->saved_data["use_tf32"] = useTF32; + ctx->saved_data["use_tf32"] = useTF32; ctx->saved_data["kernel_width"] = kWidth; - ctx->saved_data["transposed"] = transposed; + ctx->saved_data["transposed"] = transposed; - return {output}; + return { output }; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list + backward(AutogradContext *ctx, variable_list grad_output) { variable_list saved = ctx->get_saved_variables(); - TORCH_CHECK(saved.size() > 0, "No backward context computed during forward. Please pass in training=True when calling kmap.build_implicit_gemm()"); - - Variable inFeatures = saved.at(0); - Variable kernels = saved.at(1); - Variable outInMapBwd = saved.at(2); - Variable reorderOutInMapBwd = saved.at(3); - Variable sortedMaskBwdW = saved.at(4); - Variable sortedMaskBwdD = saved.at(5); - Variable reorderLocBwd = saved.at(6); - bool useTF32 = ctx->saved_data["use_tf32"].toBool(); - torch::Tensor kWidth = ctx->saved_data["kernel_width"].toTensor(); - bool transposed = ctx->saved_data["transposed"].toBool(); - - Variable gradOut = grad_output.at(0); + TORCH_CHECK( + saved.size() > 0, + "No backward context computed during forward. 
Please pass in training=True when calling kmap.build_implicit_gemm()"); + + Variable inFeatures = saved.at(0); + Variable kernels = saved.at(1); + Variable outInMapBwd = saved.at(2); + Variable reorderOutInMapBwd = saved.at(3); + Variable sortedMaskBwdW = saved.at(4); + Variable sortedMaskBwdD = saved.at(5); + Variable reorderLocBwd = saved.at(6); + bool useTF32 = ctx->saved_data["use_tf32"].toBool(); + torch::Tensor kWidth = ctx->saved_data["kernel_width"].toTensor(); + bool transposed = ctx->saved_data["transposed"].toBool(); + + Variable gradOut = grad_output.at(0); torch::Tensor gradInput, gradWeight; // Dispatching following torchsparse++ int kernelVolume = kernels.size(0); - int inC = kernels.size(1); - int outC = kernels.size(2); + int inC = kernels.size(1); + int outC = kernels.size(2); if (kernelVolume < 32) { - gradInput = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { + gradInput = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { return ops::dispatchSparseConvolutionImplicitGEMMSorted( - gradOut, kernels.transpose(2, 1).contiguous(), - reorderOutInMapBwd, - sortedMaskBwdD, - reorderLocBwd, - inFeatures.size(0), inC, useTF32, true); + gradOut, kernels.transpose(2, 1).contiguous(), reorderOutInMapBwd, + sortedMaskBwdD, reorderLocBwd, inFeatures.size(0), inC, useTF32, true); }); gradWeight = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { return ops::dispatchSparseConvolutionImplicitGEMMGradSorted( - gradOut, inFeatures, reorderOutInMapBwd, - sortedMaskBwdW, reorderLocBwd, 32, - useTF32, true); + gradOut, inFeatures, reorderOutInMapBwd, sortedMaskBwdW, reorderLocBwd, 32, + useTF32, true); }); } else { - gradInput = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { + gradInput = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { return ops::dispatchSparseConvolutionImplicitGEMM( - gradOut, kernels.transpose(2, 1).contiguous(), - outInMapBwd, - inFeatures.size(0), inC, useTF32, true); + gradOut, kernels.transpose(2, 1).contiguous(), outInMapBwd, inFeatures.size(0), + inC, useTF32, true); }); gradWeight = FVDB_DISPATCH_KERNEL_DEVICE(inFeatures.device(), [&]() { return ops::dispatchSparseConvolutionImplicitGEMMGrad( - gradOut, inFeatures, outInMapBwd, 32, - useTF32, true); + gradOut, inFeatures, outInMapBwd, 32, useTF32, true); }); } if (!transposed) { - gradWeight = gradWeight.reshape( - {kWidth[2].item(), kWidth[1].item(), kWidth[0].item(), outC, inC}).permute({3, 4, 2, 1, 0}); + gradWeight = gradWeight + .reshape({ kWidth[2].item(), kWidth[1].item(), + kWidth[0].item(), outC, inC }) + .permute({ 3, 4, 2, 1, 0 }); } else { - gradWeight = gradWeight.reshape( - {kWidth[2].item(), kWidth[1].item(), kWidth[0].item(), outC, inC}).permute({4, 3, 2, 1, 0}); + gradWeight = gradWeight + .reshape({ kWidth[2].item(), kWidth[1].item(), + kWidth[0].item(), outC, inC }) + .permute({ 4, 3, 2, 1, 0 }); } - return {gradInput, gradWeight, torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor()}; + return { gradInput, gradWeight, torch::Tensor(), + torch::Tensor(), torch::Tensor(), torch::Tensor() }; } }; } // namespace autograd } // namespace detail } // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_SPARSECONVOLUTIONIMPLICITGEMM_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/SparseConvolutionKernelMap.h b/fvdb/src/detail/autograd/SparseConvolutionKernelMap.h index c9303b467e..ae68311051 100644 --- a/fvdb/src/detail/autograd/SparseConvolutionKernelMap.h +++ b/fvdb/src/detail/autograd/SparseConvolutionKernelMap.h @@ -1,96 +1,116 
@@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_SPARSECONVOLUTIONKERNELMAP_H +#define FVDB_DETAIL_AUTOGRAD_SPARSECONVOLUTIONKERNELMAP_H -#include - -#include "detail/ops/convolution/backend/ConvOps.h" - -#include "SparseConvPackInfo.h" +#include +#include +#include namespace fvdb { namespace detail { namespace autograd { struct SparseConvolutionKernelMap : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - - static variable_list forward(AutogradContext *ctx, - Variable inFeatures, Variable kernels, - const SparseConvPackInfo& packInfo, - bool transposed) { + using Variable = torch::autograd::Variable; + static variable_list + forward(AutogradContext *ctx, Variable inFeatures, Variable kernels, + const SparseConvPackInfo &packInfo, bool transposed) { TORCH_CHECK(packInfo.neighborMap().has_value() && packInfo.neighborSizes().has_value(), "Neighbor map must be built for sparse convolution"); - torch::Tensor nbmaps = packInfo.neighborMap().value(); - torch::Tensor nbsizes = packInfo.neighborSizes().value(); - const std::vector sizes = { (int) packInfo.sourceGrid().total_voxels(), (int) packInfo.targetGrid().total_voxels() }; - const bool middleAcceleration = !(packInfo.sourceGrid().is_mutable() && packInfo.targetGrid().is_mutable()) && \ - packInfo.stride().value() == Vec3iOrScalar(1).value(); + torch::Tensor nbmaps = packInfo.neighborMap().value(); + torch::Tensor nbsizes = packInfo.neighborSizes().value(); + const std::vector sizes = { (int)packInfo.sourceGrid().total_voxels(), + (int)packInfo.targetGrid().total_voxels() }; + const bool middleAcceleration = + !(packInfo.sourceGrid().is_mutable() && packInfo.targetGrid().is_mutable()) && + packInfo.stride().value() == Vec3iOrScalar(1).value(); - TORCH_CHECK(packInfo.sourceGrid().is_mutable() == packInfo.targetGrid().is_mutable(), "Source and target grids must both be mutable or immutable"); + TORCH_CHECK(packInfo.sourceGrid().is_mutable() == packInfo.targetGrid().is_mutable(), + "Source and target grids must both be mutable or immutable"); // Check features TORCH_CHECK_VALUE(inFeatures.is_contiguous(), "features must be contiguous"); - TORCH_CHECK_TYPE(inFeatures.is_floating_point(), "features must have a floating point type"); - TORCH_CHECK_VALUE(inFeatures.dim() == 2, std::string("Expected features to have 2 dimensions (shape (n, nF)) but got ") + - std::to_string(inFeatures.dim()) + " dimensions"); + TORCH_CHECK_TYPE(inFeatures.is_floating_point(), + "features must have a floating point type"); + TORCH_CHECK_VALUE( + inFeatures.dim() == 2, + std::string("Expected features to have 2 dimensions (shape (n, nF)) but got ") + + std::to_string(inFeatures.dim()) + " dimensions"); // Check kernels TORCH_CHECK_TYPE(kernels.is_floating_point(), "kernels must have a floating point type"); for (int i = 0; i < kernels.dim(); i += 1) { - TORCH_CHECK_VALUE(kernels.size(i) != 0, "kernels tensor has zero dimension (dim = " + std::to_string(i) + ")"); + TORCH_CHECK_VALUE(kernels.size(i) != 0, "kernels tensor has zero dimension (dim = " + + std::to_string(i) + ")"); } // Check pack info - TORCH_CHECK(nbmaps.is_contiguous() && nbmaps.scalar_type() == torch::kInt32, "nbmaps must be contiguous"); - TORCH_CHECK(nbsizes.is_contiguous() && nbsizes.scalar_type() == torch::kInt32, 
"nbsizes must be contiguous"); - - auto opt = torch::TensorOptions().dtype(torch::kInt32).device(inFeatures.device()); - torch::Tensor kWidth = torch::empty({3,}, opt); + TORCH_CHECK(nbmaps.is_contiguous() && nbmaps.scalar_type() == torch::kInt32, + "nbmaps must be contiguous"); + TORCH_CHECK(nbsizes.is_contiguous() && nbsizes.scalar_type() == torch::kInt32, + "nbsizes must be contiguous"); + + auto opt = torch::TensorOptions().dtype(torch::kInt32).device(inFeatures.device()); + torch::Tensor kWidth = torch::empty( + { + 3, + }, + opt); if (!transposed) { - TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[0], "The number of input features must match the number of voxels"); - TORCH_CHECK_VALUE(kernels.dim() == 5, std::string("Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") + - std::to_string(kernels.dim()) + " dimensions"); - TORCH_CHECK_VALUE(kernels.size(1) == inFeatures.size(1), - "Expected input channels of kernels (" + std::to_string(kernels.size(1)) + - ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); + TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[0], + "The number of input features must match the number of voxels"); + TORCH_CHECK_VALUE( + kernels.dim() == 5, + std::string( + "Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") + + std::to_string(kernels.dim()) + " dimensions"); + TORCH_CHECK_VALUE( + kernels.size(1) == inFeatures.size(1), + "Expected input channels of kernels (" + std::to_string(kernels.size(1)) + + ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); const int outC = kernels.size(0), inC = kernels.size(1); kWidth[0] = kernels.size(2); kWidth[1] = kernels.size(3); kWidth[2] = kernels.size(4); - kernels = kernels.permute({4, 3, 2, 1, 0}).reshape({-1, inC, outC}).contiguous(); + kernels = kernels.permute({ 4, 3, 2, 1, 0 }).reshape({ -1, inC, outC }).contiguous(); } else { - TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[1], "The number of input features must match the number of voxels"); - TORCH_CHECK_VALUE(kernels.dim() == 5, std::string("Expected kernels to have 5 dimensions (shape (in_ch, out_ch, d, h, w)) but got ") + - std::to_string(kernels.dim()) + " dimensions"); - TORCH_CHECK_VALUE(kernels.size(0) == inFeatures.size(1), - "Expected input channels of kernels (" + std::to_string(kernels.size(0)) + - ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); + TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[1], + "The number of input features must match the number of voxels"); + TORCH_CHECK_VALUE( + kernels.dim() == 5, + std::string( + "Expected kernels to have 5 dimensions (shape (in_ch, out_ch, d, h, w)) but got ") + + std::to_string(kernels.dim()) + " dimensions"); + TORCH_CHECK_VALUE( + kernels.size(0) == inFeatures.size(1), + "Expected input channels of kernels (" + std::to_string(kernels.size(0)) + + ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); const int inC = kernels.size(0), outC = kernels.size(1); kWidth[0] = kernels.size(2); kWidth[1] = kernels.size(3); kWidth[2] = kernels.size(4); - kernels = kernels.permute({4, 3, 2, 0, 1}).reshape({-1, inC, outC}).contiguous(); + kernels = kernels.permute({ 4, 3, 2, 0, 1 }).reshape({ -1, inC, outC }).contiguous(); } // Save for backward - ctx->save_for_backward({inFeatures, kernels, nbmaps, nbsizes}); - ctx->saved_data["transposed"] = transposed; + ctx->save_for_backward({ inFeatures, kernels, nbmaps, nbsizes }); + 
ctx->saved_data["transposed"] = transposed; ctx->saved_data["kernel_width"] = kWidth; - ctx->saved_data["use_me"] = packInfo.useME(); + ctx->saved_data["use_me"] = packInfo.useME(); torch::Tensor output; if (packInfo.targetGrid().total_voxels() > 0) { auto opt = torch::TensorOptions().dtype(inFeatures.dtype()).device(inFeatures.device()); if (!transposed) { - output = torch::zeros({sizes[1], kernels.size(-1)}, opt); + output = torch::zeros({ sizes[1], kernels.size(-1) }, opt); } else { - output = torch::zeros({sizes[0], kernels.size(-1)}, opt); + output = torch::zeros({ sizes[0], kernels.size(-1) }, opt); } // NOTE: Francis: We need .cpu().contiguous() here because we copied the convolution // implementation from torch_sparse which runs std::max_element on a pointer @@ -98,29 +118,30 @@ struct SparseConvolutionKernelMap : public torch::autograd::Function( - inFeatures, output, kernels, nbmaps, - nbsizes.cpu().contiguous(), transposed, middleAcceleration); + inFeatures, output, kernels, nbmaps, nbsizes.cpu().contiguous(), transposed, + middleAcceleration); }); } else { auto opt = torch::TensorOptions().dtype(inFeatures.dtype()).device(inFeatures.device()); - output = torch::empty({0, kernels.size(-1)}, opt); + output = torch::empty({ 0, kernels.size(-1) }, opt); } - return {output}; + return { output }; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list + backward(AutogradContext *ctx, variable_list grad_output) { // Use data saved in forward - variable_list saved = ctx->get_saved_variables(); - Variable inFeatures = saved.at(0); - Variable kernels = saved.at(1); - Variable nbmaps = saved.at(2); - Variable nbsizes = saved.at(3); - bool transposed = ctx->saved_data["transposed"].toBool(); - torch::Tensor kWidth = ctx->saved_data["kernel_width"].toTensor(); - bool use_me = ctx->saved_data["use_me"].toBool(); - - torch::Tensor gradInput = torch::zeros_like(inFeatures); + variable_list saved = ctx->get_saved_variables(); + Variable inFeatures = saved.at(0); + Variable kernels = saved.at(1); + Variable nbmaps = saved.at(2); + Variable nbsizes = saved.at(3); + bool transposed = ctx->saved_data["transposed"].toBool(); + torch::Tensor kWidth = ctx->saved_data["kernel_width"].toTensor(); + bool use_me = ctx->saved_data["use_me"].toBool(); + + torch::Tensor gradInput = torch::zeros_like(inFeatures); torch::Tensor gradWeight = torch::zeros_like(kernels); Variable gradOut = grad_output.at(0); @@ -141,14 +162,23 @@ struct SparseConvolutionKernelMap : public torch::autograd::Function(), kWidth[1].item(), kWidth[0].item(), inC, outC}).permute({4, 3, 2, 1, 0}); + gradWeight = gradWeight + .reshape({ kWidth[2].item(), kWidth[1].item(), + kWidth[0].item(), inC, outC }) + .permute({ 4, 3, 2, 1, 0 }); } else { - gradWeight = gradWeight.reshape({kWidth[2].item(), kWidth[1].item(), kWidth[0].item(), inC, outC}).permute({3, 4, 2, 1, 0}); + gradWeight = gradWeight + .reshape({ kWidth[2].item(), kWidth[1].item(), + kWidth[0].item(), inC, outC }) + .permute({ 3, 4, 2, 1, 0 }); } - return {gradInput, gradWeight, torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor()}; + return { gradInput, gradWeight, torch::Tensor(), + torch::Tensor(), torch::Tensor(), torch::Tensor() }; } }; } // namespace autograd } // namespace detail } // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_SPARSECONVOLUTIONKERNELMAP_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/SplatIntoGrid.cpp b/fvdb/src/detail/autograd/SplatIntoGrid.cpp index 
c73cd277a3..10b101de03 100644 --- a/fvdb/src/detail/autograd/SplatIntoGrid.cpp +++ b/fvdb/src/detail/autograd/SplatIntoGrid.cpp @@ -3,42 +3,45 @@ // #include "SplatIntoGrid.h" -#include "detail/ops/Ops.h" +#include +#include -#include "detail/utils/Utils.h" - - -void checkForwardInputs(c10::intrusive_ptr grid, - fvdb::detail::autograd::SplatIntoGridTrilinear::JaggedVariable points, - fvdb::detail::autograd::SplatIntoGridTrilinear::Variable data) { +void +checkForwardInputs(c10::intrusive_ptr grid, + fvdb::detail::autograd::SplatIntoGridTrilinear::JaggedVariable points, + fvdb::detail::autograd::SplatIntoGridTrilinear::Variable data) { grid->checkNonEmptyGrid(); - TORCH_CHECK_VALUE(points.device() == data.device(), "points and data must be on the same device"); + TORCH_CHECK_VALUE(points.device() == data.device(), + "points and data must be on the same device"); grid->checkDevice(points); grid->checkDevice(data); points.check_valid(); TORCH_CHECK_TYPE(points.is_floating_point(), "points must have a floating point type"); TORCH_CHECK_TYPE(points.dtype() == data.dtype(), "all tensors must have the same type"); - TORCH_CHECK_VALUE(points.rdim() == 2, "Expected points to have shape [B*M, 3] (wrong number of dimensions)"); + TORCH_CHECK_VALUE(points.rdim() == 2, + "Expected points to have shape [B*M, 3] (wrong number of dimensions)"); TORCH_CHECK(points.numel() > 0, "Empty tensor (points)"); TORCH_CHECK(points.rsize(1) == 3, "points must have shape [B*M, 3] (points must be 3D)"); TORCH_CHECK_TYPE(data.is_floating_point(), "point_data must have a floating point type"); - TORCH_CHECK_VALUE(data.dim() >= 2, "Expected data to have shape [B*M, *] (at least 3 dimensions)"); + TORCH_CHECK_VALUE(data.dim() >= 2, + "Expected data to have shape [B*M, *] (at least 3 dimensions)"); TORCH_CHECK(data.numel() > 0, "Empty tensor (data)"); - TORCH_CHECK(data.size(0) == points.rsize(0), "point_data must have one value per point (shape [B*M, *]) (incorrect first dimension must match number of points)"); + TORCH_CHECK( + data.size(0) == points.rsize(0), + "point_data must have one value per point (shape [B*M, *]) (incorrect first dimension must match number of points)"); } namespace fvdb { namespace detail { namespace autograd { - -SplatIntoGridTrilinear::variable_list SplatIntoGridTrilinear::forward(SplatIntoGridTrilinear::AutogradContext *ctx, - c10::intrusive_ptr grid, - SplatIntoGridTrilinear::JaggedVariable points, - SplatIntoGridTrilinear::Variable pointData) { - +SplatIntoGridTrilinear::variable_list +SplatIntoGridTrilinear::forward(SplatIntoGridTrilinear::AutogradContext *ctx, + c10::intrusive_ptr grid, + SplatIntoGridTrilinear::JaggedVariable points, + SplatIntoGridTrilinear::Variable pointData) { checkForwardInputs(grid, points, pointData); torch::Tensor outGridData = FVDB_DISPATCH_KERNEL_DEVICE(points.device(), [&]() { @@ -46,42 +49,41 @@ SplatIntoGridTrilinear::variable_list SplatIntoGridTrilinear::forward(SplatIntoG }); // Save data for backward in context - ctx->save_for_backward({pointData, points.jdata(), points.joffsets(), points.jlidx()}); + ctx->save_for_backward({ pointData, points.jdata(), points.joffsets(), points.jlidx() }); ctx->saved_data["grid"] = grid; // int64_t numOutputValues = grid->totalVoxels(); - return variable_list({outGridData}); + return variable_list({ outGridData }); } -SplatIntoGridTrilinear::variable_list SplatIntoGridTrilinear::backward(SplatIntoGridTrilinear::AutogradContext *ctx, - SplatIntoGridTrilinear::variable_list grad_output) { - 
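// The backward pass that follows relies on the adjoint relationship used throughout this
// file: trilinearly splatting point data into voxels and trilinearly sampling voxel data at
// those points are transposes of each other, so the gradient of a splat is a sample of the
// incoming gradient (hence the call to dispatchSampleGridTrilinear below). A dense 1-D
// sketch of that identity with made-up helper names, plain libtorch, no fvdb types:
#include <torch/torch.h>

// Splat scalar weights w at fractional positions p into a dense 1-D grid of size n.
torch::Tensor splat1d(const torch::Tensor &p, const torch::Tensor &w, int64_t n) {
    torch::Tensor i0 = p.floor().to(torch::kLong);   // left cell
    torch::Tensor f  = p - i0.to(p.scalar_type());   // fractional offset in [0, 1)
    torch::Tensor g  = torch::zeros({ n }, p.options());
    g.index_add_(0, i0, w - w * f);                  // contribution to the left cell: w * (1 - f)
    g.index_add_(0, i0 + 1, w * f);                  // contribution to the right cell
    return g;
}

// Sample a dense 1-D grid g at fractional positions p (linear interpolation).
torch::Tensor sample1d(const torch::Tensor &g, const torch::Tensor &p) {
    torch::Tensor i0 = p.floor().to(torch::kLong);
    torch::Tensor f  = p - i0.to(p.scalar_type());
    return torch::lerp(g.index_select(0, i0), g.index_select(0, i0 + 1), f);
}

// For any grid-shaped gradient dLdG: dot(splat1d(p, w, n), dLdG) == dot(w, sample1d(dLdG, p)),
// the 1-D analogue of returning a trilinear sample of gradOut as the point-data gradient.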
+SplatIntoGridTrilinear::variable_list +SplatIntoGridTrilinear::backward(SplatIntoGridTrilinear::AutogradContext *ctx, + SplatIntoGridTrilinear::variable_list grad_output) { // Use data saved in forward - variable_list saved = ctx->get_saved_variables(); - Variable pointData = saved.at(0); // [B*M, *] + variable_list saved = ctx->get_saved_variables(); + Variable pointData = saved.at(0); // [B*M, *] - Variable pointCoords = saved.at(1); // [B*M, 3] - Variable pointJOffsets = saved.at(2); // [B,] - Variable pointsJLidx = saved.at(3); // [B,] - auto grid = ctx->saved_data["grid"].toCustomClass(); - Variable gradOut = grad_output.at(0); // [N, *] + Variable pointCoords = saved.at(1); // [B*M, 3] + Variable pointJOffsets = saved.at(2); // [B,] + Variable pointsJLidx = saved.at(3); // [B,] + auto grid = ctx->saved_data["grid"].toCustomClass(); + Variable gradOut = grad_output.at(0); // [N, *] auto ret = FVDB_DISPATCH_KERNEL_DEVICE(gradOut.device(), [&]() { return ops::dispatchSampleGridTrilinear( - *grid, JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, pointsJLidx), gradOut); + *grid, + JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, pointsJLidx), + gradOut); }); - return {torch::Tensor(), torch::Tensor(), ret[0]}; + return { torch::Tensor(), torch::Tensor(), ret[0] }; } - - - -SplatIntoGridBezier::variable_list SplatIntoGridBezier::forward(SplatIntoGridBezier::AutogradContext *ctx, - c10::intrusive_ptr grid, - SplatIntoGridBezier::JaggedVariable points, - SplatIntoGridBezier::Variable pointData) { - +SplatIntoGridBezier::variable_list +SplatIntoGridBezier::forward(SplatIntoGridBezier::AutogradContext *ctx, + c10::intrusive_ptr grid, + SplatIntoGridBezier::JaggedVariable points, + SplatIntoGridBezier::Variable pointData) { checkForwardInputs(grid, points, pointData); torch::Tensor outGridData = FVDB_DISPATCH_KERNEL_DEVICE(points.device(), [&]() { @@ -89,32 +91,34 @@ SplatIntoGridBezier::variable_list SplatIntoGridBezier::forward(SplatIntoGridBez }); // Save data for backward in context - ctx->save_for_backward({pointData, points.jdata(), points.joffsets(), points.jlidx()}); + ctx->save_for_backward({ pointData, points.jdata(), points.joffsets(), points.jlidx() }); ctx->saved_data["grid"] = grid; - return variable_list({outGridData}); + return variable_list({ outGridData }); } -SplatIntoGridBezier::variable_list SplatIntoGridBezier::backward(SplatIntoGridBezier::AutogradContext *ctx, - SplatIntoGridBezier::variable_list grad_output) { - +SplatIntoGridBezier::variable_list +SplatIntoGridBezier::backward(SplatIntoGridBezier::AutogradContext *ctx, + SplatIntoGridBezier::variable_list grad_output) { // Use data saved in forward - variable_list saved = ctx->get_saved_variables(); - Variable pointData = saved.at(0); // [B*M, *] + variable_list saved = ctx->get_saved_variables(); + Variable pointData = saved.at(0); // [B*M, *] - Variable pointCoords = saved.at(1); // [B*M, 3] + Variable pointCoords = saved.at(1); // [B*M, 3] Variable pointJOffsets = saved.at(2); // [B,] - Variable pointsJLidx = saved.at(3); // [B,] + Variable pointsJLidx = saved.at(3); // [B,] - auto grid = ctx->saved_data["grid"].toCustomClass(); + auto grid = ctx->saved_data["grid"].toCustomClass(); Variable gradOut = grad_output.at(0); // [N, *] torch::Tensor outGradIn = FVDB_DISPATCH_KERNEL_DEVICE(gradOut.device(), [&]() { return ops::dispatchSampleGridBezier( - *grid, JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, pointsJLidx), gradOut)[0]; + *grid, + 
JaggedTensor::from_data_offsets_and_list_ids(pointCoords, pointJOffsets, pointsJLidx), + gradOut)[0]; }); - return {torch::Tensor(), torch::Tensor(), outGradIn}; + return { torch::Tensor(), torch::Tensor(), outGradIn }; } } // namespace autograd diff --git a/fvdb/src/detail/autograd/SplatIntoGrid.h b/fvdb/src/detail/autograd/SplatIntoGrid.h index a0e183178e..7074f28031 100644 --- a/fvdb/src/detail/autograd/SplatIntoGrid.h +++ b/fvdb/src/detail/autograd/SplatIntoGrid.h @@ -1,48 +1,43 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_SPLATINTOGRID_H +#define FVDB_DETAIL_AUTOGRAD_SPLATINTOGRID_H #include #include "detail/GridBatchImpl.h" - namespace fvdb { namespace detail { namespace autograd { struct SplatIntoGridTrilinear : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - using JaggedVariable = JaggedTensor; + using Variable = torch::autograd::Variable; + using JaggedVariable = JaggedTensor; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr grid, - JaggedTensor points, - Variable pointData); + static variable_list forward(AutogradContext *ctx, c10::intrusive_ptr grid, + JaggedTensor points, Variable pointData); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; - struct SplatIntoGridBezier : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - using JaggedVariable = JaggedTensor; + using Variable = torch::autograd::Variable; + using JaggedVariable = JaggedTensor; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr grid, - JaggedVariable points, - Variable pointData); + static variable_list forward(AutogradContext *ctx, c10::intrusive_ptr grid, + JaggedVariable points, Variable pointData); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_SPLATINTOGRID_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/TransformPoints.cpp b/fvdb/src/detail/autograd/TransformPoints.cpp index 90ca2713ea..762936faa4 100644 --- a/fvdb/src/detail/autograd/TransformPoints.cpp +++ b/fvdb/src/detail/autograd/TransformPoints.cpp @@ -3,27 +3,22 @@ // #include "TransformPoints.h" -#include +#include +#include #include -#include "detail/ops/Ops.h" -#include "detail/utils/Utils.h" - - +#include namespace fvdb { namespace detail { namespace autograd { - -TransformPoints::variable_list TransformPoints::forward(TransformPoints::AutogradContext *ctx, - c10::intrusive_ptr grid, - TransformPoints::JaggedVariable points, - Variable pointsData, - bool isInverse, - bool isDual) { - +TransformPoints::variable_list +TransformPoints::forward(TransformPoints::AutogradContext *ctx, + c10::intrusive_ptr grid, + TransformPoints::JaggedVariable points, Variable pointsData, + bool isInverse, bool isDual) { 
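// TransformPoints applies the grid's world-to-index-space map (or its inverse) pointwise,
// with the primal/dual flag selecting which of the two offset lattices the result is
// expressed on. Assuming a pure scale-and-translate transform (an assumption for this
// sketch; a general affine map behaves the same way), the Jacobian does not depend on the
// points themselves, which is why backward() further down only needs the saved grid and
// offsets to propagate gradOut. Made-up names, plain libtorch, no fvdb types:
#include <torch/torch.h>

// World -> voxel index space for an [N, 3] point tensor: ijk = (xyz - origin) / voxelSize.
torch::Tensor worldToVoxel(const torch::Tensor &xyz, const torch::Tensor &origin,
                           const torch::Tensor &voxelSize) {
    return (xyz - origin) / voxelSize;
}

// Gradient of the map above w.r.t. the points: a constant diagonal scaling by 1 / voxelSize.
torch::Tensor worldToVoxelBackward(const torch::Tensor &gradOut, const torch::Tensor &voxelSize) {
    return gradOut / voxelSize;
}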
grid->checkDevice(points); TORCH_CHECK_VALUE(points.rdim() == 2, "points must have shape [B*N, 3]"); TORCH_CHECK_VALUE(points.rsize(-1) == 3, "points must have shape [B*N, 3]"); @@ -36,58 +31,59 @@ TransformPoints::variable_list TransformPoints::forward(TransformPoints::Autogra torch::Tensor outTxPoints; if (isInverse) { outTxPoints = FVDB_DISPATCH_KERNEL_DEVICE(points.device(), [&]() { - return ops::dispatchInvTransformPointsToGrid( - *grid, pointsWrap, !isDual); + return ops::dispatchInvTransformPointsToGrid(*grid, pointsWrap, !isDual); }); } else { outTxPoints = FVDB_DISPATCH_KERNEL_DEVICE(points.device(), [&]() { - return ops::dispatchTransformPointsToGrid( - *grid, pointsWrap, !isDual); + return ops::dispatchTransformPointsToGrid(*grid, pointsWrap, !isDual); }); } - ctx->save_for_backward({points.joffsets(), points.jlidx()}); + ctx->save_for_backward({ points.joffsets(), points.jlidx() }); - ctx->saved_data["grid"] = grid; - ctx->saved_data["isDual"] = isDual; + ctx->saved_data["grid"] = grid; + ctx->saved_data["isDual"] = isDual; ctx->saved_data["isInverse"] = isInverse; - return {outTxPoints}; // [B*N, 3] + return { outTxPoints }; // [B*N, 3] } - -TransformPoints::variable_list TransformPoints::backward(TransformPoints::AutogradContext *ctx, - TransformPoints::variable_list grad_output) { - +TransformPoints::variable_list +TransformPoints::backward(TransformPoints::AutogradContext *ctx, + TransformPoints::variable_list grad_output) { variable_list saved = ctx->get_saved_variables(); Variable pointsJOffsets = saved.at(0); - Variable pointsJLidx = saved.at(1); - Variable gradOut = grad_output.at(0); // [B*N, 3] + Variable pointsJLidx = saved.at(1); + Variable gradOut = grad_output.at(0); // [B*N, 3] // Use data saved in forward - auto grid = ctx->saved_data["grid"].toCustomClass(); - const bool isDual = ctx->saved_data["isDual"].toBool(); + auto grid = ctx->saved_data["grid"].toCustomClass(); + const bool isDual = ctx->saved_data["isDual"].toBool(); const bool isInverse = ctx->saved_data["isInverse"].toBool(); Variable outGradIn; // = torch::empty_like(gradOut); // [B*N, 3] if (isInverse) { outGradIn = FVDB_DISPATCH_KERNEL_DEVICE(gradOut.device(), [&]() { return ops::dispatchInvTransformPointsToGridBackward( - *grid, JaggedTensor::from_data_offsets_and_list_ids(gradOut, pointsJOffsets, pointsJLidx), !isDual); + *grid, + JaggedTensor::from_data_offsets_and_list_ids(gradOut, pointsJOffsets, pointsJLidx), + !isDual); }); } else { outGradIn = FVDB_DISPATCH_KERNEL_DEVICE(gradOut.device(), [&]() { return ops::dispatchTransformPointsToGridBackward( - *grid, JaggedTensor::from_data_offsets_and_list_ids(gradOut, pointsJOffsets, pointsJLidx), !isDual); + *grid, + JaggedTensor::from_data_offsets_and_list_ids(gradOut, pointsJOffsets, pointsJLidx), + !isDual); }); } - // Variable outGradIn = outGradInReshape.reshape(getShapeButReplaceFirstDim(fineData.size(0), gradOut)); - return {torch::Tensor(), torch::Tensor(), outGradIn, torch::Tensor(), torch::Tensor()}; + // Variable outGradIn = outGradInReshape.reshape(getShapeButReplaceFirstDim(fineData.size(0), + // gradOut)); + return { torch::Tensor(), torch::Tensor(), outGradIn, torch::Tensor(), torch::Tensor() }; } - } // namespace autograd } // namespace detail } // namespace fvdb \ No newline at end of file diff --git a/fvdb/src/detail/autograd/TransformPoints.h b/fvdb/src/detail/autograd/TransformPoints.h index 8d8d166b4b..47302e9085 100644 --- a/fvdb/src/detail/autograd/TransformPoints.h +++ b/fvdb/src/detail/autograd/TransformPoints.h @@ -1,34 
+1,32 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_TRANSFORMPOINTS_H +#define FVDB_DETAIL_AUTOGRAD_TRANSFORMPOINTS_H #include #include "detail/GridBatchImpl.h" - namespace fvdb { namespace detail { namespace autograd { struct TransformPoints : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - using JaggedVariable = JaggedTensor; - - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr grid, - JaggedVariable points, - Variable pointsData, - bool isInverse, + using Variable = torch::autograd::Variable; + using JaggedVariable = JaggedTensor; + + static variable_list forward(AutogradContext *ctx, c10::intrusive_ptr grid, + JaggedVariable points, Variable pointsData, bool isInverse, bool isDual); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_TRANSFORMPOINTS_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/UpsampleGrid.cpp b/fvdb/src/detail/autograd/UpsampleGrid.cpp index 767e42176a..9d1f60037e 100644 --- a/fvdb/src/detail/autograd/UpsampleGrid.cpp +++ b/fvdb/src/detail/autograd/UpsampleGrid.cpp @@ -3,74 +3,70 @@ // #include "UpsampleGrid.h" -#include +#include +#include #include -#include "detail/ops/Ops.h" -#include "detail/utils/Utils.h" +#include namespace fvdb { namespace detail { namespace autograd { -UpsampleGrid::variable_list UpsampleGrid::forward(UpsampleGrid::AutogradContext *ctx, - c10::intrusive_ptr coarseGrid, - c10::intrusive_ptr fineGrid, - nanovdb::Coord upsamplingFactor, - UpsampleGrid::Variable coarseData) { +UpsampleGrid::variable_list +UpsampleGrid::forward(UpsampleGrid::AutogradContext *ctx, + c10::intrusive_ptr coarseGrid, + c10::intrusive_ptr fineGrid, nanovdb::Coord upsamplingFactor, + UpsampleGrid::Variable coarseData) { // Save data for backward in context - ctx->save_for_backward({coarseData}); + ctx->save_for_backward({ coarseData }); - ctx->saved_data["coarse_grid"] = coarseGrid; - ctx->saved_data["fine_grid"] = fineGrid; - ctx->saved_data["upsampling_factor_x"] = (int64_t) upsamplingFactor[0]; - ctx->saved_data["upsampling_factor_y"] = (int64_t) upsamplingFactor[1]; - ctx->saved_data["upsampling_factor_z"] = (int64_t) upsamplingFactor[2]; + ctx->saved_data["coarse_grid"] = coarseGrid; + ctx->saved_data["fine_grid"] = fineGrid; + ctx->saved_data["upsampling_factor_x"] = (int64_t)upsamplingFactor[0]; + ctx->saved_data["upsampling_factor_y"] = (int64_t)upsamplingFactor[1]; + ctx->saved_data["upsampling_factor_z"] = (int64_t)upsamplingFactor[2]; if (fineGrid->totalVoxels() == 0) { - return variable_list({torch::empty({0, coarseData.size(1)}, coarseData.options())}); + return variable_list({ torch::empty({ 0, coarseData.size(1) }, coarseData.options()) }); } torch::Tensor ret = FVDB_DISPATCH_KERNEL_DEVICE(coarseData.device(), [&]() { - return ops::dispatchUpsampleGridNearest( - *coarseGrid, *fineGrid, coarseData, upsamplingFactor); + return ops::dispatchUpsampleGridNearest(*coarseGrid, *fineGrid, coarseData, + upsamplingFactor); }); - return variable_list({ret}); + return 
variable_list({ ret }); } -UpsampleGrid::variable_list UpsampleGrid::backward(UpsampleGrid::AutogradContext *ctx, - UpsampleGrid::variable_list grad_output) { - +UpsampleGrid::variable_list +UpsampleGrid::backward(UpsampleGrid::AutogradContext *ctx, + UpsampleGrid::variable_list grad_output) { // // Use data saved in forward - variable_list saved = ctx->get_saved_variables(); - Variable coarseData = saved.at(0); + variable_list saved = ctx->get_saved_variables(); + Variable coarseData = saved.at(0); - auto fineGrid = ctx->saved_data["fine_grid"].toCustomClass(); - auto coarseGrid = ctx->saved_data["coarse_grid"].toCustomClass(); - const int64_t upsamplingFactorX = ctx->saved_data["upsampling_factor_x"].toInt(); - const int64_t upsamplingFactorY = ctx->saved_data["upsampling_factor_y"].toInt(); - const int64_t upsamplingFactorZ = ctx->saved_data["upsampling_factor_z"].toInt(); + auto fineGrid = ctx->saved_data["fine_grid"].toCustomClass(); + auto coarseGrid = ctx->saved_data["coarse_grid"].toCustomClass(); + const int64_t upsamplingFactorX = ctx->saved_data["upsampling_factor_x"].toInt(); + const int64_t upsamplingFactorY = ctx->saved_data["upsampling_factor_y"].toInt(); + const int64_t upsamplingFactorZ = ctx->saved_data["upsampling_factor_z"].toInt(); const nanovdb::Coord upsamplingFactor(upsamplingFactorX, upsamplingFactorY, upsamplingFactorZ); - Variable gradOut = grad_output.at(0); // [#fine_voxels, *] + Variable gradOut = grad_output.at(0); // [#fine_voxels, *] if (fineGrid->totalVoxels() == 0) { auto ret = torch::zeros_like(coarseData); - return {torch::Tensor(), torch::Tensor(), torch::Tensor(), ret}; + return { torch::Tensor(), torch::Tensor(), torch::Tensor(), ret }; } torch::Tensor outGradIn = FVDB_DISPATCH_KERNEL_DEVICE(coarseData.device(), [&]() { - return ops::dispatchUpsampleGridNearestBackward( - *fineGrid, *coarseGrid, - gradOut, - coarseData, - upsamplingFactor - ); + return ops::dispatchUpsampleGridNearestBackward(*fineGrid, *coarseGrid, gradOut, + coarseData, upsamplingFactor); }); - return {torch::Tensor(), torch::Tensor(), torch::Tensor(), outGradIn}; + return { torch::Tensor(), torch::Tensor(), torch::Tensor(), outGradIn }; } -} // namespace autograd -} // namespace detail -} // namespace fvdb +} // namespace autograd +} // namespace detail +} // namespace fvdb diff --git a/fvdb/src/detail/autograd/UpsampleGrid.h b/fvdb/src/detail/autograd/UpsampleGrid.h index 5457f0e063..5d3cf84707 100644 --- a/fvdb/src/detail/autograd/UpsampleGrid.h +++ b/fvdb/src/detail/autograd/UpsampleGrid.h @@ -1,31 +1,31 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_UPSAMPLEGRID_H +#define FVDB_DETAIL_AUTOGRAD_UPSAMPLEGRID_H + #include #include "detail/GridBatchImpl.h" - namespace fvdb { namespace detail { namespace autograd { struct UpsampleGrid : public torch::autograd::Function { - using variable_list = torch::autograd::variable_list; + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - c10::intrusive_ptr coarseGrid, + static variable_list forward(AutogradContext *ctx, c10::intrusive_ptr coarseGrid, c10::intrusive_ptr fineGrid, - nanovdb::Coord upsamplingFactor, - Variable coarseData); + nanovdb::Coord upsamplingFactor, Variable coarseData); - static variable_list backward(AutogradContext *ctx, - 
variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_UPSAMPLEGRID_H \ No newline at end of file diff --git a/fvdb/src/detail/autograd/VolumeRender.cpp b/fvdb/src/detail/autograd/VolumeRender.cpp index c78b37f8e4..341ddf07f8 100644 --- a/fvdb/src/detail/autograd/VolumeRender.cpp +++ b/fvdb/src/detail/autograd/VolumeRender.cpp @@ -3,22 +3,20 @@ // #include "VolumeRender.h" -#include "detail/ops/Ops.h" -#include "detail/utils/Utils.h" +#include +#include namespace fvdb { namespace detail { namespace autograd { -VolumeRender::variable_list VolumeRender::forward(VolumeRender::AutogradContext *ctx, - const VolumeRender::Variable& sigmas, - const VolumeRender::Variable& rgbs, - const VolumeRender::Variable& deltaTs, - const VolumeRender::Variable& ts, - const VolumeRender::Variable& jOffsets, - double tsmtThreshold) { +VolumeRender::variable_list +VolumeRender::forward(VolumeRender::AutogradContext *ctx, const VolumeRender::Variable &sigmas, + const VolumeRender::Variable &rgbs, const VolumeRender::Variable &deltaTs, + const VolumeRender::Variable &ts, const VolumeRender::Variable &jOffsets, + double tsmtThreshold) { const int numRays = jOffsets.size(0) - 1; - const int N = sigmas.size(0); + const int N = sigmas.size(0); TORCH_CHECK(jOffsets.dim() == 1, "jOffsets must have shape (nRays+1,)"); TORCH_CHECK(sigmas.dim() == 1, "sigmas must have shape (nRays*nSamplesPerRay,)"); @@ -31,76 +29,82 @@ VolumeRender::variable_list VolumeRender::forward(VolumeRender::AutogradContext TORCH_CHECK(sigmas.device() == ts.device(), "All tensors must be on the same device"); TORCH_CHECK(sigmas.device() == jOffsets.device(), "All tensors must be on the same device"); - TORCH_CHECK(sigmas.dtype() == rgbs.dtype(), "All floating point tensors must be on the same dtype"); - TORCH_CHECK(sigmas.dtype() == deltaTs.dtype(), "All floating point tensors must be on the same dtype"); - TORCH_CHECK(sigmas.dtype() == ts.dtype(),"All floating point tensors must be on the same dtype"); - TORCH_CHECK(jOffsets.dtype() == torch::dtype(JOffsetsScalarType).dtype(), "jOffsets must be of type torch.int32"); - - TORCH_CHECK(sigmas.size(0) == rgbs.size(0), "sigmas and rgbs must have the same number of elements"); - TORCH_CHECK(sigmas.size(0) == deltaTs.size(0), "sigmas and deltaTs must have the same number of elements"); - TORCH_CHECK(sigmas.size(0) == ts.size(0), "sigmas and ts must have the same number of elements"); - torch::Tensor outOpacity = torch::zeros({numRays}, sigmas.options()); - torch::Tensor outDepth = torch::zeros({numRays}, sigmas.options()); + TORCH_CHECK(sigmas.dtype() == rgbs.dtype(), + "All floating point tensors must be on the same dtype"); + TORCH_CHECK(sigmas.dtype() == deltaTs.dtype(), + "All floating point tensors must be on the same dtype"); + TORCH_CHECK(sigmas.dtype() == ts.dtype(), + "All floating point tensors must be on the same dtype"); + TORCH_CHECK(jOffsets.dtype() == torch::dtype(JOffsetsScalarType).dtype(), + "jOffsets must be of type torch.int32"); + + TORCH_CHECK(sigmas.size(0) == rgbs.size(0), + "sigmas and rgbs must have the same number of elements"); + TORCH_CHECK(sigmas.size(0) == deltaTs.size(0), + "sigmas and deltaTs must have the same number of elements"); + TORCH_CHECK(sigmas.size(0) == ts.size(0), + "sigmas and ts must have the same number of elements"); + torch::Tensor outOpacity = 
torch::zeros({ numRays }, sigmas.options()); + torch::Tensor outDepth = torch::zeros({ numRays }, sigmas.options()); // torch::Tensor outDepthSq = torch::zeros({numRays}, sigmas.options()); - torch::Tensor outRgb = torch::zeros({numRays, 3}, sigmas.options()); - torch::Tensor outWs = torch::zeros({N}, sigmas.options()); - torch::Tensor outTotalSamples = torch::zeros({numRays}, torch::dtype(torch::kLong).device(sigmas.device())); + torch::Tensor outRgb = torch::zeros({ numRays, 3 }, sigmas.options()); + torch::Tensor outWs = torch::zeros({ N }, sigmas.options()); + torch::Tensor outTotalSamples = + torch::zeros({ numRays }, torch::dtype(torch::kLong).device(sigmas.device())); FVDB_DISPATCH_KERNEL_DEVICE(sigmas.device(), [&]() { - ops::dispatchVolumeRender( - sigmas, rgbs, deltaTs, ts, jOffsets, tsmtThreshold, - outOpacity, outDepth, outRgb, outWs, outTotalSamples); + ops::dispatchVolumeRender(sigmas, rgbs, deltaTs, ts, jOffsets, tsmtThreshold, + outOpacity, outDepth, outRgb, outWs, outTotalSamples); }); ctx->saved_data["tsmtThreshold"] = tsmtThreshold; - ctx->save_for_backward({ - sigmas, rgbs, deltaTs, ts, jOffsets, - outOpacity, outDepth, outRgb, outWs - }); + ctx->save_for_backward( + { sigmas, rgbs, deltaTs, ts, jOffsets, outOpacity, outDepth, outRgb, outWs }); return { outRgb, outDepth, outOpacity, outWs, outTotalSamples }; } -VolumeRender::variable_list VolumeRender::backward(VolumeRender::AutogradContext *ctx, - VolumeRender::variable_list grad_output) { - Variable dLdRgb = grad_output.at(0); - Variable dLdDepth = grad_output.at(1); +VolumeRender::variable_list +VolumeRender::backward(VolumeRender::AutogradContext *ctx, + VolumeRender::variable_list grad_output) { + Variable dLdRgb = grad_output.at(0); + Variable dLdDepth = grad_output.at(1); Variable dLdOpacity = grad_output.at(2); - Variable dLdWs = grad_output.at(3); + Variable dLdWs = grad_output.at(3); // Variable dLdDepthSq = grad_output.at(3); - variable_list saved = ctx->get_saved_variables(); - Variable sigmas = saved.at(0); - Variable rgbs = saved.at(1); - Variable deltaTs = saved.at(2); - Variable ts = saved.at(3); - Variable jOffsets = saved.at(4); + variable_list saved = ctx->get_saved_variables(); + Variable sigmas = saved.at(0); + Variable rgbs = saved.at(1); + Variable deltaTs = saved.at(2); + Variable ts = saved.at(3); + Variable jOffsets = saved.at(4); Variable outOpacity = saved.at(5); - Variable outDepth = saved.at(6); + Variable outDepth = saved.at(6); // Variable outDepthSq = ctx->saved_data["outDepthSq"].toTensor(); - Variable outRgb = saved.at(7); - Variable outWs = saved.at(8); + Variable outRgb = saved.at(7); + Variable outWs = saved.at(8); const double tsmtThreshold = ctx->saved_data["tsmtThreshold"].toDouble(); const int N = sigmas.size(0); - Variable dLdSigmas = torch::zeros({N}, sigmas.options()); - Variable dLdRgbs = torch::zeros({N, 3}, sigmas.options()); + Variable dLdSigmas = torch::zeros({ N }, sigmas.options()); + Variable dLdRgbs = torch::zeros({ N, 3 }, sigmas.options()); FVDB_DISPATCH_KERNEL_DEVICE(sigmas.device(), [&]() { ops::dispatchVolumeRenderBackward( - dLdOpacity, dLdDepth, /*dLdDepthSq,*/ dLdRgb, dLdWs, - sigmas, rgbs, outWs, deltaTs, ts, jOffsets, - outOpacity, outDepth, /*outDepthSq,*/ outRgb, tsmtThreshold, - dLdSigmas, dLdRgbs); + dLdOpacity, dLdDepth, /*dLdDepthSq,*/ dLdRgb, dLdWs, sigmas, rgbs, outWs, deltaTs, ts, + jOffsets, outOpacity, outDepth, /*outDepthSq,*/ outRgb, tsmtThreshold, dLdSigmas, + dLdRgbs); }); - return { dLdSigmas, dLdRgbs, torch::Tensor(), torch::Tensor(), 
torch::Tensor(), torch::Tensor() }; + return { + dLdSigmas, dLdRgbs, torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor() + }; } - } // namespace autograd } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/autograd/VolumeRender.h b/fvdb/src/detail/autograd/VolumeRender.h index 010dff48ab..8c20722435 100644 --- a/fvdb/src/detail/autograd/VolumeRender.h +++ b/fvdb/src/detail/autograd/VolumeRender.h @@ -1,33 +1,29 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_AUTOGRAD_VOLUMERENDER_H +#define FVDB_DETAIL_AUTOGRAD_VOLUMERENDER_H #include - namespace fvdb { namespace detail { namespace autograd { -struct VolumeRender : public torch::autograd::Function -{ - using variable_list = torch::autograd::variable_list; +struct VolumeRender : public torch::autograd::Function { + using variable_list = torch::autograd::variable_list; using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; + using Variable = torch::autograd::Variable; - static variable_list forward(AutogradContext *ctx, - const Variable& sigmas, - const Variable& rgbs, - const Variable& deltaTs, - const Variable& ts, - const Variable& raysAcc, - double tsmtThreshold); + static variable_list forward(AutogradContext *ctx, const Variable &sigmas, const Variable &rgbs, + const Variable &deltaTs, const Variable &ts, + const Variable &raysAcc, double tsmtThreshold); - static variable_list backward(AutogradContext *ctx, - variable_list grad_output); + static variable_list backward(AutogradContext *ctx, variable_list grad_output); }; } // namespace autograd } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_AUTOGRAD_VOLUMERENDER_H \ No newline at end of file diff --git a/fvdb/src/detail/build/Build.h b/fvdb/src/detail/build/Build.h index f119f242ec..38a323e4c0 100644 --- a/fvdb/src/detail/build/Build.h +++ b/fvdb/src/detail/build/Build.h @@ -1,15 +1,14 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#pragma once +#ifndef FVDB_DETAIL_BUILD_BUILD_H +#define FVDB_DETAIL_BUILD_BUILD_H -#include - -#include "detail/VoxelCoordTransform.h" -#include "detail/GridBatchImpl.h" - -#include "detail/utils/Utils.h" +#include +#include +#include +#include namespace fvdb { namespace detail { @@ -27,37 +26,42 @@ nanovdb::GridHandle buildEmptyGrid(torch::Device device, bool /// @param batchSize The number of grids in the batch /// @param size The size of the grid in voxels /// @param ijkMin The coordinate of the bottom-back-left corner of the grid -/// @param mask An optional mask tensor that can be used to mask out some of the voxels (shape = size) +/// @param mask An optional mask tensor that can be used to mask out some of the voxels (shape = +/// size) /// @return A handle to the nanovdb grid nanovdb::GridHandle buildDenseGrid(torch::Device device, bool isMutable, - const uint32_t batchSize, - const nanovdb::Coord& size, - const nanovdb::Coord& ijkMin, - const torch::optional& mask); + const uint32_t batchSize, + const nanovdb::Coord &size, + const nanovdb::Coord &ijkMin, + const torch::optional &mask); /// @brief Build a NanoVDB grid representing the coarse grid of a given fine grid /// @param isMutable Whether the grid should be mutable or not /// @param fineGridHdl The handle to the fine grid -/// @param branchingFactor The coarsening factor from the fine grid to the coarse grid (i.e. 
N = [2, 2, 2] for a 2x2x2 coarsening) +/// @param branchingFactor The coarsening factor from the fine grid to the coarse grid (i.e. N = [2, +/// 2, 2] for a 2x2x2 coarsening) /// @return A handle to the nanovdb grid (the device will match fineGridHdl) -nanovdb::GridHandle buildCoarseGridFromFineGrid(bool isMutable, - const GridBatchImpl& fineGridHdl, - const nanovdb::Coord branchingFactor); +nanovdb::GridHandle +buildCoarseGridFromFineGrid(bool isMutable, const GridBatchImpl &fineGridHdl, + const nanovdb::Coord branchingFactor); /// @brief Build a NanoVDB grid representing the fine grid of a given coarse grid /// @param isMutable Whether the grid should be mutable or not /// @param coarseGridHdl The handle to the coarse grid -/// @param subdivMask An optional mask JaggedTensor that can be used to not refine certain voxels (shape = [B, -1] matching number of coarse voxels) -/// @param subdivisionFactor The refinement factor from the coarse grid to the fine grid (i.e. (2, 2, 2) for a 2x2x2 refinement) +/// @param subdivMask An optional mask JaggedTensor that can be used to not refine certain voxels +/// (shape = [B, -1] matching number of coarse voxels) +/// @param subdivisionFactor The refinement factor from the coarse grid to the fine grid (i.e. (2, +/// 2, 2) for a 2x2x2 refinement) /// @return A handle to the nanovdb grid (the device will match coarseGridHdl) -nanovdb::GridHandle buildFineGridFromCoarseGrid(bool isMutable, - const GridBatchImpl& coarseGridHdl, - const torch::optional& subdivMask, - const nanovdb::Coord subdivisionFactor); +nanovdb::GridHandle +buildFineGridFromCoarseGrid(bool isMutable, const GridBatchImpl &coarseGridHdl, + const torch::optional &subdivMask, + const nanovdb::Coord subdivisionFactor); -nanovdb::GridHandle buildConvGridFromGrid(bool isMutable, - const GridBatchImpl& baseGridHdl, - const nanovdb::Coord& kernelSize, const nanovdb::Coord& stride); +nanovdb::GridHandle buildConvGridFromGrid(bool isMutable, + const GridBatchImpl &baseGridHdl, + const nanovdb::Coord &kernelSize, + const nanovdb::Coord &stride); /// @brief Build a NanoVDB grid which is a padded version of the given grid /// @param isMutable Whether the grid should be mutable or not @@ -66,11 +70,13 @@ nanovdb::GridHandle buildConvGridFromGrid(bool isMutable, /// @param bmax The padding in the positive direction /// @param excludeBorder Whether to exclude the border voxels from padding /// @return A handle to the padded nanovdb grid (the device will match baseGridHdl) -nanovdb::GridHandle buildPaddedGridFromGrid(bool isMutable, - const GridBatchImpl& baseGridHdl, - int bmin, int bmax, bool excludeBorder); +nanovdb::GridHandle buildPaddedGridFromGrid(bool isMutable, + const GridBatchImpl &baseGridHdl, + int bmin, int bmax, + bool excludeBorder); -/// @brief Build a NanoVDB grid from a set of points and pad each voxel ijk which contains a point from ijk - bmin to ijk + bmax +/// @brief Build a NanoVDB grid from a set of points and pad each voxel ijk which contains a point +/// from ijk - bmin to ijk + bmax /// @param device The device on which the grid will be allocated /// @param isMutable Whether the grid should be mutable or not /// @param points The points to be encoded in the grid (JaggedTensor of shape = (B, -1, 3)) @@ -78,23 +84,24 @@ nanovdb::GridHandle buildPaddedGridFromGrid(bool isMutable, /// @param bmin The minimum padding (i.e. we pad ijk from ijk - bmin to ijk + bmax) /// @param bmax The maximum padding (i.e. 
we pad ijk from ijk - bmin to ijk + bmax) /// @return A handle to the nanovdb grid (the device will match points) -nanovdb::GridHandle buildPaddedGridFromPoints(bool isMutable, - const JaggedTensor& points, - const std::vector& tx, - const nanovdb::Coord& bmin, - const nanovdb::Coord& bmax); +nanovdb::GridHandle +buildPaddedGridFromPoints(bool isMutable, const JaggedTensor &points, + const std::vector &tx, const nanovdb::Coord &bmin, + const nanovdb::Coord &bmax); -/// @brief Build a NanoVDB grid from a set of points where the 8 nearest voxels to each point are added to the grid +/// @brief Build a NanoVDB grid from a set of points where the 8 nearest voxels to each point are +/// added to the grid /// @param device The device on which the grid will be allocated /// @param isMutable Whether the grid should be mutable or not /// @param points The points to be encoded in the grid (JaggedTensor of shape = (B, -1, 3)) /// @param tx Transform from world to voxel coordinates /// @return A handle to the nanovdb grid (the device will match points) -nanovdb::GridHandle buildNearestNeighborGridFromPoints(bool isMutable, - const JaggedTensor& points, - const std::vector& tx); +nanovdb::GridHandle +buildNearestNeighborGridFromPoints(bool isMutable, const JaggedTensor &points, + const std::vector &tx); -/// @brief Build a NanoVDB grid from a set of ijk coordinates pad each voxel from ijk - bmin to ijk + bmax +/// @brief Build a NanoVDB grid from a set of ijk coordinates pad each voxel from ijk - bmin to ijk +/// + bmax /// @param device The device on which the grid will be allocated /// @param isMutable Whether the grid should be mutable or not /// @param coords The ijk coordinates to be encoded in the grid (JaggedTensor of shape = (B, -1, 3)) @@ -102,22 +109,25 @@ nanovdb::GridHandle buildNearestNeighborGridFromPoints(bool i /// @param bmin The minimum padding (i.e. we pad ijk from ijk - bmin to ijk + bmax) /// @param bmax The maximum padding (i.e. we pad ijk from ijk - bmin to ijk + bmax) /// @return A handle to the nanovdb grid (the device will match coords) -nanovdb::GridHandle buildPaddedGridFromCoords(bool isMutable, - const JaggedTensor& coords, - const nanovdb::Coord& bmin, - const nanovdb::Coord& bmax); - +nanovdb::GridHandle buildPaddedGridFromCoords(bool isMutable, + const JaggedTensor &coords, + const nanovdb::Coord &bmin, + const nanovdb::Coord &bmax); -/// @brief Build a NanoVDB grid by voxelizing a mesh (i.e. each voxel in the ouput grid intersects the mesh) +/// @brief Build a NanoVDB grid by voxelizing a mesh (i.e. 
each voxel in the ouput grid intersects +/// the mesh) /// @param isMutable Whether the grid should be mutable or not -/// @param meshVertices A JaggedTensor of shape = (B, -1, 3) containing the vertices of each mesh to voxelize -/// @param meshFaces A JaggedTensor of shape = (B, -1, 3) containing the face indexes of each mesh to voxelize +/// @param meshVertices A JaggedTensor of shape = (B, -1, 3) containing the vertices of each mesh to +/// voxelize +/// @param meshFaces A JaggedTensor of shape = (B, -1, 3) containing the face indexes of each mesh +/// to voxelize /// @return A handle to the nanovdb grid (the device will match meshVertices and meshFaces) -nanovdb::GridHandle buildGridFromMesh(bool isMutable, - const JaggedTensor meshVertices, - const JaggedTensor meshFaces, - const std::vector& tx); +nanovdb::GridHandle +buildGridFromMesh(bool isMutable, const JaggedTensor meshVertices, const JaggedTensor meshFaces, + const std::vector &tx); } // namespace build } // namespace detail -} // namespace fvdb \ No newline at end of file +} // namespace fvdb + +#endif // FVDB_DETAIL_BUILD_BUILD_H \ No newline at end of file diff --git a/fvdb/src/detail/build/CoarseFromFine.cpp b/fvdb/src/detail/build/CoarseFromFine.cpp index 6a403165b1..d4b789d063 100644 --- a/fvdb/src/detail/build/CoarseFromFine.cpp +++ b/fvdb/src/detail/build/CoarseFromFine.cpp @@ -3,47 +3,47 @@ // #include "Build.h" +#include +#include + #include -#include #include - -#include "detail/utils/Utils.h" -#include "detail/ops/Ops.h" - +#include namespace fvdb { namespace detail { namespace build { - template -nanovdb::GridHandle buildCoarseGridFromFineGridCPU(const GridBatchImpl& fineBatchHdl, - const nanovdb::Coord branchingFactor) { - +nanovdb::GridHandle +buildCoarseGridFromFineGridCPU(const GridBatchImpl &fineBatchHdl, + const nanovdb::Coord branchingFactor) { using IndexTree = nanovdb::NanoTree; - const nanovdb::GridHandle& fineGridHdl = fineBatchHdl.nanoGridHandle(); + const nanovdb::GridHandle &fineGridHdl = fineBatchHdl.nanoGridHandle(); std::vector> batchHandles; batchHandles.reserve(fineGridHdl.gridCount()); for (uint32_t bidx = 0; bidx < fineGridHdl.gridCount(); bidx += 1) { - const nanovdb::NanoGrid* fineGrid = fineGridHdl.template grid(bidx); + const nanovdb::NanoGrid *fineGrid = fineGridHdl.template grid(bidx); if (!fineGrid) { throw std::runtime_error("Failed to get pointer to nanovdb index grid"); } - const IndexTree& fineTree = fineGrid->tree(); + const IndexTree &fineTree = fineGrid->tree(); - using ProxyGridT = nanovdb::tools::build::Grid; - auto proxyGrid = std::make_shared(-1.0f); + using ProxyGridT = nanovdb::tools::build::Grid; + auto proxyGrid = std::make_shared(-1.0f); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); for (auto it = ActiveVoxelIterator(fineTree); it.isValid(); it++) { - const nanovdb::Coord coarseIjk = (it->first.asVec3d() / branchingFactor.asVec3d()).floor(); + const nanovdb::Coord coarseIjk = + (it->first.asVec3d() / branchingFactor.asVec3d()).floor(); proxyGridAccessor.setValue(coarseIjk, 1.0f); } proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); ret.buffer().setDevice(torch::kCPU, true); batchHandles.push_back(std::move(ret)); } @@ -55,12 +55,12 @@ nanovdb::GridHandle buildCoarseGridFromFineGridCPU(const Grid } } - -nanovdb::GridHandle buildCoarseGridFromFineGrid(bool isMutable, - const GridBatchImpl& fineBatchHdl, - const nanovdb::Coord 
branchingFactor) { +nanovdb::GridHandle +buildCoarseGridFromFineGrid(bool isMutable, const GridBatchImpl &fineBatchHdl, + const nanovdb::Coord branchingFactor) { if (fineBatchHdl.device().is_cuda()) { - JaggedTensor coords = ops::dispatchCoarseIJKForFineGrid(fineBatchHdl, branchingFactor); + JaggedTensor coords = + ops::dispatchCoarseIJKForFineGrid(fineBatchHdl, branchingFactor); return ops::dispatchCreateNanoGridFromIJK(coords, isMutable); } else { return FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { @@ -69,7 +69,6 @@ nanovdb::GridHandle buildCoarseGridFromFineGrid(bool isMutabl } } - } // namespace build } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/build/ConvGrid.cpp b/fvdb/src/detail/build/ConvGrid.cpp index 455cfdf9a9..cfa2e56e8d 100644 --- a/fvdb/src/detail/build/ConvGrid.cpp +++ b/fvdb/src/detail/build/ConvGrid.cpp @@ -3,46 +3,47 @@ // #include "Build.h" +#include +#include + #include -#include #include - -#include "detail/utils/Utils.h" -#include "detail/ops/Ops.h" - +#include namespace fvdb { namespace detail { namespace build { template -nanovdb::GridHandle buildCoarseGridFromFineGridCPU(const GridBatchImpl& fineBatchHdl, - const nanovdb::Coord branchingFactor) { - +nanovdb::GridHandle +buildCoarseGridFromFineGridCPU(const GridBatchImpl &fineBatchHdl, + const nanovdb::Coord branchingFactor) { using IndexTree = nanovdb::NanoTree; - const nanovdb::GridHandle& fineGridHdl = fineBatchHdl.nanoGridHandle(); + const nanovdb::GridHandle &fineGridHdl = fineBatchHdl.nanoGridHandle(); std::vector> batchHandles; batchHandles.reserve(fineGridHdl.gridCount()); for (uint32_t bidx = 0; bidx < fineGridHdl.gridCount(); bidx += 1) { - const nanovdb::NanoGrid* fineGrid = fineGridHdl.template grid(bidx); + const nanovdb::NanoGrid *fineGrid = fineGridHdl.template grid(bidx); if (!fineGrid) { throw std::runtime_error("Failed to get pointer to nanovdb index grid"); } - const IndexTree& fineTree = fineGrid->tree(); + const IndexTree &fineTree = fineGrid->tree(); - using ProxyGridT = nanovdb::tools::build::Grid; - auto proxyGrid = std::make_shared(-1.0f); + using ProxyGridT = nanovdb::tools::build::Grid; + auto proxyGrid = std::make_shared(-1.0f); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); for (auto it = ActiveVoxelIterator(fineTree); it.isValid(); it++) { - const nanovdb::Coord coarseIjk = (it->first.asVec3d() / branchingFactor.asVec3d()).floor(); + const nanovdb::Coord coarseIjk = + (it->first.asVec3d() / branchingFactor.asVec3d()).floor(); proxyGridAccessor.setValue(coarseIjk, 1.0f); } proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); ret.buffer().setDevice(torch::kCPU, true); batchHandles.push_back(std::move(ret)); } @@ -54,24 +55,23 @@ nanovdb::GridHandle buildCoarseGridFromFineGridCPU(const Grid } } - template -nanovdb::GridHandle buildConvGridFromGridCPU(const GridBatchImpl& baseBatchHdl, - const nanovdb::Coord& kernelSize, - const nanovdb::Coord& stride) { - +nanovdb::GridHandle +buildConvGridFromGridCPU(const GridBatchImpl &baseBatchHdl, const nanovdb::Coord &kernelSize, + const nanovdb::Coord &stride) { if (stride == nanovdb::Coord(1) || stride == kernelSize) { return buildCoarseGridFromFineGridCPU(baseBatchHdl, stride); } - const nanovdb::GridHandle& baseGridHdl = baseBatchHdl.nanoGridHandle(); + const nanovdb::GridHandle &baseGridHdl = baseBatchHdl.nanoGridHandle(); std::vector> batchHandles; 
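    // Worked example of the footprint computed below (illustrative numbers only): an odd
    // kernelSize such as (3, 3, 3) uses the symmetric window [-1, 1] on each axis, while an
    // even kernelSize such as (2, 2, 2) uses [0, 1]. With stride = (2, 2, 2), a base voxel
    // ijk0 = (4, 4, 4) emits candidates dstIjk = ijk0 + (dk, dj, di), and only candidates
    // whose coordinates are divisible by the stride are kept, e.g. dstIjk = (4, 4, 4) maps
    // to output voxel (4 / 2, 4 / 2, 4 / 2) = (2, 2, 2).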
batchHandles.reserve(baseGridHdl.gridCount()); int lower[3], upper[3]; for (int i = 0; i < 3; i += 1) { if (kernelSize[i] % 2 == 0) { - lower[i] = 0; upper[i] = kernelSize[i] - 1; + lower[i] = 0; + upper[i] = kernelSize[i] - 1; } else { lower[i] = -(kernelSize[i] - 1) / 2; upper[i] = (kernelSize[i] - 1) / 2; @@ -79,33 +79,37 @@ nanovdb::GridHandle buildConvGridFromGridCPU(const GridBatchI } for (uint32_t bidx = 0; bidx < baseGridHdl.gridCount(); bidx += 1) { - - const nanovdb::NanoGrid* baseGrid = baseGridHdl.template grid(bidx); + const nanovdb::NanoGrid *baseGrid = baseGridHdl.template grid(bidx); if (!baseGrid) { throw std::runtime_error("Failed to get pointer to nanovdb index grid"); } - using ProxyGridT = nanovdb::tools::build::Grid; - auto proxyGrid = std::make_shared(-1.0f); + using ProxyGridT = nanovdb::tools::build::Grid; + auto proxyGrid = std::make_shared(-1.0f); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); for (auto it = ActiveVoxelIterator(baseGrid->tree()); it.isValid(); it++) { - const nanovdb::Coord& ijk0 = it->first; + const nanovdb::Coord &ijk0 = it->first; for (int di = lower[0]; di <= upper[0]; di += 1) { for (int dj = lower[1]; dj <= upper[1]; dj += 1) { for (int dk = lower[2]; dk <= upper[2]; dk += 1) { const nanovdb::Coord dstIjk = ijk0 + nanovdb::Coord(dk, dj, di); - if (dstIjk[0] % stride[2] != 0 || dstIjk[1] % stride[1] != 0 || dstIjk[2] % stride[0] != 0) continue; - proxyGridAccessor.setValue(nanovdb::Coord( - dstIjk[0] / stride[2], dstIjk[1] / stride[1], dstIjk[2] / stride[0]), 1.0f); + if (dstIjk[0] % stride[2] != 0 || dstIjk[1] % stride[1] != 0 || + dstIjk[2] % stride[0] != 0) + continue; + proxyGridAccessor.setValue(nanovdb::Coord(dstIjk[0] / stride[2], + dstIjk[1] / stride[1], + dstIjk[2] / stride[0]), + 1.0f); } } } } proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); ret.buffer().setDevice(torch::kCPU, true); batchHandles.push_back(std::move(ret)); } @@ -117,19 +121,18 @@ nanovdb::GridHandle buildConvGridFromGridCPU(const GridBatchI } } - -nanovdb::GridHandle buildConvGridFromGrid(bool isMutable, - const GridBatchImpl& baseGridHdl, - const nanovdb::Coord& kernelSize, - const nanovdb::Coord& stride) { +nanovdb::GridHandle +buildConvGridFromGrid(bool isMutable, const GridBatchImpl &baseGridHdl, + const nanovdb::Coord &kernelSize, const nanovdb::Coord &stride) { /** * Logic for building the conv grid is the same as torchsparse 2.0.0b. - * However, torchsparse has a bug that creates excessive voxels in the void space, it is fixed in a customized - * branch - hence the additional URL for pre-built wheels. + * However, torchsparse has a bug that creates excessive voxels in the void space, it is fixed + * in a customized branch - hence the additional URL for pre-built wheels. 
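     *
     * Rough usage sketch (the grid name below is a placeholder, not part of this change):
     *
     *   // Output topology for a 3x3x3 sparse convolution with stride 2 over baseGrid:
     *   auto handle = buildConvGridFromGrid(false, baseGrid,
     *                                       nanovdb::Coord(3, 3, 3), nanovdb::Coord(2, 2, 2));
     *
     * When stride == 1 or stride == kernelSize this reduces to the coarse-grid build, as the
     * CPU implementation above does explicitly.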
*/ if (baseGridHdl.device().is_cuda()) { - JaggedTensor coords = ops::dispatchConvIJKForGrid(baseGridHdl, kernelSize, stride); + JaggedTensor coords = + ops::dispatchConvIJKForGrid(baseGridHdl, kernelSize, stride); return ops::dispatchCreateNanoGridFromIJK(coords, isMutable); } else { return FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { @@ -138,8 +141,6 @@ nanovdb::GridHandle buildConvGridFromGrid(bool isMutable, } } - - } // namespace build } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/build/DenseGrid.cpp b/fvdb/src/detail/build/DenseGrid.cpp index 2a37717928..5b3b080fc0 100644 --- a/fvdb/src/detail/build/DenseGrid.cpp +++ b/fvdb/src/detail/build/DenseGrid.cpp @@ -3,32 +3,28 @@ // #include "Build.h" +#include +#include + #include -#include #include - -#include "detail/utils/Utils.h" -#include "detail/ops/Ops.h" - +#include namespace fvdb { namespace detail { namespace build { - template -nanovdb::GridHandle buildDenseGridCPU(const uint32_t batchSize, - const nanovdb::Coord& size, - const nanovdb::Coord& ijkMin, - torch::optional mask) { - +nanovdb::GridHandle +buildDenseGridCPU(const uint32_t batchSize, const nanovdb::Coord &size, + const nanovdb::Coord &ijkMin, torch::optional mask) { torch::TensorAccessor maskAccessor(nullptr, nullptr, nullptr); if (mask.has_value()) { maskAccessor = mask.value().accessor(); } - using ProxyGridT = nanovdb::tools::build::Grid; - auto proxyGrid = std::make_shared(0.0f); + using ProxyGridT = nanovdb::tools::build::Grid; + auto proxyGrid = std::make_shared(0.0f); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); for (int32_t i = 0; i < size[0]; i += 1) { @@ -49,7 +45,9 @@ nanovdb::GridHandle buildDenseGridCPU(const uint32_t batchSiz } proxyGridAccessor.merge(); - nanovdb::GridHandle ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); + nanovdb::GridHandle ret = + nanovdb::tools::createNanoGrid(*proxyGrid, 0u, + false, false); ret.buffer().setDevice(torch::kCPU, true /* sync */); TorchDeviceBuffer guide(0, nullptr); @@ -69,28 +67,29 @@ nanovdb::GridHandle buildDenseGridCPU(const uint32_t batchSiz } } - - -nanovdb::GridHandle buildDenseGrid(torch::Device device, bool isMutable, - const uint32_t batchSize, - const nanovdb::Coord& size, - const nanovdb::Coord& ijkMin, - const torch::optional& mask) { - - TORCH_CHECK(size[0] > 0 && size[1] > 0 && size[2] > 0, "Size must be greater than 0 in all dimensions"); - TORCH_CHECK((__uint128_t) size[0] * size[1] * size[2] <= std::numeric_limits::max(), - "Size of dense grid exceeds the number of voxels supported by a GridBatch"); - TORCH_CHECK((__uint128_t) size[0] * size[1] * size[2] * batchSize <= std::numeric_limits::max(), - "Size and batch size exceed the number of voxels supported by a GridBatch"); +nanovdb::GridHandle +buildDenseGrid(torch::Device device, bool isMutable, const uint32_t batchSize, + const nanovdb::Coord &size, const nanovdb::Coord &ijkMin, + const torch::optional &mask) { + TORCH_CHECK(size[0] > 0 && size[1] > 0 && size[2] > 0, + "Size must be greater than 0 in all dimensions"); + TORCH_CHECK((__uint128_t)size[0] * size[1] * size[2] <= std::numeric_limits::max(), + "Size of dense grid exceeds the number of voxels supported by a GridBatch"); + TORCH_CHECK((__uint128_t)size[0] * size[1] * size[2] * batchSize <= + std::numeric_limits::max(), + "Size and batch size exceed the number of voxels supported by a GridBatch"); if (mask.has_value()) { - TORCH_CHECK(mask.value().device() == device, "Mask device must match device of dense grid to build"); + 
TORCH_CHECK(mask.value().device() == device, + "Mask device must match device of dense grid to build"); TORCH_CHECK(mask.value().dtype() == torch::kBool, "Mask must be of type bool"); TORCH_CHECK(mask.value().dim() == 3, "Mask must be 3D"); - TORCH_CHECK(mask.value().size(0) == size[0] && mask.value().size(1) == size[1] && mask.value().size(2) == size[2], + TORCH_CHECK(mask.value().size(0) == size[0] && mask.value().size(1) == size[1] && + mask.value().size(2) == size[2], "Mask must have same size as dense grid to build"); } if (device.is_cuda()) { - return ops::dispatchCreateNanoGridFromDense(batchSize, ijkMin, size, isMutable, device, mask); + return ops::dispatchCreateNanoGridFromDense(batchSize, ijkMin, size, + isMutable, device, mask); } else { return FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { return buildDenseGridCPU(batchSize, size, ijkMin, mask); @@ -98,7 +97,6 @@ nanovdb::GridHandle buildDenseGrid(torch::Device device, bool } } - } // namespace build } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/build/EmptyGrid.cpp b/fvdb/src/detail/build/EmptyGrid.cpp index b490125b84..eb07a527d5 100644 --- a/fvdb/src/detail/build/EmptyGrid.cpp +++ b/fvdb/src/detail/build/EmptyGrid.cpp @@ -3,32 +3,31 @@ // #include "Build.h" -#include "detail/utils/Utils.h" +#include #include -#include #include - +#include namespace fvdb { namespace detail { namespace build { - -nanovdb::GridHandle buildEmptyGrid(torch::Device device, bool isMutable) { +nanovdb::GridHandle +buildEmptyGrid(torch::Device device, bool isMutable) { return FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { - using ProxyGridT = nanovdb::tools::build::Grid; - auto proxyGrid = std::make_shared(0.0f); + using ProxyGridT = nanovdb::tools::build::Grid; + auto proxyGrid = std::make_shared(0.0f); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); ret.buffer().setDevice(device, true /* sync */); return ret; }); } - } // namespace build } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/build/FineFromCoarse.cpp b/fvdb/src/detail/build/FineFromCoarse.cpp index efa6c41ff4..7deede22ed 100644 --- a/fvdb/src/detail/build/FineFromCoarse.cpp +++ b/fvdb/src/detail/build/FineFromCoarse.cpp @@ -3,40 +3,37 @@ // #include "Build.h" +#include +#include + #include -#include #include - -#include "detail/utils/Utils.h" -#include "detail/ops/Ops.h" - +#include namespace fvdb { namespace detail { namespace build { - template -nanovdb::GridHandle buildFineGridFromCoarseGridCPU(const GridBatchImpl& coarseBatchHdl, - const torch::Tensor& subdivMask, - const nanovdb::Coord subdivisionFactor) { - +nanovdb::GridHandle +buildFineGridFromCoarseGridCPU(const GridBatchImpl &coarseBatchHdl, const torch::Tensor &subdivMask, + const nanovdb::Coord subdivisionFactor) { using IndexTree = nanovdb::NanoTree; - const nanovdb::GridHandle& coarseGridHdl = coarseBatchHdl.nanoGridHandle(); - const torch::TensorAccessor& subdivMaskAcc = subdivMask.accessor(); + const nanovdb::GridHandle &coarseGridHdl = coarseBatchHdl.nanoGridHandle(); + const torch::TensorAccessor &subdivMaskAcc = subdivMask.accessor(); std::vector> batchHandles; batchHandles.reserve(coarseGridHdl.gridCount()); for (uint32_t bidx = 0; bidx < coarseGridHdl.gridCount(); bidx += 1) { - const nanovdb::NanoGrid* coarseGrid = coarseGridHdl.template grid(bidx); + const nanovdb::NanoGrid 
*coarseGrid = coarseGridHdl.template grid(bidx); if (!coarseGrid) { throw std::runtime_error("Failed to get pointer to nanovdb index grid"); } - const IndexTree& coarseTree = coarseGrid->tree(); + const IndexTree &coarseTree = coarseGrid->tree(); - using ProxyGridT = nanovdb::tools::build::Grid; - auto proxyGrid = std::make_shared(-1.0f); + using ProxyGridT = nanovdb::tools::build::Grid; + auto proxyGrid = std::make_shared(-1.0f); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); for (auto it = ActiveVoxelIterator(coarseTree); it.isValid(); it++) { @@ -59,7 +56,8 @@ nanovdb::GridHandle buildFineGridFromCoarseGridCPU(const Grid } proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); ret.buffer().setDevice(torch::kCPU, true); batchHandles.push_back(std::move(ret)); } @@ -71,14 +69,13 @@ nanovdb::GridHandle buildFineGridFromCoarseGridCPU(const Grid } } - -nanovdb::GridHandle buildFineGridFromCoarseGrid(bool isMutable, - const GridBatchImpl& coarseBatchHdl, - const torch::optional& subdivMask, - const nanovdb::Coord subdivisionFactor) { - +nanovdb::GridHandle +buildFineGridFromCoarseGrid(bool isMutable, const GridBatchImpl &coarseBatchHdl, + const torch::optional &subdivMask, + const nanovdb::Coord subdivisionFactor) { if (coarseBatchHdl.device().is_cuda()) { - JaggedTensor coords = ops::dispatchFineIJKForCoarseGrid(coarseBatchHdl, subdivisionFactor, subdivMask); + JaggedTensor coords = ops::dispatchFineIJKForCoarseGrid( + coarseBatchHdl, subdivisionFactor, subdivMask); return ops::dispatchCreateNanoGridFromIJK(coords, isMutable); } else { torch::Tensor subdivMaskTensor; @@ -88,12 +85,12 @@ nanovdb::GridHandle buildFineGridFromCoarseGrid(bool isMutabl subdivMaskTensor = torch::zeros(0, torch::TensorOptions().dtype(torch::kBool)); } return FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { - return buildFineGridFromCoarseGridCPU(coarseBatchHdl, subdivMaskTensor, subdivisionFactor); + return buildFineGridFromCoarseGridCPU(coarseBatchHdl, subdivMaskTensor, + subdivisionFactor); }); } } - } // namespace build } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/build/FromMesh.cpp b/fvdb/src/detail/build/FromMesh.cpp index cda2254ced..b0025b2419 100644 --- a/fvdb/src/detail/build/FromMesh.cpp +++ b/fvdb/src/detail/build/FromMesh.cpp @@ -3,54 +3,54 @@ // #include "Build.h" +#include +#include + #include -#include #include - -#include "detail/utils/Utils.h" -#include "detail/ops/Ops.h" - +#include namespace fvdb { namespace detail { namespace build { - template -nanovdb::GridHandle buildGridFromMeshCPU(const JaggedTensor& vertices, - const JaggedTensor& triangles, - const std::vector& tx) { - - using Vec3T = nanovdb::math::Vec3; +nanovdb::GridHandle +buildGridFromMeshCPU(const JaggedTensor &vertices, const JaggedTensor &triangles, + const std::vector &tx) { + using Vec3T = nanovdb::math::Vec3; using ProxyGridT = nanovdb::tools::build::Grid; - std::vector> batchHandles; batchHandles.reserve(vertices.num_outer_lists()); for (int64_t bidx = 0; bidx < vertices.num_outer_lists(); bidx += 1) { + const torch::Tensor ti = triangles.index({ bidx }).jdata(); + const torch::Tensor vi = vertices.index({ bidx }).jdata(); + const VoxelCoordTransform &txi = tx[bidx]; - const torch::Tensor ti = triangles.index({bidx}).jdata(); - const torch::Tensor vi = vertices.index({bidx}).jdata(); - const VoxelCoordTransform& txi = tx[bidx]; - - auto proxyGrid = 
std::make_shared(-1.0f); + auto proxyGrid = std::make_shared(-1.0f); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); // int64_t numSearched = 0; // int64_t numFound = 0; // For eacjh face, compute thee min max voxels for (int faceId = 0; faceId < ti.size(0); faceId += 1) { - const torch::Tensor face = ti.index({faceId}); // 3 - const torch::Tensor faceVertices = vi.index({face}); // [3, 3] - torch::TensorAccessor faceVerticesAcc = faceVertices.accessor(); - const Vec3T v1 = txi.apply(Vec3T(faceVerticesAcc[0][0], faceVerticesAcc[0][1], faceVerticesAcc[0][2])); - const Vec3T v2 = txi.apply(Vec3T(faceVerticesAcc[1][0], faceVerticesAcc[1][1], faceVerticesAcc[1][2])); - const Vec3T v3 = txi.apply(Vec3T(faceVerticesAcc[2][0], faceVerticesAcc[2][1], faceVerticesAcc[2][2])); - - const Vec3T e1 = v2 - v1; - const Vec3T e2 = v3 - v1; - const ScalarType spacing = sqrt(3.0) / 3.0; // This is very conservative spacing but fine for now + const torch::Tensor face = ti.index({ faceId }); // 3 + const torch::Tensor faceVertices = vi.index({ face }); // [3, 3] + torch::TensorAccessor faceVerticesAcc = + faceVertices.accessor(); + const Vec3T v1 = txi.apply( + Vec3T(faceVerticesAcc[0][0], faceVerticesAcc[0][1], faceVerticesAcc[0][2])); + const Vec3T v2 = txi.apply( + Vec3T(faceVerticesAcc[1][0], faceVerticesAcc[1][1], faceVerticesAcc[1][2])); + const Vec3T v3 = txi.apply( + Vec3T(faceVerticesAcc[2][0], faceVerticesAcc[2][1], faceVerticesAcc[2][2])); + + const Vec3T e1 = v2 - v1; + const Vec3T e2 = v3 - v1; + const ScalarType spacing = + sqrt(3.0) / 3.0; // This is very conservative spacing but fine for now const int32_t numU = ceil((e1.length() + spacing) / spacing); const int32_t numV = ceil((e2.length() + spacing) / spacing); @@ -63,7 +63,7 @@ nanovdb::GridHandle buildGridFromMeshCPU(const JaggedTensor& u = 1.0 - u; v = 1.0 - v; } - const Vec3T p = v1 + e1 * u + e2 * v; + const Vec3T p = v1 + e1 * u + e2 * v; const nanovdb::Coord ijk = p.round(); proxyGridAccessor.setValue(ijk, 1.0f); @@ -75,7 +75,8 @@ nanovdb::GridHandle buildGridFromMeshCPU(const JaggedTensor& // std::cerr << "I searched over " << numSearched << " voxels" << std::endl; // std::cerr << "I found " << numFound << " voxels" << std::endl; proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); ret.buffer().setDevice(torch::kCPU, true); batchHandles.push_back(std::move(ret)); } @@ -87,24 +88,22 @@ nanovdb::GridHandle buildGridFromMeshCPU(const JaggedTensor& } } - -nanovdb::GridHandle buildGridFromMesh(bool isMutable, - const JaggedTensor meshVertices, - const JaggedTensor meshFaces, - const std::vector& tx) { +nanovdb::GridHandle +buildGridFromMesh(bool isMutable, const JaggedTensor meshVertices, const JaggedTensor meshFaces, + const std::vector &tx) { if (meshVertices.device().is_cuda()) { JaggedTensor coords = ops::dispatchIJKForMesh(meshVertices, meshFaces, tx); return ops::dispatchCreateNanoGridFromIJK(coords, isMutable); } else { return FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { - return AT_DISPATCH_FLOATING_TYPES(meshVertices.scalar_type(), "buildGridFromMeshCPU", [&]() { - return buildGridFromMeshCPU(meshVertices, meshFaces, tx); - }); + return AT_DISPATCH_FLOATING_TYPES( + meshVertices.scalar_type(), "buildGridFromMeshCPU", [&]() { + return buildGridFromMeshCPU(meshVertices, meshFaces, tx); + }); }); } } - } // namespace build } // namespace detail } // namespace fvdb diff --git 
a/fvdb/src/detail/build/NearestNeighborGridFromPoints.cpp b/fvdb/src/detail/build/NearestNeighborGridFromPoints.cpp index d763990513..805aa358a7 100644 --- a/fvdb/src/detail/build/NearestNeighborGridFromPoints.cpp +++ b/fvdb/src/detail/build/NearestNeighborGridFromPoints.cpp @@ -3,89 +3,90 @@ // #include "Build.h" +#include +#include + #include -#include #include - -#include "detail/utils/Utils.h" -#include "detail/ops/Ops.h" - +#include namespace fvdb { namespace detail { namespace build { - template -nanovdb::GridHandle buildNearestNeighborGridFromPointsCPU(const JaggedTensor& jaggedPoints, - const std::vector& txs) { - - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(jaggedPoints.scalar_type(), "buildNearestNeighborGridFromPoints", [&]() { - using ScalarT = scalar_t; - using MathT = typename at::opmath_type; - using Vec3T = typename nanovdb::math::Vec3; - using ProxyGridT = nanovdb::tools::build::Grid; - - static_assert(is_floating_point_or_half::value, "Invalid type for points, must be floating point"); - - jaggedPoints.check_valid(); - - const torch::TensorAccessor& pointsAcc = jaggedPoints.jdata().accessor(); - const torch::TensorAccessor& pointsBOffsetsAcc = jaggedPoints.joffsets().accessor(); - - std::vector> batchHandles; - batchHandles.reserve(pointsBOffsetsAcc.size(0) - 1); - for (int bi = 0; bi < (pointsBOffsetsAcc.size(0) - 1); bi += 1) { - - const VoxelCoordTransform& tx = txs[bi]; - - auto proxyGrid = std::make_shared(-1.0f); - auto proxyGridAccessor = proxyGrid->getWriteAccessor(); - - const int64_t start = pointsBOffsetsAcc[bi]; - const int64_t end = pointsBOffsetsAcc[bi+1]; - - for (int64_t pi = start; pi < end; pi += 1) { - Vec3T ijk0 = tx.apply(static_cast(pointsAcc[pi][0]), - static_cast(pointsAcc[pi][1]), - static_cast(pointsAcc[pi][2])); - nanovdb::Coord ijk000 = ijk0.floor(); - nanovdb::Coord ijk001 = ijk000 + nanovdb::Coord(0, 0, 1); - nanovdb::Coord ijk010 = ijk000 + nanovdb::Coord(0, 1, 0); - nanovdb::Coord ijk011 = ijk000 + nanovdb::Coord(0, 1, 1); - nanovdb::Coord ijk100 = ijk000 + nanovdb::Coord(1, 0, 0); - nanovdb::Coord ijk101 = ijk000 + nanovdb::Coord(1, 0, 1); - nanovdb::Coord ijk110 = ijk000 + nanovdb::Coord(1, 1, 0); - nanovdb::Coord ijk111 = ijk000 + nanovdb::Coord(1, 1, 1); - - proxyGridAccessor.setValue(ijk000, 11.0f); - proxyGridAccessor.setValue(ijk001, 11.0f); - proxyGridAccessor.setValue(ijk010, 11.0f); - proxyGridAccessor.setValue(ijk011, 11.0f); - proxyGridAccessor.setValue(ijk100, 11.0f); - proxyGridAccessor.setValue(ijk101, 11.0f); - proxyGridAccessor.setValue(ijk110, 11.0f); - proxyGridAccessor.setValue(ijk111, 11.0f); +nanovdb::GridHandle +buildNearestNeighborGridFromPointsCPU(const JaggedTensor &jaggedPoints, + const std::vector &txs) { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF( + jaggedPoints.scalar_type(), "buildNearestNeighborGridFromPoints", [&]() { + using ScalarT = scalar_t; + using MathT = typename at::opmath_type; + using Vec3T = typename nanovdb::math::Vec3; + using ProxyGridT = nanovdb::tools::build::Grid; + + static_assert(is_floating_point_or_half::value, + "Invalid type for points, must be floating point"); + + jaggedPoints.check_valid(); + + const torch::TensorAccessor &pointsAcc = + jaggedPoints.jdata().accessor(); + const torch::TensorAccessor &pointsBOffsetsAcc = + jaggedPoints.joffsets().accessor(); + + std::vector> batchHandles; + batchHandles.reserve(pointsBOffsetsAcc.size(0) - 1); + for (int bi = 0; bi < (pointsBOffsetsAcc.size(0) - 1); bi += 1) { + const VoxelCoordTransform &tx = txs[bi]; + + auto proxyGrid = 
std::make_shared(-1.0f); + auto proxyGridAccessor = proxyGrid->getWriteAccessor(); + + const int64_t start = pointsBOffsetsAcc[bi]; + const int64_t end = pointsBOffsetsAcc[bi + 1]; + + for (int64_t pi = start; pi < end; pi += 1) { + Vec3T ijk0 = tx.apply(static_cast(pointsAcc[pi][0]), + static_cast(pointsAcc[pi][1]), + static_cast(pointsAcc[pi][2])); + nanovdb::Coord ijk000 = ijk0.floor(); + nanovdb::Coord ijk001 = ijk000 + nanovdb::Coord(0, 0, 1); + nanovdb::Coord ijk010 = ijk000 + nanovdb::Coord(0, 1, 0); + nanovdb::Coord ijk011 = ijk000 + nanovdb::Coord(0, 1, 1); + nanovdb::Coord ijk100 = ijk000 + nanovdb::Coord(1, 0, 0); + nanovdb::Coord ijk101 = ijk000 + nanovdb::Coord(1, 0, 1); + nanovdb::Coord ijk110 = ijk000 + nanovdb::Coord(1, 1, 0); + nanovdb::Coord ijk111 = ijk000 + nanovdb::Coord(1, 1, 1); + + proxyGridAccessor.setValue(ijk000, 11.0f); + proxyGridAccessor.setValue(ijk001, 11.0f); + proxyGridAccessor.setValue(ijk010, 11.0f); + proxyGridAccessor.setValue(ijk011, 11.0f); + proxyGridAccessor.setValue(ijk100, 11.0f); + proxyGridAccessor.setValue(ijk101, 11.0f); + proxyGridAccessor.setValue(ijk110, 11.0f); + proxyGridAccessor.setValue(ijk111, 11.0f); + } + + proxyGridAccessor.merge(); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); + ret.buffer().setDevice(torch::kCPU, true); + batchHandles.push_back(std::move(ret)); } - proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); - ret.buffer().setDevice(torch::kCPU, true); - batchHandles.push_back(std::move(ret)); - } - - if (batchHandles.size() == 1) { - return std::move(batchHandles[0]); - } else { - return nanovdb::mergeGrids(batchHandles); - } - }); + if (batchHandles.size() == 1) { + return std::move(batchHandles[0]); + } else { + return nanovdb::mergeGrids(batchHandles); + } + }); } - -nanovdb::GridHandle buildNearestNeighborGridFromPoints(bool isMutable, - const JaggedTensor& points, - const std::vector& txs) { +nanovdb::GridHandle +buildNearestNeighborGridFromPoints(bool isMutable, const JaggedTensor &points, + const std::vector &txs) { if (points.device().is_cuda()) { JaggedTensor coords = ops::dispatchNearestNeighborIJKForPoints(points, txs); return ops::dispatchCreateNanoGridFromIJK(coords, isMutable); @@ -96,8 +97,6 @@ nanovdb::GridHandle buildNearestNeighborGridFromPoints(bool i } } - - } // namespace build } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/build/PaddedGridFromCoords.cpp b/fvdb/src/detail/build/PaddedGridFromCoords.cpp index c45554d361..275c393b0d 100644 --- a/fvdb/src/detail/build/PaddedGridFromCoords.cpp +++ b/fvdb/src/detail/build/PaddedGridFromCoords.cpp @@ -3,82 +3,82 @@ // #include "Build.h" -#include +#include +#include #include -#include #include +#include -#include "detail/utils/Utils.h" -#include "detail/ops/Ops.h" - +#include namespace fvdb { namespace detail { namespace build { - template -nanovdb::GridHandle buildPaddedGridFromCoordsCPU(const JaggedTensor& jaggedCoords, - const nanovdb::Coord& bmin, - const nanovdb::Coord& bmax) { - - return AT_DISPATCH_INTEGRAL_TYPES(jaggedCoords.scalar_type(), "buildPaddedGridFromCoords", [&]() { - using ScalarT = scalar_t; - jaggedCoords.check_valid(); - - static_assert(std::is_integral::value, "Invalid type for coords, must be integral"); - - using ProxyGridT = nanovdb::tools::build::Grid; - - const torch::TensorAccessor& coordsAcc = jaggedCoords.jdata().accessor(); - const torch::TensorAccessor& coordsBOffsetsAcc = jaggedCoords.joffsets().accessor(); - - 
std::vector> batchHandles; - batchHandles.reserve(coordsBOffsetsAcc.size(0) - 1); - for (int bi = 0; bi < (coordsBOffsetsAcc.size(0) - 1); bi += 1) { - - auto proxyGrid = std::make_shared(-1.0f); - auto proxyGridAccessor = proxyGrid->getWriteAccessor(); - - const int64_t start = coordsBOffsetsAcc[bi]; - const int64_t end = coordsBOffsetsAcc[bi+1]; - - for (unsigned ci = start; ci < end; ci += 1) { - nanovdb::Coord ijk0(coordsAcc[ci][0], coordsAcc[ci][1], coordsAcc[ci][2]); - - // Splat the normal to the 8 neighboring voxels - for (int di = bmin[0]; di <= bmax[0]; di += 1) { - for (int dj = bmin[1]; dj <= bmax[1]; dj += 1) { - for (int dk = bmin[2]; dk <= bmax[2]; dk += 1) { - const nanovdb::Coord ijk = ijk0 + nanovdb::Coord(di, dj, dk); - proxyGridAccessor.setValue(ijk, 11); +nanovdb::GridHandle +buildPaddedGridFromCoordsCPU(const JaggedTensor &jaggedCoords, const nanovdb::Coord &bmin, + const nanovdb::Coord &bmax) { + return AT_DISPATCH_INTEGRAL_TYPES( + jaggedCoords.scalar_type(), "buildPaddedGridFromCoords", [&]() { + using ScalarT = scalar_t; + jaggedCoords.check_valid(); + + static_assert(std::is_integral::value, + "Invalid type for coords, must be integral"); + + using ProxyGridT = nanovdb::tools::build::Grid; + + const torch::TensorAccessor &coordsAcc = + jaggedCoords.jdata().accessor(); + const torch::TensorAccessor &coordsBOffsetsAcc = + jaggedCoords.joffsets().accessor(); + + std::vector> batchHandles; + batchHandles.reserve(coordsBOffsetsAcc.size(0) - 1); + for (int bi = 0; bi < (coordsBOffsetsAcc.size(0) - 1); bi += 1) { + auto proxyGrid = std::make_shared(-1.0f); + auto proxyGridAccessor = proxyGrid->getWriteAccessor(); + + const int64_t start = coordsBOffsetsAcc[bi]; + const int64_t end = coordsBOffsetsAcc[bi + 1]; + + for (unsigned ci = start; ci < end; ci += 1) { + nanovdb::Coord ijk0(coordsAcc[ci][0], coordsAcc[ci][1], coordsAcc[ci][2]); + + // Splat the normal to the 8 neighboring voxels + for (int di = bmin[0]; di <= bmax[0]; di += 1) { + for (int dj = bmin[1]; dj <= bmax[1]; dj += 1) { + for (int dk = bmin[2]; dk <= bmax[2]; dk += 1) { + const nanovdb::Coord ijk = ijk0 + nanovdb::Coord(di, dj, dk); + proxyGridAccessor.setValue(ijk, 11); + } } } } - } - proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); - ret.buffer().setDevice(torch::kCPU, true); - batchHandles.push_back(std::move(ret)); - } + proxyGridAccessor.merge(); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); + ret.buffer().setDevice(torch::kCPU, true); + batchHandles.push_back(std::move(ret)); + } - if (batchHandles.size() == 1) { - return std::move(batchHandles[0]); - } else { - return nanovdb::mergeGrids(batchHandles); - } - }); + if (batchHandles.size() == 1) { + return std::move(batchHandles[0]); + } else { + return nanovdb::mergeGrids(batchHandles); + } + }); } - -nanovdb::GridHandle buildPaddedGridFromCoords(bool isMutable, - const JaggedTensor& coords, - const nanovdb::Coord& bmin, - const nanovdb::Coord& bmax) { +nanovdb::GridHandle +buildPaddedGridFromCoords(bool isMutable, const JaggedTensor &coords, const nanovdb::Coord &bmin, + const nanovdb::Coord &bmax) { if (coords.device().is_cuda()) { - JaggedTensor buildCoords = ops::dispatchPaddedIJKForCoords(coords, bmin, bmax); + JaggedTensor buildCoords = + ops::dispatchPaddedIJKForCoords(coords, bmin, bmax); return ops::dispatchCreateNanoGridFromIJK(buildCoords, isMutable); } else { return FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { diff --git 
a/fvdb/src/detail/build/PaddedGridFromGrid.cpp b/fvdb/src/detail/build/PaddedGridFromGrid.cpp index 3108545806..4f78ee99f7 100644 --- a/fvdb/src/detail/build/PaddedGridFromGrid.cpp +++ b/fvdb/src/detail/build/PaddedGridFromGrid.cpp @@ -3,48 +3,47 @@ // #include "Build.h" +#include +#include + #include -#include #include - -#include "detail/utils/Utils.h" -#include "detail/ops/Ops.h" - +#include namespace fvdb { namespace detail { namespace build { - template -nanovdb::GridHandle buildPaddedGridFromGridWithoutBorderCPU(const GridBatchImpl& baseBatchHdl, int BMIN, int BMAX) { +nanovdb::GridHandle +buildPaddedGridFromGridWithoutBorderCPU(const GridBatchImpl &baseBatchHdl, int BMIN, int BMAX) { TORCH_CHECK(BMIN <= BMAX, "BMIN must be less than BMAX"); - const nanovdb::GridHandle& baseGridHdl = baseBatchHdl.nanoGridHandle(); + const nanovdb::GridHandle &baseGridHdl = baseBatchHdl.nanoGridHandle(); std::vector> batchHandles; batchHandles.reserve(baseGridHdl.gridCount()); for (uint32_t bidx = 0; bidx < baseGridHdl.gridCount(); bidx += 1) { - - const nanovdb::NanoGrid* baseGrid = baseGridHdl.template grid(bidx); + const nanovdb::NanoGrid *baseGrid = baseGridHdl.template grid(bidx); if (!baseGrid) { throw std::runtime_error("Failed to get pointer to nanovdb index grid"); } auto baseGridAccessor = baseGrid->getAccessor(); - using ProxyGridT = nanovdb::tools::build::Grid; - auto proxyGrid = std::make_shared(-1.0f); + using ProxyGridT = nanovdb::tools::build::Grid; + auto proxyGrid = std::make_shared(-1.0f); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); for (auto it = ActiveVoxelIterator(baseGrid->tree()); it.isValid(); it++) { - nanovdb::Coord ijk0 = it->first; - bool active = true; + nanovdb::Coord ijk0 = it->first; + bool active = true; for (int di = BMIN; di <= BMAX && active; di += 1) { for (int dj = BMIN; dj <= BMAX && active; dj += 1) { for (int dk = BMIN; dk <= BMAX && active; dk += 1) { const nanovdb::Coord ijk = ijk0 + nanovdb::Coord(di, dj, dk); if (ijk != ijk0) { - active = active && baseGridAccessor.isActive(ijk); // if any surrounding is off, turn it off. + active = active && baseGridAccessor.isActive( + ijk); // if any surrounding is off, turn it off. 
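                            // Net effect: ijk0 is kept only when every voxel of its
                            // [BMIN, BMAX]^3 neighborhood is active in the base grid, i.e. the
                            // active set is eroded, which is the "exclude border" behavior.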
} } } @@ -55,7 +54,8 @@ nanovdb::GridHandle buildPaddedGridFromGridWithoutBorderCPU(c } proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); ret.buffer().setDevice(torch::kCPU, true); batchHandles.push_back(std::move(ret)); } @@ -67,25 +67,23 @@ nanovdb::GridHandle buildPaddedGridFromGridWithoutBorderCPU(c } } - - template -nanovdb::GridHandle buildPaddedGridFromGridCPU(const GridBatchImpl& baseBatchHdl, int BMIN, int BMAX) { +nanovdb::GridHandle +buildPaddedGridFromGridCPU(const GridBatchImpl &baseBatchHdl, int BMIN, int BMAX) { TORCH_CHECK(BMIN <= BMAX, "BMIN must be less than BMAX"); - const nanovdb::GridHandle& baseGridHdl = baseBatchHdl.nanoGridHandle(); + const nanovdb::GridHandle &baseGridHdl = baseBatchHdl.nanoGridHandle(); std::vector> batchHandles; batchHandles.reserve(baseGridHdl.gridCount()); for (uint32_t bidx = 0; bidx < baseGridHdl.gridCount(); bidx += 1) { - - const nanovdb::NanoGrid* baseGrid = baseGridHdl.template grid(bidx); + const nanovdb::NanoGrid *baseGrid = baseGridHdl.template grid(bidx); if (!baseGrid) { throw std::runtime_error("Failed to get pointer to nanovdb index grid"); } - using ProxyGridT = nanovdb::tools::build::Grid; - auto proxyGrid = std::make_shared(-1.0f); + using ProxyGridT = nanovdb::tools::build::Grid; + auto proxyGrid = std::make_shared(-1.0f); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); for (auto it = ActiveVoxelIterator(baseGrid->tree()); it.isValid(); it++) { @@ -101,7 +99,8 @@ nanovdb::GridHandle buildPaddedGridFromGridCPU(const GridBatc } proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); ret.buffer().setDevice(torch::kCPU, true); batchHandles.push_back(std::move(ret)); } @@ -113,16 +112,17 @@ nanovdb::GridHandle buildPaddedGridFromGridCPU(const GridBatc } } - -nanovdb::GridHandle buildPaddedGridFromGrid(bool isMutable, - const GridBatchImpl& baseBatchHdl, - int bmin, int bmax, bool excludeBorder) { +nanovdb::GridHandle +buildPaddedGridFromGrid(bool isMutable, const GridBatchImpl &baseBatchHdl, int bmin, int bmax, + bool excludeBorder) { if (baseBatchHdl.device().is_cuda()) { JaggedTensor coords; if (excludeBorder) { - coords = ops::dispatchPaddedIJKForGridWithoutBorder(baseBatchHdl, nanovdb::Coord(bmin), nanovdb::Coord(bmax)); + coords = ops::dispatchPaddedIJKForGridWithoutBorder( + baseBatchHdl, nanovdb::Coord(bmin), nanovdb::Coord(bmax)); } else { - coords = ops::dispatchPaddedIJKForGrid(baseBatchHdl, nanovdb::Coord(bmin), nanovdb::Coord(bmax)); + coords = ops::dispatchPaddedIJKForGrid(baseBatchHdl, nanovdb::Coord(bmin), + nanovdb::Coord(bmax)); } return ops::dispatchCreateNanoGridFromIJK(coords, isMutable); } else { @@ -136,9 +136,6 @@ nanovdb::GridHandle buildPaddedGridFromGrid(bool isMutable, } } - - - } // namespace build } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/build/PaddedGridFromPoints.cpp b/fvdb/src/detail/build/PaddedGridFromPoints.cpp index 55a0871686..2c4219f185 100644 --- a/fvdb/src/detail/build/PaddedGridFromPoints.cpp +++ b/fvdb/src/detail/build/PaddedGridFromPoints.cpp @@ -3,97 +3,96 @@ // #include "Build.h" +#include +#include + #include -#include #include - -#include "detail/utils/Utils.h" -#include "detail/ops/Ops.h" - +#include namespace fvdb { namespace detail { namespace build { - template -nanovdb::GridHandle 
buildPaddedGridFromPointsCPU(const JaggedTensor& pointsJagged, - const std::vector& txs, - const nanovdb::Coord& bmin, - const nanovdb::Coord& bmax) { - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(pointsJagged.scalar_type(), "buildPaddedGridFromPoints", [&](){ - using ScalarT = scalar_t; - static_assert(is_floating_point_or_half::value, "Invalid type for points, must be floating point"); - using MathT = typename at::opmath_type; - using ProxyGridT = nanovdb::tools::build::Grid; - - pointsJagged.check_valid(); - - const torch::TensorAccessor& pointsAcc = pointsJagged.jdata().accessor(); - const torch::TensorAccessor& pointsBOffsetsAcc = pointsJagged.joffsets().accessor(); - - std::vector> batchHandles; - batchHandles.reserve(pointsBOffsetsAcc.size(0) - 1); - for (int bi = 0; bi < (pointsBOffsetsAcc.size(0) - 1); bi += 1) { - VoxelCoordTransform tx = txs[bi]; - - auto proxyGrid = std::make_shared(-1.0f); - auto proxyGridAccessor = proxyGrid->getWriteAccessor(); - - const int64_t start = pointsBOffsetsAcc[bi]; - const int64_t end = pointsBOffsetsAcc[bi+1]; - - for (int64_t pi = start; pi < end; pi += 1) { - - nanovdb::Coord ijk0 = tx.apply(static_cast(pointsAcc[pi][0]), - static_cast(pointsAcc[pi][1]), - static_cast(pointsAcc[pi][2])).round(); - - // Splat the normal to the 8 neighboring voxels - for (int di = bmin[0]; di <= bmax[0]; di += 1) { - for (int dj = bmin[1]; dj <= bmax[1]; dj += 1) { - for (int dk = bmin[2]; dk <= bmax[2]; dk += 1) { - const nanovdb::Coord ijk = ijk0 + nanovdb::Coord(di, dj, dk); - proxyGridAccessor.setValue(ijk, 1.0f); +nanovdb::GridHandle +buildPaddedGridFromPointsCPU(const JaggedTensor &pointsJagged, + const std::vector &txs, + const nanovdb::Coord &bmin, const nanovdb::Coord &bmax) { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF( + pointsJagged.scalar_type(), "buildPaddedGridFromPoints", [&]() { + using ScalarT = scalar_t; + static_assert(is_floating_point_or_half::value, + "Invalid type for points, must be floating point"); + using MathT = typename at::opmath_type; + using ProxyGridT = nanovdb::tools::build::Grid; + + pointsJagged.check_valid(); + + const torch::TensorAccessor &pointsAcc = + pointsJagged.jdata().accessor(); + const torch::TensorAccessor &pointsBOffsetsAcc = + pointsJagged.joffsets().accessor(); + + std::vector> batchHandles; + batchHandles.reserve(pointsBOffsetsAcc.size(0) - 1); + for (int bi = 0; bi < (pointsBOffsetsAcc.size(0) - 1); bi += 1) { + VoxelCoordTransform tx = txs[bi]; + + auto proxyGrid = std::make_shared(-1.0f); + auto proxyGridAccessor = proxyGrid->getWriteAccessor(); + + const int64_t start = pointsBOffsetsAcc[bi]; + const int64_t end = pointsBOffsetsAcc[bi + 1]; + + for (int64_t pi = start; pi < end; pi += 1) { + nanovdb::Coord ijk0 = tx.apply(static_cast(pointsAcc[pi][0]), + static_cast(pointsAcc[pi][1]), + static_cast(pointsAcc[pi][2])) + .round(); + + // Splat the normal to the 8 neighboring voxels + for (int di = bmin[0]; di <= bmax[0]; di += 1) { + for (int dj = bmin[1]; dj <= bmax[1]; dj += 1) { + for (int dk = bmin[2]; dk <= bmax[2]; dk += 1) { + const nanovdb::Coord ijk = ijk0 + nanovdb::Coord(di, dj, dk); + proxyGridAccessor.setValue(ijk, 1.0f); + } } } } + + proxyGridAccessor.merge(); + auto ret = nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); + ret.buffer().setDevice(torch::kCPU, true); + batchHandles.push_back(std::move(ret)); } - proxyGridAccessor.merge(); - auto ret = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); - ret.buffer().setDevice(torch::kCPU, true); - 
batchHandles.push_back(std::move(ret)); - } - - if (batchHandles.size() == 1) { - return std::move(batchHandles[0]); - } else { - return nanovdb::mergeGrids(batchHandles); - } - }); + if (batchHandles.size() == 1) { + return std::move(batchHandles[0]); + } else { + return nanovdb::mergeGrids(batchHandles); + } + }); } - -nanovdb::GridHandle buildPaddedGridFromPoints(bool isMutable, - const JaggedTensor& points, - const std::vector& txs, - const nanovdb::Coord& bmin, - const nanovdb::Coord& bmax) { +nanovdb::GridHandle +buildPaddedGridFromPoints(bool isMutable, const JaggedTensor &points, + const std::vector &txs, const nanovdb::Coord &bmin, + const nanovdb::Coord &bmax) { if (points.device().is_cuda()) { - JaggedTensor coords = ops::dispatchPaddedIJKForPoints(points, bmin, bmax, txs); + JaggedTensor coords = + ops::dispatchPaddedIJKForPoints(points, bmin, bmax, txs); return ops::dispatchCreateNanoGridFromIJK(coords, isMutable); } else { - return FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { return buildPaddedGridFromPointsCPU(points, txs, bmin, bmax); }); } } - - } // namespace build } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/io/IO.h b/fvdb/src/detail/io/IO.h index da2fd5fedc..64efa98854 100644 --- a/fvdb/src/detail/io/IO.h +++ b/fvdb/src/detail/io/IO.h @@ -1,42 +1,46 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include -#include +#ifndef FVDB_DETAIL_IO_IO_H +#define FVDB_DETAIL_IO_IO_H -#include "GridBatch.h" +#include +#include -#include "Types.h" +#include +#include namespace fvdb { namespace detail { namespace io { std::tuple> -fromNVDB(nanovdb::GridHandle& handle, - const torch::optional maybeDevice = torch::optional()); +fromNVDB(nanovdb::GridHandle &handle, + const torch::optional maybeDevice = + torch::optional()); std::tuple> -fromNVDB(const std::vector>& handles, - const torch::optional maybeDevice = torch::optional()); +fromNVDB(const std::vector> &handles, + const torch::optional maybeDevice = + torch::optional()); nanovdb::GridHandle -toNVDB(const GridBatch& gridBatch, - const torch::optional maybeData = torch::optional(), - const torch::optional maybeNames = torch::optional()); +toNVDB(const GridBatch &gridBatch, + const torch::optional maybeData = torch::optional(), + const torch::optional maybeNames = + torch::optional()); std::tuple> -loadNVDB(const std::string& path, - const fvdb::NanoVDBFileGridIdentifier& gridIdentifier, - fvdb::TorchDeviceOrString device, - bool verbose); - -void saveNVDB(const std::string& path, - const GridBatch& gridBatch, - const torch::optional maybeData, +loadNVDB(const std::string &path, const fvdb::NanoVDBFileGridIdentifier &gridIdentifier, + fvdb::TorchDeviceOrString device, bool verbose); + +void saveNVDB(const std::string &path, const GridBatch &gridBatch, + const torch::optional maybeData, const torch::optional maybeNames, bool compressed = false, bool verbose = false); -} // namespace io -} // namespace detail -} // namespace fvdb +} // namespace io +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_IO_IO_H \ No newline at end of file diff --git a/fvdb/src/detail/io/LoadNanovdb.cpp b/fvdb/src/detail/io/LoadNanovdb.cpp index e1dfe1017f..7e509aa605 100644 --- a/fvdb/src/detail/io/LoadNanovdb.cpp +++ b/fvdb/src/detail/io/LoadNanovdb.cpp @@ -1,95 +1,121 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include "detail/io/IO.h" +#include "IO.h" -#include +#include +#include +#include #include #include 
-#include #include +#include -#include "Types.h" -#include "detail/utils/Utils.h" -#include "detail/GridBatchImpl.h" - +#include namespace fvdb { namespace detail { namespace io { -/// @brief Get the gridId^th grid with build type SourceGrid in a grid handle and throw an exception if the grid is none +/// @brief Get the gridId^th grid with build type SourceGrid in a grid handle and throw an exception +/// if the grid is none /// @tparam GridType The build type of the grid to read /// @param handle The grid handle to read from /// @param gridId The index of the grid in the handle to read /// @param bi The batch index of the grid in the handle to read (this is only used for logging) /// @return A host pointer to the extracted grid template -const nanovdb::NanoGrid* getGrid(const nanovdb::GridHandle& handle, uint32_t gridId, uint32_t bi) { - const nanovdb::NanoGrid* grid = handle.grid(gridId); - char gridTypeStr[nanovdb::strlen()]; +const nanovdb::NanoGrid * +getGrid(const nanovdb::GridHandle &handle, uint32_t gridId, uint32_t bi) { + const nanovdb::NanoGrid *grid = handle.grid(gridId); + char gridTypeStr[nanovdb::strlen()]; nanovdb::toStr(gridTypeStr, handle.gridType(gridId)); char expectedGridTypeStr[nanovdb::strlen()]; nanovdb::toStr(expectedGridTypeStr, nanovdb::toGridType()); TORCH_CHECK(gridId < handle.gridCount(), - "Failed to load grid " + std::to_string(gridId) + " from handle at batch index " + std::to_string(bi) + - std::string(". Grid index out of bounds.")); - TORCH_CHECK(grid != nullptr, - "Failed to load grid " + std::to_string(gridId) + " from handle at batch index " + std::to_string(bi) + - std::string(". Grid has type ") + gridTypeStr + - std::string(", but expected ") + expectedGridTypeStr + "."); + "Failed to load grid " + std::to_string(gridId) + " from handle at batch index " + + std::to_string(bi) + std::string(". Grid index out of bounds.")); + TORCH_CHECK(grid != nullptr, "Failed to load grid " + std::to_string(gridId) + + " from handle at batch index " + std::to_string(bi) + + std::string(". Grid has type ") + gridTypeStr + + std::string(", but expected ") + expectedGridTypeStr + "."); return grid; } - /// @brief Set the (row) value at index rowIdx of a tensor with 2 dimensions. /// Specialized to accept useful nanovdb types (e.g. Vec3f, Vec4f, etc...) -/// @tparam TensorAccessorT The type of tensor accessor to use (e.g. torch::TensorAccessor, torch::PackedTensorAccessor) -/// @tparam ValueT The input type of the row to write to the tensor (e.g. float, nanovdb::Vec3f, nanovdb::Vec4f) +/// @tparam TensorAccessorT The type of tensor accessor to use (e.g. torch::TensorAccessor, +/// torch::PackedTensorAccessor) +/// @tparam ValueT The input type of the row to write to the tensor (e.g. 
float, nanovdb::Vec3f, +/// nanovdb::Vec4f) /// @param acc The accessor to the tensor (must refer to a 2D tensor) /// @param rowIdx The row to read from /// @return The rowIdx^th row of the tensor casted to ValueT template -inline void valueSetter(TensorAccessorT& acc, int idx, const ValueT& value) { +inline void +valueSetter(TensorAccessorT &acc, int idx, const ValueT &value) { acc[idx][0] = value; } template -inline void valueSetter(TensorAccessorT& acc, int idx, const nanovdb::Vec3f& value) { - acc[idx][0] = value[0]; acc[idx][1] = value[1]; acc[idx][2] = value[2]; +inline void +valueSetter(TensorAccessorT &acc, int idx, const nanovdb::Vec3f &value) { + acc[idx][0] = value[0]; + acc[idx][1] = value[1]; + acc[idx][2] = value[2]; } template -inline void valueSetter(TensorAccessorT& acc, int idx, const nanovdb::Vec3d& value) { - acc[idx][0] = value[0]; acc[idx][1] = value[1]; acc[idx][2] = value[2]; +inline void +valueSetter(TensorAccessorT &acc, int idx, const nanovdb::Vec3d &value) { + acc[idx][0] = value[0]; + acc[idx][1] = value[1]; + acc[idx][2] = value[2]; } template -inline void valueSetter(TensorAccessorT& acc, int idx, const nanovdb::Vec4f& value) { - acc[idx][0] = value[0]; acc[idx][1] = value[1]; acc[idx][2] = value[2]; acc[idx][3] = value[3]; +inline void +valueSetter(TensorAccessorT &acc, int idx, const nanovdb::Vec4f &value) { + acc[idx][0] = value[0]; + acc[idx][1] = value[1]; + acc[idx][2] = value[2]; + acc[idx][3] = value[3]; } template -inline void valueSetter(TensorAccessorT& acc, int idx, const nanovdb::Vec4d& value) { - acc[idx][0] = value[0]; acc[idx][1] = value[1]; acc[idx][2] = value[2]; acc[idx][3] = value[3]; +inline void +valueSetter(TensorAccessorT &acc, int idx, const nanovdb::Vec4d &value) { + acc[idx][0] = value[0]; + acc[idx][1] = value[1]; + acc[idx][2] = value[2]; + acc[idx][3] = value[3]; } template -inline void valueSetter(TensorAccessorT& acc, int idx, const nanovdb::math::Rgba8& value) { - acc[idx][0] = value.r(); acc[idx][1] = value.g(); acc[idx][2] = value.b(); acc[idx][3] = value.a(); +inline void +valueSetter(TensorAccessorT &acc, int idx, const nanovdb::math::Rgba8 &value) { + acc[idx][0] = value.r(); + acc[idx][1] = value.g(); + acc[idx][2] = value.b(); + acc[idx][3] = value.a(); } /// @brief Return whether a nanovdb blind metadata is a valid FVDB tensor grid blind metadata, /// and if so, what the dtype is (if any). -/// FVDB Blind data is named "fvdb_jdata" where dtype is an optional dtype name. If no dtype is specified, -/// then the blind data just records the size of the tensor, and the scalar type should be determinied from the -/// grid type itself (e.g. Vec3f grids will have a float32 scalar type). +/// FVDB Blind data is named "fvdb_jdata" where dtype is an optional dtype name. If no +/// dtype is specified, then the blind data just records the size of the tensor, and the +/// scalar type should be determinied from the grid type itself (e.g. Vec3f grids will have a +/// float32 scalar type). 
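/// For illustration (this mirrors the checks in the function body; `meta` stands for the
/// blind-metadata argument, and the accepted dtype spellings are whatever
/// StringToTorchScalarType understands):
///   strncmp(meta.mName, "fvdb_jdata", 10) == 0                         // required prefix
///   strnlen(meta.mName, nanovdb::GridBlindMetaData::MaxNameSize) == 10 // no suffix: scalar type comes from the grid type
///   std::string(meta.mName + 10)                                       // otherwise the suffix names the torch dtype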
/// @param blindMetadata The blind metadata to check -/// @return A tuple containing whether the blind metadata is valid, and the dtype of the tensor (or None if no dtype is specified) -std::tuple> isFvdbBlindData(const nanovdb::GridBlindMetaData& blindMetadata) { - if(strncmp(blindMetadata.mName, "fvdb_jdata", 10) != 0) { +/// @return A tuple containing whether the blind metadata is valid, and the dtype of the tensor (or +/// None if no dtype is specified) +std::tuple> +isFvdbBlindData(const nanovdb::GridBlindMetaData &blindMetadata) { + if (strncmp(blindMetadata.mName, "fvdb_jdata", 10) != 0) { return std::make_tuple(false, torch::nullopt); } // Check if we load the dtype name, we won't overrun the buffer - const int64_t blindDataNameLen = strnlen(blindMetadata.mName, nanovdb::GridBlindMetaData::MaxNameSize); - TORCH_CHECK(blindDataNameLen < nanovdb::GridBlindMetaData::MaxNameSize, "Invalid blind metadata for nanovdb grid."); + const int64_t blindDataNameLen = + strnlen(blindMetadata.mName, nanovdb::GridBlindMetaData::MaxNameSize); + TORCH_CHECK(blindDataNameLen < nanovdb::GridBlindMetaData::MaxNameSize, + "Invalid blind metadata for nanovdb grid."); // There's no scalar type specified -- we're just storing a size of the tensor if (blindDataNameLen == 10) { @@ -97,192 +123,238 @@ std::tuple> isFvdbBlindData(const nanovdb::G } // Get the dtype of the blind data tensor - const std::string blindDtypeName = std::string(blindMetadata.mName + 10); - const torch::Dtype blindDtype = StringToTorchScalarType(blindDtypeName); + const std::string blindDtypeName = std::string(blindMetadata.mName + 10); + const torch::Dtype blindDtype = StringToTorchScalarType(blindDtypeName); return std::make_tuple(true, torch::optional(blindDtype)); } - -/// @brief Copy a source index grid (ValueIndex(Mask) or ValueOnIndex(Mask)) to a nanovdb::GridHandle. -/// If the source type is ValueIndex or ValueIndex mask it will be set to ValueOnIndex or ValueOnIndexMask respectively. -/// @tparam SourceGridType The type of the source grid (must be a nanovdb::ValueIndex or nanovdb::ValueIndexMask) +/// @brief Copy a source index grid (ValueIndex(Mask) or ValueOnIndex(Mask)) to a +/// nanovdb::GridHandle. +/// If the source type is ValueIndex or ValueIndex mask it will be set to ValueOnIndex or +/// ValueOnIndexMask respectively. 
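/// A minimal usage sketch (the pointer name is a placeholder; the allowed pairings follow
/// the static_asserts in the implementation):
///   auto h1 = copyIndexGridToHandle<nanovdb::ValueIndex, nanovdb::ValueOnIndex>(srcGrid);
///   auto h2 = copyIndexGridToHandle<nanovdb::ValueIndexMask, nanovdb::ValueOnIndexMask>(srcGrid);
///   // Masked source types must map to masked target types, and unmasked to unmasked.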
+/// @tparam SourceGridType The type of the source grid (must be a nanovdb::ValueIndex or +/// nanovdb::ValueIndexMask) /// @tparam TargetGridType The type of the target grid (must be a form of index grid) /// @param sourceGrid A host pointer to the source grid to copy /// @return A handle to the copied grid template -nanovdb::GridHandle copyIndexGridToHandle(const nanovdb::NanoGrid* sourceGrid) { - constexpr bool isSrcValueOnIndex = nanovdb::util::is_same::value; - constexpr bool isSrcValueOnIndexMask = nanovdb::util::is_same::value; - constexpr bool isSrcValueIndex = nanovdb::util::is_same::value; - constexpr bool isSrcValueIndexMask = nanovdb::util::is_same::value; - constexpr bool isTgtValueOnIndex = nanovdb::util::is_same::value; - constexpr bool isTgtValueOnIndexMask = nanovdb::util::is_same::value; - - static_assert(isSrcValueOnIndex || isSrcValueOnIndexMask || isSrcValueIndex || isSrcValueIndexMask, +nanovdb::GridHandle +copyIndexGridToHandle(const nanovdb::NanoGrid *sourceGrid) { + constexpr bool isSrcValueOnIndex = + nanovdb::util::is_same::value; + constexpr bool isSrcValueOnIndexMask = + nanovdb::util::is_same::value; + constexpr bool isSrcValueIndex = + nanovdb::util::is_same::value; + constexpr bool isSrcValueIndexMask = + nanovdb::util::is_same::value; + constexpr bool isTgtValueOnIndex = + nanovdb::util::is_same::value; + constexpr bool isTgtValueOnIndexMask = + nanovdb::util::is_same::value; + + static_assert(isSrcValueOnIndex || isSrcValueOnIndexMask || isSrcValueIndex || + isSrcValueIndexMask, "Bad source type in copyIndexGridToHandle must be an Index grid type."); - static_assert(isTgtValueOnIndex || isTgtValueOnIndexMask, - "Bad target type in copyIndexGridToHandle must be ValueOnIndex or ValueOnIndexMask."); - static_assert((isTgtValueOnIndex && (isSrcValueIndex || isSrcValueOnIndex)) || - (isTgtValueOnIndexMask && (isSrcValueIndexMask || isSrcValueOnIndexMask)), - "Bad target grid type for given source grid type in copyIndexGridToHandle. If source is a masked grid, then target must also be a masked grid."); - - const ptrdiff_t gridSize = sourceGrid->blindDataCount() > 0 ? nanovdb::util::PtrDiff(&sourceGrid->blindMetaData(0), sourceGrid) : sourceGrid->gridSize(); + static_assert( + isTgtValueOnIndex || isTgtValueOnIndexMask, + "Bad target type in copyIndexGridToHandle must be ValueOnIndex or ValueOnIndexMask."); + static_assert( + (isTgtValueOnIndex && (isSrcValueIndex || isSrcValueOnIndex)) || + (isTgtValueOnIndexMask && (isSrcValueIndexMask || isSrcValueOnIndexMask)), + "Bad target grid type for given source grid type in copyIndexGridToHandle. If source is a masked grid, then target must also be a masked grid."); + + const ptrdiff_t gridSize = + sourceGrid->blindDataCount() > 0 + ? 
nanovdb::util::PtrDiff(&sourceGrid->blindMetaData(0), sourceGrid) + : sourceGrid->gridSize(); TorchDeviceBuffer buf(gridSize); memcpy(buf.data(), sourceGrid, gridSize); - nanovdb::GridData* data = reinterpret_cast(buf.data()); - data->mGridCount = 1; - data->mGridSize = gridSize; - data->mGridClass = nanovdb::GridClass::IndexGrid; - data->mGridType = nanovdb::toGridType(); + nanovdb::GridData *data = reinterpret_cast(buf.data()); + data->mGridCount = 1; + data->mGridSize = gridSize; + data->mGridClass = nanovdb::GridClass::IndexGrid; + data->mGridType = nanovdb::toGridType(); return nanovdb::GridHandle(std::move(buf)); } - -/// @brief Load a nanovdb ValueOnIndex or ValueOnIndexMask grid with tensor blind metatada (GridClass = TensorGrid) into +/// @brief Load a nanovdb ValueOnIndex or ValueOnIndexMask grid with tensor blind metadata +/// (GridClass = TensorGrid) into /// an index grid of the same type stored in a TorchDeviceBuffer) and a torch tensor of data /// (i.e. the standard grid format for FVDB). -/// @tparam SourceGridType The type of the source grid (must be a nanovdb::ValueOnIndex or nanovdb::ValueOnIndexMask) +/// @tparam SourceGridType The type of the source grid (must be a nanovdb::ValueOnIndex or +/// nanovdb::ValueOnIndexMask) /// @tparam TargetGridType The type of the target grid (must be a form of index grid) /// @param sourceGrid A host pointer to the source grid to load -/// @return A tuple containing the index grid, the name of the grid, the tensor of data, the voxel size, and the voxel origin +/// @return A tuple containing the index grid, the name of the grid, the tensor of data, the voxel +/// size, and the voxel origin template -std::tuple, std::string, torch::Tensor, nanovdb::Vec3d, nanovdb::Vec3d> -nanovdbTensorGridToFVDBGrid(const nanovdb::NanoGrid* sourceGrid) { - static_assert(nanovdb::util::is_same::value || - nanovdb::util::is_same::value, - "Bad source grid type in nanovdbTensorGridToFVDBGrid. Must be ValueOnIndex or ValueOnIndexMask."); - static_assert(nanovdb::util::is_same::value || - nanovdb::util::is_same::value, - "Bad target grid type in nanovdbTensorGridToFVDBGrid. Must be ValueOnIndex or ValueOnIndexMask."); - static_assert(nanovdb::util::is_same::value, - "Mismatched source and target grid types in nanovdbTensorGridToFVDBGrid. They must be identical."); - - TORCH_CHECK(sourceGrid->gridClass() == nanovdb::GridClass::TensorGrid, "Invalid grid class: Index grids which are not saved with fVDB are not yet supported."); +std::tuple, std::string, torch::Tensor, nanovdb::Vec3d, + nanovdb::Vec3d> +nanovdbTensorGridToFVDBGrid(const nanovdb::NanoGrid *sourceGrid) { + static_assert( + nanovdb::util::is_same::value || + nanovdb::util::is_same::value, + "Bad source grid type in nanovdbTensorGridToFVDBGrid. Must be ValueOnIndex or ValueOnIndexMask."); + static_assert( + nanovdb::util::is_same::value || + nanovdb::util::is_same::value, + "Bad target grid type in nanovdbTensorGridToFVDBGrid. Must be ValueOnIndex or ValueOnIndexMask."); + static_assert( + nanovdb::util::is_same::value, + "Mismatched source and target grid types in nanovdbTensorGridToFVDBGrid.
They must be identical."); + + TORCH_CHECK( + sourceGrid->gridClass() == nanovdb::GridClass::TensorGrid, + "Invalid grid class: Index grids which are not saved with fVDB are not yet supported."); // Copy the index grid from the loaded buffer and update metadata to be consistent with FVDB - nanovdb::GridHandle retHandle = copyIndexGridToHandle(sourceGrid); + nanovdb::GridHandle retHandle = + copyIndexGridToHandle(sourceGrid); // Check if this grid has FVDB blind data attached to it - bool foundFVDB = false; + bool foundFVDB = false; torch::Dtype blindDtype; for (unsigned i = 0; i < sourceGrid->blindDataCount(); i += 1) { - const nanovdb::GridBlindMetaData& blindMetadata = sourceGrid->blindMetaData(i); + const nanovdb::GridBlindMetaData &blindMetadata = sourceGrid->blindMetaData(i); // Don't need to warn for grid name if (blindMetadata.mDataClass == nanovdb::GridBlindDataClass::GridName) { continue; } - std::tuple> isFvdb = isFvdbBlindData(sourceGrid->blindMetaData(0)); + std::tuple> isFvdb = + isFvdbBlindData(sourceGrid->blindMetaData(0)); if (std::get<0>(isFvdb)) { - TORCH_CHECK(!foundFVDB, "Internal Error: Grid has multiple FVDB blind data tensors. Only one is supported."); - TORCH_CHECK(std::get<1>(isFvdb).has_value(), "Invalid blind metadata for nanovdb Tensor grid."); - foundFVDB = true; + TORCH_CHECK( + !foundFVDB, + "Internal Error: Grid has multiple FVDB blind data tensors. Only one is supported."); + TORCH_CHECK(std::get<1>(isFvdb).has_value(), + "Invalid blind metadata for nanovdb Tensor grid."); + foundFVDB = true; blindDtype = std::get<1>(isFvdb).value(); } else { - TORCH_WARN("Grid has blind data, but it is not valid FVDB blind data. Blind data will be ignored."); + TORCH_WARN( + "Grid has blind data, but it is not valid FVDB blind data.
Blind data will be ignored."); } } - // If there is no FVDB blind data, this is just an index grid, so just return an empty data tensor + // If there is no FVDB blind data, this is just an index grid, so just return an empty data + // tensor if (!foundFVDB) { - return std::make_tuple(std::move(retHandle), - sourceGrid->gridName(), - torch::empty({0}), + return std::make_tuple(std::move(retHandle), sourceGrid->gridName(), torch::empty({ 0 }), sourceGrid->data()->mVoxelSize, sourceGrid->data()->mMap.applyMap(nanovdb::Vec3d(0.0))); } // Pointer to actual blind data - uint8_t* readHead = (uint8_t*)(sourceGrid->blindMetaData(0).blindData()); + uint8_t *readHead = (uint8_t *)(sourceGrid->blindMetaData(0).blindData()); // Read the shape of the tensor - const int64_t ndim = *reinterpret_cast(readHead); + const int64_t ndim = *reinterpret_cast(readHead); readHead += sizeof(int64_t); std::vector blindDataShape; blindDataShape.reserve(ndim); for (int i = 0; i < ndim; i++) { - blindDataShape.push_back(*reinterpret_cast(readHead)); + blindDataShape.push_back(*reinterpret_cast(readHead)); readHead += sizeof(int64_t); } // Copy the blind data tensor - torch::Tensor retData = torch::from_blob(const_cast(readHead), blindDataShape, blindDtype).clone(); + torch::Tensor retData = + torch::from_blob(const_cast(readHead), blindDataShape, blindDtype).clone(); // Load the name and the transform - const std::string name = sourceGrid->gridName(); - const nanovdb::Vec3d voxSize = sourceGrid->mVoxelSize; + const std::string name = sourceGrid->gridName(); + const nanovdb::Vec3d voxSize = sourceGrid->mVoxelSize; const nanovdb::Vec3d voxOrigin = sourceGrid->mMap.applyMap(nanovdb::Vec3d(0.0)); return std::make_tuple(std::move(retHandle), name, retData, voxSize, voxOrigin); } -/// @brief Load a nanovdb index grid (ValueOnIndex(Mask) or ValueIndex(Mask)) into an ValueOnIndex or ValueIndex grid -/// (stored in a TorchDeviceBuffer) and an empty tensor of data (i.e. the standard grid format for FVDB). +/// @brief Load a nanovdb index grid (ValueOnIndex(Mask) or ValueIndex(Mask)) into an ValueOnIndex +/// or ValueIndex grid +/// (stored in a TorchDeviceBuffer) and an empty tensor of data (i.e. the standard grid +/// format for FVDB). 
/// @tparam SourceGridType The type of the source grid (must not be an index grid) -/// @tparam TargetGridType The type of the target grid (must be a nanovdb::ValueOnIndex or nanovdb::ValueOnIndexMask) +/// @tparam TargetGridType The type of the target grid (must be a nanovdb::ValueOnIndex or +/// nanovdb::ValueOnIndexMask) /// @param sourceGrid A host pointer to the source grid to load -/// @return A tuple containing the index grid, the name of the grid, the empty tensor of data, the voxel size, and the voxel origin +/// @return A tuple containing the index grid, the name of the grid, the empty tensor of data, the +/// voxel size, and the voxel origin template -std::tuple, std::string, torch::Tensor, nanovdb::Vec3d, nanovdb::Vec3d> -nanovdbIndexGridToFVDBGrid(const nanovdb::NanoGrid* sourceGrid) { - nanovdb::GridHandle retHandle = copyIndexGridToHandle(sourceGrid); - const std::string name = sourceGrid->gridName(); - const nanovdb::Vec3d voxSize = sourceGrid->data()->mVoxelSize; +std::tuple, std::string, torch::Tensor, nanovdb::Vec3d, + nanovdb::Vec3d> +nanovdbIndexGridToFVDBGrid(const nanovdb::NanoGrid *sourceGrid) { + nanovdb::GridHandle retHandle = + copyIndexGridToHandle(sourceGrid); + const std::string name = sourceGrid->gridName(); + const nanovdb::Vec3d voxSize = sourceGrid->data()->mVoxelSize; const nanovdb::Vec3d voxOrigin = sourceGrid->data()->mMap.applyMap(nanovdb::Vec3d(0.0)); - return std::make_tuple(std::move(retHandle), name, torch::empty({0}), voxSize, voxOrigin); + return std::make_tuple(std::move(retHandle), name, torch::empty({ 0 }), voxSize, voxOrigin); } - -/// @brief Load a nanovdb grid with scalar or vector data stored in the leaves into a ValueOnIndex grid -/// (stored in a TorchDeviceBuffer) and a tensor of data (i.e. the standard grid format for FVDB). +/// @brief Load a nanovdb grid with scalar or vector data stored in the leaves into a ValueOnIndex +/// grid +/// (stored in a TorchDeviceBuffer) and a tensor of data (i.e. the standard grid format for +/// FVDB). 
/// @tparam SourceGridType The type of the source grid (must not be an index grid) -/// @tparam TargetGridType The type of the target grid (must be a nanovdb::ValueOnIndex or nanovdb::ValueOnIndexMask) +/// @tparam TargetGridType The type of the target grid (must be a nanovdb::ValueOnIndex or +/// nanovdb::ValueOnIndexMask) /// @tparam ScalarType The scalar type of data stored in the source grid /// @tparam DataDim The dimension of the data stored in the source grid /// @param sourceGrid A host pointer to the source grid to load -/// @return A tuple containing the index grid, the name of the grid, the tensor of data, the voxel size, and the voxel origin +/// @return A tuple containing the index grid, the name of the grid, the tensor of data, the voxel +/// size, and the voxel origin template -std::tuple, std::string, torch::Tensor, nanovdb::Vec3d, nanovdb::Vec3d> -nanovdbGridToFvdbGrid(const nanovdb::NanoGrid* sourceGrid) { +std::tuple, std::string, torch::Tensor, nanovdb::Vec3d, + nanovdb::Vec3d> +nanovdbGridToFvdbGrid(const nanovdb::NanoGrid *sourceGrid) { static_assert(nanovdb::util::is_same::value, "Bad target type in copyIndexGridToHandle must be ValueOnIndex."); - static_assert(!nanovdb::util::is_same::value && - !nanovdb::util::is_same::value && - !nanovdb::util::is_same::value && - !nanovdb::util::is_same::value, - "Bad source type in nanovdbGridToIndexGridAndData must NOT be an Index grid type."); + static_assert( + !nanovdb::util::is_same::value && + !nanovdb::util::is_same::value && + !nanovdb::util::is_same::value && + !nanovdb::util::is_same::value, + "Bad source type in nanovdbGridToIndexGridAndData must NOT be an Index grid type."); // Create the index grid for the loaded grid - using ProxyGridT = nanovdb::tools::build::Grid; - auto proxyGrid = std::make_shared(0.0f); + using ProxyGridT = nanovdb::tools::build::Grid; + auto proxyGrid = std::make_shared(0.0f); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); - for (auto it = ActiveVoxelIteratorIJKOnly(sourceGrid->tree()); it.isValid(); it++) { + for (auto it = ActiveVoxelIteratorIJKOnly(sourceGrid->tree()); it.isValid(); + it++) { proxyGridAccessor.setValue(*it, 1.0f); } proxyGridAccessor.merge(); - nanovdb::GridHandle retHandle = nanovdb::tools::createNanoGrid(*proxyGrid, 0u, false, false); - nanovdb::NanoGrid* outGrid = retHandle.template grid(); + nanovdb::GridHandle retHandle = + nanovdb::tools::createNanoGrid( + *proxyGrid, 0u, false, false); + nanovdb::NanoGrid *outGrid = retHandle.template grid(); TORCH_CHECK(outGrid != nullptr, "Internal error: failed to get outGrid."); - TORCH_CHECK(outGrid->gridClass() == nanovdb::GridClass::IndexGrid, "Internal error: outGrid is not an index grid."); - TORCH_CHECK(outGrid->gridType() == nanovdb::GridType::OnIndex || outGrid->gridType() == nanovdb::GridType::OnIndexMask, + TORCH_CHECK(outGrid->gridClass() == nanovdb::GridClass::IndexGrid, + "Internal error: outGrid is not an index grid."); + TORCH_CHECK(outGrid->gridType() == nanovdb::GridType::OnIndex || + outGrid->gridType() == nanovdb::GridType::OnIndexMask, "Internal error: outGrid is not an index grid."); // Load data at the voxels into a tensor - int64_t numVox = outGrid->activeVoxelCount(); - int64_t dim = DataDim; - torch::TensorOptions opts = torch::TensorOptions().device(torch::kCPU).dtype(); - torch::Tensor outData = torch::empty({numVox, dim}, opts); - auto outDataAcc = outData.accessor(); - auto sourceGridAccessor = sourceGrid->getAccessor(); + int64_t numVox = outGrid->activeVoxelCount(); + int64_t dim = 
DataDim; + torch::TensorOptions opts = torch::TensorOptions().device(torch::kCPU).dtype(); + torch::Tensor outData = torch::empty({ numVox, dim }, opts); + auto outDataAcc = outData.accessor(); + auto sourceGridAccessor = sourceGrid->getAccessor(); for (auto it = ActiveVoxelIterator(outGrid->tree()); it.isValid(); it++) { valueSetter(outDataAcc, it->second, sourceGridAccessor.getValue(it->first)); } // If there's extra blind data we need to load, check if any of it is FVDB blind data. - // We use FVDB blind data in save to store the shape of the tensor so we can load it back in the same shape - // the user saved it in. This lets us handle saving (N, 1), (1, N, 1), (N, )... shaped tensors properly. + // We use FVDB blind data in save to store the shape of the tensor so we can load it back in the + // same shape the user saved it in. This lets us handle saving (N, 1), (1, N, 1), (N, )... + // shaped tensors properly. bool foundFVDB = false; for (unsigned i = 0; i < sourceGrid->blindDataCount(); i += 1) { - const nanovdb::GridBlindMetaData& blindMetadata = sourceGrid->blindMetaData(i); + const nanovdb::GridBlindMetaData &blindMetadata = sourceGrid->blindMetaData(i); // Don't need to warn for grid name if (blindMetadata.mDataClass == nanovdb::GridBlindDataClass::GridName) { @@ -292,25 +364,30 @@ nanovdbGridToFvdbGrid(const nanovdb::NanoGrid* sourceGrid) { // Otherwise, check if this is an FVDB blind data tensor std::tuple> isFvdb = isFvdbBlindData(blindMetadata); if (!std::get<0>(isFvdb)) { - TORCH_WARN("Grid has blind data, but it is not valid FVDB blind data. Blind data will be ignored."); + TORCH_WARN( + "Grid has blind data, but it is not valid FVDB blind data. Blind data will be ignored."); } else { - TORCH_CHECK(!foundFVDB, "Internal Error: Grid has multiple FVDB blind data tensors. Only one is supported."); + TORCH_CHECK( + !foundFVDB, + "Internal Error: Grid has multiple FVDB blind data tensors. Only one is supported."); foundFVDB = true; - TORCH_CHECK(!std::get<1>(isFvdb).has_value(), - "Invalid FVDB blind metadata for nanovdb grid. Should not have extra type."); + TORCH_CHECK( + !std::get<1>(isFvdb).has_value(), + "Invalid FVDB blind metadata for nanovdb grid. Should not have extra type."); // Pointer to actual blind data - uint8_t* readHead = (uint8_t*)(sourceGrid->blindMetaData(0).blindData()); + uint8_t *readHead = (uint8_t *)(sourceGrid->blindMetaData(0).blindData()); // Read the shape of the tensor - const int64_t ndim = *reinterpret_cast(readHead); - TORCH_CHECK(sourceGrid->blindMetaData(0).blindDataSize() == nanovdb::math::AlignUp<32U>(sizeof(int64_t) * (ndim + 1)), + const int64_t ndim = *reinterpret_cast(readHead); + TORCH_CHECK(sourceGrid->blindMetaData(0).blindDataSize() == + nanovdb::math::AlignUp<32U>(sizeof(int64_t) * (ndim + 1)), "Invalid FVDB blind data for nanovdb grid. 
Unexpected size."); readHead += sizeof(int64_t); std::vector blindDataShape; blindDataShape.reserve(ndim); for (int i = 0; i < ndim; i++) { - blindDataShape.push_back(*reinterpret_cast(readHead)); + blindDataShape.push_back(*reinterpret_cast(readHead)); readHead += sizeof(int64_t); } @@ -319,147 +396,157 @@ nanovdbGridToFvdbGrid(const nanovdb::NanoGrid* sourceGrid) { } // Load the name and the transform - const std::string name = sourceGrid->gridName(); - const nanovdb::Vec3d voxSize = sourceGrid->data()->mVoxelSize; + const std::string name = sourceGrid->gridName(); + const nanovdb::Vec3d voxSize = sourceGrid->data()->mVoxelSize; const nanovdb::Vec3d voxOrigin = sourceGrid->data()->mMap.applyMap(nanovdb::Vec3d(0.0)); return std::make_tuple(std::move(retHandle), name, outData, voxSize, voxOrigin); } - -/// @brief Load a single nanovdb grid in a nanovdb::GridHandle into an ValueOnIndex or ValueOnIndexMask grid -/// stored in a nanovdb::GridHandle as well as torch::Tensor encoding the data at the voxels -/// (i.e. the standard format for FVDB). -/// There are 3 cases: +/// @brief Load a single nanovdb grid in a nanovdb::GridHandle into an +/// ValueOnIndex or ValueOnIndexMask grid +/// stored in a nanovdb::GridHandle as well as torch::Tensor encoding the +/// data at the voxels (i.e. the standard format for FVDB). There are 3 cases: /// 1. The input grid has scalar or vector values at the leaves: /// - Load a ValueOnIndex grid and torch::Tensor of values -/// 2. The input grid is a ValueOnIndex or ValueOnIndexMask and has its grid class set to TensorGrid: -/// - Load a matching ValueOnIndex or ValueOnIndexMask grid and torch::Tensor of values corresponding to +/// 2. The input grid is a ValueOnIndex or ValueOnIndexMask and has its grid class set to +/// TensorGrid: +/// - Load a matching ValueOnIndex or ValueOnIndexMask grid and torch::Tensor of values +/// corresponding to /// the blind data (if it is present) -/// 3. The input grid is an index grid (ValueIndex(Mask) or ValueOnIndex(Mask)) but doesn't have a TensorGrid class set: -/// - Load a ValueOnIndex or ValueOnIndexMask grid (depending if the input type has a mask or not) and an empty torch::Tensor of values +/// 3. 
The input grid is an index grid (ValueIndex(Mask) or ValueOnIndex(Mask)) but doesn't +/// have a TensorGrid class set: +/// - Load a ValueOnIndex or ValueOnIndexMask grid (depending if the input type has a +/// mask or not) and an empty torch::Tensor of values /// /// @param handle The grid handle to read from /// @param gridId The index of the grid in the handle to read /// @param bi The batch index of the grid in the handle to read (this is only used for logging) -/// @return A tuple containing the loaded index grid, the name of the grid, the tensor of data, the voxel size, and the voxel origin -std::tuple, std::string, torch::Tensor, nanovdb::Vec3d, nanovdb::Vec3d> -loadOneGrid(const nanovdb::GridHandle& handle, uint32_t gridId, uint32_t bi) { - +/// @return A tuple containing the loaded index grid, the name of the grid, the tensor of data, the +/// voxel size, and the voxel origin +std::tuple, std::string, torch::Tensor, nanovdb::Vec3d, + nanovdb::Vec3d> +loadOneGrid(const nanovdb::GridHandle &handle, uint32_t gridId, uint32_t bi) { if (handle.gridMetaData()->gridClass() == nanovdb::GridClass::TensorGrid) { - TORCH_CHECK(handle.gridType() == nanovdb::GridType::OnIndex || handle.gridType() == nanovdb::GridType::OnIndexMask, - "Invalid grid type: Tensor grids which are not saved with fVDB are not yet supported."); + TORCH_CHECK( + handle.gridType() == nanovdb::GridType::OnIndex || + handle.gridType() == nanovdb::GridType::OnIndexMask, + "Invalid grid type: Tensor grids which are not saved with fVDB are not yet supported."); if (handle.gridType() == nanovdb::GridType::OnIndex) { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbTensorGridToFVDBGrid(sourceGrid); + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbTensorGridToFVDBGrid( + sourceGrid); } else if (handle.gridType() == nanovdb::GridType::OnIndexMask) { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbTensorGridToFVDBGrid(sourceGrid); + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbTensorGridToFVDBGrid(sourceGrid); } } switch (handle.gridType()) { - case nanovdb::GridType::Float: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::Double: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::Int32: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::Int64: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::Mask: - case nanovdb::GridType::Boolean: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::Vec3f: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::Vec3d: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::RGBA8: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::Vec4f: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, 
gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::Vec4d: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::Fp16: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbGridToFvdbGrid(sourceGrid); - } - case nanovdb::GridType::Index: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbIndexGridToFVDBGrid(sourceGrid); - } - case nanovdb::GridType::IndexMask: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbIndexGridToFVDBGrid(sourceGrid); - } - case nanovdb::GridType::OnIndex: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbIndexGridToFVDBGrid(sourceGrid); - } - case nanovdb::GridType::OnIndexMask: - { - const nanovdb::NanoGrid* sourceGrid = getGrid(handle, gridId, bi); - return nanovdbIndexGridToFVDBGrid(sourceGrid); - } - default: - // Unhandled cases include: Int16, UInt32, Fp4, Fp8, FpN - char gridTypeStr[nanovdb::strlen()]; - nanovdb::toStr(gridTypeStr, handle.gridType()); - throw std::runtime_error( - std::string("Grid type not supported: ") + gridTypeStr); + case nanovdb::GridType::Float: { + const nanovdb::NanoGrid *sourceGrid = getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid(sourceGrid); + } + case nanovdb::GridType::Double: { + const nanovdb::NanoGrid *sourceGrid = getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid(sourceGrid); + } + case nanovdb::GridType::Int32: { + const nanovdb::NanoGrid *sourceGrid = getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid(sourceGrid); + } + case nanovdb::GridType::Int64: { + const nanovdb::NanoGrid *sourceGrid = getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid(sourceGrid); + } + case nanovdb::GridType::Mask: + case nanovdb::GridType::Boolean: { + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid( + sourceGrid); + } + case nanovdb::GridType::Vec3f: { + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid(sourceGrid); + } + case nanovdb::GridType::Vec3d: { + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid(sourceGrid); + } + case nanovdb::GridType::RGBA8: { + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid( + sourceGrid); + } + case nanovdb::GridType::Vec4f: { + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid(sourceGrid); + } + case nanovdb::GridType::Vec4d: { + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid(sourceGrid); + } + case nanovdb::GridType::Fp16: { + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbGridToFvdbGrid( + sourceGrid); + } + case nanovdb::GridType::Index: { + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbIndexGridToFVDBGrid(sourceGrid); + } + case nanovdb::GridType::IndexMask: { + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbIndexGridToFVDBGrid( + sourceGrid); + } + case nanovdb::GridType::OnIndex: { + const nanovdb::NanoGrid *sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbIndexGridToFVDBGrid(sourceGrid); + } + case nanovdb::GridType::OnIndexMask: { + const nanovdb::NanoGrid 
*sourceGrid = + getGrid(handle, gridId, bi); + return nanovdbIndexGridToFVDBGrid( + sourceGrid); + } + default: + // Unhandled cases include: Int16, UInt32, Fp4, Fp8, FpN + char gridTypeStr[nanovdb::strlen()]; + nanovdb::toStr(gridTypeStr, handle.gridType()); + throw std::runtime_error(std::string("Grid type not supported: ") + gridTypeStr); } } std::tuple> -fromNVDB(nanovdb::GridHandle& handle, +fromNVDB(nanovdb::GridHandle &handle, const torch::optional maybeDevice) { - return fromNVDB({handle}, maybeDevice); + return fromNVDB({ handle }, maybeDevice); } std::tuple> -fromNVDB(const std::vector>& handles, - const torch::optional maybeDevice) { +fromNVDB(const std::vector> &handles, + const torch::optional maybeDevice) { // Load the grids, data, names, voxel origins, and sizes - std::vector data; + std::vector data; std::vector> grids; - std::vector voxSizes, voxOrigins; - std::vector names; - uint32_t bi = 0; - nanovdb::GridType lastGridType = nanovdb::GridType::Unknown; + std::vector voxSizes, voxOrigins; + std::vector names; + uint32_t bi = 0; + nanovdb::GridType lastGridType = nanovdb::GridType::Unknown; for (size_t handleId = 0; handleId < handles.size(); handleId += 1) { for (size_t gridId = 0; gridId < handles[handleId].gridCount(); gridId += 1) { auto gridData = loadOneGrid(handles[handleId], gridId, bi); @@ -474,11 +561,12 @@ fromNVDB(const std::vector>& handles, // In all but two cases, we load a ValueOnIndex grid and a tensor of data: // 1. When the user saved a mutable Tensor grid with save // 2. When the user loaded a batch with a ValueOnIndexMask grid - // If the file the list of grids the user loaded contains a mix of ValueOnIndex and ValueOnIndexMask grids, - // then it's unclear what to do, so throw an exception. + // If the file the list of grids the user loaded contains a mix of ValueOnIndex and + // ValueOnIndexMask grids, then it's unclear what to do, so throw an exception. if (bi > 0) { - TORCH_CHECK(lastGridType == grids.back().gridData()->mGridType, - "All grids in a batch must have the same mutability (i.e. all ValueOnIndex or all ValueOnIndexMask)."); + TORCH_CHECK( + lastGridType == grids.back().gridData()->mGridType, + "All grids in a batch must have the same mutability (i.e. 
all ValueOnIndex or all ValueOnIndexMask)."); } lastGridType = grids.back().gridData()->mGridType; @@ -487,9 +575,11 @@ fromNVDB(const std::vector>& handles, } // Merge all the loaded grids into a single handle - TORCH_CHECK_VALUE(grids.size() <= fvdb::GridBatch::MAX_GRIDS_PER_BATCH, "Cannot load more than ", fvdb::GridBatch::MAX_GRIDS_PER_BATCH, " grids."); + TORCH_CHECK_VALUE(grids.size() <= fvdb::GridBatch::MAX_GRIDS_PER_BATCH, + "Cannot load more than ", fvdb::GridBatch::MAX_GRIDS_PER_BATCH, " grids."); nanovdb::GridHandle resCpu = nanovdb::mergeGrids(grids); - c10::intrusive_ptr ret = c10::make_intrusive(std::move(resCpu), voxSizes, voxOrigins); + c10::intrusive_ptr ret = + c10::make_intrusive(std::move(resCpu), voxSizes, voxOrigins); // Merge loaded data Tensors into a JaggedTensor JaggedTensor dataJagged(data); @@ -498,7 +588,7 @@ fromNVDB(const std::vector>& handles, if (maybeDevice.has_value()) { torch::Device toDevice = maybeDevice.value().value(); if (toDevice != ret->device()) { - ret = ret->clone(toDevice); + ret = ret->clone(toDevice); dataJagged = dataJagged.to(toDevice); } } @@ -507,28 +597,28 @@ fromNVDB(const std::vector>& handles, } std::tuple> -loadNVDB(const std::string& path, - const NanoVDBFileGridIdentifier& gridIdentifier, - TorchDeviceOrString device, - bool verbose) { - +loadNVDB(const std::string &path, const NanoVDBFileGridIdentifier &gridIdentifier, + TorchDeviceOrString device, bool verbose) { // Load a std::vector of grid handles each containing a one grid to load // If the user specified specific indices or names of grid to load, use that as a filter std::vector> sourceHandles; if (gridIdentifier.specifiesIndices()) { - for (uint64_t index : gridIdentifier.indicesValue()) { + for (uint64_t index: gridIdentifier.indicesValue()) { try { - sourceHandles.emplace_back(nanovdb::io::readGrid(path, index, verbose)); - } catch(std::runtime_error& e) { + sourceHandles.emplace_back( + nanovdb::io::readGrid(path, index, verbose)); + } catch (std::runtime_error &e) { TORCH_CHECK_INDEX(false, "Grid id ", index, " is out of range."); } } } else if (gridIdentifier.specifiesNames()) { - for (const std::string& name : gridIdentifier.namesValue()) { + for (const std::string &name: gridIdentifier.namesValue()) { try { - sourceHandles.emplace_back(nanovdb::io::readGrid(path, name, verbose)); - } catch(std::runtime_error& e) { - TORCH_CHECK_INDEX(false, "Grid with name '", name, "' not found in file '", path, "'."); + sourceHandles.emplace_back( + nanovdb::io::readGrid(path, name, verbose)); + } catch (std::runtime_error &e) { + TORCH_CHECK_INDEX(false, "Grid with name '", name, "' not found in file '", path, + "'."); } } } else { @@ -538,8 +628,6 @@ loadNVDB(const std::string& path, return fromNVDB(sourceHandles, device); } - - } // namespace io } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/io/SaveNanoVDB.cpp b/fvdb/src/detail/io/SaveNanoVDB.cpp index 9d9eb41aec..0d5c5f8ca9 100644 --- a/fvdb/src/detail/io/SaveNanoVDB.cpp +++ b/fvdb/src/detail/io/SaveNanoVDB.cpp @@ -1,39 +1,42 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include "detail/io/IO.h" +#include "IO.h" -#include "detail/utils/Utils.h" - -#include -#include - -#include +#include #include #include -#include #include +#include + +#include #include +#include +#include namespace fvdb { namespace detail { namespace io { -/// @brief Copy a std::string to a char buffer with a fixed size and throw an exception if the string is too long +/// @brief Copy a 
std::string to a char buffer with a fixed size and throw an exception if the +/// string is too long /// @param targetBuf A pointer to the buffer to write the string to /// @param maxSize The maximum size of the target buffer /// @param sourceSting The source string to copy /// @param bufName A name for this string to use when throwing an exception (default is "String") -void setFixedSizeStringBuf(char* targetBuf, size_t maxSize, std::string sourceSting, std::string bufName = "String") { +void +setFixedSizeStringBuf(char *targetBuf, size_t maxSize, std::string sourceSting, + std::string bufName = "String") { memset(targetBuf, 0, maxSize); - TORCH_CHECK_VALUE(sourceSting.size() < maxSize, bufName + " exceeds maximum character length of " + std::to_string(maxSize) + "."); + TORCH_CHECK_VALUE(sourceSting.size() < maxSize, bufName + + " exceeds maximum character length of " + + std::to_string(maxSize) + "."); strncpy(targetBuf, sourceSting.c_str(), maxSize); } - /// @brief Get the (row) value at index rowIdx from a tensor with 2 dimensions. /// Specialized to return useful nanovdb types (e.g. Vec3f, Vec4f, etc...) /// @tparam ScalarT The scalar type of values @@ -42,56 +45,66 @@ void setFixedSizeStringBuf(char* targetBuf, size_t maxSize, std::string sourceSt /// @param rowIdx The row to read from /// @return The rowIdx^th row of the tensor casted to ValueT template -inline ValueT valueGetter(torch::TensorAccessor& acc, int rowIdx) { +inline ValueT +valueGetter(torch::TensorAccessor &acc, int rowIdx) { return acc[rowIdx][0]; } template <> -inline nanovdb::Vec3f valueGetter(torch::TensorAccessor& acc, int rowIdx) { - return {acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2]}; +inline nanovdb::Vec3f +valueGetter(torch::TensorAccessor &acc, int rowIdx) { + return { acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2] }; } template <> -inline nanovdb::Vec4f valueGetter(torch::TensorAccessor& acc, int rowIdx) { - return {acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2], acc[rowIdx][3]}; +inline nanovdb::Vec4f +valueGetter(torch::TensorAccessor &acc, int rowIdx) { + return { acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2], acc[rowIdx][3] }; } template <> -inline nanovdb::Vec3d valueGetter(torch::TensorAccessor& acc, int rowIdx) { - return {acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2]}; +inline nanovdb::Vec3d +valueGetter(torch::TensorAccessor &acc, int rowIdx) { + return { acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2] }; } template <> -inline nanovdb::Vec4d valueGetter(torch::TensorAccessor& acc, int rowIdx) { - return {acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2], acc[rowIdx][3]}; +inline nanovdb::Vec4d +valueGetter(torch::TensorAccessor &acc, int rowIdx) { + return { acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2], acc[rowIdx][3] }; } template <> -inline nanovdb::Vec3i valueGetter(torch::TensorAccessor& acc, int rowIdx) { - return {acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2]}; +inline nanovdb::Vec3i +valueGetter(torch::TensorAccessor &acc, int rowIdx) { + return { acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2] }; } template <> -inline nanovdb::math::Rgba8 valueGetter(torch::TensorAccessor& acc, int rowIdx) { - return {acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2], acc[rowIdx][3]}; +inline nanovdb::math::Rgba8 +valueGetter(torch::TensorAccessor &acc, int rowIdx) { + return { acc[rowIdx][0], acc[rowIdx][1], acc[rowIdx][2], acc[rowIdx][3] }; } - -/// @brief Helper function to copy an index grid with a corresponding JaggedTensor of values to a nanovdb grid -/// with values stored directly in the leaves. 
This will only work for values which correspond to valid nanovdb -/// grid types (e.g. Vec3f, Vec4f, Vec3d, Vec4d, etc...) -/// @tparam OutGridType The type of data to store in the returned grid (e.g. float, nanovdb::Vec3f, etc...) -/// @tparam InScalarType The scalar type of the input jagged tensor (e.g float, double, int32_t, etc...) +/// @brief Helper function to copy an index grid with a corresponding JaggedTensor of values to a +/// nanovdb grid +/// with values stored directly in the leaves. This will only work for values which +/// correspond to valid nanovdb grid types (e.g. Vec3f, Vec4f, Vec3d, Vec4d, etc...) +/// @tparam OutGridType The type of data to store in the returned grid (e.g. float, nanovdb::Vec3f, +/// etc...) +/// @tparam InScalarType The scalar type of the input jagged tensor (e.g float, double, int32_t, +/// etc...) /// @param gridBatch The batch of index grids to copy /// @param data The JaggedTensor of data to copy /// @param names The names of the grids in the batch to write to the copied output (optional) /// @return A nanovdb grid handle with the copied data stored in the leaves template -nanovdb::GridHandle fvdbToNanovdbGridWithValues(const GridBatch& gridBatch, - const JaggedTensor& data, - const std::vector& names) { - - TORCH_CHECK(names.size() == 0 || names.size() == (size_t) gridBatch.grid_count(), - "Invalid parameter for names, must be empty or a list of the same length as the batch size. Got " - + std::to_string(names.size()) + " names for batch size " + std::to_string(gridBatch.grid_count())); +nanovdb::GridHandle +fvdbToNanovdbGridWithValues(const GridBatch &gridBatch, const JaggedTensor &data, + const std::vector &names) { + TORCH_CHECK( + names.size() == 0 || names.size() == (size_t)gridBatch.grid_count(), + "Invalid parameter for names, must be empty or a list of the same length as the batch size. Got " + + std::to_string(names.size()) + " names for batch size " + + std::to_string(gridBatch.grid_count())); TORCH_CHECK(!gridBatch.is_mutable(), "Need to use indexing with mutable grids!"); - using ProxyGridT = nanovdb::tools::build::Grid; - using GridValueT = typename ProxyGridT::ValueType; + using ProxyGridT = nanovdb::tools::build::Grid; + using GridValueT = typename ProxyGridT::ValueType; using HostGridHandle = nanovdb::GridHandle; // We'll build each grid from the ijk values and data, so get accessors for these @@ -105,32 +118,39 @@ nanovdb::GridHandle fvdbToNanovdbGridWithValues(const GridB if (jdataCpu.ndimension() == 1) { jdataCpu = jdataCpu.unsqueeze(1); } - TORCH_CHECK(jdataCpu.size(0) == gridBatch.total_voxels(), "Invalid data tensor size. Must match number of voxels in grid batch."); + TORCH_CHECK(jdataCpu.size(0) == gridBatch.total_voxels(), + "Invalid data tensor size. Must match number of voxels in grid batch."); - auto ijkAccessor = ijkValues.jdata().accessor(); + auto ijkAccessor = ijkValues.jdata().accessor(); auto jdataAccessor = jdataCpu.accessor(); // Populate a vector of host buffers for each grid in the batch std::vector buffers(gridBatch.grid_count()); for (int64_t bi = 0; bi < gridBatch.grid_count(); bi += 1) { const std::string name = names.size() > 0 ? 
names[bi] : ""; - TORCH_CHECK_VALUE(name.size() < nanovdb::GridData::MaxNameSize, "Grid name " + name + " exceeds maximum character length of " + std::to_string(nanovdb::GridData::MaxNameSize) + "."); + TORCH_CHECK_VALUE(name.size() < nanovdb::GridData::MaxNameSize, + "Grid name " + name + " exceeds maximum character length of " + + std::to_string(nanovdb::GridData::MaxNameSize) + "."); - auto proxyGrid = std::make_shared(GridValueT(0), name); + auto proxyGrid = std::make_shared(GridValueT(0), name); auto proxyGridAccessor = proxyGrid->getWriteAccessor(); - const int start = ijkValues.joffsets()[bi].item(); - const int end = ijkValues.joffsets()[bi+1].item(); + const int start = ijkValues.joffsets()[bi].item(); + const int end = ijkValues.joffsets()[bi + 1].item(); const int64_t numVoxels = end - start; - const int64_t numData = data.joffsets()[bi+1].item() - data.joffsets()[bi].item(); + const int64_t numData = + data.joffsets()[bi + 1].item() - data.joffsets()[bi].item(); TORCH_CHECK_VALUE(numData == gridBatch.num_voxels_at(bi), - "Invalid number of voxels in jagged tensor at index " + std::to_string(bi) + - ". Expected it to match the number of voxels at grid index " + std::to_string(bi) + ". " + - "Got " + std::to_string(numVoxels) + " but expected " + - std::to_string(gridBatch.num_voxels_at(bi)) + "."); + "Invalid number of voxels in jagged tensor at index " + + std::to_string(bi) + + ". Expected it to match the number of voxels at grid index " + + std::to_string(bi) + ". " + "Got " + std::to_string(numVoxels) + + " but expected " + std::to_string(gridBatch.num_voxels_at(bi)) + "."); for (int i = 0; i < numVoxels; i += 1) { - const GridValueT& value = valueGetter(jdataAccessor, start + i); - const nanovdb::Coord ijk(ijkAccessor[start + i][0], ijkAccessor[start + i][1], ijkAccessor[start + i][2]); + const GridValueT &value = + valueGetter(jdataAccessor, start + i); + const nanovdb::Coord ijk(ijkAccessor[start + i][0], ijkAccessor[start + i][1], + ijkAccessor[start + i][2]); proxyGridAccessor.setValue(ijk, value); } proxyGridAccessor.merge(); @@ -138,22 +158,21 @@ nanovdb::GridHandle fvdbToNanovdbGridWithValues(const GridB // Write shape of tensor to blind data so we can load it back with the same shape // This lets us handle things like (N, 1, 3) tensors which we can save as Vec3f grids nanovdb::tools::CreateNanoGrid converter(*proxyGrid); - converter.addBlindData("fvdb_jdata", - nanovdb::GridBlindDataSemantic::Unknown, - nanovdb::GridBlindDataClass::Unknown, - nanovdb::GridType::Unknown, - data.jdata().dim() + 1, - sizeof(int64_t)); - buffers.push_back(converter.template getHandle(nanovdb::HostBuffer())); + converter.addBlindData("fvdb_jdata", nanovdb::GridBlindDataSemantic::Unknown, + nanovdb::GridBlindDataClass::Unknown, nanovdb::GridType::Unknown, + data.jdata().dim() + 1, sizeof(int64_t)); + buffers.push_back( + converter.template getHandle(nanovdb::HostBuffer())); TORCH_CHECK(buffers.back().gridCount() == 1, "Internal error. Invalid grid count."); - nanovdb::NanoGrid* nanoGrid = buffers.back().grid(); - TORCH_CHECK(nanoGrid->blindDataCount() == 1, "Internal error. Invalid blind metadata count."); - int64_t* writeHead = (int64_t*) nanoGrid->blindMetaData(0).blindData(); - JaggedTensor dataBi = data.index({bi}); - *writeHead = (int64_t) dataBi.jdata().dim(); + nanovdb::NanoGrid *nanoGrid = buffers.back().grid(); + TORCH_CHECK(nanoGrid->blindDataCount() == 1, + "Internal error. 
Invalid blind metadata count."); + int64_t *writeHead = (int64_t *)nanoGrid->blindMetaData(0).blindData(); + JaggedTensor dataBi = data.index({ bi }); + *writeHead = (int64_t)dataBi.jdata().dim(); writeHead += 1; for (int di = 0; di < dataBi.jdata().dim(); di += 1) { - *writeHead = (int64_t) dataBi.jdata().size(di); + *writeHead = (int64_t)dataBi.jdata().size(di); writeHead += 1; } } @@ -166,11 +185,11 @@ nanovdb::GridHandle fvdbToNanovdbGridWithValues(const GridB } } -nanovdb::GridHandle maybeConvertToStandardNanovdbGrid(const fvdb::GridBatch& gridBatch, - const fvdb::JaggedTensor data, - const std::vector names) -{ - // We can't convert mutable grids to a standard format because we don't know what do with disabled voxels +nanovdb::GridHandle +maybeConvertToStandardNanovdbGrid(const fvdb::GridBatch &gridBatch, const fvdb::JaggedTensor data, + const std::vector names) { + // We can't convert mutable grids to a standard format because we don't know what to do with + // disabled voxels if (gridBatch.is_mutable()) { return nanovdb::GridHandle(); } @@ -178,9 +197,11 @@ nanovdb::GridHandle maybeConvertToStandardNanovdbGrid(const // Get a squeezed view of the tensor so we can save data with redundant dimensions // (e.g. shape (N, 1, 3) can get saved as a Vec3f grid) torch::Tensor jdataSqueezed = data.jdata().squeeze(); - if (jdataSqueezed.numel() == 1 && jdataSqueezed.dim() == 0) { // Make sure we have at least 1 dimension + if (jdataSqueezed.numel() == 1 && + jdataSqueezed.dim() == 0) { // Make sure we have at least 1 dimension jdataSqueezed = jdataSqueezed.unsqueeze(0); - TORCH_CHECK(jdataSqueezed.ndimension() == 1, "Internal error: Invalid jdata shape when saving grid."); + TORCH_CHECK(jdataSqueezed.ndimension() == 1, + "Internal error: Invalid jdata shape when saving grid."); } if (data.dtype() == torch::kHalf) { if (jdataSqueezed.dim() == 1 || (jdataSqueezed.dim() == 2 && jdataSqueezed.size(1) == 1)) { @@ -232,16 +253,13 @@ nanovdb::GridHandle maybeConvertToStandardNanovdbGrid(const return nanovdb::GridHandle(); } -bool maybeSaveStandardNanovdbGrid(const std::string& path, - const GridBatch& gridBatch, - const JaggedTensor data, - const std::vector names, - nanovdb::io::Codec codec, - bool verbose) { - - nanovdb::GridHandle gridHandle = maybeConvertToStandardNanovdbGrid(gridBatch, data, names); - if (gridHandle.isEmpty()) - { +bool +maybeSaveStandardNanovdbGrid(const std::string &path, const GridBatch &gridBatch, + const JaggedTensor data, const std::vector names, + nanovdb::io::Codec codec, bool verbose) { + nanovdb::GridHandle gridHandle = + maybeConvertToStandardNanovdbGrid(gridBatch, data, names); + if (gridHandle.isEmpty()) { return false; } @@ -249,48 +267,54 @@ bool maybeSaveStandardNanovdbGrid(const std::string& path, return true; } -nanovdb::GridHandle getIndexGrid(const GridBatch& gridBatch, - const std::vector names = {}) { - - const nanovdb::GridHandle& nanoGridHdl = gridBatch.nanovdb_grid_handle(); +nanovdb::GridHandle +getIndexGrid(const GridBatch &gridBatch, const std::vector names = {}) { + const nanovdb::GridHandle &nanoGridHdl = gridBatch.nanovdb_grid_handle(); // Allocate memory and get pointer to host grid buffer nanovdb::HostBuffer writeBuf(nanoGridHdl.buffer().size()); - void* writeHead = writeBuf.data(); + void *writeHead = writeBuf.data(); // Get pointer to grid read from (possibly on the device) - const bool isCuda = nanoGridHdl.buffer().device().is_cuda(); - void* readHead = isCuda ?
nanoGridHdl.buffer().deviceData() : nanoGridHdl.buffer().data(); + const bool isCuda = nanoGridHdl.buffer().device().is_cuda(); + void *readHead = isCuda ? nanoGridHdl.buffer().deviceData() : nanoGridHdl.buffer().data(); const size_t sourceGridByteSize = nanoGridHdl.buffer().size(); // Write out the full grid to the buffer if (isCuda) { - at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(gridBatch.device().index()); - cudaMemcpyAsync(writeHead, readHead, sourceGridByteSize, cudaMemcpyDeviceToHost, defaultStream.stream()); + at::cuda::CUDAStream defaultStream = + at::cuda::getCurrentCUDAStream(gridBatch.device().index()); + cudaMemcpyAsync(writeHead, readHead, sourceGridByteSize, cudaMemcpyDeviceToHost, + defaultStream.stream()); cudaStreamSynchronize(defaultStream.stream()); } else { memcpy(writeHead, readHead, sourceGridByteSize); } - nanovdb::GridHandle retHandle = nanovdb::GridHandle (std::move(writeBuf)); + nanovdb::GridHandle retHandle = + nanovdb::GridHandle(std::move(writeBuf)); // Write voxelSize and origin information to the output buffer - for (int64_t bi = 0; bi < gridBatch.grid_count(); bi += 1) - { + for (int64_t bi = 0; bi < gridBatch.grid_count(); bi += 1) { nanovdb::GridData *retGridData = (nanovdb::GridData *)(retHandle.gridData(bi)); - torch::Tensor voxelSize = gridBatch.voxel_size_at(bi, torch::kFloat64); - torch::Tensor origin = gridBatch.origin_at(bi, torch::kFloat64); - retGridData->mVoxelSize = {voxelSize[0].item(), voxelSize[1].item(), voxelSize[2].item()}; - retGridData->mMap = nanovdb::Map(voxelSize[0].item(), {origin[0].item(), origin[1].item(), origin[2].item()}); + torch::Tensor voxelSize = gridBatch.voxel_size_at(bi, torch::kFloat64); + torch::Tensor origin = gridBatch.origin_at(bi, torch::kFloat64); + retGridData->mVoxelSize = { voxelSize[0].item(), voxelSize[1].item(), + voxelSize[2].item() }; + retGridData->mMap = nanovdb::Map( + voxelSize[0].item(), + { origin[0].item(), origin[1].item(), origin[2].item() }); } // If you passed in grid names, write them to the output buffer if (names.size() > 0) { for (int64_t bi = 0; bi < gridBatch.grid_count(); bi += 1) { const std::string name = names.size() > 0 ? 
names[bi] : ""; - TORCH_CHECK_VALUE(name.size() < nanovdb::GridData::MaxNameSize, "Grid name " + name + " exceeds maximum character length of " + std::to_string(nanovdb::GridData::MaxNameSize) + "."); - nanovdb::GridData* retGridData = (nanovdb::GridData*) (retHandle.gridData(bi)); - #pragma GCC diagnostic ignored "-Wstringop-truncation" + TORCH_CHECK_VALUE(name.size() < nanovdb::GridData::MaxNameSize, + "Grid name " + name + " exceeds maximum character length of " + + std::to_string(nanovdb::GridData::MaxNameSize) + "."); + nanovdb::GridData *retGridData = (nanovdb::GridData *)(retHandle.gridData(bi)); +#pragma GCC diagnostic ignored "-Wstringop-truncation" strncpy(retGridData->mGridName, names[bi].c_str(), nanovdb::GridData::MaxNameSize); } } @@ -299,12 +323,9 @@ nanovdb::GridHandle getIndexGrid(const GridBatch& gridBatch return retHandle; } -void saveIndexGrid(const std::string& path, - const GridBatch& gridBatch, - const std::vector names, - nanovdb::io::Codec codec, - bool verbose) { - +void +saveIndexGrid(const std::string &path, const GridBatch &gridBatch, + const std::vector names, nanovdb::io::Codec codec, bool verbose) { // If you don't pass in data, then we just write the grid nanovdb::GridHandle writeHandle = getIndexGrid(gridBatch, names); @@ -312,32 +333,31 @@ void saveIndexGrid(const std::string& path, nanovdb::io::writeGrid(path, writeHandle, codec, verbose); } -void saveIndexGridWithBlindData(const std::string& path, - const GridBatch& gridBatch, - const JaggedTensor data, - const std::vector names, - nanovdb::io::Codec codec, - bool verbose) { - - const nanovdb::GridHandle& nanoGridHdl = gridBatch.nanovdb_grid_handle(); +void +saveIndexGridWithBlindData(const std::string &path, const GridBatch &gridBatch, + const JaggedTensor data, const std::vector names, + nanovdb::io::Codec codec, bool verbose) { + const nanovdb::GridHandle &nanoGridHdl = gridBatch.nanovdb_grid_handle(); // Make a (possible) cpu copy of the data jagged tensor JaggedTensor cpuData = data.cpu().contiguous(); // Compute blind data sizes padded to be 32 byte aligned - std::vector blindDataPadding; // Size of each blind data padded to 32 bytes - std::vector paddedBlindDataSizes; // The amount of padding added to each blind data to achieve 32 byte alignment + std::vector blindDataPadding; // Size of each blind data padded to 32 bytes + std::vector paddedBlindDataSizes; // The amount of padding added to each blind data to + // achieve 32 byte alignment uint64_t totalBlindDataSize = 0; for (int bi = 0; bi < gridBatch.grid_count(); bi += 1) { - JaggedTensor dataBi = cpuData.index({bi}); - const int64_t numVoxelsBi = gridBatch.num_voxels_at(bi); + JaggedTensor dataBi = cpuData.index({ bi }); + const int64_t numVoxelsBi = gridBatch.num_voxels_at(bi); const int64_t jdataBytesBi = dataBi.jdata().numel() * dataBi.jdata().element_size(); - TORCH_CHECK_VALUE(numVoxelsBi == dataBi.rsize(0), - "Invalid number of voxels in jagged tensor at index " + std::to_string(bi) + - ". Expected it to match the number of voxels at grid index " + std::to_string(bi) + ". " + - "Got " + std::to_string(dataBi.jdata().size(0)) + " but expected " + - std::to_string(gridBatch.num_voxels_at(bi)) + "."); - const uint64_t blindDataSizeBi = jdataBytesBi + sizeof(int64_t) * (dataBi.rdim() + 1); + TORCH_CHECK_VALUE( + numVoxelsBi == dataBi.rsize(0), + "Invalid number of voxels in jagged tensor at index " + std::to_string(bi) + + ". Expected it to match the number of voxels at grid index " + std::to_string(bi) + + ". 
" + "Got " + std::to_string(dataBi.jdata().size(0)) + " but expected " + + std::to_string(gridBatch.num_voxels_at(bi)) + "."); + const uint64_t blindDataSizeBi = jdataBytesBi + sizeof(int64_t) * (dataBi.rdim() + 1); const uint64_t paddedBlindDataSizeBi = nanovdb::math::AlignUp<32UL>(blindDataSizeBi); blindDataPadding.push_back(paddedBlindDataSizeBi - blindDataSizeBi); paddedBlindDataSizes.push_back(paddedBlindDataSizeBi); @@ -345,61 +365,71 @@ void saveIndexGridWithBlindData(const std::string& path, } // Allocate a big enough buffer to allocate the index grid and blind data - const size_t allocSize = nanoGridHdl.buffer().size() + // Grids (32B aligned) - sizeof(nanovdb::GridBlindMetaData) * gridBatch.grid_count() + // Blind metadata (32B aligned) - totalBlindDataSize; // Blind data (32B aligned) + const size_t allocSize = nanoGridHdl.buffer().size() + // Grids (32B aligned) + sizeof(nanovdb::GridBlindMetaData) * + gridBatch.grid_count() + // Blind metadata (32B aligned) + totalBlindDataSize; // Blind data (32B aligned) nanovdb::HostBuffer writeBuf(allocSize); // Get pointer to read (possibly on the device) and write pointers - const bool isCuda = nanoGridHdl.buffer().device().is_cuda(); - uint8_t* writeHead = static_cast(writeBuf.data()); - uint8_t* readHead = static_cast(isCuda ? nanoGridHdl.buffer().deviceData() : nanoGridHdl.buffer().data()); + const bool isCuda = nanoGridHdl.buffer().device().is_cuda(); + uint8_t *writeHead = static_cast(writeBuf.data()); + uint8_t *readHead = static_cast(isCuda ? nanoGridHdl.buffer().deviceData() + : nanoGridHdl.buffer().data()); // Copy each grid and each entry in the jagged tensor for (int bi = 0; bi < gridBatch.grid_count(); bi += 1) { - // Copy the full bi^th index grid to the buffer const size_t sourceGridByteSize = nanoGridHdl.gridSize(bi); if (isCuda) { - at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(gridBatch.device().index()); - cudaMemcpyAsync((void*) writeHead, (void*) readHead, sourceGridByteSize, cudaMemcpyDeviceToHost, defaultStream.stream()); + at::cuda::CUDAStream defaultStream = + at::cuda::getCurrentCUDAStream(gridBatch.device().index()); + cudaMemcpyAsync((void *)writeHead, (void *)readHead, sourceGridByteSize, + cudaMemcpyDeviceToHost, defaultStream.stream()); } else { - memcpy((void*) writeHead, (void*) readHead, sourceGridByteSize); + memcpy((void *)writeHead, (void *)readHead, sourceGridByteSize); } // Update the metadata for the copied grid in the buffer to be a tensor grid with blind data - nanovdb::GridData* writeGridData = reinterpret_cast(writeHead); - writeGridData->mGridClass = nanovdb::GridClass::TensorGrid; - writeGridData->mGridType = gridBatch.is_mutable() ? nanovdb::GridType::OnIndexMask : nanovdb::GridType::OnIndex; - writeGridData->mBlindMetadataCount = 1; + nanovdb::GridData *writeGridData = reinterpret_cast(writeHead); + writeGridData->mGridClass = nanovdb::GridClass::TensorGrid; + writeGridData->mGridType = + gridBatch.is_mutable() ? nanovdb::GridType::OnIndexMask : nanovdb::GridType::OnIndex; + writeGridData->mBlindMetadataCount = 1; writeGridData->mBlindMetadataOffset = sourceGridByteSize; - const std::string name = names.size() > 0 ? names[bi] : ""; - setFixedSizeStringBuf(writeGridData->mGridName, nanovdb::GridData::MaxNameSize, name, "Grid name " + name); - writeGridData->mGridSize = sourceGridByteSize + sizeof(nanovdb::GridBlindMetaData) + paddedBlindDataSizes[bi]; + const std::string name = names.size() > 0 ? 
names[bi] : ""; + setFixedSizeStringBuf(writeGridData->mGridName, nanovdb::GridData::MaxNameSize, name, + "Grid name " + name); + writeGridData->mGridSize = + sourceGridByteSize + sizeof(nanovdb::GridBlindMetaData) + paddedBlindDataSizes[bi]; readHead += sourceGridByteSize; writeHead += sourceGridByteSize; // Write out blind metadata to the end of the grid - nanovdb::GridBlindMetaData* blindMetadata = reinterpret_cast(writeHead); + nanovdb::GridBlindMetaData *blindMetadata = + reinterpret_cast(writeHead); blindMetadata->mDataOffset = int64_t(sizeof(nanovdb::GridBlindMetaData)); blindMetadata->mValueCount = paddedBlindDataSizes[bi]; // Number of bytes - blindMetadata->mValueSize = 1; // 1 byte per value - blindMetadata->mSemantic = nanovdb::GridBlindDataSemantic::Unknown; - blindMetadata->mDataClass = nanovdb::GridBlindDataClass::Unknown; - blindMetadata->mDataType = nanovdb::GridType::Unknown; - const std::string fvdbBlindName = "fvdb_jdata" + TorchScalarTypeToStr(cpuData.scalar_type()); - setFixedSizeStringBuf(blindMetadata->mName, nanovdb::GridBlindMetaData::MaxNameSize, fvdbBlindName, "blind metadata name"); + blindMetadata->mValueSize = 1; // 1 byte per value + blindMetadata->mSemantic = nanovdb::GridBlindDataSemantic::Unknown; + blindMetadata->mDataClass = nanovdb::GridBlindDataClass::Unknown; + blindMetadata->mDataType = nanovdb::GridType::Unknown; + const std::string fvdbBlindName = + "fvdb_jdata" + TorchScalarTypeToStr(cpuData.scalar_type()); + setFixedSizeStringBuf(blindMetadata->mName, nanovdb::GridBlindMetaData::MaxNameSize, + fvdbBlindName, "blind metadata name"); TORCH_CHECK(blindMetadata->isValid(), "Invalid blind metadata"); writeHead += sizeof(nanovdb::GridBlindMetaData); // i^th jdata entry in the jagged tensor - JaggedTensor dataBi = cpuData.index({bi}); + JaggedTensor dataBi = cpuData.index({ bi }); TORCH_CHECK(dataBi.is_contiguous(), "Jagged tensor must be contiguous"); - // Write the shape of bi^th jdata tensor so we can load it with the same shape it was saved with - *reinterpret_cast(writeHead) = (int64_t) dataBi.rdim(); + // Write the shape of bi^th jdata tensor so we can load it with the same shape it was saved + // with + *reinterpret_cast(writeHead) = (int64_t)dataBi.rdim(); writeHead += sizeof(int64_t); for (int di = 0; di < dataBi.rdim(); di += 1) { - *reinterpret_cast(writeHead) = (int64_t) dataBi.rsize(di); + *reinterpret_cast(writeHead) = (int64_t)dataBi.rsize(di); writeHead += sizeof(int64_t); } @@ -407,14 +437,15 @@ void saveIndexGridWithBlindData(const std::string& path, const int64_t jdataSize = dataBi.jdata().numel() * dataBi.jdata().element_size(); TORCH_CHECK(dataBi.jdata().is_contiguous(), "Jagged tensor must be contiguous"); TORCH_CHECK(dataBi.device().is_cpu(), "Jagged tensor must be on CPU"); - memcpy((void*) writeHead, (void*) dataBi.jdata().data_ptr(), jdataSize); + memcpy((void *)writeHead, (void *)dataBi.jdata().data_ptr(), jdataSize); writeHead += jdataSize; - writeHead += blindDataPadding[bi]; // Add padding to make sure we're 32 byte aligned + writeHead += blindDataPadding[bi]; // Add padding to make sure we're 32 byte aligned } // Synchronize cuda stream if we just did a bunch of GPU -> CPU transfers if (isCuda) { - at::cuda::CUDAStream defaultStream = at::cuda::getCurrentCUDAStream(gridBatch.device().index()); + at::cuda::CUDAStream defaultStream = + at::cuda::getCurrentCUDAStream(gridBatch.device().index()); cudaStreamSynchronize(defaultStream.stream()); } @@ -424,36 +455,30 @@ void saveIndexGridWithBlindData(const std::string& path, } 
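// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the patch): the "fvdb_jdata" blind
// data written by saveIndexGridWithBlindData above is a shape header of
// (ndim + 1) int64 values followed by the raw tensor bytes, with the whole
// payload padded up to a 32-byte boundary; the load path in LoadNanovdb.cpp
// parses it back in the same order. The helper names below (alignUp32,
// shapeHeaderBytes, writeBlindPayload) are hypothetical and exist only for
// this example; dst is assumed to be at least 8-byte aligned, as the 32-byte
// aligned nanovdb blind-data region guarantees.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <cstring>
#include <vector>

namespace fvdb_blind_example {

// Round n up to the next multiple of 32 bytes, mirroring nanovdb::math::AlignUp<32UL>.
inline uint64_t alignUp32(uint64_t n) { return (n + 31u) & ~uint64_t(31); }

// Bytes used by the shape header: one int64 for ndim plus one int64 per dimension.
inline uint64_t shapeHeaderBytes(const std::vector<int64_t> &shape) {
    return sizeof(int64_t) * (shape.size() + 1);
}

// Write [ndim][dim_0 ... dim_{ndim-1}][raw tensor bytes] into dst, which must
// hold at least alignUp32(shapeHeaderBytes(shape) + dataBytes) bytes.
inline void writeBlindPayload(uint8_t *dst, const std::vector<int64_t> &shape,
                              const void *data, uint64_t dataBytes) {
    int64_t *head = reinterpret_cast<int64_t *>(dst);
    *head++ = static_cast<int64_t>(shape.size()); // ndim
    for (int64_t d : shape) {
        *head++ = d;                              // each dimension
    }
    std::memcpy(head, data, dataBytes);           // raw jdata bytes follow the header
}

} // namespace fvdb_blind_example
// ---------------------------------------------------------------------------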
nanovdb::GridHandle -toNVDB(const GridBatch& gridBatch, - const torch::optional maybeData, +toNVDB(const GridBatch &gridBatch, const torch::optional maybeData, const torch::optional maybeNames) { - // Get optional names std::vector names; - if (maybeNames.has_value()) - { + if (maybeNames.has_value()) { names = maybeNames.value().value(); - TORCH_CHECK_VALUE(names.size() == 0 || names.size() == (size_t)gridBatch.grid_count(), - "Invalid parameter for names, must be empty or a list of the same length as the batch size. Got " + std::to_string(names.size()) + " names for batch size " + std::to_string(gridBatch.grid_count())); + TORCH_CHECK_VALUE( + names.size() == 0 || names.size() == (size_t)gridBatch.grid_count(), + "Invalid parameter for names, must be empty or a list of the same length as the batch size. Got " + + std::to_string(names.size()) + " names for batch size " + + std::to_string(gridBatch.grid_count())); } - if (maybeData.has_value()) - { + if (maybeData.has_value()) { return maybeConvertToStandardNanovdbGrid(gridBatch, maybeData.value(), names); - } - else - { + } else { return getIndexGrid(gridBatch, names); } } -void saveNVDB(const std::string& path, - const GridBatch& gridBatch, - const torch::optional maybeData, - const torch::optional maybeNames, - bool compressed, - bool verbose) { - +void +saveNVDB(const std::string &path, const GridBatch &gridBatch, + const torch::optional maybeData, + const torch::optional maybeNames, bool compressed, bool verbose) { // Which Codec to use for saving nanovdb::io::Codec codec = compressed ? nanovdb::io::Codec::BLOSC : nanovdb::io::Codec::NONE; @@ -461,9 +486,11 @@ void saveNVDB(const std::string& path, std::vector names; if (maybeNames.has_value()) { names = maybeNames.value().value(); - TORCH_CHECK_VALUE(names.size() == 0 || names.size() == (size_t) gridBatch.grid_count(), - "Invalid parameter for names, must be empty or a list of the same length as the batch size. Got " - + std::to_string(names.size()) + " names for batch size " + std::to_string(gridBatch.grid_count())); + TORCH_CHECK_VALUE( + names.size() == 0 || names.size() == (size_t)gridBatch.grid_count(), + "Invalid parameter for names, must be empty or a list of the same length as the batch size. Got " + + std::to_string(names.size()) + " names for batch size " + + std::to_string(gridBatch.grid_count())); } JaggedTensor data; @@ -475,20 +502,23 @@ void saveNVDB(const std::string& path, } TORCH_CHECK_VALUE(data.jdata().ndimension() >= 1, "Invalid jagged data shape in save_nvdb"); - TORCH_CHECK_VALUE(gridBatch.total_voxels() == data.jdata().size(0), "Invalid jagged data shape in save_nvdb. Must match number of voxels"); - TORCH_CHECK_VALUE(gridBatch.device() == data.device(), "Device should match between grid batch and data"); - - // Heuristically determine if we can use a standard nanovdb grid (e.g. vec3f, float, vec3i, etc...) to store the data - // If so, we save such a grid -- otherwise we save an index grid with custom blind data + TORCH_CHECK_VALUE(gridBatch.total_voxels() == data.jdata().size(0), + "Invalid jagged data shape in save_nvdb. Must match number of voxels"); + TORCH_CHECK_VALUE(gridBatch.device() == data.device(), + "Device should match between grid batch and data"); + + // Heuristically determine if we can use a standard nanovdb grid (e.g. vec3f, float, vec3i, + // etc...) 
to store the data If so, we save such a grid -- otherwise we save an index grid with + // custom blind data if (maybeSaveStandardNanovdbGrid(path, gridBatch, data, names, codec, verbose)) { return; } else { - // If we didn't manage to save a standard nanovdb grid, just save a tensor grid with blind data + // If we didn't manage to save a standard nanovdb grid, just save a tensor grid with blind + // data saveIndexGridWithBlindData(path, gridBatch, data, names, codec, verbose); } } - } // namespace io } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/ops/ActiveGridGoords.cu b/fvdb/src/detail/ops/ActiveGridGoords.cu index 598a2525d0..bf02e49593 100644 --- a/fvdb/src/detail/ops/ActiveGridGoords.cu +++ b/fvdb/src/detail/ops/ActiveGridGoords.cu @@ -1,117 +1,131 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include -#include +#include -#include "detail/utils/cuda/Utils.cuh" +#include +#include namespace fvdb { namespace detail { namespace ops { - /// @brief Per-voxel callback for getting the enabled grid coordinates in a batch of grids template typename TorchAccessor> -__hostdev__ inline void enabledGridCoordsVoxelCallback(int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, - GridBatchImpl::Accessor gridAccessor, - TorchAccessor leafBaseOffset, - TorchAccessor outGridCoords) { - const nanovdb::NanoGrid* grid = gridAccessor.grid(batchIdx); - const typename nanovdb::NanoGrid::LeafNodeType& leaf = grid->tree().template getFirstNode<0>()[leafIdx]; - const nanovdb::Coord& ijk = leaf.offsetToGlobalCoord(voxelIdx); - const int64_t outIdx = leafBaseOffset[leafIdx] + leaf.template get>(voxelIdx); +__hostdev__ inline void +enabledGridCoordsVoxelCallback(int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, + GridBatchImpl::Accessor gridAccessor, + TorchAccessor leafBaseOffset, + TorchAccessor outGridCoords) { + const nanovdb::NanoGrid *grid = gridAccessor.grid(batchIdx); + const typename nanovdb::NanoGrid::LeafNodeType &leaf = + grid->tree().template getFirstNode<0>()[leafIdx]; + const nanovdb::Coord &ijk = leaf.offsetToGlobalCoord(voxelIdx); + const int64_t outIdx = + leafBaseOffset[leafIdx] + leaf.template get>(voxelIdx); if (leaf.template get>(voxelIdx)) { outGridCoords[outIdx][0] = ijk[0]; outGridCoords[outIdx][1] = ijk[1]; outGridCoords[outIdx][2] = ijk[2]; } - } - /// @brief Per-voxel callback which computes the active grid coordinates for a batch of grids template typename TorchAccessor> -__hostdev__ inline void activeGridCoordsVoxelCallback(int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, - GridBatchImpl::Accessor gridAccessor, - TorchAccessor outGridCoords) { - - const nanovdb::NanoGrid* grid = gridAccessor.grid(batchIdx); - const typename nanovdb::NanoGrid::LeafNodeType& leaf = grid->tree().template getFirstNode<0>()[leafIdx]; +__hostdev__ inline void +activeGridCoordsVoxelCallback(int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, + GridBatchImpl::Accessor gridAccessor, + TorchAccessor outGridCoords) { + const nanovdb::NanoGrid *grid = gridAccessor.grid(batchIdx); + const typename nanovdb::NanoGrid::LeafNodeType &leaf = + grid->tree().template getFirstNode<0>()[leafIdx]; const int64_t baseOffset = gridAccessor.voxelOffset(batchIdx); - - const nanovdb::Coord& ijk = leaf.offsetToGlobalCoord(voxelIdx); + const nanovdb::Coord &ijk = leaf.offsetToGlobalCoord(voxelIdx); if (leaf.isActive(voxelIdx)) { - const int64_t idx = baseOffset + (int64_t) leaf.getValue(voxelIdx) - 1; + const int64_t idx = baseOffset + 
(int64_t)leaf.getValue(voxelIdx) - 1; outGridCoords[idx][0] = ijk[0]; outGridCoords[idx][1] = ijk[1]; outGridCoords[idx][2] = ijk[2]; } } - /// @brief Get the enabled grid coordinates for a batch of grids (ignoring disabled voxels) /// @param gridBatch The batch of grids (must be mutable) /// @param outGridCoords Tensor which will contain the output grid coordinates template -void GetEnabledGridCoords(const GridBatchImpl& gridBatch, torch::Tensor& outGridCoords) { +void +GetEnabledGridCoords(const GridBatchImpl &gridBatch, torch::Tensor &outGridCoords) { using GridType = nanovdb::ValueOnIndexMask; // Compute a prefix sum of the unmasked voxels per leaf - const torch::Tensor leafBaseOffset = countEnabledPerLeafShiftedByOne(gridBatch).cumsum(0, torch::kInt64); + const torch::Tensor leafBaseOffset = + countEnabledPerLeafShiftedByOne(gridBatch).cumsum(0, torch::kInt64); // Get the unmasked grid coordinates auto leafBaseOffsetAcc = tensorAccessor(leafBaseOffset); - auto outCoordsAcc = tensorAccessor(outGridCoords); + auto outCoordsAcc = tensorAccessor(outGridCoords); if constexpr (DeviceTag == torch::kCUDA) { - auto cb = [=] __device__ (int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, int64_t, GridBatchImpl::Accessor gridAccessor) { - enabledGridCoordsVoxelCallback(batchIdx, leafIdx, voxelIdx, gridAccessor, leafBaseOffsetAcc, outCoordsAcc); + auto cb = [=] __device__(int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, int64_t, + GridBatchImpl::Accessor gridAccessor) { + enabledGridCoordsVoxelCallback( + batchIdx, leafIdx, voxelIdx, gridAccessor, leafBaseOffsetAcc, outCoordsAcc); }; forEachVoxelCUDA(1024, 1, gridBatch, cb); } else { - auto cb = [=] (int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, int64_t, GridBatchImpl::Accessor gridAccessor) { - enabledGridCoordsVoxelCallback(batchIdx, leafIdx, voxelIdx, gridAccessor, leafBaseOffsetAcc, outCoordsAcc); + auto cb = [=](int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, int64_t, + GridBatchImpl::Accessor gridAccessor) { + enabledGridCoordsVoxelCallback( + batchIdx, leafIdx, voxelIdx, gridAccessor, leafBaseOffsetAcc, outCoordsAcc); }; forEachVoxelCPU(1, gridBatch, cb); } } - -/// @brief Get the active grid coordinates for a batch of grids (including disabled coordinates in mutable grids) +/// @brief Get the active grid coordinates for a batch of grids (including disabled coordinates in +/// mutable grids) /// @tparam GridType The type of the grid (one of ValueOnIndex, ValueOnIndexMask) /// @param gridBatch The batch of grids /// @param outGridCoords Tensor which will contain the output grid coordinates template -void GetActiveGridCoords(const GridBatchImpl& gridBatch, torch::Tensor& outGridCoords) { +void +GetActiveGridCoords(const GridBatchImpl &gridBatch, torch::Tensor &outGridCoords) { auto outCoordsAcc = tensorAccessor(outGridCoords); if constexpr (DeviceTag == torch::kCUDA) { - auto cb = [=] __device__ (int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, int64_t, GridBatchImpl::Accessor gridAccessor) { - activeGridCoordsVoxelCallback(batchIdx, leafIdx, voxelIdx, gridAccessor, outCoordsAcc); + auto cb = [=] __device__(int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, int64_t, + GridBatchImpl::Accessor gridAccessor) { + activeGridCoordsVoxelCallback(batchIdx, leafIdx, voxelIdx, + gridAccessor, outCoordsAcc); }; forEachVoxelCUDA(1024, 1, gridBatch, cb); } else { - auto cb = [=] (int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, int64_t, GridBatchImpl::Accessor gridAccessor) { - activeGridCoordsVoxelCallback(batchIdx, leafIdx, 
voxelIdx, gridAccessor, outCoordsAcc); + auto cb = [=](int64_t batchIdx, int64_t leafIdx, int64_t voxelIdx, int64_t, + GridBatchImpl::Accessor gridAccessor) { + activeGridCoordsVoxelCallback(batchIdx, leafIdx, voxelIdx, + gridAccessor, outCoordsAcc); }; forEachVoxelCPU(1, gridBatch, cb); } } - -/// @brief Get the number of active (or enabled for mutable grids) ijk coordiantes in a batch of grids +/// @brief Get the number of active (or enabled for mutable grids) ijk coordiantes in a batch of +/// grids /// @tparam DeviceTag Which device to run on /// @param gridBatch The batch of grids to get the active coordinates for -/// @param ignoreDisabledVoxels If set to true, and the grid batch is mutable, also return coordinates that are disabled +/// @param ignoreDisabledVoxels If set to true, and the grid batch is mutable, also return +/// coordinates that are disabled /// @return A JaggedTensor or shape [B, -1, 3] of active/enabled IJK coordinates template -JaggedTensor ActiveGridCoords(const GridBatchImpl& gridBatch, bool ignoreDisabledVoxels) { +JaggedTensor +ActiveGridCoords(const GridBatchImpl &gridBatch, bool ignoreDisabledVoxels) { gridBatch.checkNonEmptyGrid(); - auto opts = torch::TensorOptions().dtype(torch::kInt32).device(gridBatch.device()); - torch::Tensor outGridCoords = torch::empty({gridBatch.totalEnabledVoxels(ignoreDisabledVoxels), 3}, opts); + auto opts = torch::TensorOptions().dtype(torch::kInt32).device(gridBatch.device()); + torch::Tensor outGridCoords = + torch::empty({ gridBatch.totalEnabledVoxels(ignoreDisabledVoxels), 3 }, opts); FVDB_DISPATCH_GRID_TYPES(gridBatch, [&]() { - if (ignoreDisabledVoxels || nanovdb::util::is_same::value) { + if (ignoreDisabledVoxels || + nanovdb::util::is_same::value) { GetActiveGridCoords(gridBatch, outGridCoords); } else if (nanovdb::util::is_same::value) { TORCH_CHECK(!ignoreDisabledVoxels, "This should never happen"); @@ -121,19 +135,18 @@ JaggedTensor ActiveGridCoords(const GridBatchImpl& gridBatch, bool ignoreDisable return gridBatch.jaggedTensor(outGridCoords, ignoreDisabledVoxels); } - - template <> -JaggedTensor dispatchActiveGridCoords(const GridBatchImpl& gridBatch, bool ignoreMasked) { +JaggedTensor +dispatchActiveGridCoords(const GridBatchImpl &gridBatch, bool ignoreMasked) { return ActiveGridCoords(gridBatch, ignoreMasked); } template <> -JaggedTensor dispatchActiveGridCoords(const GridBatchImpl& gridBatch, bool ignoreMasked) { +JaggedTensor +dispatchActiveGridCoords(const GridBatchImpl &gridBatch, bool ignoreMasked) { return ActiveGridCoords(gridBatch, ignoreMasked); } - } // namespace ops } // namespace detail } // namespace fvdb \ No newline at end of file diff --git a/fvdb/src/detail/ops/ActiveVoxelsInBoundsMask.cu b/fvdb/src/detail/ops/ActiveVoxelsInBoundsMask.cu index d6527d9395..8b646394de 100644 --- a/fvdb/src/detail/ops/ActiveVoxelsInBoundsMask.cu +++ b/fvdb/src/detail/ops/ActiveVoxelsInBoundsMask.cu @@ -1,133 +1,153 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include +#include +#include -#include "detail/utils/cuda/Utils.cuh" -#include "detail/utils/nanovdb/CustomAccessors.h" +#include namespace fvdb { namespace detail { namespace ops { - -/// @brief Per-voxel callback to compute a mask of the enabled voxels in a bounding box for a batch of grids +/// @brief Per-voxel callback to compute a mask of the enabled voxels in a bounding box for a batch +/// of grids template typename TorchAccessor> -__hostdev__ inline void enabledGridVoxelInBoundsMaskCallback(int32_t batchIdx, 
int32_t leafIdx, int32_t voxelIdx, - GridBatchImpl::Accessor gridAccessor, - TorchAccessor leafBaseOffset, - TorchAccessor bboxes, - TorchAccessor outGridBoundsMask) { - const nanovdb::CoordBBox maskBbox(nanovdb::Coord(bboxes[batchIdx][0][0], bboxes[batchIdx][0][1], bboxes[batchIdx][0][2]), - nanovdb::Coord(bboxes[batchIdx][1][0], bboxes[batchIdx][1][1], bboxes[batchIdx][1][2])); - - const nanovdb::NanoGrid *grid = gridAccessor.grid(batchIdx); - const typename nanovdb::NanoGrid::LeafNodeType& leaf = grid->tree().template getFirstNode<0>()[leafIdx]; +__hostdev__ inline void +enabledGridVoxelInBoundsMaskCallback(int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, + GridBatchImpl::Accessor gridAccessor, + TorchAccessor leafBaseOffset, + TorchAccessor bboxes, + TorchAccessor outGridBoundsMask) { + const nanovdb::CoordBBox maskBbox( + nanovdb::Coord(bboxes[batchIdx][0][0], bboxes[batchIdx][0][1], bboxes[batchIdx][0][2]), + nanovdb::Coord(bboxes[batchIdx][1][0], bboxes[batchIdx][1][1], bboxes[batchIdx][1][2])); + + const nanovdb::NanoGrid *grid = gridAccessor.grid(batchIdx); + const typename nanovdb::NanoGrid::LeafNodeType &leaf = + grid->tree().template getFirstNode<0>()[leafIdx]; if (maskBbox.hasOverlap(leaf.bbox())) { const nanovdb::Coord ijk = leaf.offsetToGlobalCoord(voxelIdx); if (leaf.template get>(voxelIdx) && maskBbox.isInside(ijk)) { - const int64_t outIdx = leafBaseOffset[leafIdx] + leaf.template get>(voxelIdx); + const int64_t outIdx = + leafBaseOffset[leafIdx] + leaf.template get>(voxelIdx); outGridBoundsMask[outIdx] = true; } } } -/// @brief Per-voxel callback to compute a mask of the active grid voxels in a bounding box for a batch of grids +/// @brief Per-voxel callback to compute a mask of the active grid voxels in a bounding box for a +/// batch of grids template typename TorchAccessor> -__hostdev__ inline void activeGridVoxelInBoundsMaskCallback(int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, - GridBatchImpl::Accessor gridAccessor, - TorchAccessor bboxes, - TorchAccessor outGridBoundsMask) { - - const nanovdb::CoordBBox maskBbox(nanovdb::Coord(bboxes[batchIdx][0][0], bboxes[batchIdx][0][1], bboxes[batchIdx][0][2]), - nanovdb::Coord(bboxes[batchIdx][1][0], bboxes[batchIdx][1][1], bboxes[batchIdx][1][2])); - - const nanovdb::NanoGrid *grid = gridAccessor.grid(batchIdx); - const typename nanovdb::NanoGrid::LeafNodeType& leaf = grid->tree().template getFirstNode<0>()[leafIdx]; +__hostdev__ inline void +activeGridVoxelInBoundsMaskCallback(int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, + GridBatchImpl::Accessor gridAccessor, + TorchAccessor bboxes, + TorchAccessor outGridBoundsMask) { + const nanovdb::CoordBBox maskBbox( + nanovdb::Coord(bboxes[batchIdx][0][0], bboxes[batchIdx][0][1], bboxes[batchIdx][0][2]), + nanovdb::Coord(bboxes[batchIdx][1][0], bboxes[batchIdx][1][1], bboxes[batchIdx][1][2])); + + const nanovdb::NanoGrid *grid = gridAccessor.grid(batchIdx); + const typename nanovdb::NanoGrid::LeafNodeType &leaf = + grid->tree().template getFirstNode<0>()[leafIdx]; if (maskBbox.hasOverlap(leaf.bbox())) { const nanovdb::Coord ijk = leaf.offsetToGlobalCoord(voxelIdx); if (leaf.isActive(voxelIdx) && maskBbox.isInside(ijk)) { const int64_t baseOffset = gridAccessor.voxelOffset(batchIdx); - const int64_t idx = baseOffset + (int64_t)leaf.getValue(voxelIdx) - 1; - outGridBoundsMask[idx] = true; + const int64_t idx = baseOffset + (int64_t)leaf.getValue(voxelIdx) - 1; + outGridBoundsMask[idx] = true; } } } -/// @brief Get a boolean mask of the enabled grid voxels for a batch 
of grids (ignoring disabled voxels) +/// @brief Get a boolean mask of the enabled grid voxels for a batch of grids (ignoring disabled +/// voxels) /// @param gridBatch The batch of grids (must be mutable) /// @param batchBboxes The batch of bounding boxes /// @param outGridCoords Tensor which will contain the output grid coordinates template -void GetEnabledVoxelsInBoundsMask(const GridBatchImpl& gridBatch, - torch::Tensor& batchBboxes, - torch::Tensor& outGridBoundsMask) { +void +GetEnabledVoxelsInBoundsMask(const GridBatchImpl &gridBatch, torch::Tensor &batchBboxes, + torch::Tensor &outGridBoundsMask) { using GridType = nanovdb::ValueOnIndexMask; // Compute a prefix sum of the unmasked voxels per leaf - const torch::Tensor leafBaseOffset = countEnabledPerLeafShiftedByOne(gridBatch).cumsum(0, torch::kInt64); + const torch::Tensor leafBaseOffset = + countEnabledPerLeafShiftedByOne(gridBatch).cumsum(0, torch::kInt64); // Get the unmasked grid coordinates auto leafBaseOffsetAcc = tensorAccessor(leafBaseOffset); - auto outMaskAcc = tensorAccessor(outGridBoundsMask); - auto bboxAcc = tensorAccessor(batchBboxes); + auto outMaskAcc = tensorAccessor(outGridBoundsMask); + auto bboxAcc = tensorAccessor(batchBboxes); if constexpr (DeviceTag == torch::kCUDA) { - auto cb = [=] __device__(int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, int32_t, GridBatchImpl::Accessor gridAccessor) { - enabledGridVoxelInBoundsMaskCallback(batchIdx, leafIdx, voxelIdx, gridAccessor, leafBaseOffsetAcc, bboxAcc, outMaskAcc); + auto cb = [=] __device__(int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, int32_t, + GridBatchImpl::Accessor gridAccessor) { + enabledGridVoxelInBoundsMaskCallback( + batchIdx, leafIdx, voxelIdx, gridAccessor, leafBaseOffsetAcc, bboxAcc, outMaskAcc); }; forEachVoxelCUDA(1024, 1, gridBatch, cb); } else { - auto cb = [=](int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, int32_t, GridBatchImpl::Accessor gridAccessor) { - enabledGridVoxelInBoundsMaskCallback(batchIdx, leafIdx, voxelIdx, gridAccessor, leafBaseOffsetAcc, bboxAcc, outMaskAcc); + auto cb = [=](int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, int32_t, + GridBatchImpl::Accessor gridAccessor) { + enabledGridVoxelInBoundsMaskCallback( + batchIdx, leafIdx, voxelIdx, gridAccessor, leafBaseOffsetAcc, bboxAcc, outMaskAcc); }; forEachVoxelCPU(1, gridBatch, cb); } } -/// @brief Get a boolean mask of the active grid voxels for a batch of grids (including disabled coordinates in mutable grids) +/// @brief Get a boolean mask of the active grid voxels for a batch of grids (including disabled +/// coordinates in mutable grids) /// @tparam GridType The type of the grid (one of ValueOnIndex, ValueOnIndexMask) /// @param gridBatch The batch of grids /// @param batchBboxes The batch of bounding boxes /// @param outGridCoords Tensor which will contain the output grid coordinates template -void GetActiveVoxelsInBoundsMask(const GridBatchImpl& gridBatch, - torch::Tensor& batchBboxes, - torch::Tensor& outGridBoundsMask) { +void +GetActiveVoxelsInBoundsMask(const GridBatchImpl &gridBatch, torch::Tensor &batchBboxes, + torch::Tensor &outGridBoundsMask) { auto outMaskAcc = tensorAccessor(outGridBoundsMask); - auto bboxAcc = tensorAccessor(batchBboxes); + auto bboxAcc = tensorAccessor(batchBboxes); if constexpr (DeviceTag == torch::kCUDA) { - auto cb = [=] __device__(int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, int32_t, GridBatchImpl::Accessor gridAccessor) { - activeGridVoxelInBoundsMaskCallback(batchIdx, leafIdx, voxelIdx, gridAccessor, 
bboxAcc, outMaskAcc); + auto cb = [=] __device__(int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, int32_t, + GridBatchImpl::Accessor gridAccessor) { + activeGridVoxelInBoundsMaskCallback( + batchIdx, leafIdx, voxelIdx, gridAccessor, bboxAcc, outMaskAcc); }; forEachVoxelCUDA(1024, 1, gridBatch, cb); } else { - auto cb = [=](int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, int32_t, GridBatchImpl::Accessor gridAccessor) { - activeGridVoxelInBoundsMaskCallback(batchIdx, leafIdx, voxelIdx, gridAccessor, bboxAcc, outMaskAcc); + auto cb = [=](int32_t batchIdx, int32_t leafIdx, int32_t voxelIdx, int32_t, + GridBatchImpl::Accessor gridAccessor) { + activeGridVoxelInBoundsMaskCallback( + batchIdx, leafIdx, voxelIdx, gridAccessor, bboxAcc, outMaskAcc); }; forEachVoxelCPU(1, gridBatch, cb); } } template -JaggedTensor ActiveVoxelsInBoundsMask(const GridBatchImpl& batchHdl, - const Vec3iBatch& ijkMin, - const Vec3iBatch& ijkMax, - bool ignoreDisabledVoxels) { - +JaggedTensor +ActiveVoxelsInBoundsMask(const GridBatchImpl &batchHdl, const Vec3iBatch &ijkMin, + const Vec3iBatch &ijkMax, bool ignoreDisabledVoxels) { batchHdl.checkNonEmptyGrid(); // output storage - auto opts = torch::TensorOptions().dtype(torch::kBool).device(batchHdl.device()); - torch::Tensor outGridBoundsMask = torch::zeros({batchHdl.totalEnabledVoxels(ignoreDisabledVoxels)}, opts); + auto opts = torch::TensorOptions().dtype(torch::kBool).device(batchHdl.device()); + torch::Tensor outGridBoundsMask = + torch::zeros({ batchHdl.totalEnabledVoxels(ignoreDisabledVoxels) }, opts); // bbox to tensor storage - const std::vector& bboxMins = ijkMin.value(batchHdl.batchSize(), false, "ijk_min"); - const std::vector& bboxMaxs = ijkMax.value(batchHdl.batchSize(), false, "ijk_max"); + const std::vector &bboxMins = + ijkMin.value(batchHdl.batchSize(), false, "ijk_min"); + const std::vector &bboxMaxs = + ijkMax.value(batchHdl.batchSize(), false, "ijk_max"); - torch::Tensor batchBboxes = torch::empty({batchHdl.batchSize(), 2, 3}, - torch::TensorOptions().dtype(torch::kInt32).device(batchHdl.device())); + torch::Tensor batchBboxes = + torch::empty({ batchHdl.batchSize(), 2, 3 }, + torch::TensorOptions().dtype(torch::kInt32).device(batchHdl.device())); for (size_t batchIdx = 0; batchIdx < batchHdl.batchSize(); batchIdx++) { for (size_t dimIdx = 0; dimIdx < 3; dimIdx++) { @@ -138,8 +158,10 @@ JaggedTensor ActiveVoxelsInBoundsMask(const GridBatchImpl& batchHdl, // create boolean mask of active voxels FVDB_DISPATCH_GRID_TYPES(batchHdl, [&]() { - if (ignoreDisabledVoxels || nanovdb::util::is_same::value) { - GetActiveVoxelsInBoundsMask(batchHdl, batchBboxes, outGridBoundsMask); + if (ignoreDisabledVoxels || + nanovdb::util::is_same::value) { + GetActiveVoxelsInBoundsMask(batchHdl, batchBboxes, + outGridBoundsMask); } else if (nanovdb::util::is_same::value) { TORCH_CHECK(!ignoreDisabledVoxels, "This should never happen"); GetEnabledVoxelsInBoundsMask(batchHdl, batchBboxes, outGridBoundsMask); @@ -149,21 +171,24 @@ JaggedTensor ActiveVoxelsInBoundsMask(const GridBatchImpl& batchHdl, return batchHdl.jaggedTensor(outGridBoundsMask, ignoreDisabledVoxels); } - template <> -JaggedTensor dispatchActiveVoxelsInBoundsMask(const GridBatchImpl& batchHdl, - const Vec3iBatch& boundsMinIjk, - const Vec3iBatch& boundsMaxIjk, - bool ignoreDisabledVoxels) { - return ActiveVoxelsInBoundsMask(batchHdl, boundsMinIjk, boundsMaxIjk, ignoreDisabledVoxels); +JaggedTensor +dispatchActiveVoxelsInBoundsMask(const GridBatchImpl &batchHdl, + const Vec3iBatch &boundsMinIjk, + const 
Vec3iBatch &boundsMaxIjk, + bool ignoreDisabledVoxels) { + return ActiveVoxelsInBoundsMask(batchHdl, boundsMinIjk, boundsMaxIjk, + ignoreDisabledVoxels); } template <> -JaggedTensor dispatchActiveVoxelsInBoundsMask(const GridBatchImpl& batchHdl, - const Vec3iBatch& boundsMinIjk, - const Vec3iBatch& boundsMaxIjk, - bool ignoreDisabledVoxels) { - return ActiveVoxelsInBoundsMask(batchHdl, boundsMinIjk, boundsMaxIjk, ignoreDisabledVoxels); +JaggedTensor +dispatchActiveVoxelsInBoundsMask(const GridBatchImpl &batchHdl, + const Vec3iBatch &boundsMinIjk, + const Vec3iBatch &boundsMaxIjk, + bool ignoreDisabledVoxels) { + return ActiveVoxelsInBoundsMask(batchHdl, boundsMinIjk, boundsMaxIjk, + ignoreDisabledVoxels); } } // namespace ops diff --git a/fvdb/src/detail/ops/BuildDeviceGrid.cu b/fvdb/src/detail/ops/BuildDeviceGrid.cu index 790875d8ee..11288be6ac 100644 --- a/fvdb/src/detail/ops/BuildDeviceGrid.cu +++ b/fvdb/src/detail/ops/BuildDeviceGrid.cu @@ -1,37 +1,32 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include "detail/utils/Utils.h" -#include "detail/utils/cuda/Utils.cuh" +#include +#include -#include "detail/GridBatchImpl.h" -#include "detail/build/Build.h" +#include +#include + +#include -#include #include #include +#include #include -#include - - namespace fvdb { namespace detail { namespace ops { - template typename TensorAccessorT> -__hostdev__ void populateGridMetadataKernel( - uint32_t numGrids, - const nanovdb::NanoGrid* grids, - const nanovdb::Vec3d* voxelSizes, - const nanovdb::Vec3d* voxelOrigins, - TensorAccessorT gridOffsets, - GridBatchImpl::GridMetadata* perGridMetadata, - GridBatchImpl::GridBatchMetadata* batchMetadata) { - - batchMetadata->mMaxVoxels = 0; +__hostdev__ void +populateGridMetadataKernel(uint32_t numGrids, const nanovdb::NanoGrid *grids, + const nanovdb::Vec3d *voxelSizes, const nanovdb::Vec3d *voxelOrigins, + TensorAccessorT gridOffsets, + GridBatchImpl::GridMetadata *perGridMetadata, + GridBatchImpl::GridBatchMetadata *batchMetadata) { + batchMetadata->mMaxVoxels = 0; batchMetadata->mMaxLeafCount = 0; batchMetadata->mIsMutable = nanovdb::util::is_same::value; @@ -39,90 +34,91 @@ __hostdev__ void populateGridMetadataKernel( nanovdb::Coord bbMin = nanovdb::Coord::max(); nanovdb::Coord bbMax = nanovdb::Coord::min(); - nanovdb::NanoGrid* currentGrid = (nanovdb::NanoGrid*) &grids[0]; - uint32_t i = 0; - uint64_t byteCount = 0; + nanovdb::NanoGrid *currentGrid = (nanovdb::NanoGrid *)&grids[0]; + uint32_t i = 0; + uint64_t byteCount = 0; perGridMetadata[i].mCumVoxels = 0; - perGridMetadata[i].mCumBytes = 0; + perGridMetadata[i].mCumBytes = 0; perGridMetadata[i].mCumLeaves = 0; gridOffsets[i] = 0; while (i < numGrids - 1) { - byteCount = currentGrid->gridSize(); - const uint32_t leafCount = currentGrid->tree().nodeCount(0); + byteCount = currentGrid->gridSize(); + const uint32_t leafCount = currentGrid->tree().nodeCount(0); const uint64_t voxelCount = currentGrid->tree().activeVoxelCount(); - GridBatchImpl::GridMetadata& metaCur = perGridMetadata[i]; - GridBatchImpl::GridMetadata& metaNext = perGridMetadata[i + 1]; + GridBatchImpl::GridMetadata &metaCur = perGridMetadata[i]; + GridBatchImpl::GridMetadata &metaNext = perGridMetadata[i + 1]; metaCur.setTransform(voxelSizes[i], voxelOrigins[i]); metaCur.mNumVoxels = voxelCount; - metaCur.mNumBytes = byteCount; + metaCur.mNumBytes = byteCount; metaCur.mNumLeaves = leafCount; - metaCur.mBBox = currentGrid->tree().bbox(); + metaCur.mBBox = currentGrid->tree().bbox(); 
metaNext.mCumVoxels = metaCur.mCumVoxels + voxelCount; - metaNext.mCumBytes = metaCur.mCumBytes + byteCount; + metaNext.mCumBytes = metaCur.mCumBytes + byteCount; metaNext.mCumLeaves = metaCur.mCumLeaves + leafCount; - gridOffsets[i+1] = metaCur.mCumVoxels + metaCur.mNumVoxels; + gridOffsets[i + 1] = metaCur.mCumVoxels + metaCur.mNumVoxels; // number of voxels exceeds maximum indexable value assert(voxelCount <= std::numeric_limits::max()); - batchMetadata->mMaxVoxels = max(batchMetadata->mMaxVoxels, static_cast(voxelCount)); + batchMetadata->mMaxVoxels = + max(batchMetadata->mMaxVoxels, static_cast(voxelCount)); batchMetadata->mMaxLeafCount = max(batchMetadata->mMaxLeafCount, leafCount); - bbMin = bbMin.minComponent(currentGrid->tree().bbox().min()); - bbMax = bbMax.maxComponent(currentGrid->tree().bbox().max()); - currentGrid = (nanovdb::NanoGrid*) (((uint8_t*) currentGrid) + byteCount); + bbMin = bbMin.minComponent(currentGrid->tree().bbox().min()); + bbMax = bbMax.maxComponent(currentGrid->tree().bbox().max()); + currentGrid = (nanovdb::NanoGrid *)(((uint8_t *)currentGrid) + byteCount); i += 1; } perGridMetadata[i].setTransform(voxelSizes[i], voxelOrigins[i]); perGridMetadata[i].mNumVoxels = currentGrid->tree().activeVoxelCount(); - perGridMetadata[i].mNumBytes = currentGrid->gridSize(); + perGridMetadata[i].mNumBytes = currentGrid->gridSize(); perGridMetadata[i].mNumLeaves = currentGrid->tree().nodeCount(0); - perGridMetadata[i].mBBox = currentGrid->tree().bbox(); + perGridMetadata[i].mBBox = currentGrid->tree().bbox(); - gridOffsets[i+1] = perGridMetadata[i].mCumVoxels + perGridMetadata[i].mNumVoxels; + gridOffsets[i + 1] = perGridMetadata[i].mCumVoxels + perGridMetadata[i].mNumVoxels; - batchMetadata->mMaxVoxels = max(batchMetadata->mMaxVoxels, perGridMetadata[i].mNumVoxels); + batchMetadata->mMaxVoxels = max(batchMetadata->mMaxVoxels, perGridMetadata[i].mNumVoxels); batchMetadata->mMaxLeafCount = max(batchMetadata->mMaxLeafCount, perGridMetadata[i].mNumLeaves); // number of voxels exceeds maximum indexable value - assert(perGridMetadata[i].mCumVoxels + perGridMetadata[i].mNumVoxels <= std::numeric_limits::max()); + assert(perGridMetadata[i].mCumVoxels + perGridMetadata[i].mNumVoxels <= + std::numeric_limits::max()); batchMetadata->mTotalVoxels = perGridMetadata[i].mCumVoxels + perGridMetadata[i].mNumVoxels; // number of grid leaf nodes exceeds maximum indexable value - assert(perGridMetadata[i].mCumLeaves + perGridMetadata[i].mNumLeaves <= std::numeric_limits::max()); + assert(perGridMetadata[i].mCumLeaves + perGridMetadata[i].mNumLeaves <= + std::numeric_limits::max()); batchMetadata->mTotalLeaves = perGridMetadata[i].mCumLeaves + perGridMetadata[i].mNumLeaves; - bbMin = bbMin.minComponent(currentGrid->tree().bbox().min()); - bbMax = bbMax.maxComponent(currentGrid->tree().bbox().max()); + bbMin = bbMin.minComponent(currentGrid->tree().bbox().min()); + bbMax = bbMax.maxComponent(currentGrid->tree().bbox().max()); batchMetadata->mTotalBBox = nanovdb::CoordBBox(bbMin, bbMax); batchMetadata->mIsMutable = nanovdb::util::is_same::value; } - template typename TensorAccessorT> -__global__ void populateGridMetadataCUDA( - uint32_t numGrids, - const nanovdb::NanoGrid* grids, - const nanovdb::Vec3d* voxelSizes, - const nanovdb::Vec3d* voxelOrigins, - TensorAccessorT outBatchOffsets, - GridBatchImpl::GridMetadata* perGridMetadata, - GridBatchImpl::GridBatchMetadata* batchMetadata) { - - populateGridMetadataKernel(numGrids, grids, voxelSizes, voxelOrigins, outBatchOffsets, perGridMetadata, 
batchMetadata); +__global__ void +populateGridMetadataCUDA(uint32_t numGrids, const nanovdb::NanoGrid *grids, + const nanovdb::Vec3d *voxelSizes, const nanovdb::Vec3d *voxelOrigins, + TensorAccessorT outBatchOffsets, + GridBatchImpl::GridMetadata *perGridMetadata, + GridBatchImpl::GridBatchMetadata *batchMetadata) { + populateGridMetadataKernel( + numGrids, grids, voxelSizes, voxelOrigins, outBatchOffsets, perGridMetadata, batchMetadata); } - -__global__ void ijkForDense(nanovdb::Coord origin, nanovdb::Coord size, TorchRAcc32 outIJKAccessor) { - const int32_t w = size[0], h = size[1], d = size[2]; - const uint64_t tid = (static_cast(blockIdx.x) * blockDim.x) + threadIdx.x; // = x * (h * d) + y * d + z) +__global__ void +ijkForDense(nanovdb::Coord origin, nanovdb::Coord size, TorchRAcc32 outIJKAccessor) { + const int32_t w = size[0], h = size[1], d = size[2]; + const uint64_t tid = (static_cast(blockIdx.x) * blockDim.x) + + threadIdx.x; // = x * (h * d) + y * d + z) if (tid >= outIJKAccessor.size(0)) { return; @@ -137,34 +133,35 @@ __global__ void ijkForDense(nanovdb::Coord origin, nanovdb::Coord size, TorchRAc outIJKAccessor[tid][2] = zi + origin[2]; } - struct NanoVDBGridBuilderTorchAllocator { - std::set mAllocatedData; + std::set mAllocatedData; - cudaError_t DeviceAllocate(void** ptr, size_t size, cudaStream_t stream) { + cudaError_t + DeviceAllocate(void **ptr, size_t size, cudaStream_t stream) { *ptr = c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(size, stream); mAllocatedData.insert(*ptr); - return (cudaError_t) CUDA_SUCCESS; + return (cudaError_t)CUDA_SUCCESS; } - cudaError_t DeviceFree(void* ptr) { + cudaError_t + DeviceFree(void *ptr) { c10::cuda::CUDACachingAllocator::raw_delete(ptr); mAllocatedData.erase(ptr); - return (cudaError_t) CUDA_SUCCESS; + return (cudaError_t)CUDA_SUCCESS; } - void FreeAllCached() { - for (void* ptr : mAllocatedData) { + void + FreeAllCached() { + for (void *ptr: mAllocatedData) { c10::cuda::CUDACachingAllocator::raw_delete(ptr); } mAllocatedData.clear(); } }; - template <> -nanovdb::GridHandle dispatchCreateNanoGridFromIJK( - const JaggedTensor& ijk, bool isMutable) { +nanovdb::GridHandle +dispatchCreateNanoGridFromIJK(const JaggedTensor &ijk, bool isMutable) { TORCH_CHECK(ijk.is_contiguous(), "ijk must be contiguous"); TORCH_CHECK(ijk.device().is_cuda(), "device must be cuda"); TORCH_CHECK(ijk.device().has_index(), "device must have index"); @@ -175,15 +172,17 @@ nanovdb::GridHandle dispatchCreateNanoGridFromIJK ret = FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { - // This guide buffer is a hack to pass in a device with an index to the cudaCreateNanoGrid function. We can't pass in a device directly - // but we can pass in a buffer which gets passed to TorchDeviceBuffer::create. The guide buffer holds the device and - // effectively passes it to the created buffer. + // This guide buffer is a hack to pass in a device with an index to the cudaCreateNanoGrid + // function. We can't pass in a device directly but we can pass in a buffer which gets + // passed to TorchDeviceBuffer::create. The guide buffer holds the device and effectively + // passes it to the created buffer. TorchDeviceBuffer guide(0, nullptr, false, ijk.device().index()); - // FIXME: This is slow because we have to copy this data to the host and then build the grids. Ideally we want to do this in a single invocation. + // FIXME: This is slow because we have to copy this data to the host and then build the + // grids. Ideally we want to do this in a single invocation. 
torch::Tensor ijkBOffsetTensor = ijk.joffsets().cpu(); - auto ijkBOffset = ijkBOffsetTensor.accessor(); - torch::Tensor ijkData = ijk.jdata(); + auto ijkBOffset = ijkBOffsetTensor.accessor(); + torch::Tensor ijkData = ijk.jdata(); TORCH_CHECK(ijkData.is_contiguous(), "ijk must be contiguous"); TORCH_CHECK(ijkData.dim() == 2, "ijk must have shape (N, 3)"); TORCH_CHECK(ijkData.size(1) == 3, "ijk must have shape (N, 3)"); @@ -192,13 +191,16 @@ nanovdb::GridHandle dispatchCreateNanoGridFromIJK> handles; for (int i = 0; i < (ijkBOffset.size(0) - 1); i += 1) { const int64_t startIdx = ijkBOffset[i]; - const int64_t nVoxels = ijkBOffset[i+1] - startIdx; + const int64_t nVoxels = ijkBOffset[i + 1] - startIdx; // torch::Tensor ijkDataSlice = ijkData.narrow(0, startIdx, nVoxels); - const int32_t* dataPtr = ijkData.data_ptr() + 3 * startIdx; - - handles.push_back(nVoxels == 0 ? build::buildEmptyGrid(guide.device(), isMutable) : - nanovdb::tools::cuda::voxelsToGrid( - (nanovdb::Coord*) dataPtr, nVoxels, 1.0, guide)); + const int32_t *dataPtr = ijkData.data_ptr() + 3 * startIdx; + + handles.push_back( + nVoxels == 0 ? build::buildEmptyGrid(guide.device(), isMutable) + : nanovdb::tools::cuda::voxelsToGrid( + (nanovdb::Coord *)dataPtr, nVoxels, 1.0, guide)); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -206,7 +208,8 @@ nanovdb::GridHandle dispatchCreateNanoGridFromIJK dispatchCreateNanoGridFromIJK -nanovdb::GridHandle dispatchCreateNanoGridFromDense(uint32_t batchSize, - nanovdb::Coord origin, - nanovdb::Coord size, - bool isMutable, - torch::Device device, - const torch::optional& maybeMask) { +nanovdb::GridHandle +dispatchCreateNanoGridFromDense(uint32_t batchSize, nanovdb::Coord origin, + nanovdb::Coord size, bool isMutable, + torch::Device device, + const torch::optional &maybeMask) { TORCH_CHECK(device.is_cuda(), "device must be cuda"); TORCH_CHECK(device.has_index(), "device must have index"); @@ -230,29 +231,28 @@ nanovdb::GridHandle dispatchCreateNanoGridFromDense(size[0]) * size[1] * size[2]; constexpr int NUM_THREADS = 1024; - const int64_t NUM_BLOCKS = GET_BLOCKS(gridVolume, NUM_THREADS); - + const int64_t NUM_BLOCKS = GET_BLOCKS(gridVolume, NUM_THREADS); - const torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kInt32).device(device); - torch::Tensor ijkData = torch::empty({gridVolume, 3}, opts); + const torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kInt32).device(device); + torch::Tensor ijkData = torch::empty({ gridVolume, 3 }, opts); if (NUM_BLOCKS > 0) { ijkForDense<<>>( - origin, size, - ijkData.packed_accessor32()); + origin, size, ijkData.packed_accessor32()); C10_CUDA_KERNEL_LAUNCH_CHECK(); } if (maybeMask.has_value()) { - torch::Tensor mask = maybeMask.value().view({-1}); + torch::Tensor mask = maybeMask.value().view({ -1 }); TORCH_CHECK(mask.device() == device, "mask must be on same device as ijkData"); - ijkData = ijkData.index({mask}); + ijkData = ijkData.index({ mask }); } nanovdb::GridHandle ret = FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { - // This guide buffer is a hack to pass in a device with an index to the cudaCreateNanoGrid function. We can't pass in a device directly - // but we can pass in a buffer which gets passed to TorchDeviceBuffer::create. The guide buffer holds the device and - // effectively passes it to the created buffer. + // This guide buffer is a hack to pass in a device with an index to the cudaCreateNanoGrid + // function. 
We can't pass in a device directly but we can pass in a buffer which gets + // passed to TorchDeviceBuffer::create. The guide buffer holds the device and effectively + // passes it to the created buffer. TorchDeviceBuffer guide(0, nullptr, false, device.index()); TORCH_CHECK(ijkData.is_contiguous(), "ijkData must be contiguous"); @@ -261,9 +261,12 @@ nanovdb::GridHandle dispatchCreateNanoGridFromDense> handles; for (int i = 0; i < batchSize; i += 1) { const int64_t nVoxels = ijkData.size(0); - handles.push_back(nVoxels == 0 ? build::buildEmptyGrid(guide.device(), isMutable) : - nanovdb::tools::cuda::voxelsToGrid( - (nanovdb::Coord*) ijkData.data_ptr(), nVoxels, 1.0, guide)); + handles.push_back( + nVoxels == 0 ? build::buildEmptyGrid(guide.device(), isMutable) + : nanovdb::tools::cuda::voxelsToGrid( + (nanovdb::Coord *)ijkData.data_ptr(), nVoxels, 1.0, guide)); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -271,7 +274,8 @@ nanovdb::GridHandle dispatchCreateNanoGridFromDense dispatchCreateNanoGridFromDense -void dispatchPopulateGridMetadata(const nanovdb::GridHandle& gridHdl, - const std::vector& voxelSizes, - const std::vector& voxelOrigins, - const bool isMutable, - torch::Tensor& outBatchOffsets, - GridBatchImpl::GridMetadata* outPerGridMetadataHost, - GridBatchImpl::GridMetadata* outPerGridMetadataDevice, - GridBatchImpl::GridBatchMetadata* outBatchMetadataHost, - GridBatchImpl::GridBatchMetadata* outBatchMetadataDevice) { +void +dispatchPopulateGridMetadata( + const nanovdb::GridHandle &gridHdl, + const std::vector &voxelSizes, const std::vector &voxelOrigins, + const bool isMutable, torch::Tensor &outBatchOffsets, + GridBatchImpl::GridMetadata *outPerGridMetadataHost, + GridBatchImpl::GridMetadata *outPerGridMetadataDevice, + GridBatchImpl::GridBatchMetadata *outBatchMetadataHost, + GridBatchImpl::GridBatchMetadata *outBatchMetadataDevice) { c10::cuda::CUDAGuard deviceGuard(gridHdl.buffer().device()); // Copy sizes and origins to device buffers - RAIIRawDeviceBuffer deviceVoxSizes(voxelSizes.size(), gridHdl.buffer().device()); - deviceVoxSizes.setData((nanovdb::Vec3d*) voxelSizes.data(), true /* blocking */); - const nanovdb::Vec3d* deviceVoxSizesPtr = deviceVoxSizes.devicePtr; + RAIIRawDeviceBuffer deviceVoxSizes(voxelSizes.size(), + gridHdl.buffer().device()); + deviceVoxSizes.setData((nanovdb::Vec3d *)voxelSizes.data(), true /* blocking */); + const nanovdb::Vec3d *deviceVoxSizesPtr = deviceVoxSizes.devicePtr; - RAIIRawDeviceBuffer deviceVoxOrigins(voxelOrigins.size(), gridHdl.buffer().device()); - deviceVoxOrigins.setData((nanovdb::Vec3d*) voxelOrigins.data(), true /* blocking */); - const nanovdb::Vec3d* deviceVoxOriginsPtr = deviceVoxOrigins.devicePtr; + RAIIRawDeviceBuffer deviceVoxOrigins(voxelOrigins.size(), + gridHdl.buffer().device()); + deviceVoxOrigins.setData((nanovdb::Vec3d *)voxelOrigins.data(), true /* blocking */); + const nanovdb::Vec3d *deviceVoxOriginsPtr = deviceVoxOrigins.devicePtr; - outBatchOffsets = torch::empty({(fvdb::JOffsetsType) (voxelOrigins.size() + 1)}, torch::TensorOptions().dtype(fvdb::JOffsetsScalarType).device(gridHdl.buffer().device())); + outBatchOffsets = torch::empty( + { (fvdb::JOffsetsType)(voxelOrigins.size() + 1) }, + torch::TensorOptions().dtype(fvdb::JOffsetsScalarType).device(gridHdl.buffer().device())); // Read metadata into device buffers FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { TORCH_CHECK(gridHdl.deviceData() != nullptr, "GridHandle is empty"); - const nanovdb::NanoGrid* grids = (nanovdb::NanoGrid*) gridHdl.deviceData(); + const 
nanovdb::NanoGrid *grids = + (nanovdb::NanoGrid *)gridHdl.deviceData(); populateGridMetadataCUDA<<<1, 1>>>( - gridHdl.gridCount(), grids, - (const nanovdb::Vec3d*) deviceVoxSizesPtr, - (const nanovdb::Vec3d*) deviceVoxOriginsPtr, + gridHdl.gridCount(), grids, (const nanovdb::Vec3d *)deviceVoxSizesPtr, + (const nanovdb::Vec3d *)deviceVoxOriginsPtr, outBatchOffsets.packed_accessor32(), - outPerGridMetadataDevice, - outBatchMetadataDevice); + outPerGridMetadataDevice, outBatchMetadataDevice); }); C10_CUDA_KERNEL_LAUNCH_CHECK(); const size_t metaDataByteSize = sizeof(GridBatchImpl::GridMetadata) * gridHdl.gridCount(); - cudaMemcpy(outPerGridMetadataHost, outPerGridMetadataDevice, metaDataByteSize, cudaMemcpyDeviceToHost); - cudaMemcpy(outBatchMetadataHost, outBatchMetadataDevice, sizeof(GridBatchImpl::GridBatchMetadata), cudaMemcpyDeviceToHost); + cudaMemcpy(outPerGridMetadataHost, outPerGridMetadataDevice, metaDataByteSize, + cudaMemcpyDeviceToHost); + cudaMemcpy(outBatchMetadataHost, outBatchMetadataDevice, + sizeof(GridBatchImpl::GridBatchMetadata), cudaMemcpyDeviceToHost); } template <> -void dispatchPopulateGridMetadata(const nanovdb::GridHandle& gridHdl, - const std::vector& voxelSizes, - const std::vector& voxelOrigins, - const bool isMutable, - torch::Tensor& outBatchOffsets, - GridBatchImpl::GridMetadata* outPerGridMetadataHost, - GridBatchImpl::GridMetadata* outPerGridMetadataDevice, - GridBatchImpl::GridBatchMetadata* outBatchMetadataHost, - GridBatchImpl::GridBatchMetadata* outBatchMetadataDevice) { - - outBatchOffsets = torch::empty({(fvdb::JOffsetsType) (voxelOrigins.size() + 1)}, torch::TensorOptions().dtype(fvdb::JOffsetsScalarType).device(gridHdl.buffer().device())); +void +dispatchPopulateGridMetadata( + const nanovdb::GridHandle &gridHdl, + const std::vector &voxelSizes, const std::vector &voxelOrigins, + const bool isMutable, torch::Tensor &outBatchOffsets, + GridBatchImpl::GridMetadata *outPerGridMetadataHost, + GridBatchImpl::GridMetadata *outPerGridMetadataDevice, + GridBatchImpl::GridBatchMetadata *outBatchMetadataHost, + GridBatchImpl::GridBatchMetadata *outBatchMetadataDevice) { + outBatchOffsets = torch::empty( + { (fvdb::JOffsetsType)(voxelOrigins.size() + 1) }, + torch::TensorOptions().dtype(fvdb::JOffsetsScalarType).device(gridHdl.buffer().device())); FVDB_DISPATCH_GRID_TYPES_MUTABLE(isMutable, [&]() { TORCH_CHECK(gridHdl.data() != nullptr, "GridHandle is empty"); - const nanovdb::NanoGrid* grids = (nanovdb::NanoGrid*) gridHdl.data(); + const nanovdb::NanoGrid *grids = (nanovdb::NanoGrid *)gridHdl.data(); populateGridMetadataKernel( - gridHdl.gridCount(), grids, voxelSizes.data(), voxelOrigins.data(), outBatchOffsets.accessor(), - outPerGridMetadataHost, outBatchMetadataHost); + gridHdl.gridCount(), grids, voxelSizes.data(), voxelOrigins.data(), + outBatchOffsets.accessor(), outPerGridMetadataHost, + outBatchMetadataHost); }); } - } // namespace ops } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/ops/CoordsInGrid.cu b/fvdb/src/detail/ops/CoordsInGrid.cu index 1816cb9b1c..62e1cae463 100644 --- a/fvdb/src/detail/ops/CoordsInGrid.cu +++ b/fvdb/src/detail/ops/CoordsInGrid.cu @@ -1,57 +1,63 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include - -#include "detail/utils/cuda/Utils.cuh" +#include +#include namespace fvdb { namespace detail { namespace ops { -template typename JaggedAccessor, template typename TensorAccessor> -__hostdev__ inline void coordsInGridCallback(int32_t bidx, int32_t eidx, - 
JaggedAccessor ijk, - TensorAccessor outMask, - BatchGridAccessor batchAccessor, - bool ignoreMasked) { - const auto* gpuGrid = batchAccessor.grid(bidx); - auto primalAcc = gpuGrid->getAccessor(); - - const auto& ijkCoord = ijk.data()[eidx]; +template typename JaggedAccessor, + template typename TensorAccessor> +__hostdev__ inline void +coordsInGridCallback(int32_t bidx, int32_t eidx, JaggedAccessor ijk, + TensorAccessor outMask, BatchGridAccessor batchAccessor, + bool ignoreMasked) { + const auto *gpuGrid = batchAccessor.grid(bidx); + auto primalAcc = gpuGrid->getAccessor(); + + const auto &ijkCoord = ijk.data()[eidx]; const nanovdb::Coord vox(ijkCoord[0], ijkCoord[1], ijkCoord[2]); - const bool isActive = ignoreMasked ? primalAcc.isActive(vox) : primalAcc.template get>(vox); - outMask[eidx] = isActive; + const bool isActive = ignoreMasked ? primalAcc.isActive(vox) + : primalAcc.template get>(vox); + outMask[eidx] = isActive; } - template -JaggedTensor CoordsInGrid(const GridBatchImpl& batchHdl, const JaggedTensor& ijk, bool ignoreMasked) { - +JaggedTensor +CoordsInGrid(const GridBatchImpl &batchHdl, const JaggedTensor &ijk, bool ignoreMasked) { batchHdl.checkNonEmptyGrid(); batchHdl.checkDevice(ijk); TORCH_CHECK_TYPE(!ijk.is_floating_point(), "ijk must have an integeral type"); - TORCH_CHECK(ijk.rdim() == 2, std::string("Expected ijk to have 2 dimensions (shape (n, 3)) but got ") + std::to_string(ijk.rdim()) + " dimensions"); + TORCH_CHECK(ijk.rdim() == 2, + std::string("Expected ijk to have 2 dimensions (shape (n, 3)) but got ") + + std::to_string(ijk.rdim()) + " dimensions"); TORCH_CHECK(ijk.rsize(0) > 0, "Empty tensor (ijk)"); - TORCH_CHECK(ijk.rsize(1) == 3, "Expected 3 dimensional ijk but got ijk.shape[1] = " + std::to_string(ijk.rsize(1))); + TORCH_CHECK(ijk.rsize(1) == 3, "Expected 3 dimensional ijk but got ijk.shape[1] = " + + std::to_string(ijk.rsize(1))); - auto opts = torch::TensorOptions().dtype(torch::kBool).device(ijk.device()); - torch::Tensor outMask = torch::empty({ijk.rsize(0)}, opts); + auto opts = torch::TensorOptions().dtype(torch::kBool).device(ijk.device()); + torch::Tensor outMask = torch::empty({ ijk.rsize(0) }, opts); FVDB_DISPATCH_GRID_TYPES(batchHdl, [&]() { AT_DISPATCH_INTEGRAL_TYPES(ijk.scalar_type(), "CoordsInGrid", [&]() { - - auto batchAcc = gridBatchAccessor(batchHdl); + auto batchAcc = gridBatchAccessor(batchHdl); auto outMaskAccessor = tensorAccessor(outMask); if constexpr (DeviceTag == torch::kCUDA) { - auto cb = [=] __device__ (int32_t bidx, int32_t eidx, int32_t cidx, JaggedRAcc32 ijkAcc) { - coordsInGridCallback(bidx, eidx, ijkAcc, outMaskAccessor, batchAcc, ignoreMasked); + auto cb = [=] __device__(int32_t bidx, int32_t eidx, int32_t cidx, + JaggedRAcc32 ijkAcc) { + coordsInGridCallback( + bidx, eidx, ijkAcc, outMaskAccessor, batchAcc, ignoreMasked); }; forEachJaggedElementChannelCUDA(1024, 1, ijk, cb); } else { - auto cb = [=] (int32_t bidx, int32_t eidx, int32_t cidx, JaggedAcc ijkAcc) { - coordsInGridCallback(bidx, eidx, ijkAcc, outMaskAccessor, batchAcc, ignoreMasked); + auto cb = [=](int32_t bidx, int32_t eidx, int32_t cidx, + JaggedAcc ijkAcc) { + coordsInGridCallback( + bidx, eidx, ijkAcc, outMaskAccessor, batchAcc, ignoreMasked); }; forEachJaggedElementChannelCPU(1, ijk, cb); } @@ -61,18 +67,20 @@ JaggedTensor CoordsInGrid(const GridBatchImpl& batchHdl, const JaggedTensor& ijk return ijk.jagged_like(outMask); } - template <> -JaggedTensor dispatchCoordsInGrid(const GridBatchImpl& batchHdl, const JaggedTensor& coords, bool ignoreMasked) { 
+JaggedTensor +dispatchCoordsInGrid(const GridBatchImpl &batchHdl, const JaggedTensor &coords, + bool ignoreMasked) { return CoordsInGrid(batchHdl, coords, ignoreMasked); } template <> -JaggedTensor dispatchCoordsInGrid(const GridBatchImpl& batchHdl, const JaggedTensor& coords, bool ignoreMasked) { +JaggedTensor +dispatchCoordsInGrid(const GridBatchImpl &batchHdl, const JaggedTensor &coords, + bool ignoreMasked) { return CoordsInGrid(batchHdl, coords, ignoreMasked); } - } // namespace ops } // namespace detail } // namespace fvdb diff --git a/fvdb/src/detail/ops/CountEnabledVoxels.cu b/fvdb/src/detail/ops/CountEnabledVoxels.cu index 3a8fc22e38..3706c6eba7 100644 --- a/fvdb/src/detail/ops/CountEnabledVoxels.cu +++ b/fvdb/src/detail/ops/CountEnabledVoxels.cu @@ -1,11 +1,10 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: MPL-2.0 // -#include - -#include "detail/utils/cuda/Utils.cuh" -#include "detail/utils/nanovdb/CustomAccessors.h" +#include +#include +#include namespace fvdb { namespace detail { @@ -16,12 +15,14 @@ namespace ops { /// @param li the index of the leaf to process /// @param outUnmaskedPerLeaf the output tensor storing the number of unmasked voxels in each leaf template