@@ -34,7 +34,13 @@ def forward(ctx, *args):
34
34
elif input_tensor .ndim - strel_tensor .ndim == 1 :
35
35
output_tensor , indexes = morphology_cuda .erosion_batched_forward (input_pad , strel_tensor , BLOCK_SHAPE )
36
36
elif input_tensor .ndim - strel_tensor .ndim == 2 :
37
- raise NotImplementedError ("ToDo" )
37
+ batch_channel_dim = input_pad .shape [0 ] * input_pad .shape [1 ]
38
+ input_height = input_pad .shape [2 ]
39
+ input_width = input_pad .shape [3 ]
40
+ input_view = input_pad .view (batch_channel_dim , input_height , input_width )
41
+ output_tensor , indexes = morphology_cuda .erosion_batched_forward (input_view , strel_tensor , BLOCK_SHAPE )
42
+ output_tensor = output_tensor .view (* input_tensor .shape )
43
+ indexes = indexes .view (* input_tensor .shape , 2 )
38
44
else :
39
45
raise NotImplementedError ("Currently, nnMorpho only supports as input:\n "
40
46
"- 2D tensors of the form (H, W)\n "
@@ -57,7 +63,12 @@ def backward(ctx, *grad_outputs):
57
63
elif grad_output .ndim - len (strel_shape ) == 1 :
58
64
result = morphology_cuda .erosion_batched_backward (grad_output , indexes , strel_shape , BLOCK_SHAPE )
59
65
elif grad_output .ndim - len (strel_shape ) == 2 :
60
- raise NotImplementedError ("ToDo" )
66
+ batch_channel_dim = grad_output .shape [0 ] * grad_output .shape [1 ]
67
+ input_height = grad_output .shape [2 ]
68
+ input_width = grad_output .shape [3 ]
69
+ grad_output_view = grad_output .view (batch_channel_dim , input_height , input_width )
70
+ indexes_view = indexes .view (batch_channel_dim , input_height , input_width , 2 )
71
+ result = morphology_cuda .erosion_batched_backward (grad_output_view , indexes_view , strel_shape , BLOCK_SHAPE )
61
72
else :
62
73
raise NotImplementedError ("Currently, nnMorpho only supports as input:\n "
63
74
"- 2D tensors of the form (H, W)\n "
@@ -100,7 +111,13 @@ def forward(ctx, *args):
100
111
elif input_tensor .ndim - strel_tensor .ndim == 1 :
101
112
output_tensor , indexes = morphology_cuda .dilation_batched_forward (input_pad , strel_tensor , BLOCK_SHAPE )
102
113
elif input_tensor .ndim - strel_tensor .ndim == 2 :
103
- raise NotImplementedError ("ToDo" )
114
+ batch_channel_dim = input_pad .shape [0 ] * input_pad .shape [1 ]
115
+ input_height = input_pad .shape [2 ]
116
+ input_width = input_pad .shape [3 ]
117
+ input_view = input_pad .view (batch_channel_dim , input_height , input_width )
118
+ output_tensor , indexes = morphology_cuda .dilation_batched_forward (input_view , strel_tensor , BLOCK_SHAPE )
119
+ output_tensor = output_tensor .view (* input_tensor .shape )
120
+ indexes = indexes .view (* input_tensor .shape , 2 )
104
121
else :
105
122
raise NotImplementedError ("Currently, nnMorpho only supports as input:\n "
106
123
"- 2D tensors of the form (H, W)\n "
@@ -123,7 +140,12 @@ def backward(ctx, *grad_outputs):
123
140
elif grad_output .ndim - len (strel_shape ) == 1 :
124
141
result = morphology_cuda .dilation_batched_backward (grad_output , indexes , strel_shape , BLOCK_SHAPE )
125
142
elif grad_output .ndim - len (strel_shape ) == 2 :
126
- raise NotImplementedError ("ToDo" )
143
+ batch_channel_dim = grad_output .shape [0 ] * grad_output .shape [1 ]
144
+ input_height = grad_output .shape [2 ]
145
+ input_width = grad_output .shape [3 ]
146
+ grad_output_view = grad_output .view (batch_channel_dim , input_height , input_width )
147
+ indexes_view = indexes .view (batch_channel_dim , input_height , input_width , 2 )
148
+ result = morphology_cuda .dilation_batched_backward (grad_output_view , indexes_view , strel_shape , BLOCK_SHAPE )
127
149
else :
128
150
raise NotImplementedError ("Currently, nnMorpho only supports as input:\n "
129
151
"- 2D tensors of the form (H, W)\n "
0 commit comments