Skip to content

Commit e305516

Browse files
committed
fix tests on GPU
1 parent 547759a commit e305516

File tree

5 files changed

+47
-38
lines changed

5 files changed

+47
-38
lines changed

csrc/cuda/knn_cuda.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,15 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
9090
CHECK_CUDA(ptr_x.value());
9191
CHECK_INPUT(ptr_x.value().dim() == 1);
9292
} else {
93-
ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong));
93+
ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
94+
x.options().dtype(torch::kLong));
9495
}
9596
if (ptr_y.has_value()) {
9697
CHECK_CUDA(ptr_y.value());
9798
CHECK_INPUT(ptr_y.value().dim() == 1);
9899
} else {
99-
ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong));
100+
ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
101+
y.options().dtype(torch::kLong));
100102
}
101103
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
102104

csrc/cuda/radius_cuda.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,15 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
5858
CHECK_CUDA(ptr_x.value());
5959
CHECK_INPUT(ptr_x.value().dim() == 1);
6060
} else {
61-
ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong));
61+
ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
62+
x.options().dtype(torch::kLong));
6263
}
6364
if (ptr_y.has_value()) {
6465
CHECK_CUDA(ptr_y.value());
6566
CHECK_INPUT(ptr_y.value().dim() == 1);
6667
} else {
67-
ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong));
68+
ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
69+
y.options().dtype(torch::kLong));
6870
}
6971
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
7072

csrc/knn.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ torch::Tensor knn(torch::Tensor x, torch::Tensor y,
1717
int64_t num_workers) {
1818
if (x.device().is_cuda()) {
1919
#ifdef WITH_CUDA
20-
return knn_cuda(x, y, ptr_x, ptr_x, k, cosine);
20+
return knn_cuda(x, y, ptr_x, ptr_y, k, cosine);
2121
#else
2222
AT_ERROR("Not compiled with CUDA support");
2323
#endif

test/test_knn.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from .utils import grad_dtypes, devices, tensor
99

1010

11+
def to_set(edge_index):
12+
return set([(i, j) for i, j in edge_index.t().tolist()])
13+
14+
1115
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
1216
def test_knn(dtype, device):
1317
x = tensor([
@@ -28,18 +32,15 @@ def test_knn(dtype, device):
2832
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
2933
batch_y = tensor([0, 1], torch.long, device)
3034

31-
row, col = knn(x, y, 2)
32-
assert row.tolist() == [0, 0, 1, 1]
33-
assert col.tolist() == [2, 3, 0, 1]
35+
edge_index = knn(x, y, 2)
36+
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
3437

35-
row, col = knn(x, y, 2, batch_x, batch_y)
36-
assert row.tolist() == [0, 0, 1, 1]
37-
assert col.tolist() == [2, 3, 4, 5]
38+
edge_index = knn(x, y, 2, batch_x, batch_y)
39+
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
3840

3941
if x.is_cuda:
40-
row, col = knn(x, y, 2, batch_x, batch_y, cosine=True)
41-
assert row.tolist() == [0, 0, 1, 1]
42-
assert col.tolist() == [0, 1, 4, 5]
42+
edge_index = knn(x, y, 2, batch_x, batch_y, cosine=True)
43+
assert to_set(edge_index) == set([(0, 0), (0, 1), (1, 4), (1, 5)])
4344

4445

4546
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
@@ -51,25 +52,24 @@ def test_knn_graph(dtype, device):
5152
[+1, -1],
5253
], dtype, device)
5354

54-
row, col = knn_graph(x, k=2, flow='target_to_source')
55-
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
56-
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
55+
edge_index = knn_graph(x, k=2, flow='target_to_source')
56+
assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
57+
(2, 3), (3, 0), (3, 2)])
5758

58-
row, col = knn_graph(x, k=2, flow='source_to_target')
59-
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
60-
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
59+
edge_index = knn_graph(x, k=2, flow='source_to_target')
60+
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
61+
(3, 2), (0, 3), (2, 3)])
6162

6263

6364
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
6465
def test_knn_graph_large(dtype, device):
6566
x = torch.randn(1000, 3)
6667

67-
row, col = knn_graph(x, k=5, flow='target_to_source', loop=True,
68-
num_workers=6)
69-
pred = set([(i, j) for i, j in zip(row.tolist(), col.tolist())])
68+
edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True,
69+
num_workers=6)
7070

7171
tree = scipy.spatial.cKDTree(x.numpy())
7272
_, col = tree.query(x.cpu(), k=5)
7373
truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
7474

75-
assert pred == truth
75+
assert to_set(edge_index) == truth

test/test_radius.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from .utils import grad_dtypes, devices, tensor
99

1010

11+
def to_set(edge_index):
12+
return set([(i, j) for i, j in edge_index.t().tolist()])
13+
14+
1115
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
1216
def test_radius(dtype, device):
1317
x = tensor([
@@ -28,11 +32,13 @@ def test_radius(dtype, device):
2832
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
2933
batch_y = tensor([0, 1], torch.long, device)
3034

31-
out = radius(x, y, 2, max_num_neighbors=4)
32-
assert out.tolist() == [[0, 0, 0, 0, 1, 1, 1, 1], [0, 1, 2, 3, 1, 2, 5, 6]]
35+
edge_index = radius(x, y, 2, max_num_neighbors=4)
36+
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
37+
(1, 2), (1, 5), (1, 6)])
3338

34-
out = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
35-
assert out.tolist() == [[0, 0, 0, 0, 1, 1], [0, 1, 2, 3, 5, 6]]
39+
edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
40+
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
41+
(1, 6)])
3642

3743

3844
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
@@ -44,25 +50,24 @@ def test_radius_graph(dtype, device):
4450
[+1, -1],
4551
], dtype, device)
4652

47-
row, col = radius_graph(x, r=2, flow='target_to_source')
48-
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
49-
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
53+
edge_index = radius_graph(x, r=2, flow='target_to_source')
54+
assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
55+
(2, 3), (3, 0), (3, 2)])
5056

51-
row, col = radius_graph(x, r=2, flow='source_to_target')
52-
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
53-
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
57+
edge_index = radius_graph(x, r=2, flow='source_to_target')
58+
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
59+
(3, 2), (0, 3), (2, 3)])
5460

5561

5662
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
5763
def test_radius_graph_large(dtype, device):
5864
x = torch.randn(1000, 3)
5965

60-
row, col = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
61-
max_num_neighbors=1000, num_workers=6)
62-
pred = set([(i, j) for i, j in zip(row.tolist(), col.tolist())])
66+
edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
67+
max_num_neighbors=1000, num_workers=6)
6368

6469
tree = scipy.spatial.cKDTree(x.numpy())
6570
col = tree.query_ball_point(x.cpu(), r=0.5)
6671
truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
6772

68-
assert pred == truth
73+
assert to_set(edge_index) == truth

0 commit comments

Comments
 (0)