Release trained model on SegPath data #26

Open
abs51295 opened this issue Apr 24, 2024 · 1 comment

Hello,

Are you planning to release the trained model on SegPath data so that we can directly run inference on our samples?


pidemal commented Jul 3, 2024

@abs51295, not sure if this helps, but I made the following wrapper. I would appreciate it if @Richarizardd could confirm what type of adapters they used; my understanding is that they did an end-to-end fine-tune.

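```python
import os
import types

import timm
import torch.nn as nn
from huggingface_hub import login
from transformers import (
    Mask2FormerConfig,
    Mask2FormerForUniversalSegmentation,
    Mask2FormerImageProcessor,
)


## Wrap UNI to make it compatible with the Mask2Former backbone interface
class UNIWrapper(nn.Module):
    def __init__(self, out_channels_list=(96, 192, 384, 768), num_features=4):
        super().__init__()
        # UNI is a ViT-L/16; with this config, forward() returns one pooled
        # 1024-d embedding per image
        self.uni_model = timm.create_model(
            "hf-hub:MahmoodLab/uni",
            pretrained=True,
            init_values=1e-5,
            dynamic_img_size=True,
        )

        # Uncomment to freeze UNI and train only the adapters and Mask2Former heads
        # for param in self.uni_model.parameters():
        #     param.requires_grad = False

        # One adapter per pyramid level: a 1x1 conv to the channel width the
        # Swin-Tiny pixel decoder expects, then 8x, 4x, 2x, 1x upsampling
        self.adapters = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(1024, out_channels, kernel_size=1),
                nn.Upsample(scale_factor=2 ** (num_features - 1 - i)),
            )
            for i, out_channels in enumerate(out_channels_list)
        ])

    def forward(self, x):
        features = self.uni_model(x)  # (batch, 1024)
        batch_size, embed_dim = features.shape
        # Treat the embedding as a 1x1 map. It has no spatial layout, so
        # nearest upsampling makes every position in each output map identical.
        features = features.view(batch_size, embed_dim, 1, 1)

        multi_scale_features = [adapter(features) for adapter in self.adapters]

        # Mask2Former's pixel-level module reads `.feature_maps` from the encoder output
        return types.SimpleNamespace(feature_maps=multi_scale_features)


def load_model(num_classes=1):
    image_processor = Mask2FormerImageProcessor(
        reduce_labels=True,
        do_resize=True,
        ignore_index=255,
        size={"height": 224, "width": 224},
    )

    # Reuse the Swin-Tiny Mask2Former architecture, but with `num_classes` labels
    model_config = Mask2FormerConfig.from_pretrained(
        "facebook/mask2former-swin-tiny-cityscapes-semantic",
        num_labels=num_classes,
    )
    # UNI weights are gated on the Hugging Face Hub; this assumes your access
    # token is stored in the HF_TOKEN environment variable
    login(token=os.environ["HF_TOKEN"])

    # Instantiate Mask2Former (randomly initialized) and swap in the UNI backbone
    model = Mask2FormerForUniversalSegmentation(model_config)
    model.model.pixel_level_module.encoder = UNIWrapper()

    total_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"total trainable parameters: {total_parameters:,}")

    return model, image_processor
```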
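To see what the wrapper actually hands to the pixel decoder, the spatial sizes can be worked out by hand. The helpers below are hypothetical (none of these names come from UNI, timm, or transformers); they assume a 224x224 input, UNI's 16-pixel patches, and the stride-4/8/16/32 pyramid that a Swin backbone would normally feed into Mask2Former.

```python
# Hypothetical helpers: compare the map sizes produced by UNIWrapper with the
# stride-4/8/16/32 pyramid a hierarchical (Swin) backbone would provide.


def wrapper_feature_sizes(num_levels: int = 4) -> list[int]:
    """Side length of each UNIWrapper map: a 1x1 map upsampled by 2**(num_levels-1-i)."""
    return [2 ** (num_levels - 1 - i) for i in range(num_levels)]


def swin_feature_sizes(image_size: int = 224, strides=(4, 8, 16, 32)) -> list[int]:
    """Side length of each map for a backbone with the given output strides."""
    return [image_size // s for s in strides]


def vit_pyramid_scales(patch_size: int = 16, strides=(4, 8, 16, 32)) -> list[float]:
    """Resize factor needed to turn one stride-`patch_size` ViT token grid into each level."""
    return [patch_size / s for s in strides]


print(wrapper_feature_sizes())  # [8, 4, 2, 1]
print(swin_feature_sizes())     # [56, 28, 14, 7]
print(vit_pyramid_scales())     # [4.0, 2.0, 1.0, 0.5]
```

So with the pooled embedding, the finest map is 8x8 (and spatially constant), whereas the decoder was designed around a 56x56 map at stride 4 for a 224x224 image.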

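If spatial detail matters, one common alternative (not confirmed as what the UNI authors used for SegPath) is to take UNI's patch tokens instead of the pooled embedding and build the pyramid from that single stride-16 grid, roughly in the spirit of ViTDet's simple feature pyramid. The module below is a hypothetical sketch: `PatchTokenPyramid` and its arguments are illustrative names, and the default output widths are chosen to match the Swin-Tiny pixel decoder used above.

```python
import torch
import torch.nn as nn


class PatchTokenPyramid(nn.Module):
    """Hypothetical adapter: ViT patch tokens (one stride-16 grid) -> stride 4/8/16/32 maps."""

    def __init__(self, embed_dim=1024, out_channels_list=(96, 192, 384, 768)):
        super().__init__()
        resamplers = [
            # stride 16 -> 4: two 2x transposed convolutions
            nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim // 2, kernel_size=2, stride=2),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim // 2, embed_dim // 4, kernel_size=2, stride=2),
            ),
            # stride 16 -> 8: one 2x transposed convolution
            nn.ConvTranspose2d(embed_dim, embed_dim // 2, kernel_size=2, stride=2),
            # stride 16 -> 16: keep the grid as-is
            nn.Identity(),
            # stride 16 -> 32: 2x downsampling
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        in_dims = [embed_dim // 4, embed_dim // 2, embed_dim, embed_dim]
        # A 1x1 conv per level projects to the channel width the decoder expects
        self.levels = nn.ModuleList([
            nn.Sequential(resample, nn.Conv2d(in_dim, out_ch, kernel_size=1))
            for resample, in_dim, out_ch in zip(resamplers, in_dims, out_channels_list)
        ])

    def forward(self, patch_tokens, grid_hw):
        # patch_tokens: (B, N, C) with any CLS/prefix tokens already removed
        b, n, c = patch_tokens.shape
        h, w = grid_hw
        if n != h * w:
            raise ValueError(f"expected {h * w} patch tokens, got {n}")
        grid = patch_tokens.transpose(1, 2).reshape(b, c, h, w)
        return [level(grid) for level in self.levels]
```

With a timm ViT such as UNI, the tokens should be obtainable as `uni_model.forward_features(x)[:, uni_model.num_prefix_tokens:]` with a grid of `(H // 16, W // 16)`; a wrapper could then return `types.SimpleNamespace(feature_maps=...)` exactly as `UNIWrapper` does. This keeps a real spatial layout at every level, at the cost of extra adapter parameters.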