forked from carlinds/splatad
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_strategy.py
67 lines (55 loc) · 1.97 KB
/
test_strategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Tests for the functions in the CUDA extension.
Usage:
```bash
pytest <THIS_PY_FILE> -s
```
"""
import pytest
import torch
# Device on which the test places its Gaussian parameters and camera matrices.
device = torch.device("cuda", 0)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_strategy():
    """Smoke-test ``DefaultStrategy`` and ``MCMCStrategy`` on a tiny random scene.

    Builds 100 randomly initialised Gaussians with one Adam optimizer per
    parameter, renders a single 10x10 view using identity view and intrinsics
    matrices, and then drives each strategy through its sanity check, state
    initialisation and step hooks at step 600. Nothing is asserted explicitly:
    the test passes when none of the calls raises.
    """
    # Kept local to the test body rather than imported at module level.
    from gsplat.rendering import rasterization
    from gsplat.strategy import DefaultStrategy, MCMCStrategy

    torch.manual_seed(42)

    num_gaussians = 100
    learning_rate = 1e-3
    step = 600

    # Prepare Gaussians. The random draws happen in the same order as the
    # dictionary keys below, so the seeded values are reproducible.
    means = torch.randn(num_gaussians, 3)
    scales = torch.rand(num_gaussians, 3)
    quats = torch.randn(num_gaussians, 4)
    opacities = torch.rand(num_gaussians)
    colors = torch.rand(num_gaussians, 3)
    params = torch.nn.ParameterDict(
        {
            "means": means,
            "scales": scales,
            "quats": quats,
            "opacities": opacities,
            "colors": colors,
        }
    ).to(device)

    # One Adam optimizer per parameter, keyed by the parameter's name.
    optimizers = {
        name: torch.optim.Adam([param], lr=learning_rate)
        for name, param in params.items()
    }

    # A dummy rendering call with a single camera whose matrices are identity.
    view_matrices = torch.eye(4).unsqueeze(0).to(device)
    intrinsics = torch.eye(3).unsqueeze(0).to(device)
    render_colors, _, info = rasterization(
        means=params["means"],
        quats=params["quats"],  # F.normalize is fused into the kernel
        scales=torch.exp(params["scales"]),
        opacities=torch.sigmoid(params["opacities"]),
        colors=params["colors"],
        velocities=None,
        viewmats=view_matrices,
        Ks=intrinsics,
        width=10,
        height=10,
        packed=False,
    )

    # Test DefaultStrategy: it has hooks both before and after backward().
    default_strategy = DefaultStrategy(verbose=True)
    default_strategy.check_sanity(params, optimizers)
    default_state = default_strategy.initialize_state()
    default_strategy.step_pre_backward(
        params, optimizers, default_state, step=step, info=info
    )
    # retain_graph=True because the same render is back-propagated again below.
    render_colors.mean().backward(retain_graph=True)
    default_strategy.step_post_backward(
        params, optimizers, default_state, step=step, info=info
    )

    # Test MCMCStrategy: only the post-backward hook is exercised, and it is
    # additionally given the learning rate.
    mcmc_strategy = MCMCStrategy(verbose=True)
    mcmc_strategy.check_sanity(params, optimizers)
    mcmc_state = mcmc_strategy.initialize_state()
    render_colors.mean().backward(retain_graph=True)
    mcmc_strategy.step_post_backward(
        params, optimizers, mcmc_state, step=step, info=info, lr=learning_rate
    )
if __name__ == "__main__":
    # Run the test directly when this file is executed as a script.
    test_strategy()