Skip to content

Commit e31d6d2

Browse files
committed
Functions working properly
1 parent a5332f9 commit e31d6d2

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

nnMorpho/functions.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ def forward(ctx, *args):
3434
elif input_tensor.ndim - strel_tensor.ndim == 1:
3535
output_tensor, indexes = morphology_cuda.erosion_batched_forward(input_pad, strel_tensor, BLOCK_SHAPE)
3636
elif input_tensor.ndim - strel_tensor.ndim == 2:
37-
raise NotImplementedError("ToDo")
37+
batch_channel_dim = input_pad.shape[0] * input_pad.shape[1]
38+
input_height = input_pad.shape[2]
39+
input_width = input_pad.shape[3]
40+
input_view = input_pad.view(batch_channel_dim, input_height, input_width)
41+
output_tensor, indexes = morphology_cuda.erosion_batched_forward(input_view, strel_tensor, BLOCK_SHAPE)
42+
output_tensor = output_tensor.view(*input_tensor.shape)
43+
indexes = indexes.view(*input_tensor.shape, 2)
3844
else:
3945
raise NotImplementedError("Currently, nnMorpho only supports as input:\n"
4046
"- 2D tensors of the form (H, W)\n"
@@ -57,7 +63,12 @@ def backward(ctx, *grad_outputs):
5763
elif grad_output.ndim - len(strel_shape) == 1:
5864
result = morphology_cuda.erosion_batched_backward(grad_output, indexes, strel_shape, BLOCK_SHAPE)
5965
elif grad_output.ndim - len(strel_shape) == 2:
60-
raise NotImplementedError("ToDo")
66+
batch_channel_dim = grad_output.shape[0] * grad_output.shape[1]
67+
input_height = grad_output.shape[2]
68+
input_width = grad_output.shape[3]
69+
grad_output_view = grad_output.view(batch_channel_dim, input_height, input_width)
70+
indexes_view = indexes.view(batch_channel_dim, input_height, input_width, 2)
71+
result = morphology_cuda.erosion_batched_backward(grad_output_view, indexes_view, strel_shape, BLOCK_SHAPE)
6172
else:
6273
raise NotImplementedError("Currently, nnMorpho only supports as input:\n"
6374
"- 2D tensors of the form (H, W)\n"
@@ -100,7 +111,13 @@ def forward(ctx, *args):
100111
elif input_tensor.ndim - strel_tensor.ndim == 1:
101112
output_tensor, indexes = morphology_cuda.dilation_batched_forward(input_pad, strel_tensor, BLOCK_SHAPE)
102113
elif input_tensor.ndim - strel_tensor.ndim == 2:
103-
raise NotImplementedError("ToDo")
114+
batch_channel_dim = input_pad.shape[0] * input_pad.shape[1]
115+
input_height = input_pad.shape[2]
116+
input_width = input_pad.shape[3]
117+
input_view = input_pad.view(batch_channel_dim, input_height, input_width)
118+
output_tensor, indexes = morphology_cuda.dilation_batched_forward(input_view, strel_tensor, BLOCK_SHAPE)
119+
output_tensor = output_tensor.view(*input_tensor.shape)
120+
indexes = indexes.view(*input_tensor.shape, 2)
104121
else:
105122
raise NotImplementedError("Currently, nnMorpho only supports as input:\n"
106123
"- 2D tensors of the form (H, W)\n"
@@ -123,7 +140,12 @@ def backward(ctx, *grad_outputs):
123140
elif grad_output.ndim - len(strel_shape) == 1:
124141
result = morphology_cuda.dilation_batched_backward(grad_output, indexes, strel_shape, BLOCK_SHAPE)
125142
elif grad_output.ndim - len(strel_shape) == 2:
126-
raise NotImplementedError("ToDo")
143+
batch_channel_dim = grad_output.shape[0] * grad_output.shape[1]
144+
input_height = grad_output.shape[2]
145+
input_width = grad_output.shape[3]
146+
grad_output_view = grad_output.view(batch_channel_dim, input_height, input_width)
147+
indexes_view = indexes.view(batch_channel_dim, input_height, input_width, 2)
148+
result = morphology_cuda.dilation_batched_backward(grad_output_view, indexes_view, strel_shape, BLOCK_SHAPE)
127149
else:
128150
raise NotImplementedError("Currently, nnMorpho only supports as input:\n"
129151
"- 2D tensors of the form (H, W)\n"

tests/test_functions_gradients.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
from nnMorpho.operations import erosion, dilation
33
from nnMorpho.functions import ErosionFunction, DilationFunction
44

5-
image_size = (5, 8)
5+
color = True
6+
7+
if color:
8+
image_size = (3, 5, 8)
9+
else:
10+
image_size = (5, 8)
11+
612
strel_size = (3, 3)
713
origin = (2, 2)
814

9-
rand = True
10-
1115
image_1 = torch.rand(image_size, device='cuda:0')
1216
image_2 = torch.rand(image_size, device='cuda:0')
1317
image_batch = torch.stack((image_1, image_2), dim=0)
@@ -58,6 +62,7 @@
5862
print("Gradients for erosion")
5963
print("Gradient batch:\n", strel_batch_erosion.grad)
6064
print("Gradient stack:\n", strel_stack_erosion.grad)
65+
print("Error between erosion:", torch.norm(eroded_batch - eroded_stack, 1).item())
6166

6267
# Dilation
6368

@@ -81,3 +86,4 @@
8186
print("Gradients for dilation")
8287
print("Gradient batch:\n", strel_batch_dilation.grad)
8388
print("Gradient stack:\n", strel_stack_dilation.grad)
89+
print("Error between dilation:", torch.norm(dilated_batch - dilated_stack, 1).item())

0 commit comments

Comments
 (0)