Skip to content

Commit

Permalink
Merge pull request #676 from bioimage-io/update_unet_example
Browse files Browse the repository at this point in the history
use sigmoid postprocessing in unet2d example
  • Loading branch information
FynnBe authored Dec 10, 2024
2 parents 9983d3e + 6cc7343 commit 3196742
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ outputs:
sample_tensor:
source: test_output.png
sha256: 7bce8b53bcd0a12487a61f953aafe0f3700652848980d1083964c5bcb9555eec
postprocessing:
- id: sigmoid
- id: ensure_dtype
kwargs:
dtype: float32

weights:
pytorch_state_dict:
Expand All @@ -105,7 +110,7 @@ weights:
architecture:
callable: UNet2d
source: unet2d.py
sha256: 7cdd8332dc3e3735e71c328f81b63a9ac86c028f80522312484ca9a4027d4ce1
sha256: 589f0c9e60fa00f015213cd251541bcbf9582644f3ecffb2e6f3a30d2af1931a
kwargs: { input_channels: 1, output_channels: 1 }
dependencies:
source: environment.yaml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ outputs:
reference_tensor: raw
scale: [1.0, 0.0, null, 1.0, 1.0]
offset: [0.0, 0.5, 0.5, 0.0, 0.0]
postprocessing:
- name: sigmoid

test_inputs: [test_input.npy]
test_outputs: [test_output_expanded.npy]
Expand All @@ -67,7 +69,7 @@ weights:
sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2
source: https://zenodo.org/records/3446812/files/unet2d_weights.torch
architecture: unet2d_expand_output_shape.py:UNet2d
architecture_sha256: 80a886acc734f848a8e018d8063cfd7e003d7e20076583b28326bfdd6136be32
architecture_sha256: 1441e8cfaf387a98a1c0bb937d59a2e9d6c311a8912cd88b39c11ecff503ccfe
kwargs: { input_channels: 1, output_channels: 1 }
dependencies: conda:environment.yaml

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ outputs:
data_range: [-.inf, .inf]
halo: [0, 0, 32, 32]
shape: [1, 1, 512, 512]
postprocessing:
- name: sigmoid

dependencies: conda:environment.yaml

Expand All @@ -63,7 +65,7 @@ weights:
sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2
source: https://zenodo.org/records/3446812/files/unet2d_weights.torch
architecture: unet2d.py:UNet2d
architecture_sha256: 7cdd8332dc3e3735e71c328f81b63a9ac86c028f80522312484ca9a4027d4ce1
architecture_sha256: 589f0c9e60fa00f015213cd251541bcbf9582644f3ecffb2e6f3a30d2af1931a
kwargs: { input_channels: 1, output_channels: 1 }
onnx:
sha256: f1f086d5e340f9d4d7001a1b62a2b835f9b87a2fb5452c4fe7d8cc821bdf539c
Expand Down
8 changes: 1 addition & 7 deletions example_descriptions/models/unet2d_nuclei_broad/unet2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def forward(self, input):


class UNet2d(nn.Module):
def __init__(self, input_channels, output_channels, training=False):
def __init__(self, input_channels, output_channels):
super().__init__()
self.input_channels = input_channels
self.output_channels = output_channels
Expand All @@ -41,7 +41,6 @@ def __init__(self, input_channels, output_channels, training=False):
)

self.output = nn.Conv2d(16, self.output_channels, 1)
self.training = training

def conv_layer(self, in_channels, out_channels):
kernel_size = 3
Expand Down Expand Up @@ -78,9 +77,4 @@ def forward(self, input):

x = self.output(x)

# apply a sigmoid directly if we are in inference mode
if not self.training:
# postprocessing
x = torch.sigmoid(x)

return x
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def forward(self, input):


class UNet2d(nn.Module):
def __init__(self, input_channels, output_channels, training=False):
def __init__(self, input_channels, output_channels):
super().__init__()
self.input_channels = input_channels
self.output_channels = output_channels
Expand All @@ -41,7 +41,6 @@ def __init__(self, input_channels, output_channels, training=False):
)

self.output = nn.Conv2d(16, self.output_channels, 1)
self.training = training

def conv_layer(self, in_channels, out_channels):
kernel_size = 3
Expand Down Expand Up @@ -78,11 +77,6 @@ def forward(self, input):

x = self.output(x)

# apply a sigmoid directly if we are in inference mode
if not self.training:
# postprocessing
x = torch.sigmoid(x)

# expand the shape across z
out_shape = tuple(x.shape)
expanded_shape = out_shape[:2] + (1,) + out_shape[2:]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ outputs:
reference_tensor: raw
scale: [1.0, 0.0, 1.0, 1.0]
offset: [0.0, 0.5, 0.0, 0.0]
postprocessing:
- name: sigmoid

dependencies: conda:environment.yaml

Expand All @@ -72,7 +74,7 @@ weights:
source: https://zenodo.org/records/3446812/files/unet2d_weights.torch
sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2
architecture: unet2d.py:UNet2d
architecture_sha256: 7cdd8332dc3e3735e71c328f81b63a9ac86c028f80522312484ca9a4027d4ce1
architecture_sha256: 589f0c9e60fa00f015213cd251541bcbf9582644f3ecffb2e6f3a30d2af1931a
kwargs: { input_channels: 1, output_channels: 1 }
onnx:
source: weights.onnx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ outputs:
reference_tensor: raw
scale: [1.0, 0.0, 1.0, 1.0]
offset: [0.0, 0.5, 0.0, 0.0]
postprocessing:
- name: sigmoid

test_inputs: [test_input.npy]
test_outputs: [test_output.npy]
Expand All @@ -69,7 +71,7 @@ weights:
source: https://zenodo.org/records/3446812/files/unet2d_weights.torch
sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2
architecture: unet2d.py:UNet2d
architecture_sha256: 7cdd8332dc3e3735e71c328f81b63a9ac86c028f80522312484ca9a4027d4ce1
architecture_sha256: 589f0c9e60fa00f015213cd251541bcbf9582644f3ecffb2e6f3a30d2af1931a
kwargs: { input_channels: 1, output_channels: 1 }
dependencies: conda:environment.yaml
onnx:
Expand Down

0 comments on commit 3196742

Please sign in to comment.