Support pytorch acceleration on M1 mac hardware #14
base: main
Conversation
Ryan suggests (and I agree) that we turn this code block into a utility function, since the pattern is repeated so many times. I haven't done that yet because we're not sure where this utility function should live.

```python
# check whether GPU or M1 Mac hardware is available
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

model.to(device)
```
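A minimal sketch of what that helper could look like (the function name and module path are placeholders until we decide where it should live):

```python
# hypothetical helper, e.g. in empanada/utils.py (final location TBD)
import torch

def get_device():
    """Return the best available device: CUDA, then MPS (M1 Mac), then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda:0')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')
```

Call sites would then shrink to `model.to(get_device())`.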
Problem: The models empanada uses are not completely compatible with the new `mps` backend (makes sense I guess, it's a very new pytorch feature):

```
    226     coarse_sem_seg_logits,
    227     self.train_num_points,
    228     self.oversample_ratio,
    229     self.importance_sample_ratio,
    230 )
    232 # sample points at coarse and fine resolutions
    233 coarse_sem_seg_points = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)

NotImplementedError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/models/point_rend.py", line 87, in get_uncertain_point_coords_with_randomness
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
                   ~~~~~~~~~~~~ <--- HERE
    point_uncertainties = calculate_uncertainty(point_logits)
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/models/point_rend.py", line 55, in point_sample
    point_coords = point_coords.unsqueeze(2)
    output = F.grid_sample(features, 2.0 * point_coords - 1.0, mode=mode, align_corners=align_corners)
             ~~~~~~~~~~~~~ <--- HERE
    if add_dim:
  File "/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/nn/functional.py", line 4221, in grid_sample
    align_corners = False
    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
           ~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: The operator 'aten::grid_sampler_2d' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

...the code just hangs at this point until you press Control+C in the terminal to kill it.

<details>
<summary>Full traceback (click to expand)</summary>

```
(conv): Identity()
)
(3): Resample2d(
(conv): Identity()
)
)
(resize_down): Resize2d(
(resample): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
(after_combines): ModuleList(
(0): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
(1): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
(2): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
(3): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
)
)
)
(2): BiFPNLayer(
(top_down_fpn): TopDownFPN(
(resamplings): ModuleList(
(0): Resample2d(
(conv): Identity()
)
(1): Resample2d(
(conv): Identity()
)
(2): Resample2d(
(conv): Identity()
)
(3): Resample2d(
(conv): Identity()
)
)
(resize_up): Resize2d(
(resample): Interpolate2d()
)
(after_combines): ModuleList(
(0): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
(1): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
(2): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
(3): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
)
)
(bottom_up_fpn): BottomUpFPN(
(resamplings): ModuleList(
(0): Resample2d(
(conv): Identity()
)
(1): Resample2d(
(conv): Identity()
)
(2): Resample2d(
(conv): Identity()
)
(3): Resample2d(
(conv): Identity()
)
)
(resize_down): Resize2d(
(resample): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
(after_combines): ModuleList(
(0): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
(1): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
(2): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
(3): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): SiLU(inplace=True)
)
)
)
)
)
)
(semantic_decoder): BiFPNDecoder(
(upsamplings): ModuleList(
(0): Sequential(
(0): ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Sequential(
(0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(2): Sequential(
(0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(3): Sequential(
(0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(4): Sequential(
(0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(fusion): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=256, bias=False)
(1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(semantic_head): PanopticDeepLabHead(
(head): Sequential(
(0): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
)
)
(ins_center): PanopticDeepLabHead(
(head): Sequential(
(0): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
)
)
(ins_xy): PanopticDeepLabHead(
(head): Sequential(
(0): Sequential(
(0): SeparableConv2d(
(sepconv): Sequential(
(0): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=128, bias=False)
(1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1))
)
)
(interpolate): Interpolate2d()
(semantic_pr): PointRendSemSegHead(
(point_head): StandardPointHead(
(fc_layers): ModuleList(
(0): Sequential(
(0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
(1): ReLU(inplace=True)
)
(1): Sequential(
(0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
(1): ReLU(inplace=True)
)
(2): Sequential(
(0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
(1): ReLU(inplace=True)
)
)
(predictor): Conv1d(129, 1, kernel_size=(1,), stride=(1,))
)
(interpolate): Interpolate2d()
)
)
    159 if self.training:
    160     # interpolate to original resolution (4x)
    161     heads_out['sem_logits'] = self.interpolate(pr_out['sem_seg_logits'])

File ~/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/nn/modules/module.py:1131, in Module._call_impl(self=PointRendSemSegHead(...), *input=(tensor([[[[0.0088, 0.0088, 0.0088, ...]]]], device='mps:0', grad_fn=<ConvolutionBackward0>), tensor([[[[0.0000e+00, 7.4089e-02, 8.8945e-01, ...]]]], device='mps:0', grad_fn=<ReluBackward0>)), **kwargs={})
1127 # If we don't have any hooks, we want to skip the rest of the logic in
1128 # this function, and just call forward.
1129 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1130 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1131 return forward_call(*input, **kwargs)
        forward_call = <bound method PointRendSemSegHead.forward of PointRendSemSegHead(...)>
        input = (tensor([[[[0.0088, 0.0088, 0.0088, ...]]]], device='mps:0', grad_fn=<ConvolutionBackward0>),
                 tensor([[[[0.0000e+00, 7.4089e-02, 8.8945e-01, ...]]]], device='mps:0', grad_fn=<ReluBackward0>))
        kwargs = {}
1132 # Do not call functions when jit is used
1133 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Documents/GitHub/empanada/empanada/empanada/models/point_rend.py:225, in PointRendSemSegHead.forward(self=PointRendSemSegHead(...), coarse_sem_seg_logits=tensor([[[[0.0088, 0.0088, 0.0088, ...]]]], device='mps:0', grad_fn=<ConvolutionBackward0>), features=tensor([[[[0.0000e+00, 7.4089e-02, 8.8945e-01, ...]]]], device='mps:0', grad_fn=<ReluBackward0>))
222 if self.training:
223 # pick the points to apply point rend
224 with torch.no_grad():
--> 225 point_coords = get_uncertain_point_coords_with_randomness(
get_uncertain_point_coords_with_randomness = <torch.jit.ScriptFunction object at 0x28b911c20>
        coarse_sem_seg_logits = tensor([[[[0.0088, 0.0088, 0.0088, ...]]]], device='mps:0', grad_fn=<ConvolutionBackward0>)
        self = PointRendSemSegHead(...)
self.train_num_points = 1024
self.oversample_ratio = 3
self.importance_sample_ratio = 0.75
226 coarse_sem_seg_logits,
227 self.train_num_points,
228 self.oversample_ratio,
229 self.importance_sample_ratio,
230 )
232 # sample points at coarse and fine resolutions
233 coarse_sem_seg_points = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)
NotImplementedError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/models/point_rend.py", line 87, in get_uncertain_point_coords_with_randomness
num_sampled = int(num_points * oversample_ratio)
point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
~~~~~~~~~~~~ <--- HERE
point_uncertainties = calculate_uncertainty(point_logits)
File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/models/point_rend.py", line 55, in point_sample
point_coords = point_coords.unsqueeze(2)
output = F.grid_sample(features, 2.0 * point_coords - 1.0, mode=mode, align_corners=align_corners)
~~~~~~~~~~~~~ <--- HERE
if add_dim:
File "/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/nn/functional.py", line 4221, in grid_sample
align_corners = False
return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: The operator 'aten::grid_sampler_2d' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

</details>
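For reference, the fallback the error message mentions has to be in the environment before torch is imported. A minimal way to try it from Python (I haven't verified whether this also avoids the hang):

```python
# set the CPU fallback for MPS-unsupported ops *before* importing torch
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import torch  # ops like aten::grid_sampler_2d should now fall back to CPU (slower)
```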
Ryan says the model incompatibility is only in the last layer of the model, so we could potentially remove it and replace it with a simpler upsampling. He already has code for that in empanada, but empanada-napari will need to be edited so it doesn't cause other errors there (mostly in the inference part of the empanada-napari plugin).
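A rough sketch of the idea (illustrative only; these names are not empanada's actual API, and Ryan's existing code may differ):

```python
# hypothetical: bypass the PointRend refinement (which needs grid_sample,
# unsupported on MPS) and just bilinearly upsample the coarse logits
import torch.nn.functional as F

def upsample_sem_logits(coarse_sem_seg_logits, scale_factor=4):
    # coarser than PointRend refinement, but avoids aten::grid_sampler_2d
    return F.interpolate(
        coarse_sem_seg_logits,
        scale_factor=scale_factor,
        mode='bilinear',
        align_corners=False,
    )
```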
Results of an experiment using the script to run on M1 hardware. tl;dr: not great.
<details>
<summary>Config file <code>mps_config.yml</code>, modelled on <code>GitHub/empanada/projects/mitonet/configs/mmm_panoptic_deeplab_pointrend.yaml</code> (click to expand)</summary>

```yaml
DATASET:
  dataset_name: "Baselines"
  class_names: { 1: "membrane" }
  labels: [ 1 ]
  thing_list: [ ]
  norms: { mean: 0.508979, std: 0.148561 }

MODEL:
  arch: "PanopticDeepLab"
  encoder: "resnet50"
  num_classes: 1
  stage4_stride: 16
  decoder_channels: 256
  low_level_stages: [ 1 ]
  low_level_channels_project: [ 32 ]
  atrous_rates: [ 2, 4, 6 ]
  aspp_channels: null
  aspp_dropout: 0.5
  ins_decoder: False
  ins_ratio: 0.5

TRAIN:
  run_name: "Panoptic DeepLab Baseline"

  # image and model directories
  train_dir: "/Users/genevieb/Documents/Projects/empanada/training-nuclei-membrane-patches/kidney-nuclei-membrane-patches/"
  additional_train_dirs: null
  model_dir: "/Users/genevieb/Documents/Projects/empanada/nuclei-membrane-model/"
  save_freq: 1

  # path to .pth file for resuming training
  resume: null

  # pretraining parameters
  encoder_pretraining: "/Users/genevieb/.empanada/checkpoints/cem1.5m_swav_resnet50_200ep_balanced.pth.tar"
  whole_pretraining: null
  finetune_layer: "none"

  # set the lr schedule
  lr_schedule: "OneCycleLR"
  schedule_params:
    max_lr: 0.003
    epochs: 30
    steps_per_epoch: -1
    pct_start: 0.3

  # setup the optimizer
  amp: True # automatic mixed precision
  optimizer: "AdamW"
  optimizer_params:
    weight_decay: 0.1

  # criterion parameters
  criterion: "PanopticLoss"
  criterion_params:
    ce_weight: 1
    mse_weight: 200
    l1_weight: 0.01
    top_k_percent: 0.2
    pr_weight: 1

  # performance metrics
  print_freq: 50
  metrics:
    - { metric: "IoU", name: "semantic_iou", labels: [ 1 ], output_key: "sem_logits", target_key: "sem" }

  # dataset parameters
  batch_size: 16
  workers: 0
  dataset_class: "SingleClassInstanceDataset"
  dataset_params:
    weight_gamma: 0.7
  augmentations:
    - { aug: "RandomScale", scale_limit: [ -0.9, 1 ] }
    - { aug: "PadIfNeeded", min_height: 256, min_width: 256, border_mode: 0 }
    - { aug: "RandomCrop", height: 256, width: 256 }
    - { aug: "Rotate", limit: 180, border_mode: 0 }
    - { aug: "RandomBrightnessContrast", brightness_limit: 0.3, contrast_limit: 0.3 }
    - { aug: "HorizontalFlip" }
    - { aug: "VerticalFlip" }

  # distributed training parameters
  multiprocessing_distributed: False
  gpu: null
  mps: 'mps'
  world_size: 1
  rank: 0
  dist_url: "tcp://localhost:10001"
  dist_backend: "nccl"

EVAL:
  eval_dir: "/Users/genevieb/Documents/Projects/empanada/training-nuclei-membrane-patches/kidney-nuclei-membrane-patches/"
  eval_track_indices: null # from test_eval
  eval_track_freq: 1
  epochs_per_eval: 1

  # parameters needed for eval_metrics
  metrics:
    - { metric: "IoU", name: "semantic_iou", labels: [ 1 ], output_key: "sem_logits", target_key: "sem" }

  # parameters needed for inference
  engine: "PanopticDeepLabEngine"
  engine_params:
    thing_list: [ ]
    label_divisor: 1000
    stuff_area: 64
    void_label: 0
    nms_threshold: 0.1
    nms_kernel: 7
    confidence_thr: 0.5
```

</details>

<details>
<summary>Contents of <code>mps_train.py</code> script, modelled on <code>GitHub/empanada/scripts/train.py</code> (click to expand)</summary>

```python
import os
import time
import yaml
import argparse
import mlflow
import random
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
from skimage import io
from skimage import measure
from matplotlib import pyplot as plt
from empanada import losses
from empanada import data
from empanada import metrics
from empanada import models
from empanada.inference import engines
from empanada.config_loaders import load_config
from empanada.data.utils.sampler import DistributedWeightedSampler
from empanada.data.utils.transforms import FactorPad
archs = sorted(name for name in models.__dict__
               if callable(models.__dict__[name])
)
schedules = sorted(name for name in lr_scheduler.__dict__
                   if callable(lr_scheduler.__dict__[name]) and not name.startswith('__')
                   and name[0].isupper()
)
optimizers = sorted(name for name in optim.__dict__
                    if callable(optim.__dict__[name]) and not name.startswith('__')
                    and name[0].isupper()
)
augmentations = sorted(name for name in A.__dict__
                       if callable(A.__dict__[name]) and not name.startswith('__')
                       and name[0].isupper()
)
datasets = sorted(name for name in data.__dict__
                  if callable(data.__dict__[name])
)
engine_names = sorted(name for name in engines.__dict__
                      if callable(engines.__dict__[name])
)
loss_names = sorted(name for name in losses.__dict__
                    if callable(losses.__dict__[name])
)
def parse_args():
    parser = argparse.ArgumentParser(description='Runs panoptic deeplab training')
    parser.add_argument('config', type=str, metavar='config', help='Path to a config yaml file')
    return parser.parse_args()

def main():
    args = parse_args()

    # read the config file
    config = load_config(args.config)
    config['config_file'] = args.config
    config['config_name'] = os.path.basename(args.config).split('.yaml')[0]

    # create model directory if None
    if not os.path.isdir(config['TRAIN']['model_dir']):
        os.mkdir(config['TRAIN']['model_dir'])

    # validate parameters
    assert config['MODEL']['arch'] in archs
    assert config['TRAIN']['lr_schedule'] in schedules
    assert config['TRAIN']['optimizer'] in optimizers
    assert config['TRAIN']['criterion'] in loss_names
    assert config['EVAL']['engine'] in engine_names

    if config['TRAIN']['dist_url'] == "env://" and config['TRAIN']['world_size'] == -1:
        config['TRAIN']['world_size'] = int(os.environ["WORLD_SIZE"])

    ngpus_per_node = torch.cuda.device_count()
    if config['TRAIN']['multiprocessing_distributed']:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        config['TRAIN']['world_size'] = ngpus_per_node * config['TRAIN']['world_size']
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config))
    else:
        # Simply call main_worker function
        main_worker(config['TRAIN']['gpu'], ngpus_per_node, config)
def main_worker(gpu, ngpus_per_node, config):
    config['gpu'] = gpu

    if config['gpu'] is not None:
        print(f"Use GPU: {gpu} for training")

    if config['TRAIN']['multiprocessing_distributed']:
        if config['TRAIN']['dist_url'] == "env://" and config['TRAIN']['rank'] == -1:
            config['TRAIN']['rank'] = int(os.environ["RANK"])

        if config['TRAIN']['multiprocessing_distributed']:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            config['TRAIN']['rank'] = config['TRAIN']['rank'] * ngpus_per_node + config['gpu']

        dist.init_process_group(backend=config['TRAIN']['dist_backend'], init_method=config['TRAIN']['dist_url'],
                                world_size=config['TRAIN']['world_size'], rank=config['TRAIN']['rank'])

    # setup the model and pick dataset class
    model_arch = config['MODEL']['arch']
    model = models.__dict__[model_arch](**config['MODEL'])
    dataset_class_name = config['TRAIN']['dataset_class']
    data_cls = data.__dict__[dataset_class_name]

    # load pre-trained weights, if using
    if config['TRAIN']['whole_pretraining'] is not None:
        state = torch.load(config['TRAIN']['whole_pretraining'], map_location='cpu')
        state_dict = state['state_dict']

        # remove the prefix 'module' from all of the keys
        for k in list(state_dict.keys()):
            if k.startswith('module'):
                state_dict[k[len('module.'):]] = state_dict[k]

            # delete renamed or unused k
            del state_dict[k]

        msg = model.load_state_dict(state['state_dict'], strict=True)
        norms = state['norms']
    elif config['TRAIN']['encoder_pretraining'] is not None:
        state = torch.load(config['TRAIN']['encoder_pretraining'], map_location='cpu')
        state_dict = state['state_dict']

        # add the prefix 'encoder' to all of the keys
        for k in list(state_dict.keys()):
            if not k.startswith('fc'):
                state_dict['encoder.' + k] = state_dict[k]

            # delete renamed or unused k
            del state_dict[k]

        msg = model.load_state_dict(state['state_dict'], strict=False)
        norms = {}
        norms['mean'] = state['norms'][0]
        norms['std'] = state['norms'][1]
    else:
        norms = config['DATASET']['norms']

    finetune_layer = config['TRAIN']['finetune_layer']
    # start by freezing all encoder parameters
    for pname, param in model.named_parameters():
        if 'encoder' in pname:
            param.requires_grad = False

    # freeze encoder layers
    if finetune_layer == 'none':
        # leave all encoder layers frozen
        pass
    elif finetune_layer == 'all':
        # unfreeze all encoder parameters
        for pname, param in model.named_parameters():
            if 'encoder' in pname:
                param.requires_grad = True
    else:
        valid_layers = ['stage1', 'stage2', 'stage3', 'stage4']
        assert finetune_layer in valid_layers
        # unfreeze all layers from finetune_layer onward
        for layer_name in valid_layers[valid_layers.index(finetune_layer):]:
            # unfreeze all parameters in this encoder layer
            for pname, param in model.named_parameters():
                if f'encoder.{layer_name}' in pname:
                    param.requires_grad = True

    num_trainable = sum(p[1].numel() for p in model.named_parameters() if p[1].requires_grad)
    print(f'Model with {num_trainable} trainable parameters.')
    # CPU
    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
        print('Using CPU, this will be slow')

    # M1 Mac
    if torch.backends.mps.is_available():
        model = model.to('mps')
    # GPU
    elif config['TRAIN']['multiprocessing_distributed']:
        # use Synced batchnorm
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if config['gpu'] is not None:
            torch.cuda.set_device(config['gpu'])
            model.cuda(config['gpu'])
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            config['TRAIN']['batch_size'] = int(config['TRAIN']['batch_size'] / ngpus_per_node)
            config['TRAIN']['workers'] = int((config['TRAIN']['workers'] + ngpus_per_node - 1) / ngpus_per_node)
            model = DDP(model, device_ids=[config['gpu']])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = DDP(model)
    elif config['gpu'] is not None:
        torch.cuda.set_device(config['gpu'])
        model = model.cuda(config['gpu'])
    else:
        # script the model
        #model = torch.jit.script(model, optimize=True)
        model = torch.nn.DataParallel(model).cuda()
        #raise Exception

    cudnn.benchmark = True
    # set the training image augmentations
    config['aug_string'] = []
    dataset_augs = []
    for aug_params in config['TRAIN']['augmentations']:
        aug_name = aug_params['aug']
        assert aug_name in augmentations or aug_name == 'CopyPaste', \
            f'{aug_name} is not a valid albumentations augmentation!'
        config['aug_string'].append(aug_params['aug'])
        del aug_params['aug']
        if aug_name == 'CopyPaste':
            dataset_augs.append(CopyPaste(**aug_params))
        else:
            dataset_augs.append(A.__dict__[aug_name](**aug_params))

    config['aug_string'] = ','.join(config['aug_string'])

    tfs = A.Compose([
        *dataset_augs,
        A.Normalize(**norms),
        ToTensorV2()
    ])

    # create training dataset and loader
    train_dataset = data_cls(config['TRAIN']['train_dir'], transforms=tfs, **config['TRAIN']['dataset_params'])
    if config['TRAIN']['additional_train_dirs'] is not None:
        for train_dir in config['TRAIN']['additional_train_dirs']:
            add_dataset = data_cls(train_dir, transforms=tfs, **config['TRAIN']['dataset_params'])
            train_dataset = train_dataset + add_dataset

    if config['TRAIN']['multiprocessing_distributed']:
        if config['TRAIN']['dataset_params']['weight_gamma'] is None:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        else:
            train_sampler = DistributedWeightedSampler(train_dataset, train_dataset.weights)
    elif config['TRAIN']['dataset_params']['weight_gamma'] is not None:
        train_sampler = WeightedRandomSampler(train_dataset.weights, len(train_dataset))
    else:
        train_sampler = None

    # num workers always less than number of batches in train dataset
    num_workers = min(config['TRAIN']['workers'], len(train_dataset) // config['TRAIN']['batch_size'])

    train_loader = DataLoader(
        train_dataset, batch_size=config['TRAIN']['batch_size'], shuffle=(train_sampler is None),
        num_workers=config['TRAIN']['workers'], pin_memory=torch.cuda.is_available(), sampler=train_sampler,
        drop_last=True
    )

    if config['EVAL']['eval_dir'] is not None:
        eval_tfs = A.Compose([
            FactorPad(128),  # pad image to be divisible by 128
            A.Normalize(**norms),
            ToTensorV2()
        ])
        eval_dataset = data_cls(config['EVAL']['eval_dir'], transforms=eval_tfs, **config['TRAIN']['dataset_params'])
        # evaluation runs on a single gpu
        eval_loader = DataLoader(eval_dataset, batch_size=1, shuffle=False,
                                 pin_memory=torch.cuda.is_available(),
                                 num_workers=config['TRAIN']['workers'])

        # pick images to track during validation
        if config['EVAL']['eval_track_indices'] is None:
            # randomly pick 8 examples from eval set to track
            config['EVAL']['eval_track_indices'] = [random.randint(0, len(eval_dataset)) for _ in range(8)]
    else:
        eval_loader = None
    # set criterion
    criterion_name = config['TRAIN']['criterion']
    if config['gpu'] is not None:
        criterion = losses.__dict__[criterion_name](**config['TRAIN']['criterion_params']).cuda(config['gpu'])
    else:
        criterion = losses.__dict__[criterion_name](**config['TRAIN']['criterion_params']).cuda()

    # set optimizer and lr scheduler
    opt_name = config['TRAIN']['optimizer']
    opt_params = config['TRAIN']['optimizer_params']
    optimizer = configure_optimizer(model, opt_name, **opt_params)
    schedule_name = config['TRAIN']['lr_schedule']
    schedule_params = config['TRAIN']['schedule_params']

    if 'steps_per_epoch' in schedule_params:
        n_steps = schedule_params['steps_per_epoch']
        if n_steps != len(train_loader):
            schedule_params['steps_per_epoch'] = len(train_loader)
            print(f'Steps per epoch adjusted from {n_steps} to {len(train_loader)}')

    scheduler = lr_scheduler.__dict__[schedule_name](optimizer, **schedule_params)
    scaler = GradScaler() if config['TRAIN']['amp'] else None

    # optionally resume from a checkpoint
    config['run_id'] = None
    config['start_epoch'] = 0
    if config['TRAIN']['resume'] is not None:
        if os.path.isfile(config['TRAIN']['resume']):
            print("=> loading checkpoint")
            if config['gpu'] is None:
                checkpoint = torch.load(config['TRAIN']['resume'])
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(config['gpu'])
                checkpoint = torch.load(config['TRAIN']['resume'], map_location=loc)

            config['start_epoch'] = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            if scaler is not None:
                scaler.load_state_dict(checkpoint['scaler'])

            # use the saved norms
            norms = checkpoint['norms']
            config['run_id'] = checkpoint['run_id']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(config['TRAIN']['resume'], checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(config['TRAIN']['resume']))

    # training and evaluation loop
    if 'epochs' in config['TRAIN']['schedule_params']:
        epochs = config['TRAIN']['schedule_params']['epochs']
    elif 'epochs' in config['TRAIN']:
        epochs = config['TRAIN']['epochs']
    else:
        raise Exception('Number of training epochs not defined!')

    config['TRAIN']['epochs'] = epochs

    # log important parameters and start/resume mlflow run
    prepare_logging(config)

    for epoch in range(config['start_epoch'], epochs):
        if config['TRAIN']['multiprocessing_distributed']:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer,
              scheduler, scaler, epoch, config)

        is_distributed = config['TRAIN']['multiprocessing_distributed']
        gpu_rank = config['TRAIN']['rank'] % ngpus_per_node

        # evaluate on validation set, does not support multiGPU
        if eval_loader is not None and (epoch + 1) % config['EVAL']['epochs_per_eval'] == 0:
            if not is_distributed or (is_distributed and gpu_rank == 0):
                validate(eval_loader, model, criterion, epoch, config)

        save_now = (epoch + 1) % config['TRAIN']['save_freq'] == 0
        if save_now and not is_distributed or (is_distributed and gpu_rank == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': config['MODEL']['arch'],
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'scaler': scaler.state_dict(),
                'run_id': mlflow.active_run().info.run_id,
                'norms': norms
            }, os.path.join(config['TRAIN']['model_dir'], f"{config['config_name']}_checkpoint.pth.tar"))
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    torch.save(state, filename)

def prepare_logging(config):
    # log parameters for run, or resume existing run
    if config['run_id'] is None and config['TRAIN']['rank'] == 0:
        # log parameters in mlflow
        mlflow.end_run()
        mlflow.set_experiment(config['DATASET']['dataset_name'])

        # log the full config file after inheritance
        artifact_path = 'mlruns/' + mlflow.get_artifact_uri().split('/mlruns/')[-1]
        config_fp = os.path.join(artifact_path, os.path.basename(config['config_file']))
        with open(config_fp, mode='w') as f:
            yaml.dump(config, f)

        # we don't want to add everything in the config
        # to mlflow parameters, we'll just add the most
        # likely to change parameters
        mlflow.log_param('run_name', config['TRAIN']['run_name'])
        mlflow.log_param('architecture', config['MODEL']['arch'])
        mlflow.log_param('encoder_pretraining', config['TRAIN']['encoder_pretraining'])
        mlflow.log_param('whole_pretraining', config['TRAIN']['whole_pretraining'])
        mlflow.log_param('epochs', config['TRAIN']['epochs'])
        mlflow.log_param('batch_size', config['TRAIN']['batch_size'])
        mlflow.log_param('lr_schedule', config['TRAIN']['lr_schedule'])
        mlflow.log_param('optimizer', config['TRAIN']['optimizer'])
        aug_names = config['aug_string']
        mlflow.log_param('augmentations', aug_names)
    else:
        # resume existing run
        mlflow.start_run(run_id=config['run_id'])

def log_metrics(progress, meters, epoch, dataset='Train'):
    # log all the losses from progress
    for meter in progress.meters:
        mlflow.log_metric(f'{dataset}_{meter.name}', meter.avg, step=epoch)

    for metric_name, values in meters.history.items():
        mlflow.log_metric(f'{dataset}_{metric_name}', values[-1], step=epoch)
def configure_optimizer(model, opt_name, **opt_params):
"""
Takes an optimizer and separates parameters into two groups
that either use weight decay or are exempt.
Only BatchNorm parameters and biases are excluded.
"""
# easy if there's no weight_decay
if 'weight_decay' not in opt_params:
return optim.__dict__[opt_name](model.parameters(), **opt_params)
elif opt_params['weight_decay'] == 0:
return optim.__dict__[opt_name](model.parameters(), **opt_params)
# otherwise separate parameters into two groups
decay = set()
no_decay = set()
blacklist = (torch.nn.BatchNorm2d,)
for mn, m in model.named_modules():
for pn, p in m.named_parameters(recurse=False):
full_name = '%s.%s' % (mn, pn) if mn else pn
if full_name.endswith('bias'):
no_decay.add(full_name)
elif full_name.endswith('weight') and isinstance(m, blacklist):
no_decay.add(full_name)
else:
decay.add(full_name)
param_dict = {pn: p for pn, p in model.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
    assert len(inter_params) == 0, "Overlapping decay and no decay"
    assert len(param_dict.keys() - union_params) == 0, "Missing decay parameters"
decay_params = [param_dict[pn] for pn in sorted(list(decay))]
no_decay_params = [param_dict[pn] for pn in sorted(list(no_decay))]
param_groups = [
{"params": decay_params, **opt_params},
{"params": no_decay_params, **opt_params}
]
param_groups[1]['weight_decay'] = 0 # overwrite default to 0 for no_decay group
return optim.__dict__[opt_name](param_groups, **opt_params)
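# Example usage of configure_optimizer (hypothetical hyperparameter values,
# not taken from any empanada config); biases and BatchNorm weights land in
# the zero-weight-decay group automatically:
#
#   optimizer = configure_optimizer(model, 'AdamW', lr=1e-3, weight_decay=1e-4)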
def train(
train_loader,
model,
criterion,
optimizer,
scheduler,
scaler,
epoch,
config
):
# generic progress
batch_time = ProgressAverageMeter('Time', ':6.3f')
data_time = ProgressAverageMeter('Data', ':6.3f')
loss_meters = None
progress = ProgressMeter(
len(train_loader),
[batch_time, data_time],
prefix="Epoch: [{}]".format(epoch)
)
# end of epoch metrics
class_names = config['DATASET']['class_names']
metric_dict = {}
for metric_params in config['TRAIN']['metrics']:
reg_name = metric_params['name']
metric_name = metric_params['metric']
metric_params = {k: v for k,v in metric_params.items() if k not in ['name', 'metric']}
metric_dict[reg_name] = metrics.__dict__[metric_name](metrics.EMAMeter, **metric_params)
meters = metrics.ComposeMetrics(metric_dict, class_names)
# switch to train mode
model.train()
end = time.time()
for i, batch in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
images = batch['image']
target = {k: v for k,v in batch.items() if k not in ['image', 'fname']}
if config.get('gpu') is not None and torch.cuda.is_available():
images = images.to(config['gpu'], non_blocking=True)
target = {k: tensor.to(config['gpu'], non_blocking=True)
for k,tensor in target.items()}
if config['TRAIN'].get('mps') is not None and torch.backends.mps.is_available():
images = images.to('mps', non_blocking=True)
target = {k: tensor.to('mps', non_blocking=True)
for k,tensor in target.items()}
# zero grad before running
optimizer.zero_grad()
# compute output
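        # note: autocast/GradScaler target CUDA; on an M1 Mac PyTorch disables
        # them with a UserWarning (visible in the log below), so this branch
        # effectively runs in full precision on the mps device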
if scaler is not None:
with autocast():
output = model(images)
loss, aux_loss = criterion(output, target) # output and target are both dicts
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
output = model(images)
loss, aux_loss = criterion(output, target)
loss.backward()
optimizer.step()
        # update the LR; the scheduler steps once per batch, not once per epoch
        scheduler.step()
# record losses
if loss_meters is None:
loss_meters = {}
for k,v in aux_loss.items():
loss_meters[k] = ProgressEMAMeter(k, ':.4e')
loss_meters[k].update(v)
# add to progress
progress.meters.append(loss_meters[k])
else:
for k,v in aux_loss.items():
loss_meters[k].update(v)
# calculate human-readable per epoch metrics
with torch.no_grad():
meters.evaluate(output, target)
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % config['TRAIN']['print_freq'] == 0:
progress.display(i)
# end of epoch print evaluation metrics
print('\n')
print(f'Epoch {epoch} training metrics:')
meters.display()
log_metrics(progress, meters, epoch, dataset='Train')
def validate(
eval_loader,
model,
criterion,
epoch,
config
):
# validation metrics to track
class_names = config['DATASET']['class_names']
metric_dict = {}
for metric_params in config['EVAL']['metrics']:
reg_name = metric_params['name']
metric_name = metric_params['metric']
metric_params = {k: v for k,v in metric_params.items() if k not in ['name', 'metric']}
metric_dict[reg_name] = metrics.__dict__[metric_name](metrics.AverageMeter, **metric_params)
meters = metrics.ComposeMetrics(metric_dict, class_names)
# validation tracking
batch_time = ProgressAverageMeter('Time', ':6.3f')
loss_meters = None
progress = ProgressMeter(
len(eval_loader),
[batch_time],
prefix='Validation: '
)
# create the Inference Engine
engine_name = config['EVAL']['engine']
engine = engines.__dict__[engine_name](model, **config['EVAL']['engine_params'])
for i, batch in enumerate(eval_loader):
end = time.time()
images = batch['image']
target = {k: v for k,v in batch.items() if k not in ['image', 'fname']}
        if config.get('gpu') is not None and torch.cuda.is_available():
            images = images.cuda(config['gpu'], non_blocking=True)
            target = {k: tensor.cuda(config['gpu'], non_blocking=True)
                      for k,tensor in target.items()}
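        # note: unlike train(), there is no mps branch here, so on an
        # M1 Mac validation tensors stay on the CPU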
# compute panoptic segmentations
# from prediction and ground truth
output = engine.infer(images)
semantic = engine._harden_seg(output['sem'])
output['pan_seg'] = engine.postprocess(
semantic, output['ctr_hmp'], output['offsets']
)
target['pan_seg'] = engine.postprocess(
target['sem'].unsqueeze(1), target['ctr_hmp'], target['offsets']
)
loss, aux_loss = criterion(output, target)
# record losses
if loss_meters is None:
loss_meters = {}
for k,v in aux_loss.items():
loss_meters[k] = ProgressAverageMeter(k, ':.4e')
loss_meters[k].update(v)
# add to progress
progress.meters.append(loss_meters[k])
else:
for k,v in aux_loss.items():
loss_meters[k].update(v)
# compute metrics
with torch.no_grad():
meters.evaluate(output, target)
batch_time.update(time.time() - end)
if i % config['TRAIN']['print_freq'] == 0:
progress.display(i)
if i in config['EVAL']['eval_track_indices'] and (epoch + 1) % config['EVAL']['eval_track_freq'] == 0:
impath = batch['fname'][0]
fname = '.'.join(os.path.basename(impath).split('.')[:-1])
image = io.imread(impath)
# gt and prediction
h, w = image.shape
gt = measure.label(target['pan_seg'].squeeze().cpu().numpy()[:h, :w])
pred = measure.label(output['pan_seg'].squeeze().cpu().numpy()[:h, :w])
artifact_path = 'mlruns/' + mlflow.get_artifact_uri().split('/mlruns/')[-1]
f, ax = plt.subplots(1, 3, figsize=(12, 4))
ax[0].imshow(image, cmap='gray')
ax[1].imshow(gt, cmap='plasma')
ax[2].imshow(pred, cmap='plasma')
plt.savefig(os.path.join(artifact_path, f'{fname}_epoch{epoch}.png'))
plt.clf()
# end of epoch print evaluation metrics
print('\n')
    print('Validation results:')
meters.display()
log_metrics(progress, meters, epoch, dataset='Eval')
class ProgressAverageMeter(metrics.AverageMeter):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
super().__init__()
def __str__(self):
fmtstr = '{name} {avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
class ProgressEMAMeter(metrics.EMAMeter):
"""Computes and stores the exponential moving average and current value"""
def __init__(self, name, fmt=':f', momentum=0.98):
self.name = name
self.fmt = fmt
super().__init__(momentum)
def __str__(self):
fmtstr = '{name} {avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
class ProgressMeter:
def __init__(self, num_batches, meters, prefix=""):
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
self.meters = meters
self.prefix = prefix
def display(self, batch):
entries = [self.prefix + self.batch_fmtstr.format(batch)]
entries += [str(meter) for meter in self.meters]
print('\t'.join(entries))
def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
fmt = '{:' + str(num_digits) + 'd}'
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
if __name__ == "__main__":
    main()
Error message (click to expand):
(napari-empanada-dev) genevieb@Admins-MacBook-Pro scripts % python train_mps.py ../projects/mitonet/configs/mps_config.yml
Model with 15842660 trainable parameters.
Found 1 image subdirectories with 16 images.
Found 1 image subdirectories with 16 images.
Steps per epoch adjusted from -1 to 1
/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py:115: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/amp/autocast_mode.py:198: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
Traceback (most recent call last):
File "/Users/genevieb/Documents/GitHub/empanada/empanada/scripts/train_mps.py", line 727, in <module>
main()
File "/Users/genevieb/Documents/GitHub/empanada/empanada/scripts/train_mps.py", line 102, in main
main_worker(config['TRAIN']['gpu'], ngpus_per_node, config)
File "/Users/genevieb/Documents/GitHub/empanada/empanada/scripts/train_mps.py", line 370, in main_worker
train(train_loader, model, criterion, optimizer,
File "/Users/genevieb/Documents/GitHub/empanada/empanada/scripts/train_mps.py", line 543, in train
loss, aux_loss = criterion(output, target) # output and target are both dicts
File "/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/losses.py", line 140, in forward
ce = self.ce_loss(output['sem_logits'], target['sem'])
File "/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/losses.py", line 44, in forward
pixel_losses, _ = torch.topk(pixel_losses, top_k_pixels)
RuntimeError: Currently topk on mps works only for k<=16
(napari-empanada-dev) genevieb@Admins-MacBook-Pro scripts %
Ryan says we use a top k of a very large number (like thousands), but the mps backend only supports k<=16. So we're probably not going to be able to run these models easily on the M1 Mac hardware 😢 Maybe we could modify the loss function we are using so that it doesn't encounter this particular problem? |
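One possible workaround (a minimal sketch, untested on M1 hardware; the helper name topk_mps_safe is hypothetical, not part of empanada): route the topk call through the CPU whenever the tensor lives on the mps device and k exceeds 16, then move the results back. Slower, but it avoids the RuntimeError:

import torch

def topk_mps_safe(values, k):
    # torch.topk on the mps backend currently only supports k <= 16,
    # so compute larger top-k selections on the CPU as a fallback
    if values.device.type == 'mps' and k > 16:
        topk_vals, topk_idx = torch.topk(values.cpu(), k)
        return topk_vals.to(values.device), topk_idx.to(values.device)
    return torch.topk(values, k)

# the failing line in empanada/losses.py would then become:
# pixel_losses, _ = topk_mps_safe(pixel_losses, top_k_pixels)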
So GPU M1 acceleration appears to be blocked by pytorch/pytorch#78915.
2D and 3D inferencing works fine by appending some code in the Engine2D and Engine3D classes to use the ARM quantization backend. After this existing line:
model_url = model_config['model_quantized']
Add the following:
model_url = model_config['model_quantized']
if torch.backends.mps.is_available():
# set quantized backend to ARM
torch.backends.quantized.engine = 'qnnpack' |
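For context: qnnpack is PyTorch's ARM-oriented quantized kernel backend, while the default fbgemm backend targets x86 CPUs, which is why quantized models generally need this switch on Apple Silicon. You can check which engines your build supports with (a quick sanity check, not empanada code):

import torch
print(torch.backends.quantized.supported_engines)  # e.g. ['none', 'qnnpack']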
As described above, we ran into difficulties implementing M1 acceleration support for model training. However, accelerating inference might still be useful. So if your inference-only experiments look more promising, that could be worth continuing. |
DO NOT MERGE
The pytorch nightly (I'm using version '1.13.0.dev20220616') now supports acceleration on M1 Mac hardware. This branch experiments with adding support for this in empanada-napari.
Cross-reference: volume-em/empanada-napari#17