Ruffed the code
Fixed the code including the unsafe fixes using ALL rules
rabinadk1 committed Dec 28, 2023
1 parent 35b22a9 commit 591bca6
Showing 7 changed files with 36 additions and 29 deletions.
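
The commit message describes a Ruff run with every rule selected and with unsafe fixes applied. The repository's Ruff configuration is not part of this diff, so the following is only a sketch of the kind of setup that produces fixes like the ones below: roughly, a `select = ["ALL"]` entry under `[tool.ruff]` in pyproject.toml, followed by a run such as

    ruff check . --fix --unsafe-fixes

which applies the available autofixes, including the ones Ruff marks as unsafe.
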
10 changes: 5 additions & 5 deletions src/data/core_datasets/phrasecutdataset.py
@@ -24,14 +24,14 @@ def __init__(
         image_dir: StrPath = "images",
         transforms: Optional[Callable] = None,
         return_tensors: Literal["tf", "pt", "np"] = "np",
-    ):
+    ) -> None:
         super().__init__()

         data_root = Path(data_root)

         with (data_root / task_json_path).open() as f:
             self.tasks: Tuple[Mapping[str, Union[str, PolygonType]], ...] = tuple(
-                json.load(f)
+                json.load(f),
             )

         self.image_path = data_root / image_dir
@@ -41,14 +41,14 @@ def __init__(

         self.transforms = transforms

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.tasks)

     def __getitem__(self, idx: int):
         task = self.tasks[idx]

         image = self.load_image(
-            self.image_path / f"{task['image_id']}.jpg", cv2.IMREAD_COLOR
+            self.image_path / f"{task['image_id']}.jpg", cv2.IMREAD_COLOR,
         )
         # Convert BGR to RGB
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -120,7 +120,7 @@ def plot_img_mask_cut(
         phrase: str,
         mask: np.ndarray,
         figsize: Tuple[int, int] = (15, 5),
-    ):
+    ) -> None:
         _, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize)

         ax1.imshow(PhraseCutDataset.img_normalize(img))
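
The hunks above are mechanical applications of two rule families: flake8-annotations (ANN), which asks for explicit return types on `__init__`, `__len__` and similar methods, and flake8-commas (COM812), which adds a trailing comma after the last argument of a multi-line call. A minimal sketch of both patterns; the class and names below are illustrative, not taken from the repository:

    from typing import Callable, Optional

    class ExampleDataset:
        def __init__(
            self,
            transforms: Optional[Callable] = None,
        ) -> None:  # ANN204: special methods get an explicit return type
            self.items = tuple(
                range(10),  # COM812: trailing comma on the last argument
            )

        def __len__(self) -> int:  # return type written out instead of left implicit
            return len(self.items)
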
7 changes: 4 additions & 3 deletions src/models/core_model/trans_segmentor/__init__.py
@@ -107,14 +107,14 @@ def get_with_pos_enc(x: torch.Tensor):

         # Get positional encoding
         posenc = TransformerSegmentor.get_posenc(
-            d_model=H, token_length=N, device=x.device, dtype=x.dtype
+            d_model=H, token_length=N, device=x.device, dtype=x.dtype,
         )

         # Add positional encoding to the input and return it
         return x + posenc

     @staticmethod
-    @lru_cache()
+    @lru_cache
     @torch.no_grad()
     def get_posenc(
         d_model: int,
@@ -123,8 +123,9 @@ def get_posenc(
         dtype: torch.dtype = None,
     ) -> torch.Tensor:
         if d_model % 2 != 0:
+            msg = f"Cannot use sin/cos positional encoding with odd dim (got dim={d_model})"
             raise ValueError(
-                f"Cannot use sin/cos positional encoding with odd dim (got dim={d_model})"
+                msg,
             )

         # Create a tensor of shape (token_length, d_model)
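
Two further rules show up in this file: pyupgrade's UP011, which drops the empty parentheses from `@lru_cache()`, and flake8-errmsg's EM102, which moves an f-string out of the `raise` so the traceback shows a plain `raise ValueError(msg)` instead of repeating the full message text. A small sketch of the resulting pattern, with an illustrative function rather than the repository's code:

    from functools import lru_cache

    @lru_cache  # UP011: no parentheses when lru_cache takes no arguments
    def halve(n: int) -> float:
        if n % 2 != 0:
            # EM102: build the message first, then raise with the variable
            msg = f"Cannot halve an odd number (got {n})"
            raise ValueError(msg)
        return n / 2
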
36 changes: 20 additions & 16 deletions src/models/image_text_module.py
@@ -1,6 +1,7 @@
 # pyright: reportIncompatibleMethodOverride=false
-from pathlib import Path
-from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Literal, Mapping

 import torch
 import wandb
@@ -9,6 +10,9 @@
 from torchmetrics import Dice, JaccardIndex
 from transformers import AutoTokenizer, PreTrainedTokenizerBase

+if TYPE_CHECKING:
+    from pathlib import Path
+
 BatchType = Mapping[str, Any]


@@ -23,7 +27,7 @@ def __init__(
         loss_fn: nn.Module,
         optimizer: type[optim.optimizer.Optimizer],
         scheduler: type[optim.lr_scheduler.LRScheduler],
-        tokenizer_name_or_path: Union[str, Path],
+        tokenizer_name_or_path: str | Path,
         compile: bool,
         task: Literal["binary", "multiclass", "multilabel"],
         threshold: float = 0.5,
@@ -47,7 +51,7 @@ def __init__(
         self.scheduler = scheduler

         self.tokenizer = AutoTokenizer.from_pretrained(
-            pretrained_model_name_or_path=tokenizer_name_or_path
+            pretrained_model_name_or_path=tokenizer_name_or_path,
         )

         # Dice Loggers
@@ -68,8 +72,8 @@ def forward(self, *args, **kwargs):
         return self.net(*args, **kwargs)

     def model_step(
-        self, batch: BatchType
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        self, batch: BatchType,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Perform a single model step on a batch of data.

         :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
@@ -122,7 +126,7 @@ def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor:

         # and the average across the epoch, to the progress bar and logger
         self.log(
-            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True,
         )

         # return loss or backpropagation will fail
@@ -131,7 +135,7 @@ def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor:
     def on_train_epoch_end(self) -> None:
         # log epoch metric
         self.log_dict(
-            {"train_dice_epoch": self.train_dice, "train_iou_epoch": self.train_iou}
+            {"train_dice_epoch": self.train_dice, "train_iou_epoch": self.train_iou},
         )

     def validation_step(self, batch: BatchType, batch_idx: int) -> None:
@@ -166,15 +170,15 @@ def validation_step(self, batch: BatchType, batch_idx: int) -> None:
             text = batch["text"]

             plot_input_ids = self.decode_input_ids(
-                text.input_ids[: self.hparams.log_image_num]
+                text.input_ids[: self.hparams.log_image_num],
             )

             plot_label = map(wandb.Image, targets[: self.hparams.log_image_num])

             data = list(zip(plot_images, plot_input_ids, plot_label))

             self.logger.log_table(
-                "val_caption_label", columns=self.plot_columns, data=data
+                "val_caption_label", columns=self.plot_columns, data=data,
             )

             # Stop logging now
@@ -183,16 +187,16 @@ def validation_step(self, batch: BatchType, batch_idx: int) -> None:
     def decode_input_ids(
         self,
         input_ids: torch.Tensor,
-        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        tokenizer: PreTrainedTokenizerBase | None = None,
     ):
         tokenizer = tokenizer or self.tokenizer
         return tokenizer.batch_decode(
-            input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+            input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True,
         )

     @staticmethod
     def normalize_img(img: torch.Tensor):
-        kwargs = dict(dim=(2, 3), keepdim=True)
+        kwargs = {"dim": (2, 3), "keepdim": True}

         img_min = img.amin(**kwargs)
         img_max = img.amax(**kwargs)
@@ -216,7 +220,7 @@ def test_step(self, batch: BatchType) -> None:
             prog_bar=True,
         )

-    def setup(self, stage: Optional[str]) -> None:
+    def setup(self, stage: str | None) -> None:
         """Lightning hook that is called at the beginning of fit (train + validate), validate,
         test, or predict.

@@ -228,7 +232,7 @@ def setup(self, stage: Optional[str]) -> None:
         if self.hparams.compile and (stage is None or stage == "fit"):
             self.net = torch.compile(self.net)

-    def configure_optimizers(self) -> Dict[str, Any]:
+    def configure_optimizers(self) -> dict[str, Any]:
         """Choose what optimizers and learning-rate schedulers to use in your optimization.
         Normally you'd need one. But in the case of GANs or similar you might have multiple.

@@ -281,7 +285,7 @@ def configure_optimizers(self) -> Dict[str, Any]:
                 raise ValueError(msg)

         # validate that we considered every parameter
-        param_dict = {pn: p for pn, p in self.named_parameters()}
+        param_dict = dict(self.named_parameters())

         extra_params = param_dict.keys() - (decay | no_decay)
         if len(extra_params) == 0:
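
Most of the churn in this file comes from the typing-modernisation rules: `from __future__ import annotations` turns every annotation into a string at runtime, flake8-type-checking (TCH) then moves annotation-only imports such as `pathlib.Path` behind `TYPE_CHECKING`, and pyupgrade rewrites `Optional`/`Union`/`Dict` into PEP 604 unions and builtin generics (UP007/UP006); the `dict(...)` call replaced by a literal is flake8-comprehensions' C408. A condensed sketch of the combined pattern, with illustrative names rather than the module's own:

    from __future__ import annotations  # annotations are no longer evaluated at runtime

    from typing import TYPE_CHECKING, Any

    if TYPE_CHECKING:
        # TCH003: only needed by the type checker, so no runtime import cost
        from pathlib import Path


    def describe(path: str | Path | None = None) -> dict[str, Any]:
        # UP007 / UP006: X | Y unions and builtin generics in annotations
        extras = {"missing": path is None}  # C408: literal instead of dict(missing=...)
        return {"path": str(path), **extras}
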
4 changes: 2 additions & 2 deletions src/train.py
@@ -2,14 +2,14 @@

 import hydra
 import rootutils
+from omegaconf import DictConfig
 from pytorch_lightning import (
     Callback,
     LightningDataModule,
     LightningModule,
     Trainer,
     seed_everything,
 )
-from omegaconf import DictConfig

 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 # ------------------------------------------------------------------------------------ #
@@ -74,7 +74,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
     trainer: Trainer = hydra.utils.instantiate(
-        cfg.trainer, callbacks=callbacks, logger=logger
+        cfg.trainer, callbacks=callbacks, logger=logger,
     )

     object_dict = {
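
The change here looks like Ruff's isort rule (I001): with the default settings, plain `import x` statements come before `from x import y` statements within a section, and the from-imports are then sorted by module name, which is why `from omegaconf import DictConfig` moves ahead of the `pytorch_lightning` import block. The resulting order, trimmed to the imports involved:

    import hydra
    import rootutils
    from omegaconf import DictConfig
    from pytorch_lightning import Trainer
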
2 changes: 1 addition & 1 deletion src/utils/instantiators.py
@@ -1,9 +1,9 @@
 from typing import List

 import hydra
+from omegaconf import DictConfig
 from pytorch_lightning import Callback
 from pytorch_lightning.loggers import Logger
-from omegaconf import DictConfig

 from src.utils import pylogger

2 changes: 1 addition & 1 deletion src/utils/pylogger.py
@@ -25,7 +25,7 @@ def __init__(
         self.rank_zero_only = rank_zero_only

     def log(
-        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs,
     ) -> None:
         """Delegate a log call to the underlying logger, after prefixing its message with the rank
         of the process it's being logged from. If `'rank'` is provided, then the log will only
4 changes: 3 additions & 1 deletion src/utils/utils.py
@@ -96,7 +96,9 @@ def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     return wrap


-def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
+def get_metric_value(
+    metric_dict: Dict[str, Any], metric_name: Optional[str],
+) -> Optional[float]:
     """Safely retrieves value of the metric logged in LightningModule.

     :param metric_dict: A dict containing metric values.
