This repository has been archived by the owner on May 3, 2023. It is now read-only.

Commit d38bf39: Final solution
commanderxa committed Apr 16, 2023 (0 parents)
Showing 12 changed files with 643 additions and 0 deletions.
11 changes: 11 additions & 0 deletions .gitignore
@@ -0,0 +1,11 @@
/__pycache__
/models
/images
/.venv

/data/train
/data/test

*.png
*.jpg
*.csv
45 changes: 45 additions & 0 deletions README.md
@@ -0,0 +1,45 @@
# UNet for Mechanical Claw Segmentation

A deep neural network (UNet) that segments mechanical claws in images.

by the `stable-confusion` team

## Data

**[Kaggle Competition](https://www.kaggle.com/competitions/gdsc-nu-ml-hackathon-bts-case-competition/overview)**

A Kaggle competition hosted by NU GDSC and BTS Kazakhstan.

## Results

**4th place** with `~87%` accuracy

## Libraries & Frameworks

- PyTorch
- polars
- tqdm

## Techniques

- Data augmentation (random crop, rotation, horizontal flip, color jitter, Gaussian blur)
- L2 regularization as weight decay
- Mixed-precision training (see the sketch below)
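
A minimal sketch of how the last two techniques typically combine in one training step. The optimizer, loss, and hyperparameters here are assumptions for illustration, since `train.py` is not part of this commit; only the `autocast`/`GradScaler` pattern is taken from `operations.py`:

```python
import torch

from model import UNet

device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet().to(device)

# L2 regularization enters through the weight_decay term (value is hypothetical)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = torch.nn.BCEWithLogitsLoss()
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(2, 3, 576, 576, device=device)                     # dummy batch
y = torch.randint(0, 2, (2, 1, 576, 576), device=device).float()   # dummy masks

# mixed precision: forward pass in float16 under autocast,
# loss scaled so small gradients survive the reduced precision
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(device == "cuda")):
    loss = criterion(model(x), y)
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
optim.zero_grad(set_to_none=True)
```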

## Project Setup

Run the following commands in the project root directory.

- `source .venv/bin/activate`
- `./setup.sh`

## Use

To use the project, either:

- download `unet.pth`
- place it inside the `models` directory

or train the model yourself using `train.py`. Before training, download the dataset into the project root directory, leave the filename unchanged, and then run `./setup.sh`.

To get predictions, run `python main.py`. Note that you must first add at least one image to the `inference/imgs` directory.
68 changes: 68 additions & 0 deletions dataset.py
@@ -0,0 +1,68 @@
import random

import polars as pl
from PIL import Image

import torch
import torchvision
import torchvision.transforms.v2 as transforms
from torch.utils.data import Dataset
from torchvision.transforms.v2 import functional as TF


class ClawDataset(Dataset):

def __init__(self, annotations: str):
self.file = pl.read_csv(annotations)
        # a path like "data/train.csv" switches on the training augmentations
        self.is_train = annotations.split(".")[0].split("/")[1] == "train"
self.transform = transforms.Compose([
transforms.ToTensor(),
# transforms.Resize((384, 512), antialias=True),
])
self.test_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
# torchvision.transforms.Resize((384, 512), antialias=True),
])

def __len__(self):
return len(self.file)

def __getitem__(self, index: int):
item = self.file[index]
x = Image.open(item["imgs"].item()).convert("RGB")

# if test (no y)
if not self.is_train:
x = self.test_transform(x)
return x, item["imgs"].item().split("/")[-1]

y = Image.open(item["masks"].item()).convert("1")
x = self.transform(x)
y = self.transform(y)

        # sample one shared 576x576 crop window (source images are 600x600)
        # and apply it to both the image and its mask
        i, j, h, w = transforms.RandomCrop.get_params(
            torch.randn(600, 600), output_size=(576, 576))
x = TF.crop(x, i, j, h, w)
y = TF.crop(y, i, j, h, w)

# rotation
angle = random.randrange(-20, 20)
x = TF.rotate(x, angle)
y = TF.rotate(y, angle)

# Random horizontal flipping
if random.random() > 0.5:
x = TF.hflip(x)
y = TF.hflip(y)

        # color jitter: brightness, contrast and saturation factors in [0.5, 1.5]
colorjitter = transforms.ColorJitter((0.5, 1.5), (0.5, 1.5), (0.5, 1.5), None)
x = colorjitter(x)

# blur
if random.random() > 0.5:
blur = torchvision.transforms.GaussianBlur((9,9), (0.1, 2))
x = blur(x)

return x, y
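
A hypothetical usage sketch for `ClawDataset`. The CSV path is an assumption inferred from `__init__` (a path like `data/train.csv` enables the training-mode augmentations), and the file is expected to hold `imgs` and `masks` columns, as read in `__getitem__`:

```python
from torch.utils.data import DataLoader

from dataset import ClawDataset

train_ds = ClawDataset("data/train.csv")
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)

x, y = train_ds[0]
print(x.shape, y.shape)  # both cropped to 576x576 by the shared RandomCrop window
```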
20 changes: 20 additions & 0 deletions main.py
@@ -0,0 +1,20 @@
import torch

from model import UNet
from operations import inference
from utils import load_sequence

device = "cuda" if torch.cuda.is_available() else "cpu"
model_file = "models/unet.pth"

model = UNet()
model = model.to(device)
model = torch.compile(model)

checkpoint = torch.load(model_file, map_location=device)  # also works on CPU-only machines
model.load_state_dict(checkpoint["model"])
model.eval()

loader = load_sequence()

inference(model, loader, device)
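
One subtlety, noted here as a sketch: `torch.compile` wraps the module, so a state dict saved from the compiled model (as `train` in `operations.py` does) carries an `_orig_mod.` prefix on every key. `main.py` sidesteps this by compiling before loading; to load the same checkpoint into a plain, uncompiled `UNet`, the prefix would have to be stripped:

```python
import torch

from model import UNet

checkpoint = torch.load("models/unet.pth", map_location="cpu")
# drop the "_orig_mod." prefix added by torch.compile
state = {k.removeprefix("_orig_mod."): v for k, v in checkpoint["model"].items()}

model = UNet()
model.load_state_dict(state)
model.eval()
```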
104 changes: 104 additions & 0 deletions model.py
@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class UNet(nn.Module):

def __init__(self, in_channels=3, out_channels=1, init_features=32):
super(UNet, self).__init__()

features = init_features
self.encoder1 = UNet._block(in_channels, features, name="enc1")
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.encoder2 = UNet._block(features, features * 2, name="enc2")
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

self.bottleneck = UNet._block(
features * 8, features * 16, name="bottleneck")

self.upconv4 = nn.ConvTranspose2d(
features * 16, features * 8, kernel_size=2, stride=2
)
self.decoder4 = UNet._block(
(features * 8) * 2, features * 8, name="dec4")
self.upconv3 = nn.ConvTranspose2d(
features * 8, features * 4, kernel_size=2, stride=2
)
self.decoder3 = UNet._block(
(features * 4) * 2, features * 4, name="dec3")
self.upconv2 = nn.ConvTranspose2d(
features * 4, features * 2, kernel_size=2, stride=2
)
self.decoder2 = UNet._block(
(features * 2) * 2, features * 2, name="dec2")
self.upconv1 = nn.ConvTranspose2d(
features * 2, features, kernel_size=2, stride=2
)
self.decoder1 = UNet._block(features * 2, features, name="dec1")

self.conv = nn.Conv2d(
in_channels=features, out_channels=out_channels, kernel_size=1
)

def forward(self, x):
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
enc3 = self.encoder3(self.pool2(enc2))
enc4 = self.encoder4(self.pool3(enc3))

bottleneck = self.bottleneck(self.pool4(enc4))

dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
return self.conv(dec1)

@staticmethod
def _block(in_channels, features, name):
return nn.Sequential(
OrderedDict(
[
(
name + "conv1",
nn.Conv2d(
in_channels=in_channels,
out_channels=features,
kernel_size=3,
padding=1,
bias=False,
),
),
# (name + "dropout1", nn.Dropout(0.4)),
(name + "norm1", nn.BatchNorm2d(num_features=features)),
(name + "relu1", nn.ReLU(inplace=True)),
(
name + "conv2",
nn.Conv2d(
in_channels=features,
out_channels=features,
kernel_size=3,
padding=1,
bias=False,
),
),
# (name + "dropout2", nn.Dropout(0.2)),
(name + "norm2", nn.BatchNorm2d(num_features=features)),
(name + "relu2", nn.ReLU(inplace=True)),
]
)
)
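
A quick shape sanity check (a sketch, not part of the commit): four 2x2 max-pools mean the input height and width must be divisible by 16, which the 576x576 crops produced by `dataset.py` satisfy, and the decoder restores the original resolution:

```python
import torch

from model import UNet

net = UNet()
out = net(torch.randn(1, 3, 576, 576))
print(out.shape)  # torch.Size([1, 1, 576, 576]): one logit map per pixel
```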
130 changes: 130 additions & 0 deletions operations.py
@@ -0,0 +1,130 @@
import os

import torch
import torchvision
from torchvision.utils import make_grid
from torchvision.utils import save_image
from torchvision.utils import draw_segmentation_masks

import numpy as np
from tqdm import tqdm
import pandas as pd

from model import UNet

losses = []
metrics = []


def train(model: UNet, loader, criterion, scaler, optim, dice, model_file, epochs: int, device, grad_scaler):
    model.train()

    for epoch in range(1, epochs + 1):
        with tqdm(iter(loader)) as tepoch:
            tepoch.set_description(f"Epoch: {epoch}")
            for x, y in tepoch:
                x, y = x.to(device), y.to(device)

                # mixed precision: autocast covers the forward pass only;
                # the backward pass runs on the scaled loss outside of it
                with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=grad_scaler):
                    prediction = model(x)
                    loss: torch.Tensor = criterion(prediction, y)

                # backprop and optimize
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()
                optim.zero_grad(set_to_none=True)
                losses.append(loss.item())

                metric = dice(prediction, y.int())
                metrics.append(metric.item())

        # save a checkpoint after every epoch
        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optim.state_dict(),
            "scaler": scaler.state_dict(),
        }
        torch.save(checkpoint, model_file)

        print(f"Loss: {np.mean(losses)}, Accuracy: {np.mean(metrics)}")
        losses.clear()
        metrics.clear()


def validation(model: UNet, loader, device):
model.eval()
with torch.no_grad():
results = []
images = []

with tqdm(iter(loader)) as tepoch:
for x, name in tepoch:
x = x.to(device)
y = model(x)
y = torch.floor(torch.sigmoid(y) + .5)

claws_with_masks = draw_segmentation_masks(image=(
x[0]*255).type(torch.uint8).cpu(), masks=(y[0] > 0).cpu(), alpha=0.7, colors="#1FFF78")
save_image((claws_with_masks / 255), os.path.join(
"data/test/masks", name[0].split('.')[0] + ".png"))

results.append(claws_with_masks)
images.append(y)

grid = make_grid(results)
img = torchvision.transforms.ToPILImage()(grid)
img.save("images/test.png")
results.clear()

        responses = []
        for i, image in enumerate(images):
            # flat indices of all mask pixels, run-length encoded as
            # alternating "start length" pairs
            coords = list(np.where(image.view(1, -1).cpu().squeeze(0) > 0)[0])
            if not coords:  # empty mask: nothing to encode
                responses.append([i, ""])
                continue
            short_coords = [str(coords[0])]
            length = 1
            for j in range(1, len(coords)):
                if coords[j] - 1 == coords[j - 1]:
                    length += 1
                else:
                    short_coords.append(str(length))
                    short_coords.append(str(coords[j]))
                    length = 1
            short_coords.append(str(length))

            responses.append([i, " ".join(short_coords)])

        sample = pd.DataFrame(responses, columns=["ImageID", "Expected"])
        sample.to_csv("sample16_new2.csv", index=False)


def inference(model, loader, device):
assert len(loader) > 0, "Add image(s) into the `inference/imgs` directory"
model.eval()
with torch.no_grad():
results = []
images = []

with tqdm(iter(loader)) as tepoch:
for x, name in tepoch:
x = x.to(device)
y = model(x)
y = torch.floor(torch.sigmoid(y) + .5)

claws_with_masks = draw_segmentation_masks(image=(
x[0]*255).type(torch.uint8).cpu(), masks=(y[0] > 0).cpu(), alpha=0.7, colors="#1FFF78")

save_image((claws_with_masks / 255), os.path.join(
"inference/masks", name[0].split('.')[0] + ".png"))

results.append(claws_with_masks)
images.append(y)

grid = make_grid(results)
img = torchvision.transforms.ToPILImage()(grid)
img.save("inference/test.jpg")
results.clear()
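
For reference, a sketch of the inverse of the run-length encoding that `validation` builds: the `Expected` string alternates flat pixel indices and run lengths, e.g. `"12 3 40 2"` marks pixels 12-14 and 40-41. The exact competition format is an assumption inferred from the encoder above:

```python
import numpy as np

def rle_decode(rle: str, shape: tuple[int, int]) -> np.ndarray:
    """Rebuild a binary mask from alternating 'start length' pairs."""
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    values = list(map(int, rle.split()))
    for start, length in zip(values[0::2], values[1::2]):
        mask[start:start + length] = 1
    return mask.reshape(shape)
```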