Tensors on Different Devices During TorchScript Export #325

Open
honest8ear opened this issue Sep 24, 2024 · 0 comments

Hello, I am currently modifying the architecture from "ONNX-SAM2-Segment-Anything" so that it can be exported to TorchScript. However, I have run into a device mismatch issue during the export. Below is my code:

from typing import Optional, Tuple, Any
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_


from sam2.modeling.sam2_base import SAM2Base

class SAM2ImageEncoder(nn.Module):
    def __init__(self, sam_model: SAM2Base) -> None:
        super().__init__()
        self.model = sam_model
        self.image_encoder = sam_model.image_encoder
        self.no_mem_embed = sam_model.no_mem_embed

    def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
        backbone_out = self.image_encoder(x)
        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(
            backbone_out["backbone_fpn"][0]
        )
        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(
            backbone_out["backbone_fpn"][1]
        )

        feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels:]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels:]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]

        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]

        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
                 for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]

        return feats[0], feats[1], feats[2]


class SAM2ImageDecoder(nn.Module):
    def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None:
        super().__init__()
        self.mask_decoder = sam_model.sam_mask_decoder
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model
        self.multimask_output = multimask_output

    @torch.no_grad()
    def forward(self, image_embed: torch.Tensor, high_res_feats_0: torch.Tensor, high_res_feats_1: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor, mask_input: torch.Tensor, has_mask_input: torch.Tensor, img_size: torch.Tensor):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        high_res_feats = [high_res_feats_0, high_res_feats_1]

        masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
            image_embeddings=image_embed,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
            repeat_image=False,
            high_res_features=high_res_feats,
        )

        if self.multimask_output:
            masks = masks[:, 1:, :, :]
            iou_predictions = iou_predictions[:, 1:]
        else:
            masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(masks, iou_predictions)

        masks = torch.clamp(masks, -32.0, 32.0)
        print(masks.shape, iou_predictions.shape)

        masks = F.interpolate(masks, (img_size[0], img_size[1]), mode="bilinear", align_corners=False)

        return masks, iou_predictions

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        point_coords = point_coords + 0.5
        padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
        padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
        point_coords = torch.cat([point_coords, padding_point], dim=1)
        point_labels = torch.cat([point_labels, padding_label], dim=1)

        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size

        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (
                point_labels == -1
        )

        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
                1 - has_mask_input
        ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

import traceback
from sam2.build_sam import build_sam2


def main():
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(device)
        model_type = 'sam2_hiera_large'
        input_size = 1024
        multimask_output = False

        if model_type == "sam2_hiera_tiny":
            model_cfg = "sam2_hiera_t.yaml"
        elif model_type == "sam2_hiera_small":
            model_cfg = "sam2_hiera_s.yaml"
        elif model_type == "sam2_hiera_base_plus":
            model_cfg = "sam2_hiera_b+.yaml"
        else:
            model_cfg = "sam2_hiera_l.yaml"

        sam2_checkpoint = f"checkpoints/{model_type}.pt"
        sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

        img = torch.randn(1, 3, input_size, input_size).to(device)

        sam2_encoder = SAM2ImageEncoder(sam2_model).to(device)
        high_res_feats_0, high_res_feats_1, image_embed = sam2_encoder(img)
        #print(high_res_feats_0.shape)
        #print(high_res_feats_1.shape)
        #print(image_embed.shape)

        sam2_decoder = SAM2ImageDecoder(sam2_model, multimask_output=multimask_output).to(device)

        backbone_stride = sam2_model.backbone_stride
        embed_size = (input_size // backbone_stride, input_size // backbone_stride)
        mask_input_size = [4 * x for x in embed_size]

        point_coords = torch.randint(low=0, high=input_size, size=(1, 5, 2), dtype=torch.float).to(device)
        point_labels = torch.randint(low=0, high=1, size=(1, 5), dtype=torch.float).to(device)
        mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float).to(device)  # mask_input with the correct shape
        has_mask_input = torch.tensor([1], dtype=torch.float).to(device)
        orig_im_size = torch.tensor([input_size, input_size], dtype=torch.int32).to(device)

        traced_decoder = torch.jit.trace(sam2_decoder, (image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size))
        traced_decoder.save("decoder.pt")

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()
if __name__ == "__main__":
    main()

While trying to export the model using torch.jit.trace in a Colab environment with a T4 GPU, I encountered the following error message:

cuda
An error occurred: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
/content/segment-anything-2/sam2/modeling/sam/mask_decoder.py:203: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert image_embeddings.shape[0] == tokens.shape[0]
/content/segment-anything-2/sam2/modeling/sam/mask_decoder.py:207: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  image_pe.size(0) == 1
Traceback (most recent call last):
  File "<ipython-input-3-b6754915e605>", line 43, in main
    traced_decoder = torch.jit.trace(sam2_decoder, (image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size))
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1000, in trace
    traced_func = _trace_impl(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 695, in _trace_impl
    return trace_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1275, in trace_module
    module._c._create_method_from_trace(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "<ipython-input-2-0239f2314634>", line 58, in forward
    masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
  File "/content/segment-anything-2/sam2/modeling/sam/mask_decoder.py", line 209, in predict_masks
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

It appears that some tensors are on cpu while others are on cuda:0, even though I have made sure that all input tensors are moved to cuda. I suspect that some internal layer or buffer inside the mask_decoder is not being moved to cuda correctly; even after moving the whole model to cuda, the issue persists.
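
For reference, a minimal check like the following (just a sketch, assuming the sam2_decoder instance from the script above) can list any parameters or buffers that are still on the CPU after calling .to(device):

# Quick device-placement check (sketch): print every parameter or buffer of the
# decoder that is not on CUDA. sam2_decoder is the SAM2ImageDecoder built above.
for name, p in sam2_decoder.named_parameters():
    if p.device.type != "cuda":
        print("parameter on CPU:", name)
for name, b in sam2_decoder.named_buffers():
    if b.device.type != "cuda":
        print("buffer on CPU:", name)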

I would greatly appreciate any suggestions or advice on how to resolve this device mismatch issue.

Best regards.
