Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SpherePad for equi2cube + cache for create_equi_grid and create_normalized_grid #15

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions equilib/cube2equi/torch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import math
from functools import cache

from typing import Dict, List, Union

Expand Down Expand Up @@ -169,6 +170,7 @@ def _equirect_facetype(h: int, w: int) -> torch.Tensor:
return tp.type(int_dtype)


@cache
def create_equi_grid(
h_out: int,
w_out: int,
Expand Down
3 changes: 2 additions & 1 deletion equilib/equi2cube/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def run(
z_down: bool,
mode: str,
override_func: Optional[Callable[[], Any]] = None,
pad: float = 0.,
) -> Union[np.ndarray, List[List[np.ndarray]], List[Dict[str, np.ndarray]]]:
"""Call Equi2Cube

Expand Down Expand Up @@ -179,7 +180,7 @@ def run(
out = np.empty((bs, c, w_face, w_face * 6), dtype=dtype)

# create grid
xyz = create_xyz_grid(w_face=w_face, batch=bs, dtype=dtype)
xyz = create_xyz_grid(w_face=w_face, pad=pad, batch=bs, dtype=dtype)
xyz = xyz[..., np.newaxis]

# FIXME: not sure why, but z-axis is facing the opposite
Expand Down
3 changes: 2 additions & 1 deletion equilib/equi2cube/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def run(
z_down: bool,
mode: str,
backend: str = "native",
pad: float = 0.,
) -> Union[torch.Tensor, List[torch.Tensor], List[Dict[str, torch.Tensor]]]:
"""Run Equi2Cube

Expand Down Expand Up @@ -201,7 +202,7 @@ def run(

# create grid
xyz = create_xyz_grid(
w_face=w_face, batch=bs, dtype=tmp_dtype, device=tmp_device
w_face=w_face, pad=pad, batch=bs, dtype=tmp_dtype, device=tmp_device
)
xyz = xyz.unsqueeze(-1)

Expand Down
79 changes: 64 additions & 15 deletions equilib/equi2equi/torch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

from typing import Dict, List, Optional
from functools import cache
from typing import Dict, List, Optional, Tuple, Iterable

import torch

Expand Down Expand Up @@ -53,6 +53,58 @@ def convert_grid(

return grid

@cache
def _get_grid(
height: int,
width: int,
bs: int,
tmp_dtype: torch.dtype,
tmp_device: torch.device,
rolls: Tuple[float],
pitchs: Tuple[float],
yaws: Tuple[float],
z_down: float,
h_equi: int,
w_equi: int,
) -> torch.Tensor:
m = create_normalized_grid(
height=height, width=width, batch=bs, dtype=tmp_dtype, device=tmp_device
)
m = m.unsqueeze(-1)

rots = [
{'roll': roll, 'pitch': pitch, 'yaw': yaw}
for (roll, pitch, yaw) in zip(rolls, pitchs, yaws)
]

# create batched rotation matrices
R = create_rotation_matrices(
rots=rots, z_down=z_down, dtype=tmp_dtype, device=tmp_device
)

# rotate the grid
M = matmul(m, R)

grid = convert_grid(M=M, h_equi=h_equi, w_equi=w_equi, method="robust")

return grid


def get_grid(
rots: Iterable[Dict[str, float]],
**kwargs
) -> torch.Tensor:
rolls = tuple([rot['roll'] for rot in rots])
pitchs = tuple([rot['pitch'] for rot in rots])
yaws = tuple([rot['yaw'] for rot in rots])

return _get_grid(
rolls=rolls,
pitchs=pitchs,
yaws=yaws,
**kwargs
)


def run(
src: torch.Tensor,
Expand Down Expand Up @@ -158,21 +210,18 @@ def run(
else:
tmp_dtype = dtype

m = create_normalized_grid(
height=height, width=width, batch=bs, dtype=tmp_dtype, device=tmp_device
)
m = m.unsqueeze(-1)

# create batched rotation matrices
R = create_rotation_matrices(
rots=rots, z_down=z_down, dtype=tmp_dtype, device=tmp_device
grid = get_grid(
height=height,
width=width,
bs=bs,
tmp_dtype=tmp_dtype,
tmp_device=tmp_device,
rots=rots,
z_down=z_down,
h_equi=h_equi,
w_equi=w_equi
)

# rotate the grid
M = matmul(m, R)

grid = convert_grid(M=M, h_equi=h_equi, w_equi=w_equi, method="robust")

# FIXME: putting `grid` to device since `pure`'s bilinear interpolation requires it
# FIXME: better way of forcing `grid` to be the same dtype?
if src.dtype != grid.dtype:
Expand Down
3 changes: 2 additions & 1 deletion equilib/numpy_utils/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def create_normalized_grid(

def create_xyz_grid(
w_face: int,
pad: float = 0.,
batch: Optional[int] = None,
dtype: np.dtype = np.dtype(np.float32),
) -> np.ndarray:
Expand All @@ -107,7 +108,7 @@ def create_xyz_grid(
ratio = (w_face - 1) / w_face

out = np.zeros((w_face, w_face * 6, 3), dtype=dtype)
rng = np.linspace(-0.5 * ratio, 0.5 * ratio, num=w_face, dtype=dtype)
rng = np.linspace(-(0.5 + pad) * ratio, (0.5 + pad) * ratio, num=w_face, dtype=dtype)

# Front face (x = 0.5)
out[:, 0 * w_face : 1 * w_face, [1, 2]] = np.stack(
Expand Down
3 changes: 2 additions & 1 deletion equilib/torch_utils/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def create_normalized_grid(

def create_xyz_grid(
w_face: int,
pad: float = 0.,
batch: Optional[int] = None,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
Expand All @@ -128,7 +129,7 @@ def create_xyz_grid(

out = torch.zeros((w_face, w_face * 6, 3), dtype=dtype, device=device)
rng = torch.linspace(
-0.5 * ratio, 0.5 * ratio, w_face, dtype=dtype, device=device
-(0.5 + pad) * ratio, (0.5 + pad) * ratio, w_face, dtype=dtype, device=device
)

# NOTE: https://github.com/pytorch/pytorch/issues/15301
Expand Down
Loading