Support pytorch acceleration on M1 mac hardware #14

Open
wants to merge 2 commits into base: main

Conversation

GenevieveBuckley

DO NOT MERGE

  1. We're waiting for the next official release of pytorch that includes M1 Mac hardware acceleration (right now I'm using the nightly build to test things).
  2. The MitoNet model is not supported on the M1 Mac mps backend, so things are currently broken because of that as well.

The pytorch nightly (I'm using version '1.13.0.dev20220616') now supports acceleration on M1 Mac hardware. This branch experiments with adding support for this in empanada-napari.

Cross-reference: volume-em/empanada-napari#17
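
A quick way to confirm that an installed pytorch build has MPS support (an illustrative snippet, not part of this PR):

    import torch

    print(torch.__version__)                  # e.g. '1.13.0.dev20220616'
    print(torch.backends.mps.is_built())      # was this build compiled with MPS support?
    print(torch.backends.mps.is_available())  # can MPS actually be used on this machine?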

@GenevieveBuckley
Author

Ryan suggests (and I agree) that we turn this code block into a utility function, since the pattern is repeated so many times. I haven't done it yet because we're not sure where this utility function should live.

    # check whether GPU or M1 Mac hardware is available
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    model.to(device)
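
For illustration, a minimal sketch of what such a utility could look like (the name `get_device` and where it would live are both hypothetical, and `model` is assumed to be already constructed):

    import torch

    def get_device() -> torch.device:
        # prefer a CUDA GPU, then M1 Mac MPS, then fall back to CPU
        if torch.cuda.is_available():
            return torch.device('cuda:0')
        elif torch.backends.mps.is_available():
            return torch.device('mps')
        return torch.device('cpu')

    # usage: replaces the repeated if/elif/else block above
    model.to(get_device())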

@GenevieveBuckley
Author

GenevieveBuckley commented Jun 17, 2022

Problem: The models empanada uses are not completely compatible with the new mps backend (which makes sense, I guess; it's a very new pytorch feature).

    226             coarse_sem_seg_logits,
    227             self.train_num_points,
    228             self.oversample_ratio,
    229             self.importance_sample_ratio,
    230         )
    232     # sample points at coarse and fine resolutions
    233     coarse_sem_seg_points = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)

NotImplementedError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/models/point_rend.py", line 87, in get_uncertain_point_coords_with_randomness
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
                   ~~~~~~~~~~~~ <--- HERE

    point_uncertainties = calculate_uncertainty(point_logits)
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/models/point_rend.py", line 55, in point_sample
        point_coords = point_coords.unsqueeze(2)

    output = F.grid_sample(features, 2.0 * point_coords - 1.0, mode=mode, align_corners=align_corners)
             ~~~~~~~~~~~~~ <--- HERE

    if add_dim:
  File "/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/nn/functional.py", line 4221, in grid_sample
        align_corners = False

    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
           ~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: The operator 'aten::grid_sampler_2d' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

...the code just hangs at this point until you press Control+C in the terminal to kill it.
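
For anyone trying the temporary workaround mentioned in the error message, the variable needs to be in the environment before torch is imported (illustrative only, and note the warning above that the CPU fallback is slow):

    # from the shell, before launching python:
    #   PYTORCH_ENABLE_MPS_FALLBACK=1 python your_script.py
    # or from python, before the first `import torch`:
    import os
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
    import torch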

Full traceback (click to expand):
              (conv): Identity()
            )
            (3): Resample2d(
              (conv): Identity()
            )
          )
          (resize_down): Resize2d(
            (resample): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
          )
          (after_combines): ModuleList(
            (0): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (2): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (3): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
          )
        )
      )
      (2): BiFPNLayer(
        (top_down_fpn): TopDownFPN(
          (resamplings): ModuleList(
            (0): Resample2d(
              (conv): Identity()
            )
            (1): Resample2d(
              (conv): Identity()
            )
            (2): Resample2d(
              (conv): Identity()
            )
            (3): Resample2d(
              (conv): Identity()
            )
          )
          (resize_up): Resize2d(
            (resample): Interpolate2d()
          )
          (after_combines): ModuleList(
            (0): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (2): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (3): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
          )
        )
        (bottom_up_fpn): BottomUpFPN(
          (resamplings): ModuleList(
            (0): Resample2d(
              (conv): Identity()
            )
            (1): Resample2d(
              (conv): Identity()
            )
            (2): Resample2d(
              (conv): Identity()
            )
            (3): Resample2d(
              (conv): Identity()
            )
          )
          (resize_down): Resize2d(
            (resample): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
          )
          (after_combines): ModuleList(
            (0): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (2): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (3): Sequential(
              (0): SeparableConv2d(
                (sepconv): Sequential(
                  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
                  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                )
              )
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
          )
        )
      )
    )
  )
  (semantic_decoder): BiFPNDecoder(
    (upsamplings): ModuleList(
      (0): Sequential(
        (0): ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (2): Sequential(
        (0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (3): Sequential(
        (0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (4): Sequential(
        (0): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
    )
    (fusion): Sequential(
      (0): SeparableConv2d(
        (sepconv): Sequential(
          (0): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=256, bias=False)
          (1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (semantic_head): PanopticDeepLabHead(
    (head): Sequential(
      (0): Sequential(
        (0): SeparableConv2d(
          (sepconv): Sequential(
            (0): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=128, bias=False)
            (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
        )
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (ins_center): PanopticDeepLabHead(
    (head): Sequential(
      (0): Sequential(
        (0): SeparableConv2d(
          (sepconv): Sequential(
            (0): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=128, bias=False)
            (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
        )
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (ins_xy): PanopticDeepLabHead(
    (head): Sequential(
      (0): Sequential(
        (0): SeparableConv2d(
          (sepconv): Sequential(
            (0): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=128, bias=False)
            (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
        )
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (interpolate): Interpolate2d()
  (semantic_pr): PointRendSemSegHead(
    (point_head): StandardPointHead(
      (fc_layers): ModuleList(
        (0): Sequential(
          (0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
          (1): ReLU(inplace=True)
        )
        (1): Sequential(
          (0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
          (1): ReLU(inplace=True)
        )
        (2): Sequential(
          (0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
          (1): ReLU(inplace=True)
        )
      )
      (predictor): Conv1d(129, 1, kernel_size=(1,), stride=(1,))
    )
    (interpolate): Interpolate2d()
  )
)
    159 if self.training:
    160     # interpolate to original resolution (4x)
    161     heads_out['sem_logits'] = self.interpolate(pr_out['sem_seg_logits'])

File ~/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/nn/modules/module.py:1131, in Module._call_impl(self=PointRendSemSegHead(
  (point_head): StandardPoi...ride=(1,))
  )
  (interpolate): Interpolate2d()
), *input=(tensor([[[[0.0088, 0.0088, 0.0088,  ..., 0.0089,...  device='mps:0', grad_fn=<ConvolutionBackward0>), tensor([[[[0.0000e+00, 7.4089e-02, 8.8945e-01,  ...+00]]]], device='mps:0', grad_fn=<ReluBackward0>)), **kwargs={})
   1127 # If we don't have any hooks, we want to skip the rest of the logic in
   1128 # this function, and just call forward.
   1129 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1130         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1131     return forward_call(*input, **kwargs)
        forward_call = <bound method PointRendSemSegHead.forward of PointRendSemSegHead(
  (point_head): StandardPointHead(
    (fc_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
        (1): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
        (1): ReLU(inplace=True)
      )
      (2): Sequential(
        (0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
        (1): ReLU(inplace=True)
      )
    )
    (predictor): Conv1d(129, 1, kernel_size=(1,), stride=(1,))
  )
  (interpolate): Interpolate2d()
)>
        input = (tensor([[[[0.0088, 0.0088, 0.0088,  ..., 0.0089, 0.0088, 0.0088],
                           ...]]]], device='mps:0', grad_fn=<ConvolutionBackward0>),
                 tensor([[[[0.0000e+00, 7.4089e-02, 8.8945e-01,  ..., 2.1623e+00,
                           ...]]]], device='mps:0', grad_fn=<ReluBackward0>))
        kwargs = {}
   1132 # Do not call functions when jit is used
   1133 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Documents/GitHub/empanada/empanada/empanada/models/point_rend.py:225, in PointRendSemSegHead.forward(self=PointRendSemSegHead(
  (point_head): StandardPoi...ride=(1,))
  )
  (interpolate): Interpolate2d()
), coarse_sem_seg_logits=tensor([[[[0.0088, 0.0088, 0.0088,  ..., 0.0089,...  device='mps:0', grad_fn=<ConvolutionBackward0>), features=tensor([[[[0.0000e+00, 7.4089e-02, 8.8945e-01,  ...+00]]]], device='mps:0', grad_fn=<ReluBackward0>))
    222 if self.training:
    223     # pick the points to apply point rend
    224     with torch.no_grad():
--> 225         point_coords = get_uncertain_point_coords_with_randomness(
        get_uncertain_point_coords_with_randomness = <torch.jit.ScriptFunction object at 0x28b911c20>
        coarse_sem_seg_logits = tensor([[[[0.0088, 0.0088, 0.0088,  ..., 0.0089, 0.0088, 0.0088],
                                           ...]]]], device='mps:0', grad_fn=<ConvolutionBackward0>)
        self = PointRendSemSegHead(
  (point_head): StandardPointHead(
    (fc_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
        (1): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
        (1): ReLU(inplace=True)
      )
      (2): Sequential(
        (0): Conv1d(129, 128, kernel_size=(1,), stride=(1,))
        (1): ReLU(inplace=True)
      )
    )
    (predictor): Conv1d(129, 1, kernel_size=(1,), stride=(1,))
  )
  (interpolate): Interpolate2d()
)
        self.train_num_points = 1024
        self.oversample_ratio = 3
        self.importance_sample_ratio = 0.75
    226             coarse_sem_seg_logits,
    227             self.train_num_points,
    228             self.oversample_ratio,
    229             self.importance_sample_ratio,
    230         )
    232     # sample points at coarse and fine resolutions
    233     coarse_sem_seg_points = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)

NotImplementedError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/models/point_rend.py", line 87, in get_uncertain_point_coords_with_randomness
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
                   ~~~~~~~~~~~~ <--- HERE

    point_uncertainties = calculate_uncertainty(point_logits)
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/models/point_rend.py", line 55, in point_sample
        point_coords = point_coords.unsqueeze(2)

    output = F.grid_sample(features, 2.0 * point_coords - 1.0, mode=mode, align_corners=align_corners)
             ~~~~~~~~~~~~~ <--- HERE

    if add_dim:
  File "/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/nn/functional.py", line 4221, in grid_sample
        align_corners = False

    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
           ~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: The operator 'aten::grid_sampler_2d' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

@GenevieveBuckley
Author

GenevieveBuckley commented Jun 17, 2022

Ryan says the model incompatibility is only in the last layer of the model, so we could potentially remove it and replace it with a simpler upsampling. He already has code for that in empanada, but empanada-napari will need to be edited so it doesn't cause other errors there (mostly in the inference part of the empanada-napari plugin).
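
As a rough illustration only (not the actual code Ryan has in empanada), swapping out the PointRend head for a pass-through, so that no `grid_sample` op runs on mps, might look like this. The `semantic_pr` attribute name comes from the model repr above, the returned dict key matches the `pr_out['sem_seg_logits']` access in the traceback, and `model` is assumed to be the loaded PanopticDeepLab model:

    import torch.nn as nn

    class PassThroughSemSegHead(nn.Module):
        # return the coarse logits unchanged; the model's existing
        # Interpolate2d then upsamples them to full resolution
        def forward(self, coarse_sem_seg_logits, features):
            return {'sem_seg_logits': coarse_sem_seg_logits}

    model.semantic_pr = PassThroughSemSegHead()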

@GenevieveBuckley
Author

GenevieveBuckley commented Jun 17, 2022

Results of an experiment running the training script on M1 hardware

tl;dr: not great

conda activate napari-empanada-dev
cd GitHub/empanada/scripts
python train_mps.py ../projects/mitonet/configs/mps_config.yml
Config file `mps_config.yml`, modelled on `GitHub/empanada/projects/mitonet/configs/mmm_panoptic_deeplab_pointrend.yaml` (click to expand):
DATASET:
  dataset_name: "Baselines"
  class_names: { 1: "membrane" }
  labels: [ 1 ]
  thing_list: [ ]
  norms: { mean: 0.508979, std: 0.148561 }

MODEL:
  arch: "PanopticDeepLab"
  encoder: "resnet50"
  num_classes: 1
  stage4_stride: 16
  decoder_channels: 256
  low_level_stages: [ 1 ]
  low_level_channels_project: [ 32 ]
  atrous_rates: [ 2, 4, 6 ]
  aspp_channels: null
  aspp_dropout: 0.5
  ins_decoder: False
  ins_ratio: 0.5

TRAIN:
  run_name: "Panoptic DeepLab Baseline"
  # image and model directories
  train_dir: "/Users/genevieb/Documents/Projects/empanada/training-nuclei-membrane-patches/kidney-nuclei-membrane-patches/"
  additional_train_dirs: null

  model_dir: "/Users/genevieb/Documents/Projects/empanada/nuclei-membrane-model/"
  save_freq: 1

  # path to .pth file for resuming training
  resume: null

  # pretraining parameters
  encoder_pretraining: "/Users/genevieb/.empanada/checkpoints/cem1.5m_swav_resnet50_200ep_balanced.pth.tar"
  whole_pretraining: null
  finetune_layer: "none"

  # set the lr schedule
  lr_schedule: "OneCycleLR"
  schedule_params:
    max_lr: 0.003
    epochs: 30
    steps_per_epoch: -1
    pct_start: 0.3

  # setup the optimizer
  amp: True  # automatic mixed precision
  optimizer: "AdamW"
  optimizer_params:
    weight_decay: 0.1

  # criterion parameters
  criterion: "PanopticLoss"
  criterion_params:
    ce_weight: 1
    mse_weight: 200
    l1_weight: 0.01
    top_k_percent: 0.2
    pr_weight: 1

  # performance metrics
  print_freq: 50
  metrics:
      - { metric: "IoU", name: "semantic_iou", labels: [ 1 ], output_key: "sem_logits",  target_key: "sem"}

  # dataset parameters
  batch_size: 16
  workers: 0
  
  dataset_class: "SingleClassInstanceDataset"
  dataset_params:
      weight_gamma: 0.7
  
  augmentations:
    - { aug: "RandomScale", scale_limit: [ -0.9, 1 ]}
    - { aug: "PadIfNeeded", min_height: 256, min_width: 256, border_mode: 0 }
    - { aug: "RandomCrop", height: 256, width: 256}
    - { aug: "Rotate", limit: 180, border_mode: 0 }
    - { aug: "RandomBrightnessContrast", brightness_limit: 0.3, contrast_limit: 0.3 }
    - { aug: "HorizontalFlip" }
    - { aug: "VerticalFlip" }

  # distributed training parameters
  multiprocessing_distributed: False
  gpu: null
  mps: 'mps'
  world_size: 1
  rank: 0
  dist_url: "tcp://localhost:10001"
  dist_backend: "nccl"

EVAL:
  eval_dir: "/Users/genevieb/Documents/Projects/empanada/training-nuclei-membrane-patches/kidney-nuclei-membrane-patches/"
  eval_track_indices: null # from test_eval
  eval_track_freq: 1
  epochs_per_eval: 1

  # parameters needed for eval_metrics
  metrics:
      - { metric: "IoU", name: "semantic_iou", labels: [ 1 ], output_key: "sem_logits",  target_key: "sem"}

  # parameters needed for inference
  engine: "PanopticDeepLabEngine"
  engine_params:
    thing_list: [ ]
    label_divisor: 1000
    stuff_area: 64
    void_label: 0
    nms_threshold: 0.1
    nms_kernel: 7
    confidence_thr: 0.5
Contents of `train_mps.py` script, modelled on `GitHub/empanada/scripts/train.py` (click to expand):
import os
import time
import yaml
import argparse
import mlflow
import random

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

import albumentations as A
from albumentations.pytorch import ToTensorV2
from skimage import io
from skimage import measure
from matplotlib import pyplot as plt

from empanada import losses
from empanada import data
from empanada import metrics
from empanada import models
from empanada.inference import engines
from empanada.config_loaders import load_config
from empanada.data.utils.sampler import DistributedWeightedSampler
from empanada.data.utils.transforms import FactorPad

archs = sorted(name for name in models.__dict__
    if callable(models.__dict__[name])
)

schedules = sorted(name for name in lr_scheduler.__dict__
    if callable(lr_scheduler.__dict__[name]) and not name.startswith('__')
    and name[0].isupper()
)

optimizers = sorted(name for name in optim.__dict__
    if callable(optim.__dict__[name]) and not name.startswith('__')
    and name[0].isupper()
)

augmentations = sorted(name for name in A.__dict__
    if callable(A.__dict__[name]) and not name.startswith('__')
    and name[0].isupper()
)

datasets = sorted(name for name in data.__dict__
    if callable(data.__dict__[name])
)

engine_names = sorted(name for name in engines.__dict__
    if callable(engines.__dict__[name])
)

loss_names = sorted(name for name in losses.__dict__
    if callable(losses.__dict__[name])
)

def parse_args():
    parser = argparse.ArgumentParser(description='Runs panoptic deeplab training')
    parser.add_argument('config', type=str, metavar='config', help='Path to a config yaml file')
    return parser.parse_args()

def main():
    args = parse_args()

    # read the config file
    config = load_config(args.config)

    config['config_file'] = args.config
    config['config_name'] = os.path.basename(args.config).split('.yaml')[0]

    # create model directory if None
    if not os.path.isdir(config['TRAIN']['model_dir']):
        os.mkdir(config['TRAIN']['model_dir'])

    # validate parameters
    assert config['MODEL']['arch'] in archs
    assert config['TRAIN']['lr_schedule'] in schedules
    assert config['TRAIN']['optimizer'] in optimizers
    assert config['TRAIN']['criterion'] in loss_names
    assert config['EVAL']['engine'] in engine_names

    if config['TRAIN']['dist_url'] == "env://" and config['TRAIN']['world_size'] == -1:
        config['TRAIN']['world_size'] = int(os.environ["WORLD_SIZE"])

    ngpus_per_node = torch.cuda.device_count()
    if config['TRAIN']['multiprocessing_distributed']:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        config['TRAIN']['world_size'] = ngpus_per_node * config['TRAIN']['world_size']
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config))
    else:
        # Simply call main_worker function
        main_worker(config['TRAIN']['gpu'], ngpus_per_node, config)

def main_worker(gpu, ngpus_per_node, config):
    config['gpu'] = gpu

    if config['gpu'] is not None:
        print(f"Use GPU: {gpu} for training")

    if config['TRAIN']['multiprocessing_distributed']:
        if config['TRAIN']['dist_url'] == "env://" and config['TRAIN']['rank'] == -1:
            config['TRAIN']['rank'] = int(os.environ["RANK"])
        if config['TRAIN']['multiprocessing_distributed']:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            config['TRAIN']['rank'] = config['TRAIN']['rank'] * ngpus_per_node + config['gpu']

        dist.init_process_group(backend=config['TRAIN']['dist_backend'], init_method=config['TRAIN']['dist_url'],
                                world_size=config['TRAIN']['world_size'], rank=config['TRAIN']['rank'])

    # setup the model and pick dataset class
    model_arch = config['MODEL']['arch']
    model = models.__dict__[model_arch](**config['MODEL'])
    dataset_class_name = config['TRAIN']['dataset_class']
    data_cls = data.__dict__[dataset_class_name]

    # load pre-trained weights, if using
    if config['TRAIN']['whole_pretraining'] is not None:
        state = torch.load(config['TRAIN']['whole_pretraining'], map_location='cpu')
        state_dict = state['state_dict']

        # remove the prefix 'module' from all of the keys
        for k in list(state_dict.keys()):
            if k.startswith('module'):
                state_dict[k[len('module.'):]] = state_dict[k]

            # delete renamed or unused k
            del state_dict[k]

        msg = model.load_state_dict(state['state_dict'], strict=True)
        norms = state['norms']
    elif config['TRAIN']['encoder_pretraining'] is not None:
        state = torch.load(config['TRAIN']['encoder_pretraining'], map_location='cpu')
        state_dict = state['state_dict']

        # add the prefix 'encoder' to all of the keys
        for k in list(state_dict.keys()):
            if not k.startswith('fc'):
                state_dict['encoder.' + k] = state_dict[k]

            # delete renamed or unused k
            del state_dict[k]

        msg = model.load_state_dict(state['state_dict'], strict=False)
        norms = {}
        norms['mean'] = state['norms'][0]
        norms['std'] = state['norms'][1]
    else:
        norms = config['DATASET']['norms']

    finetune_layer = config['TRAIN']['finetune_layer']
    # start by freezing all encoder parameters
    for pname, param in model.named_parameters():
        if 'encoder' in pname:
            param.requires_grad = False

    # freeze encoder layers
    if finetune_layer == 'none':
        # leave all encoder layers frozen
        pass
    elif finetune_layer == 'all':
        # unfreeze all encoder parameters
        for pname, param in model.named_parameters():
            if 'encoder' in pname:
                param.requires_grad = True
    else:
        valid_layers = ['stage1', 'stage2', 'stage3', 'stage4']
        assert finetune_layer in valid_layers
        # unfreeze all layers from finetune_layer onward
        for layer_name in valid_layers[valid_layers.index(finetune_layer):]:
            # freeze all encoder parameters
            for pname, param in model.named_parameters():
                if f'encoder.{layer_name}' in pname:
                    param.requires_grad = True

    num_trainable = sum(p[1].numel() for p in model.named_parameters() if p[1].requires_grad)
    print(f'Model with {num_trainable} trainable parameters.')

    # CPU
    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
        print('Using CPU, this will be slow')
    # M1 Mac
    if torch.backends.mps.is_available():
        model = model.to('mps')
    # GPU
    elif config['TRAIN']['multiprocessing_distributed']:
        # use Synced batchnorm
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if config['gpu'] is not None:
            torch.cuda.set_device(config['gpu'])
            model.cuda(config['gpu'])
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            config['TRAIN']['batch_size'] = int(config['TRAIN']['batch_size'] / ngpus_per_node)
            config['TRAIN']['workers'] = int((config['TRAIN']['workers'] + ngpus_per_node - 1) / ngpus_per_node)
            model = DDP(model, device_ids=[config['gpu']])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = DDP(model)
    elif config['gpu'] is not None:
        torch.cuda.set_device(config['gpu'])
        model = model.cuda(config['gpu'])
    else:
        # wrap in DataParallel when CUDA is available; otherwise the
        # model simply stays on the CPU
        if torch.cuda.is_available():
            model = torch.nn.DataParallel(model).cuda()

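    # cudnn autotuning only affects CUDA; setting the flag is harmless on MPS/CPU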
    cudnn.benchmark = True

    # set the training image augmentations
    config['aug_string'] = []
    dataset_augs = []
    for aug_params in config['TRAIN']['augmentations']:
        aug_name = aug_params['aug']

        assert aug_name in augmentations or aug_name == 'CopyPaste', \
        f'{aug_name} is not a valid albumentations augmentation!'

        config['aug_string'].append(aug_params['aug'])
        del aug_params['aug']
        if aug_name == 'CopyPaste':
            dataset_augs.append(CopyPaste(**aug_params))
        else:
            dataset_augs.append(A.__dict__[aug_name](**aug_params))

    config['aug_string'] = ','.join(config['aug_string'])

    tfs = A.Compose([
        *dataset_augs,
        A.Normalize(**norms),
        ToTensorV2()
    ])

    # create training dataset and loader
    train_dataset = data_cls(config['TRAIN']['train_dir'], transforms=tfs, **config['TRAIN']['dataset_params'])
    if config['TRAIN']['additional_train_dirs'] is not None:
        for train_dir in config['TRAIN']['additional_train_dirs']:
            add_dataset = data_cls(train_dir, transforms=tfs, **config['TRAIN']['dataset_params'])
            train_dataset = train_dataset + add_dataset

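    # choose a sampler: distributed (optionally weighted), weighted, or
    # fall back to the DataLoader's default shuffling (sampler=None)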
    if config['TRAIN']['multiprocessing_distributed']:
        if config['TRAIN']['dataset_params']['weight_gamma'] is None:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        else:
            train_sampler = DistributedWeightedSampler(train_dataset, train_dataset.weights)
    elif config['TRAIN']['dataset_params']['weight_gamma'] is not None:
        train_sampler = WeightedRandomSampler(train_dataset.weights, len(train_dataset))
    else:
        train_sampler = None

    # num workers always less than number of batches in train dataset
    num_workers = min(config['TRAIN']['workers'], len(train_dataset) // config['TRAIN']['batch_size'])

    train_loader = DataLoader(
        train_dataset, batch_size=config['TRAIN']['batch_size'], shuffle=(train_sampler is None),
        num_workers=num_workers, pin_memory=torch.cuda.is_available(), sampler=train_sampler,
        drop_last=True
    )

    if config['EVAL']['eval_dir'] is not None:
        eval_tfs = A.Compose([
            FactorPad(128), # pad image to be divisible by 128
            A.Normalize(**norms),
            ToTensorV2()
        ])
        eval_dataset = data_cls(config['EVAL']['eval_dir'], transforms=eval_tfs, **config['TRAIN']['dataset_params'])
        # evaluation runs on a single gpu
        eval_loader = DataLoader(eval_dataset, batch_size=1, shuffle=False,
                                 pin_memory=torch.cuda.is_available(),
                                 num_workers=config['TRAIN']['workers'])

        # pick images to track during validation
        if config['EVAL']['eval_track_indices'] is None:
            # randomly pick 8 examples from eval set to track
            config['EVAL']['eval_track_indices'] = [random.randint(0, len(eval_dataset) - 1) for _ in range(8)]

    else:
        eval_loader = None

    # set criterion and move it to the right device
    criterion_name = config['TRAIN']['criterion']
    criterion = losses.__dict__[criterion_name](**config['TRAIN']['criterion_params'])
    if config['gpu'] is not None:
        criterion = criterion.cuda(config['gpu'])
    elif torch.cuda.is_available():
        criterion = criterion.cuda()
    elif torch.backends.mps.is_available():
        criterion = criterion.to('mps')

    # set optimizer and lr scheduler
    opt_name = config['TRAIN']['optimizer']
    opt_params = config['TRAIN']['optimizer_params']
    optimizer = configure_optimizer(model, opt_name, **opt_params)

    schedule_name = config['TRAIN']['lr_schedule']
    schedule_params = config['TRAIN']['schedule_params']

    if 'steps_per_epoch' in schedule_params:
        n_steps = schedule_params['steps_per_epoch']
        if n_steps != len(train_loader):
            schedule_params['steps_per_epoch'] = len(train_loader)
            print(f'Steps per epoch adjusted from {n_steps} to {len(train_loader)}')

    scheduler = lr_scheduler.__dict__[schedule_name](optimizer, **schedule_params)

    # GradScaler is CUDA-only, so skip it when CUDA is unavailable
    scaler = GradScaler() if config['TRAIN']['amp'] and torch.cuda.is_available() else None

    # optionally resume from a checkpoint
    config['run_id'] = None
    config['start_epoch'] = 0
    if config['TRAIN']['resume'] is not None:
        if os.path.isfile(config['TRAIN']['resume']):
            print("=> loading checkpoint")
            if config['gpu'] is None:
                checkpoint = torch.load(config['TRAIN']['resume'])
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(config['gpu'])
                checkpoint = torch.load(config['TRAIN']['resume'], map_location=loc)

            config['start_epoch'] = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            if scaler is not None and checkpoint['scaler'] is not None:
                scaler.load_state_dict(checkpoint['scaler'])

            # use the saved norms
            norms = checkpoint['norms']
            config['run_id'] = checkpoint['run_id']

            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(config['TRAIN']['resume'], checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(config['TRAIN']['resume']))

    # training and evaluation loop
    if 'epochs' in config['TRAIN']['schedule_params']:
        epochs = config['TRAIN']['schedule_params']['epochs']
    elif 'epochs' in config['TRAIN']:
        epochs = config['TRAIN']['epochs']
    else:
        raise Exception('Number of training epochs not defined!')

    config['TRAIN']['epochs'] = epochs

    # log important parameters and start/resume mlflow run
    prepare_logging(config)

    for epoch in range(config['start_epoch'], epochs):
        if config['TRAIN']['multiprocessing_distributed']:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer,
              scheduler, scaler, epoch, config)

        is_distributed = config['TRAIN']['multiprocessing_distributed']
        gpu_rank = config['TRAIN']['rank'] % ngpus_per_node

        # evaluate on validation set, does not support multiGPU
        if eval_loader is not None and (epoch + 1) % config['EVAL']['epochs_per_eval'] == 0:
            if not is_distributed or (is_distributed and gpu_rank == 0):
                validate(eval_loader, model, criterion, epoch, config)

        save_now = (epoch + 1) % config['TRAIN']['save_freq'] == 0
        if save_now and (not is_distributed or gpu_rank == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': config['MODEL']['arch'],
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'scaler': scaler.state_dict() if scaler is not None else None,
                'run_id': mlflow.active_run().info.run_id,
                'norms': norms
            }, os.path.join(config['TRAIN']['model_dir'], f"{config['config_name']}_checkpoint.pth.tar"))

def save_checkpoint(state, filename='checkpoint.pth.tar'):
    torch.save(state, filename)

def prepare_logging(config):
    # log parameters for run, or resume existing run
    if config['run_id'] is None and config['TRAIN']['rank'] == 0:
        # log parameters in mlflow
        mlflow.end_run()
        mlflow.set_experiment(config['DATASET']['dataset_name'])

        # log the full config file after inheritance
        artifact_path = 'mlruns/' + mlflow.get_artifact_uri().split('/mlruns/')[-1]
        config_fp = os.path.join(artifact_path, os.path.basename(config['config_file']))
        with open(config_fp, mode='w') as f:
            yaml.dump(config, f)

        # we don't want to add everything in the config to mlflow
        # parameters, just the parameters most likely to change
        mlflow.log_param('run_name', config['TRAIN']['run_name'])
        mlflow.log_param('architecture', config['MODEL']['arch'])
        mlflow.log_param('encoder_pretraining', config['TRAIN']['encoder_pretraining'])
        mlflow.log_param('whole_pretraining', config['TRAIN']['whole_pretraining'])
        mlflow.log_param('epochs', config['TRAIN']['epochs'])
        mlflow.log_param('batch_size', config['TRAIN']['batch_size'])
        mlflow.log_param('lr_schedule', config['TRAIN']['lr_schedule'])
        mlflow.log_param('optimizer', config['TRAIN']['optimizer'])

        aug_names = config['aug_string']
        mlflow.log_param('augmentations', aug_names)
    else:
        # resume existing run
        mlflow.start_run(run_id=config['run_id'])

def log_metrics(progress, meters, epoch, dataset='Train'):
    # log all the losses from progress
    for meter in progress.meters:
        mlflow.log_metric(f'{dataset}_{meter.name}', meter.avg, step=epoch)

    for metric_name, values in meters.history.items():
        mlflow.log_metric(f'{dataset}_{metric_name}', values[-1], step=epoch)

def configure_optimizer(model, opt_name, **opt_params):
    """
    Takes an optimizer and separates parameters into two groups
    that either use weight decay or are exempt.

    Only BatchNorm parameters and biases are excluded.
    """

    # easy if there's no weight_decay
    if 'weight_decay' not in opt_params:
        return optim.__dict__[opt_name](model.parameters(), **opt_params)
    elif opt_params['weight_decay'] == 0:
        return optim.__dict__[opt_name](model.parameters(), **opt_params)

    # otherwise separate parameters into two groups
    decay = set()
    no_decay = set()

    blacklist = (torch.nn.BatchNorm2d,)
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters(recurse=False):
            full_name = '%s.%s' % (mn, pn) if mn else pn

            if full_name.endswith('bias'):
                no_decay.add(full_name)
            elif full_name.endswith('weight') and isinstance(m, blacklist):
                no_decay.add(full_name)
            else:
                decay.add(full_name)

    param_dict = {pn: p for pn, p in model.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert(len(inter_params) == 0), "Overlapping decay and no decay"
    assert(len(param_dict.keys() - union_params) == 0), "Missing decay parameters"

    decay_params = [param_dict[pn] for pn in sorted(list(decay))]
    no_decay_params = [param_dict[pn] for pn in sorted(list(no_decay))]

    param_groups = [
        {"params": decay_params, **opt_params},
        {"params": no_decay_params, **opt_params}
    ]
    param_groups[1]['weight_decay'] = 0 # overwrite default to 0 for no_decay group

    return optim.__dict__[opt_name](param_groups, **opt_params)

def train(
    train_loader,
    model,
    criterion,
    optimizer,
    scheduler,
    scaler,
    epoch,
    config
):
    # generic progress
    batch_time = ProgressAverageMeter('Time', ':6.3f')
    data_time = ProgressAverageMeter('Data', ':6.3f')
    loss_meters = None

    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time],
        prefix="Epoch: [{}]".format(epoch)
    )

    # end of epoch metrics
    class_names = config['DATASET']['class_names']
    metric_dict = {}
    for metric_params in config['TRAIN']['metrics']:
        reg_name = metric_params['name']
        metric_name = metric_params['metric']
        metric_params = {k: v for k,v in metric_params.items() if k not in ['name', 'metric']}
        metric_dict[reg_name] = metrics.__dict__[metric_name](metrics.EMAMeter, **metric_params)

    meters = metrics.ComposeMetrics(metric_dict, class_names)

    # switch to train mode
    model.train()

    end = time.time()
    for i, batch in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        images = batch['image']
        target = {k: v for k,v in batch.items() if k not in ['image', 'fname']}

        if config.get('gpu') is not None and torch.cuda.is_available():
            images = images.to(config['gpu'], non_blocking=True)
            target = {k: tensor.to(config['gpu'], non_blocking=True)
                      for k,tensor in target.items()}

        if config['TRAIN'].get('mps') is not None and torch.backends.mps.is_available():
            images = images.to('mps', non_blocking=True)
            target = {k: tensor.to('mps', non_blocking=True)
                      for k,tensor in target.items()}

        # zero grad before running
        optimizer.zero_grad()

        # compute output; keep the backward pass and optimizer step
        # outside of autocast, as recommended for mixed precision
        if scaler is not None:
            with autocast():
                output = model(images)
                loss, aux_loss = criterion(output, target)  # output and target are both dicts

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(images)
            loss, aux_loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        # update the LR
        scheduler.step()

        # record losses
        if loss_meters is None:
            loss_meters = {}
            for k,v in aux_loss.items():
                loss_meters[k] = ProgressEMAMeter(k, ':.4e')
                loss_meters[k].update(v)
                # add to progress
                progress.meters.append(loss_meters[k])
        else:
            for k,v in aux_loss.items():
                loss_meters[k].update(v)

        # calculate human-readable per epoch metrics
        with torch.no_grad():
            meters.evaluate(output, target)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config['TRAIN']['print_freq'] == 0:
            progress.display(i)

    # end of epoch print evaluation metrics
    print('\n')
    print(f'Epoch {epoch} training metrics:')
    meters.display()
    log_metrics(progress, meters, epoch, dataset='Train')

def validate(
    eval_loader,
    model,
    criterion,
    epoch,
    config
):
    # validation metrics to track
    class_names = config['DATASET']['class_names']
    metric_dict = {}
    for metric_params in config['EVAL']['metrics']:
        reg_name = metric_params['name']
        metric_name = metric_params['metric']
        metric_params = {k: v for k,v in metric_params.items() if k not in ['name', 'metric']}
        metric_dict[reg_name] = metrics.__dict__[metric_name](metrics.AverageMeter, **metric_params)

    meters = metrics.ComposeMetrics(metric_dict, class_names)

    # validation tracking
    batch_time = ProgressAverageMeter('Time', ':6.3f')
    loss_meters = None

    progress = ProgressMeter(
        len(eval_loader),
        [batch_time],
        prefix='Validation: '
    )

    # create the Inference Engine
    engine_name = config['EVAL']['engine']
    engine = engines.__dict__[engine_name](model, **config['EVAL']['engine_params'])

    for i, batch in enumerate(eval_loader):
        end = time.time()
        images = batch['image']
        target = {k: v for k,v in batch.items() if k not in ['image', 'fname']}

        if torch.cuda.is_available():
            images = images.cuda(config['gpu'], non_blocking=True)
            target = {k: tensor.cuda(config['gpu'], non_blocking=True)
                      for k,tensor in target.items()}
        elif torch.backends.mps.is_available():
            images = images.to('mps', non_blocking=True)
            target = {k: tensor.to('mps', non_blocking=True)
                      for k,tensor in target.items()}

        # compute panoptic segmentations
        # from prediction and ground truth
        output = engine.infer(images)
        semantic = engine._harden_seg(output['sem'])
        output['pan_seg'] = engine.postprocess(
            semantic, output['ctr_hmp'], output['offsets']
        )
        target['pan_seg'] = engine.postprocess(
            target['sem'].unsqueeze(1), target['ctr_hmp'], target['offsets']
        )

        loss, aux_loss = criterion(output, target)

        # record losses
        if loss_meters is None:
            loss_meters = {}
            for k,v in aux_loss.items():
                loss_meters[k] = ProgressAverageMeter(k, ':.4e')
                loss_meters[k].update(v)
                # add to progress
                progress.meters.append(loss_meters[k])
        else:
            for k,v in aux_loss.items():
                loss_meters[k].update(v)

        # compute metrics
        with torch.no_grad():
            meters.evaluate(output, target)

        batch_time.update(time.time() - end)

        if i % config['TRAIN']['print_freq'] == 0:
            progress.display(i)

        if i in config['EVAL']['eval_track_indices'] and (epoch + 1) % config['EVAL']['eval_track_freq'] == 0:
            impath = batch['fname'][0]
            fname = '.'.join(os.path.basename(impath).split('.')[:-1])
            image = io.imread(impath)

            # gt and prediction
            h, w = image.shape
            gt = measure.label(target['pan_seg'].squeeze().cpu().numpy()[:h, :w])
            pred = measure.label(output['pan_seg'].squeeze().cpu().numpy()[:h, :w])

            artifact_path = 'mlruns/' + mlflow.get_artifact_uri().split('/mlruns/')[-1]

            f, ax = plt.subplots(1, 3, figsize=(12, 4))
            ax[0].imshow(image, cmap='gray')
            ax[1].imshow(gt, cmap='plasma')
            ax[2].imshow(pred, cmap='plasma')
            plt.savefig(os.path.join(artifact_path, f'{fname}_epoch{epoch}.png'))
            plt.close(f)  # close the figure so figures don't accumulate across epochs

    # end of epoch print evaluation metrics
    print('\n')
    print('Validation results:')
    meters.display()
    log_metrics(progress, meters, epoch, dataset='Eval')

class ProgressAverageMeter(metrics.AverageMeter):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        super().__init__()

    def __str__(self):
        fmtstr = '{name} {avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)

class ProgressEMAMeter(metrics.EMAMeter):
    """Computes and stores the exponential moving average and current value"""
    def __init__(self, name, fmt=':f', momentum=0.98):
        self.name = name
        self.fmt = fmt
        super().__init__(momentum)

    def __str__(self):
        fmtstr = '{name} {avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)

class ProgressMeter:
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

if __name__ == "__main__":
    main()
Error message:
(napari-empanada-dev) genevieb@Admins-MacBook-Pro scripts % python train_mps.py ../projects/mitonet/configs/mps_config.yml
Model with 15842660 trainable parameters.
Found 1 image subdirectories with 16 images.
Found 1 image subdirectories with 16 images.
Steps per epoch adjusted from -1 to 1
/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py:115: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.
  warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.")
/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/amp/autocast_mode.py:198: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
  warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
Traceback (most recent call last):
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/scripts/train_mps.py", line 727, in <module>
    main()
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/scripts/train_mps.py", line 102, in main
    main_worker(config['TRAIN']['gpu'], ngpus_per_node, config)
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/scripts/train_mps.py", line 370, in main_worker
    train(train_loader, model, criterion, optimizer,
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/scripts/train_mps.py", line 543, in train
    loss, aux_loss = criterion(output, target)  # output and target are both dicts
  File "/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/losses.py", line 140, in forward
    ce = self.ce_loss(output['sem_logits'], target['sem'])
  File "/Users/genevieb/mambaforge/envs/napari-empanada-dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/genevieb/Documents/GitHub/empanada/empanada/empanada/losses.py", line 44, in forward
    pixel_losses, _ = torch.topk(pixel_losses, top_k_pixels)
RuntimeError: Currently topk on mps works only for k<=16
(napari-empanada-dev) genevieb@Admins-MacBook-Pro scripts %

Ryan says we use a top k of a very large number (like thousands), but the mps backend only supports k <= 16. So we're probably not going to be able to run these models easily on M1 Mac hardware 😢

Maybe we could modify the loss function we're using so that it doesn't hit this particular problem?
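
One possible, untested workaround: fall back to the CPU just for the topk call. Here is a minimal sketch, assuming the pixel_losses tensor and top_k_pixels count from empanada/empanada/losses.py; the helper name and the mean reduction are illustrative, not empanada's actual code:

import torch

def bootstrapped_topk_loss(pixel_losses, top_k_pixels):
    # MPS currently implements topk only for k <= 16, so route larger
    # k through the CPU; device transfers are tracked by autograd, so
    # gradients still flow back to the MPS tensor
    if pixel_losses.device.type == 'mps' and top_k_pixels > 16:
        top_losses, _ = torch.topk(pixel_losses.cpu(), top_k_pixels)
        top_losses = top_losses.to(pixel_losses.device)
    else:
        top_losses, _ = torch.topk(pixel_losses, top_k_pixels)
    return top_losses.mean()

The round trip through the CPU would eat into any speedup from mps, though, so it's probably only worth it if everything else in the model runs natively.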

@berombau

So M1 GPU acceleration appears to be blocked by pytorch/pytorch#78915.
Is the quantized model currently supported on the M1 CPU?
I get this error when running the quantized model on the CPU:

Traceback of TorchScript, original code (most recent call last):
  File "/data/IASEM/iasemconda/lib/python3.9/site-packages/torch/nn/quantized/modules/conv.py", line 176, in __setstate__
        self.groups = state[8]
        self.padding_mode = state[9]
        self.set_weight_bias(state[10], state[11])
        ~~~~~~~~~~~~~~~~~~~~ <--- HERE
        self.scale = state[12]
        self.zero_point = state[13]
  File "/data/IASEM/iasemconda/lib/python3.9/site-packages/torch/nn/quantized/modules/conv.py", line 401, in set_weight_bias
    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                                  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
RuntimeError: Didn't find engine for operation quantized::conv2d_prepack NoQEngine

2D and 3D inference work fine after adding a little code to the Engine2D and Engine3D classes to select the ARM quantization backend,
e.g. at the two places in empanada-napari/empanada_napari/inference.py with the line:

model_url = model_config['model_quantized']

Add the following:

model_url = model_config['model_quantized']
if torch.backends.mps.is_available():
    # set quantized backend to ARM
    torch.backends.quantized.engine = 'qnnpack'
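
Note that the backend needs to be set before the scripted model is deserialized (presumably the torch.jit.load call that consumes model_url), since the conv2d_prepack error above is raised from __setstate__ while the checkpoint is being loaded.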

@GenevieveBuckley
Author

As described above, we ran into difficulties implementing M1-accelerated model training.

However, accelerating inference might still be a useful thing to do. So if your inference-only experiments look more promising, that could be worthwhile continuing.
