Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add V-Net segmentation model #72

Open
wants to merge 2 commits into
base: stg-dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 129 additions & 38 deletions UniTrain/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,20 @@ def __init__(self, num_classes):
self.softmax = F.softmax
self.num_classes = num_classes


def conv_block(self, xb, inp_filter_size, hidden_filter_size, out_filter_size, pool = False):
layers = nn.Sequential(nn.Conv2d(inp_filter_size, hidden_filter_size, padding=0, kernel_size=1), nn.BatchNorm2d(hidden_filter_size), nn.ReLU(inplace=True),
nn.Conv2d(hidden_filter_size, hidden_filter_size, padding=1, kernel_size=3), nn.BatchNorm2d(hidden_filter_size), nn.ReLU(inplace=True),
nn.Conv2d(hidden_filter_size, out_filter_size, padding=0, kernel_size=1), nn.BatchNorm2d(out_filter_size), nn.ReLU(inplace=True))
def conv_block(
self, xb, inp_filter_size, hidden_filter_size, out_filter_size, pool=False
):
layers = nn.Sequential(
nn.Conv2d(inp_filter_size, hidden_filter_size, padding=0, kernel_size=1),
nn.BatchNorm2d(hidden_filter_size),
nn.ReLU(inplace=True),
nn.Conv2d(hidden_filter_size, hidden_filter_size, padding=1, kernel_size=3),
nn.BatchNorm2d(hidden_filter_size),
nn.ReLU(inplace=True),
nn.Conv2d(hidden_filter_size, out_filter_size, padding=0, kernel_size=1),
nn.BatchNorm2d(out_filter_size),
nn.ReLU(inplace=True),
)
layers.to(xb.device)
return layers(xb)

Expand All @@ -38,7 +47,7 @@ def forward(self, xb):
y = self.conv_block(y, 512, 256, 1024)
for i in range(0, 22):
y = self.conv_block(y, 1024, 256, 1024) + y
i+=1
i += 1

y = self.conv_block(y, 1024, 512, 2048)
y = self.conv_block(y, 2048, 512, 2048) + y
Expand All @@ -52,7 +61,7 @@ def forward(self, xb):
y = linear_layer(y)

return y


# Define the ResNet-9 model in a single class
class ResNet9(nn.Module):
Expand Down Expand Up @@ -120,28 +129,32 @@ def forward(self, x):
x = x.view(x.size(0), -1)
x = self.fc(x)

#ResNet50 functionality addition

# ResNet50 functionality addition
class ResNet9_50(nn.Module):
def __init__(self, num_classes):
super(ResNet9_50, self).__init__()

self.resnet9 = ResNet9(num_classes)
self.resnet50 = models.resnet50(pretrained=True)

def forward(self, x):
x = self.resnet50(x)
#GoogLeNet functionality addition
x = self.resnet50(x)


# GoogLeNet functionality addition
import torch
import torch.nn as nn
import torchvision.models as models


class GoogleNetModel(nn.Module):
def __init__(self, num_classes):
super(GoogleNetModel, self).__init()

# Load the pre-trained GoogleNet model
self.googlenet = models.inception_v3(pretrained=True)

# Modify the classification head to match the number of classes in your dataset
num_ftrs = self.googlenet.fc.in_features
self.googlenet.fc = nn.Linear(num_ftrs, num_classes)
Expand All @@ -156,10 +169,15 @@ def forward(self, x):


# Making a custom transfer learning model
def create_transfer_learning_model(num_classes, model = torchvision.models.resnet18, feature_extract=True, use_pretrained=True):
def create_transfer_learning_model(
num_classes,
model=torchvision.models.resnet18,
feature_extract=True,
use_pretrained=True,
):
"""
Create a transfer learning model with a custom output layer.

Args:
num_classes (int): Number of classes in the custom output layer.
model(torchvision.models.<ModelName>): Pre-trained model you want to use.
Expand All @@ -171,18 +189,19 @@ def create_transfer_learning_model(num_classes, model = torchvision.models.resne
"""
# Load a pre-trained model, for example, ResNet-18
model = model(pretrained=use_pretrained)

# Freeze the pre-trained weights if feature_extract is True
if feature_extract:
for param in model.parameters():
param.requires_grad = False

