Current Behavior
Running a sparse convolution's forward pass in fp16 works fine, but the backward pass throws a runtime error.
Minimal reproducible script using the provided example
from datetime import datetime

import numpy as np
import torch
import torch.cuda
import torch.nn as nn
import torch.optim
import torchsparse.nn as spnn
from torchsparse import SparseTensor
from torchsparse.utils.collate import sparse_collate_fn
from torchsparse.utils.quantize import sparse_quantize


def generate_random_point_cloud(size=100000, voxel_size=0.2):
    pc = np.random.randn(size, 4)
    pc[:, :3] = pc[:, :3] * 10
    labels = np.random.choice(10, size)
    coords, feats = pc[:, :3], pc
    coords -= np.min(coords, axis=0, keepdims=True)
    coords, indices = sparse_quantize(coords, voxel_size, return_index=True)

    coords = torch.tensor(coords, dtype=torch.int)
    feats = torch.tensor(feats[indices], dtype=torch.float)
    labels = torch.tensor(labels[indices], dtype=torch.long)

    input = SparseTensor(coords=coords, feats=feats)
    label = SparseTensor(coords=coords, feats=labels)
    feed_dict = {"input": input, "label": label}
    return feed_dict


def generate_batched_random_point_clouds(size=100000, voxel_size=0.2, batch_size=2):
    batch = []
    for _ in range(batch_size):
        batch.append(generate_random_point_cloud(size, voxel_size))
    return sparse_collate_fn(batch)


def dummy_train_3x3(device):
    model = nn.Sequential(
        spnn.Conv3d(4, 32, kernel_size=3, stride=1),
        spnn.Conv3d(32, 10, kernel_size=3, stride=1, transposed=True),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)

    print("Starting dummy_train_3x3...")
    time = datetime.now()
    for i in range(10):
        feed_dict = generate_batched_random_point_clouds()
        inputs = feed_dict["input"].to(device)
        inputs = inputs.half()  # convert to fp16
        targets = feed_dict["label"].F.to(device).long()
        outputs = model(inputs)
        optimizer.zero_grad()
        loss = criterion(outputs.F, targets)
        loss.backward()  # throws error
        optimizer.step()
        print("[step %d] loss = %f." % (i, loss.item()))
    time = datetime.now() - time
    print("Finished dummy_train_3x3 in ", time)

dummy_train_3x3("cuda")
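For comparison, here is roughly what the loop would look like with autocast instead of a manual .half() cast (the compatibility requested under Expected Behavior below). This is a sketch reusing the names from the script above, not a fix for the current error; it only yields real mixed precision once the torchsparse ops participate in autocast:

# Sketch: autocast variant of the training loop above. Inputs stay fp32;
# autocast picks per-op precision, and GradScaler guards against fp16
# gradient underflow.
scaler = torch.cuda.amp.GradScaler()
for i in range(10):
    feed_dict = generate_batched_random_point_clouds()
    inputs = feed_dict["input"].to(device)  # no .half() here
    targets = feed_dict["label"].F.to(device).long()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = criterion(outputs.F, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales grads; skips the step on inf/nan
    scaler.update()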
Stack trace
File ".../torchsparse_py/torchsparse/nn/functional/conv/func/implicit_gemm.py", line 224, in backward
grad_input, grad_weight = backward(
File ".../torchsparse_py/torchsparse/nn/functional/conv/func/implicit_gemm.py", line 152, in backward
torch.ops.torchsparse_ops.conv_backward_wgrad_implicit_gemm_sorted_cuda(
File ".../pip_torch/site-packages/torch/_ops.py", line 854, in __call__
return self_._op(*args, **(kwargs or {}))
RuntimeError: expected scalar type Float but found Half
Expected Behavior
I expect the backward pass to run without errors. I would also love autocast compatibility, for which I think the ImplicitGEMMConvolutionFuntion functions need to be annotated with @custom_fwd and @custom_bwd.
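As a sketch of what those annotations could look like, using a toy stand-in for the real autograd Function (the class name, signatures, and matmul body here are illustrative, not the actual torchsparse code):

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class ToyImplicitGEMMFunction(torch.autograd.Function):
    # Toy stand-in: the real Function dispatches to the CUDA kernels.

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, feats, weight):
        # With cast_inputs=torch.float32, autocast casts incoming CUDA
        # tensors to fp32 before forward runs, so an fp32-only kernel
        # never sees Half inputs.
        ctx.save_for_backward(feats, weight)
        return feats @ weight

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # custom_bwd runs backward with the same autocast state as
        # forward, keeping dtypes consistent between the two passes.
        feats, weight = ctx.saved_tensors
        return grad_output @ weight.t(), feats.t() @ grad_output

Note that cast_inputs=torch.float32 would simply make the op run in fp32 under autocast; genuinely executing in fp16 would additionally require the wgrad kernel to accept Half tensors.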
Environment
TorchSparse: 2.1.0
Anything else?
No response