Skip to content

Commit

Permalink
Merge pull request #1869 from swahtz/feature/fvdb
Browse files Browse the repository at this point in the history
Updates to fVDB
  • Loading branch information
apradhana authored Aug 14, 2024
2 parents 63f7ab0 + 9357034 commit c99b17c
Show file tree
Hide file tree
Showing 64 changed files with 3,751 additions and 216 deletions.
3 changes: 1 addition & 2 deletions fvdb/.gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
build/*
build/
/build/*
*.creator
*.includes
*.files
Expand Down
4 changes: 2 additions & 2 deletions fvdb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pip install .

To make sure that everything works by running tests:
```shell
python setup.py test
pytest tests/unit
```

### Building Documentation
Expand All @@ -118,7 +118,7 @@ docker run -it --gpus all --rm \
--mount type=bind,source="$HOME/.ssh",target=/root/.ssh \
--mount type=bind,source="$(pwd)",target=/fvdb \
fvdb-dev:latest \
conda run -n fvdb_test --no-capture-output python setup.py test
conda run -n fvdb_test --no-capture-output python setup.py develop
```


Expand Down
12 changes: 7 additions & 5 deletions fvdb/env/build_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@ name: fvdb_build
channels:
- nvidia/label/cuda-12.1.0
- pytorch
- conda-forge
dependencies:
- python=3.10
- pytorch=2.2
- pytorch-cuda=12.1
- pytorch::pytorch=2.2
- pytorch::pytorch-cuda=12.1
- git
- gitpython
- ca-certificates
- certifi
- openssl
- cuda
- cuda-nvcc
- nvidia/label/cuda-12.1.0::cuda
- nvidia/label/cuda-12.1.0::cuda-tools
- nvidia/label/cuda-12.1.0::cuda-nvcc
- nvidia/label/cuda-12.1.0::cuda-cccl
- nvidia/label/cuda-12.1.0::cuda-libraries-static
- gcc_linux-64=11
- gxx_linux-64=11
- setuptools
Expand Down
4 changes: 2 additions & 2 deletions fvdb/env/cutlass.patch
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ index e22c8be3..a29e6067 100644
storage = reinterpret_cast<uint16_t const &>(x);
#else
- __half_raw raw(x);
+ __half_raw raw(*(reinterpret_cast<const unsigned short*>(&x)));
+ __half_raw raw(*(reinterpret_cast<const __half_raw*>(&x)));
std::memcpy(&storage, &raw.x, sizeof(storage));
#endif
}
Expand All @@ -16,7 +16,7 @@ index e22c8be3..a29e6067 100644
storage = reinterpret_cast<uint16_t const &>(x);
#else
- __half_raw raw(x);
+ __half_raw raw(*(reinterpret_cast<const unsigned short*>(&x)));
+ __half_raw raw(*(reinterpret_cast<const __half_raw*>(&x)));
std::memcpy(&storage, &raw.x, sizeof(storage));
#endif
return *this;
Expand Down
12 changes: 7 additions & 5 deletions fvdb/env/test_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@ channels:
- pyg
- nvidia/label/cuda-12.1.0
- pytorch
- conda-forge
dependencies:
- python=3.10
- pytorch=2.2
- pytorch-cuda=12.1
- pytorch::pytorch=2.2
- pytorch::pytorch-cuda=12.1
- tensorboard
- pip
- git
- gitpython
- ca-certificates
- certifi
- openssl
- cuda
- cuda-nvcc
- nvidia/label/cuda-12.1.0::cuda
- nvidia/label/cuda-12.1.0::cuda-tools
- nvidia/label/cuda-12.1.0::cuda-nvcc
- nvidia/label/cuda-12.1.0::cuda-cccl
- nvidia/label/cuda-12.1.0::cuda-libraries-static
- parameterized
- gcc_linux-64=11
- gxx_linux-64=11
Expand Down
23 changes: 19 additions & 4 deletions fvdb/fvdb/_Cpp.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ class JaggedTensor:
@property
def dtype(self) -> torch.dtype: ...
@property
def jidx(self) -> torch.ShortTensor: ...
def jidx(self) -> torch.IntTensor: ...
@property
def jlidx(self) -> torch.IntTensor: ...
@property
def joffsets(self) -> torch.LongTensor: ...
@property
Expand All @@ -134,6 +136,19 @@ class JaggedTensor:
@property
def requires_grad(self) -> bool: ...

@staticmethod
def from_data_and_indices(data: torch.Tensor, indices: torch.Tensor, num_tensors: int) -> JaggedTensor: ...

@staticmethod
def from_data_indices_and_list_ids(data: torch.Tensor, indices: torch.Tensor, list_ids: torch.Tensor, num_tensors: int) -> JaggedTensor: ...

@staticmethod
def from_data_and_offsets(data: torch.Tensor, offsets: torch.Tensor) -> JaggedTensor: ...

@staticmethod
def from_data_offsets_and_list_ids(data: torch.Tensor, offsets: torch.Tensor, list_ids: torch.Tensor) -> JaggedTensor: ...


JaggedTensorOrTensor = Union[torch.Tensor, JaggedTensor]

class GridBatch:
Expand Down Expand Up @@ -243,8 +258,8 @@ class GridBatch:
def cubes_in_grid(self, cube_centers: JaggedTensorOrTensor, cube_min: Vec3dOrScalar = 0.0, cube_max: Vec3dOrScalar = 0.0, ignore_disabled: bool = False) -> JaggedTensor: ...
def cubes_intersect_grid(self, cube_centers: JaggedTensorOrTensor, cube_min: Vec3dOrScalar = 0.0, cube_max: Vec3dOrScalar = 0.0, ignore_disabled: bool = False) -> JaggedTensor: ...

def ijk_to_index(self, ijk: JaggedTensorOrTensor) -> JaggedTensor: ...
def ijk_to_inv_index(self, ijk: JaggedTensorOrTensor) -> JaggedTensor: ...
def ijk_to_index(self, ijk: JaggedTensorOrTensor, cumulative: bool = False) -> JaggedTensor: ...
def ijk_to_inv_index(self, ijk: JaggedTensorOrTensor, cumulative: bool = False) -> JaggedTensor: ...
def neighbor_indexes(self, ijk: JaggedTensorOrTensor, extent: int, bitshift: int = 0) -> JaggedTensor: ...

def splat_bezier(self, points: JaggedTensorOrTensor, points_data: JaggedTensorOrTensor) -> JaggedTensor: ...
Expand All @@ -256,7 +271,7 @@ class GridBatch:


def segments_along_rays(self, ray_origins: JaggedTensorOrTensor, ray_directions: JaggedTensorOrTensor, max_segments: int, eps: float = 0.0, ignore_masked: bool = False) -> JaggedTensor: ...
def voxels_along_rays(self, ray_origins: JaggedTensorOrTensor, ray_directions: JaggedTensorOrTensor, max_voxels: int, eps: float = 0.0, return_ijk: bool = True) -> Tuple[JaggedTensor, JaggedTensor]: ...
def voxels_along_rays(self, ray_origins: JaggedTensorOrTensor, ray_directions: JaggedTensorOrTensor, max_voxels: int, eps: float = 0.0, return_ijk: bool = True, cumulative: bool = False) -> Tuple[JaggedTensor, JaggedTensor]: ...
def uniform_ray_samples(self, ray_origins: JaggedTensorOrTensor, ray_directions: JaggedTensorOrTensor, t_min: JaggedTensorOrTensor, t_max: JaggedTensorOrTensor, step_size: float, cone_angle: float = 0.0, include_end_segments : bool = True, return_midpoints: bool = False, eps: float = 0.0) -> JaggedTensor: ...
def ray_implicit_intersection(self, ray_origins: JaggedTensorOrTensor, ray_directions: JaggedTensorOrTensor, grid_scalars: JaggedTensorOrTensor, eps: float = 0.0) -> JaggedTensor: ...

Expand Down
2 changes: 1 addition & 1 deletion fvdb/fvdb/nn/vdbtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _feature_ops(op, other: List[Union["VDBTensor", JaggedTensor, Any]]):
raw_features.append(o.feature.jdata)
elif isinstance(o, JaggedTensor):
assert pivot_tensor.total_voxels == o.jdata.size(0), "All tensors should have the same voxels"
assert pivot_tensor.grid.grid_count == len(o.joffsets), "All tensors should have the same batch size"
assert pivot_tensor.grid.grid_count == o.num_tensors, "All tensors should have the same batch size"
raw_features.append(o.jdata)
else:
raw_features.append(o)
Expand Down
Loading

0 comments on commit c99b17c

Please sign in to comment.