# Modify the output layer to match the number of classes
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)

return model


# Define the ResNet-18 model in a single class
class ResNet34(nn.Module):
def __init__(self, num_classes):
Expand Down Expand Up @@ -210,15 +229,31 @@ def make_layer(self, out_channels, num_blocks, stride):
layers.append(self.build_residual_block(self.in_channels, out_channels, stride))
self.in_channels = out_channels
for _ in range(1, num_blocks):
layers.append(self.build_residual_block(self.in_channels, out_channels, stride=1))
layers.append(
self.build_residual_block(self.in_channels, out_channels, stride=1)
)
return nn.Sequential(*layers)

def build_residual_block(self, in_channels, out_channels, stride):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
)

Expand All @@ -237,7 +272,6 @@ def forward(self, x):
return x



class ResNet50(nn.Module):
def __init__(self, num_classes):
super(ResNet50, self).__init__()
Expand All @@ -264,15 +298,31 @@ def make_layer(self, out_channels, num_blocks, stride):
layers.append(self.build_residual_block(self.in_channels, out_channels, stride))
self.in_channels = out_channels
for _ in range(1, num_blocks):
layers.append(self.build_residual_block(self.in_channels, out_channels, stride=1))
layers.append(
self.build_residual_block(self.in_channels, out_channels, stride=1)
)
return nn.Sequential(*layers)

def build_residual_block(self, in_channels, out_channels, stride):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
)

Expand Down Expand Up @@ -317,15 +367,31 @@ def make_layer(self, out_channels, num_blocks, stride):
layers.append(self.build_residual_block(self.in_channels, out_channels, stride))
self.in_channels = out_channels
for _ in range(1, num_blocks):
layers.append(self.build_residual_block(self.in_channels, out_channels, stride=1))
layers.append(
self.build_residual_block(self.in_channels, out_channels, stride=1)
)
return nn.Sequential(*layers)

def build_residual_block(self, in_channels, out_channels, stride):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
)

Expand All @@ -350,7 +416,9 @@ def __init__(self, num_classes, growth_rate=12, num_blocks=3, num_layers=4):
self.in_channels = 64

# Initial convolution layer
self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.conv1 = nn.Conv2d(
3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(self.in_channels)
self.relu = nn.ReLU(inplace=True)

Expand All @@ -367,11 +435,20 @@ def make_dense_block(self, growth_rate, num_layers):
layers = []
in_channels = self.in_channels
for _ in range(num_layers):
layers.extend([
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels, growth_rate, kernel_size=3, stride=1, padding=1, bias=False),
])
layers.extend(
[
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True),
nn.Conv2d(
in_channels,
growth_rate,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
]
)
in_channels += growth_rate
self.in_channels = in_channels
return nn.Sequential(*layers)
Expand All @@ -387,15 +464,28 @@ def forward(self, x):
x = self.fc(x)
return x


class LightVisionTransformer(nn.Module):
def __init__(self, image_size=224, patch_size=16, num_classes=1000, dim=256, depth=6, heads=4, mlp_dim=512, dropout=0.1):
def __init__(
self,
image_size=224,
patch_size=16,
num_classes=1000,
dim=256,
depth=6,
heads=4,
mlp_dim=512,
dropout=0.1,
):
super(LightVisionTransformer, self).__init()

num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size * patch_size # 3 for RGB channels

# Patch embedding layer
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
self.patch_embedding = nn.Conv2d(
3, dim, kernel_size=patch_size, stride=patch_size
)
self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = nn.Transformer(
Expand All @@ -411,10 +501,11 @@ def forward(self, x):
B, C, H, W = x.shape
x = self.patch_embedding(x)
x = x.permute(0, 2, 3, 1).view(B, -1, x.size(1)) # Flatten and transpose
x = torch.cat([self.cls_token.expand(B, -1, -1), x], dim=1) # Prepend the classification token
x = torch.cat(
[self.cls_token.expand(B, -1, -1), x], dim=1
) # Prepend the classification token
x = x + self.positional_embedding
x = self.transformer(x)
x = x.mean(dim=1) # Global average pooling
x = self.fc(x)
return x

Loading