diff --git a/README.md b/README.md
index 97e1821..21fd924 100644
--- a/README.md
+++ b/README.md
@@ -74,7 +74,7 @@ The repository contains 2 parts:
 ├── dataset                # Dataloaders
 ├── utils                  # Utility functions
 ├── train.py               # Training script
-├── evaluate.py            # Evaluation script (TODO)
+├── evaluate.py            # Evaluation script
 ├── requirements.txt
 ```
@@ -231,6 +231,48 @@
 register_model_from_yaml("my_marlin_model", "path/to/config.yaml")
 model = Marlin.from_file("my_marlin_model", "path/to/marlin.ckpt")
 ```
+
+## Evaluation
+
+### CelebV-HQ
+
+#### 1. Download the dataset
+Download the dataset from [CelebV-HQ](https://github.com/CelebV-HQ/CelebV-HQ). The file structure should look like this:
+```
+├── CelebV-HQ
+│   ├── downloaded
+│   │   ├── ***.mp4
+│   │   ├── ...
+│   ├── celebvhq_info.json
+│   ├── ...
+```
+
+#### 2. Preprocess the dataset
+Crop the face regions from the raw videos and split them into train, val and test sets.
+```bash
+python preprocess/celebvhq_preprocess.py --data_dir /path/to/CelebV-HQ
+```
+
+#### 3. Extract MARLIN features (optional, only needed for linear probing)
+Extract MARLIN features from the cropped videos and save them to a directory named after the backbone (e.g. `marlin_vit_base_ytf`) inside the `CelebV-HQ` directory.
+```bash
+python preprocess/celebvhq_extract.py --data_dir /path/to/CelebV-HQ --backbone marlin_vit_base_ytf
+```
+
+#### 4. Train and evaluate
+Train and evaluate the model adapted from MARLIN on CelebV-HQ.
+
+Use one of the configs in `config/celebv_hq/*/*.yaml` as the config file.
+```bash
+python evaluate.py \
+    --config /path/to/config \
+    --data_path /path/to/CelebV-HQ \
+    --num_workers 4 \
+    --batch_size 16
+```
+
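+To evaluate an already trained checkpoint without re-training, `evaluate.py` also accepts `--skip_train`, which requires `--resume` to point at the checkpoint. A rough example (paths are placeholders):
+```bash
+python evaluate.py \
+    --config /path/to/config \
+    --data_path /path/to/CelebV-HQ \
+    --skip_train \
+    --resume /path/to/checkpoint.ckpt
+```
+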
+
+
 
 ## License
 
 This project is under the CC BY-NC 4.0 license. See [LICENSE](LICENSE) for details.
diff --git a/config/celebv_hq/action/celebvhq_marlin_action_ft.yaml b/config/celebv_hq/action/celebvhq_marlin_action_ft.yaml
new file mode 100644
index 0000000..747783e
--- /dev/null
+++ b/config/celebv_hq/action/celebvhq_marlin_action_ft.yaml
@@ -0,0 +1,8 @@
+model_name: "celebvhq_marlin_action_ft"
+backbone: "marlin_vit_base_ytf"
+dataset: "celebvhq"
+task: "action"
+temporal_reduction: "mean"
+learning_rate: 1.0e-4
+seq_mean_pool: true
+finetune: true
\ No newline at end of file
diff --git a/config/celebv_hq/action/celebvhq_marlin_action_lp.yaml b/config/celebv_hq/action/celebvhq_marlin_action_lp.yaml
new file mode 100644
index 0000000..bc54bb5
--- /dev/null
+++ b/config/celebv_hq/action/celebvhq_marlin_action_lp.yaml
@@ -0,0 +1,8 @@
+model_name: "celebvhq_marlin_action_lp"
+backbone: "marlin_vit_base_ytf"
+dataset: "celebvhq"
+task: "action"
+temporal_reduction: "mean"
+learning_rate: 1.0e-4
+seq_mean_pool: true
+finetune: false
\ No newline at end of file
diff --git a/config/celebv_hq/appearance/celebvhq_marlin_appearance_ft.yaml b/config/celebv_hq/appearance/celebvhq_marlin_appearance_ft.yaml
new file mode 100644
index 0000000..eb79942
--- /dev/null
+++ b/config/celebv_hq/appearance/celebvhq_marlin_appearance_ft.yaml
@@ -0,0 +1,8 @@
+model_name: "celebvhq_marlin_appearance_ft"
+backbone: "marlin_vit_base_ytf"
+dataset: "celebvhq"
+task: "appearance"
+temporal_reduction: "mean"
+learning_rate: 1.0e-4
+seq_mean_pool: true
+finetune: true
\ No newline at end of file
diff --git a/config/celebv_hq/appearance/celebvhq_marlin_appearance_lp.yaml b/config/celebv_hq/appearance/celebvhq_marlin_appearance_lp.yaml
new file mode 100644
index 0000000..c54e08e
--- /dev/null
+++ b/config/celebv_hq/appearance/celebvhq_marlin_appearance_lp.yaml
@@ -0,0 +1,8 @@
+model_name: "celebvhq_marlin_appearance_lp"
+backbone: "marlin_vit_base_ytf"
+dataset: "celebvhq"
+task: "appearance"
+temporal_reduction: "mean"
+learning_rate: 1.0e-4
+seq_mean_pool: true
+finetune: false
\ No newline at end of file
diff --git a/dataset/celebv_hq.py b/dataset/celebv_hq.py
new file mode 100644
index 0000000..f5fe40a
--- /dev/null
+++ b/dataset/celebv_hq.py
@@ -0,0 +1,215 @@
+import os
+from abc import ABC, abstractmethod
+from itertools import islice
+from typing import Optional
+
+import ffmpeg
+import numpy as np
+import torch
+import torchvision
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader
+
+from marlin_pytorch.util import read_video, padding_video
+from util.misc import sample_indexes, read_text, read_json
+
+
+class CelebvHqBase(LightningDataModule, ABC):
+
+    def __init__(self, data_root: str, split: str, task: str, data_ratio: float = 1.0, take_num: int = None):
+        super().__init__()
+        self.data_root = data_root
+        self.split = split
+        assert task in ("appearance", "action")
+        self.task = task
+        self.take_num = take_num
+
+        self.name_list = list(
+            filter(lambda x: x != "", read_text(os.path.join(data_root, f"{self.split}.txt")).split("\n")))
+        self.metadata = read_json(os.path.join(data_root, "celebvhq_info.json"))
+
+        if data_ratio < 1.0:
+            self.name_list = self.name_list[:int(len(self.name_list) * data_ratio)]
+        if take_num is not None:
+            self.name_list = self.name_list[:self.take_num]
+
+        print(f"Dataset {self.split} has {len(self.name_list)} videos")
+
+    @abstractmethod
+    def __getitem__(self, index: int):
+        pass
+
+    def __len__(self):
+        return len(self.name_list)
+
+
+# for fine-tuning
+class CelebvHq(CelebvHqBase):
+
+    def __init__(self,
+            root_dir: str,
+            split: str,
+            task: str,
+            clip_frames: int,
+            temporal_sample_rate: int,
+            data_ratio: float = 1.0,
+            take_num: Optional[int] = None
+    ):
+        super().__init__(root_dir, split, task, data_ratio, take_num)
+        self.clip_frames = clip_frames
+        self.temporal_sample_rate = temporal_sample_rate
+
+    def __getitem__(self, index: int):
+        y = self.metadata["clips"][self.name_list[index]]["attributes"][self.task]
+        video_path = os.path.join(self.data_root, "cropped", self.name_list[index] + ".mp4")
+
+        probe = ffmpeg.probe(video_path)["streams"][0]
+        n_frames = int(probe["nb_frames"])
+
+        if n_frames <= self.clip_frames:
+            video = read_video(video_path, channel_first=True).video / 255
+            # pad frames to 16 (clip_frames); use a bool label to match the other return branches
+            video = padding_video(video, self.clip_frames, "same")  # (T, C, H, W)
+            video = video.permute(1, 0, 2, 3)  # (C, T, H, W)
+            return video, torch.tensor(y, dtype=torch.long).bool()
+        elif n_frames <= self.clip_frames * self.temporal_sample_rate:
+            # reset a lower temporal sample rate
+            sample_rate = n_frames // self.clip_frames
+        else:
+            sample_rate = self.temporal_sample_rate
+        # sample frames
+        video_indexes = sample_indexes(n_frames, self.clip_frames, sample_rate)
+        reader = torchvision.io.VideoReader(video_path)
+        fps = reader.get_metadata()["video"]["fps"][0]
+        reader.seek(video_indexes[0].item() / fps, True)
+        frames = []
+        for frame in islice(reader, 0, self.clip_frames * sample_rate, sample_rate):
+            frames.append(frame["data"])
+        video = torch.stack(frames) / 255  # (T, C, H, W)
+        video = video.permute(1, 0, 2, 3)  # (C, T, H, W)
+        assert video.shape[1] == self.clip_frames, video_path
+        return video, torch.tensor(y, dtype=torch.long).bool()
+
+
+# For linear probing
+class CelebvHqFeatures(CelebvHqBase):
+
+    def __init__(self, root_dir: str,
+            feature_dir: str,
+            split: str,
+            task: str,
+            temporal_reduction: str,
+            data_ratio: float = 1.0,
+            take_num: Optional[int] = None
+    ):
+        super().__init__(root_dir, split, task, data_ratio, take_num)
+        self.feature_dir = feature_dir
+        self.temporal_reduction = temporal_reduction
+
+    def __getitem__(self, index: int):
+        feat_path = os.path.join(self.data_root, self.feature_dir, self.name_list[index] + ".npy")
+
+        x = torch.from_numpy(np.load(feat_path)).float()
+
+        if x.size(0) == 0:
+            x = torch.zeros(1, 768, dtype=torch.float32)
+
+        if self.temporal_reduction == "mean":
+            x = x.mean(dim=0)
+        elif self.temporal_reduction == "max":
+            x = x.max(dim=0)[0]
+        elif self.temporal_reduction == "min":
+            x = x.min(dim=0)[0]
+        else:
+            raise ValueError(self.temporal_reduction)
+
+        y = self.metadata["clips"][self.name_list[index]]["attributes"][self.task]
+
+        return x, torch.tensor(y, dtype=torch.long).bool()
+
+
+class CelebvHqDataModule(LightningDataModule):
+
+    def __init__(self, root_dir: str,
+            load_raw: bool,
+            task: str,
+            batch_size: int,
+            num_workers: int = 0,
+            clip_frames: int = None,
+            temporal_sample_rate: int = None,
+            feature_dir: str = None,
+            temporal_reduction: str = "mean",
+            data_ratio: float = 1.0,
+            take_train: Optional[int] = None,
+            take_val: Optional[int] = None,
+            take_test: Optional[int] = None
+    ):
+        super().__init__()
+        self.root_dir = root_dir
+        self.task = task
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.clip_frames = clip_frames
+        self.temporal_sample_rate = temporal_sample_rate
+        self.feature_dir = feature_dir
+        self.temporal_reduction = temporal_reduction
+        self.load_raw = load_raw
+        self.data_ratio = data_ratio
+        self.take_train = take_train
+        self.take_val = take_val
+        self.take_test = take_test
+
+        if load_raw:
+            assert clip_frames is not None
+            assert temporal_sample_rate is not None
+        else:
+            assert feature_dir is not None
+            assert temporal_reduction is not None
+
+        self.train_dataset = None
+        self.val_dataset = None
+        self.test_dataset = None
+
+    def setup(self, stage: Optional[str] = None):
+        if self.load_raw:
+            self.train_dataset = CelebvHq(self.root_dir, "train", self.task, self.clip_frames,
+                self.temporal_sample_rate, self.data_ratio, self.take_train)
+            self.val_dataset = CelebvHq(self.root_dir, "val", self.task, self.clip_frames,
+                self.temporal_sample_rate, self.data_ratio, self.take_val)
+            self.test_dataset = CelebvHq(self.root_dir, "test", self.task, self.clip_frames,
+                self.temporal_sample_rate, 1.0, self.take_test)
+        else:
+            self.train_dataset = CelebvHqFeatures(self.root_dir, self.feature_dir, "train", self.task,
+                self.temporal_reduction, self.data_ratio, self.take_train)
+            self.val_dataset = CelebvHqFeatures(self.root_dir, self.feature_dir, "val", self.task,
+                self.temporal_reduction, self.data_ratio, self.take_val)
+            self.test_dataset = CelebvHqFeatures(self.root_dir, self.feature_dir, "test", self.task,
+                self.temporal_reduction, 1.0, self.take_test)
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            pin_memory=True,
+            drop_last=True
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            pin_memory=True
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            pin_memory=True
+        )
diff --git a/evaluate.py b/evaluate.py
index e69de29..6d41120 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -0,0 +1,153 @@
+import argparse
+
+import torch
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from tqdm.auto import tqdm
+
+from dataset.celebv_hq import CelebvHqDataModule
+from marlin_pytorch.config import resolve_config
+from marlin_pytorch.util import read_yaml
+from model.classifier import Classifier
+from util.earlystop_lr import EarlyStoppingLR
+from util.lr_logger import LrLogger
+from util.seed import Seed
+from util.system_stats_logger import SystemStatsLogger
+
+
+def train_celebvhq(args, config):
+    data_path = args.data_path
+    resume_ckpt = args.resume
+    n_gpus = args.n_gpus
+    max_epochs = args.epochs
+
+    finetune = config["finetune"]
+    learning_rate = config["learning_rate"]
+    task = config["task"]
+
+    if task == "appearance":
+        num_classes = 40
+    elif task == "action":
+        num_classes = 35
+    else:
+        raise ValueError(f"Unknown task {task}")
+
+    if finetune:
+        backbone_config = resolve_config(config["backbone"])
+
+        model = Classifier(
+            num_classes, config["backbone"], True, args.marlin_ckpt, "multilabel", config["learning_rate"],
+            args.n_gpus > 1,
+        )
+
+        dm = CelebvHqDataModule(
+            data_path, finetune, task,
+            batch_size=args.batch_size,
+            num_workers=args.num_workers,
+            clip_frames=backbone_config.n_frames,
+            temporal_sample_rate=2
+        )
+
+    else:
+        model = Classifier(
+            num_classes, config["backbone"], False,
+            None, "multilabel", config["learning_rate"], args.n_gpus > 1,
+        )
+
+        dm = CelebvHqDataModule(
+            data_path, finetune, task,
+            batch_size=args.batch_size,
+            num_workers=args.num_workers,
+            feature_dir=config["backbone"],
+            temporal_reduction=config["temporal_reduction"]
+        )
+
+    if args.skip_train:
+        dm.setup()
+        return resume_ckpt, dm
+
+    strategy = None if n_gpus <= 1 else "ddp"
+    accelerator = "cpu" if n_gpus == 0 else "gpu"
+
+    ckpt_filename = config["model_name"] + "-{epoch}-{val_auc:.3f}"
+    ckpt_monitor = "val_auc"
+
+    try:
+        precision = int(args.precision)
+    except ValueError:
+        precision = args.precision
+
+    ckpt_callback = ModelCheckpoint(dirpath=f"ckpt/{config['model_name']}", save_last=True,
+            filename=ckpt_filename,
+            monitor=ckpt_monitor,
+            mode="max")
+
+    trainer = Trainer(log_every_n_steps=1, devices=n_gpus, accelerator=accelerator, benchmark=True,
+            logger=True, precision=precision, max_epochs=max_epochs,
+            strategy=strategy, resume_from_checkpoint=resume_ckpt,
+            callbacks=[ckpt_callback, LrLogger(), EarlyStoppingLR(1e-6), SystemStatsLogger()])
+
+    trainer.fit(model, dm)
+
+    return ckpt_callback.best_model_path, dm
+
+
+def evaluate_celebvhq(args, ckpt, dm):
+    print("Load checkpoint", ckpt)
+    model = Classifier.load_from_checkpoint(ckpt)
+    accelerator = "cpu" if args.n_gpus == 0 else "gpu"
+    trainer = Trainer(log_every_n_steps=1, devices=1 if args.n_gpus > 0 else 0, accelerator=accelerator, benchmark=True,
+            logger=False, enable_checkpointing=False)
+    Seed.set(42)
+    model.eval()
+
+    # collect predictions
+    preds = trainer.predict(model, dm.test_dataloader())
+    preds = torch.cat(preds)
+
+    # collect ground truth
+    ys = torch.zeros_like(preds, dtype=torch.bool)
+    for i, (_, y) in enumerate(tqdm(dm.test_dataloader())):
+        ys[i * args.batch_size: (i + 1) * args.batch_size] = y
+
+    preds = preds.sigmoid()
+    acc = ((preds > 0.5) == ys).float().mean()
+    auc = model.auc_fn(preds, ys)
+    results = {
+        "acc": acc,
+        "auc": auc
+    }
+    print(results)
+
+
+def evaluate(args):
+    config = read_yaml(args.config)
+    dataset_name = config["dataset"]
+
+    if dataset_name == "celebvhq":
+        ckpt, dm = train_celebvhq(args, config)
+        evaluate_celebvhq(args, ckpt, dm)
+    else:
+        raise NotImplementedError(f"Dataset {dataset_name} not implemented")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser("CelebV-HQ evaluation")
+    parser.add_argument("--config", type=str, help="Path to CelebV-HQ evaluation config file.")
+    parser.add_argument("--data_path", type=str, help="Path to CelebV-HQ dataset.")
+    parser.add_argument("--marlin_ckpt", type=str, default=None,
+        help="Path to MARLIN checkpoint. Default: None, load from online.")
+    parser.add_argument("--n_gpus", type=int, default=1)
+    parser.add_argument("--precision", type=str, default="32")
+    parser.add_argument("--num_workers", type=int, default=8)
+    parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--epochs", type=int, default=2000, help="Max epochs to train.")
+    parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume training.")
+    parser.add_argument("--skip_train", action="store_true", default=False,
+        help="Skip training and evaluate only.")
+
+    args = parser.parse_args()
+    if args.skip_train:
+        assert args.resume is not None
+
+    evaluate(args)
diff --git a/model/classifier.py b/model/classifier.py
new file mode 100644
index 0000000..20c36cc
--- /dev/null
+++ b/model/classifier.py
@@ -0,0 +1,101 @@
+from typing import Optional, Union, Sequence, Dict, Literal, Any
+
+from pytorch_lightning import LightningModule
+from torch import Tensor
+from torch.nn import CrossEntropyLoss, Linear, Identity, BCEWithLogitsLoss
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torchmetrics import Accuracy, AUROC
+
+from marlin_pytorch import Marlin
+from marlin_pytorch.config import resolve_config
+
+
+class Classifier(LightningModule):
+
+    def __init__(self, num_classes: int, backbone: str, finetune: bool,
+            marlin_ckpt: Optional[str] = None,
+            task: Literal["binary", "multiclass", "multilabel"] = "binary",
+            learning_rate: float = 1e-4, distributed: bool = False
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+
+        if finetune:
+            if marlin_ckpt is None:
+                self.model = Marlin.from_online(backbone).encoder
+            else:
+                self.model = Marlin.from_file(backbone, marlin_ckpt).encoder
+        else:
+            self.model = None
+
+        config = resolve_config(backbone)
+
+        self.fc = Linear(config.encoder_embed_dim, num_classes)
+        self.learning_rate = learning_rate
+        self.distributed = distributed
+        self.task = task
+        if task == "binary":
+            self.loss_fn = BCEWithLogitsLoss()
+            self.acc_fn = Accuracy(task=task, num_classes=1)
+            self.auc_fn = AUROC(task=task, num_classes=1)
+        elif task == "multiclass":
+            self.loss_fn = CrossEntropyLoss()
+            self.acc_fn = Accuracy(task=task, num_classes=num_classes)
+            self.auc_fn = AUROC(task=task, num_classes=num_classes)
+        elif task == "multilabel":
+            self.loss_fn = BCEWithLogitsLoss()
+            self.acc_fn = Accuracy(task="binary", num_classes=1)
+            self.auc_fn = AUROC(task="binary", num_classes=1)
+
+    @classmethod
+    def from_module(cls, model, learning_rate: float = 1e-4, distributed=False):
+        return cls(model, learning_rate, distributed)
+
+    def forward(self, x):
+        if self.model is not None:
+            feat = self.model.extract_features(x, True)
+        else:
+            feat = x
+        return self.fc(feat)
+
+    def step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]]) -> Dict[str, Tensor]:
+        x, y = batch
+        y_hat = self(x)
+        if self.task == "multilabel":
+            y_hat = y_hat.flatten()
+            y = y.flatten()
+        loss = self.loss_fn(y_hat, y.float())
+        prob = y_hat.sigmoid()
+        acc = self.acc_fn(prob, y)
+        auc = self.auc_fn(prob, y)
+        return {"loss": loss, "acc": acc, "auc": auc}
+
+    def training_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
+            optimizer_idx: Optional[int] = None, hiddens: Optional[Tensor] = None
+    ) -> Dict[str, Tensor]:
+        loss_dict = self.step(batch)
+        self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True,
+            prog_bar=False, sync_dist=self.distributed)
+        return loss_dict["loss"]
+
+    def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
+            dataloader_idx: Optional[int] = None
+    ) -> Dict[str, Tensor]:
+        loss_dict = self.step(batch)
+        self.log_dict({f"val_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True,
+            prog_bar=True, sync_dist=self.distributed)
+        return loss_dict["loss"]
+
+    def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
+        return self(batch[0])
+
+    def configure_optimizers(self):
+        optimizer = Adam(self.parameters(), lr=self.learning_rate, betas=(0.5, 0.9))
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": ReduceLROnPlateau(optimizer, factor=0.5, patience=7, verbose=True, min_lr=1e-8),
+                "monitor": "train_loss"
+            }
+        }
diff --git a/model/marlin.py b/model/marlin.py
index 036689a..e3e2eaf 100644
--- a/model/marlin.py
+++ b/model/marlin.py
@@ -257,8 +257,8 @@ def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = Non
             **{k: v for k, v in d_result.items() if k != "loss"},
             **{k: v for k, v in g_result.items() if k != "loss"},
         }
-        self.log_dict({f"val_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True,
-            prog_bar=False, sync_dist=self.distributed)
+        self.log_dict({f"val_{k}": v for k, v in loss_dict.items()}, on_step=False, on_epoch=True,
+            prog_bar=True, sync_dist=self.distributed)
         return loss_dict["loss"]
 
     def _log_sample_reconstruction_image(self, batch):
diff --git a/preprocess/celebvhq_extract.py b/preprocess/celebvhq_extract.py
new file mode 100644
index 0000000..88c97ef
--- /dev/null
+++ b/preprocess/celebvhq_extract.py
@@ -0,0 +1,43 @@
+import argparse
+import os
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+from tqdm.auto import tqdm
+
+from marlin_pytorch import Marlin
+from marlin_pytorch.config import resolve_config
+
+sys.path.append(".")
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser("CelebV-HQ Feature Extraction")
+    parser.add_argument("--backbone", type=str)
+    parser.add_argument("--data_dir", type=str)
+    args = parser.parse_args()
+
+    model = Marlin.from_online(args.backbone)
+    config = resolve_config(args.backbone)
+    feat_dir = args.backbone
+
+    model.cuda()
+    model.eval()
+
+    raw_video_path = os.path.join(args.data_dir, "cropped")
+    all_videos = sorted(list(filter(lambda x: x.endswith(".mp4"), os.listdir(raw_video_path))))
+    Path(os.path.join(args.data_dir, feat_dir)).mkdir(parents=True, exist_ok=True)
+    for video_name in tqdm(all_videos):
+        video_path = os.path.join(raw_video_path, video_name)
+        save_path = os.path.join(args.data_dir, feat_dir, video_name.replace(".mp4", ".npy"))
+        try:
+            feat = model.extract_video(
+                video_path, crop_face=False,
+                sample_rate=config.tubelet_size, stride=config.n_frames,
+                keep_seq=False, reduction="none")
+
+        except Exception as e:
+            print(f"Video {video_path} error.", e)
+            feat = torch.zeros(0, model.encoder.embed_dim, dtype=torch.float32)
+        np.save(save_path, feat.cpu().numpy())
diff --git a/preprocess/celebvhq_preprocess.py b/preprocess/celebvhq_preprocess.py
new file mode 100644
index 0000000..f5a437c
--- /dev/null
+++ b/preprocess/celebvhq_preprocess.py
@@ -0,0 +1,44 @@
+# Parse labels, segment and crop the raw videos.
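+# The train.txt / val.txt / test.txt split files are created with a simple 80% / 10% / 10% split over the
+# cropped videos; the split is skipped if all three files already exist.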
+import argparse
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+
+def crop_face(root: str):
+    from util.face_sdk.face_crop import process_videos
+    source_dir = os.path.join(root, "downloaded")
+    target_dir = os.path.join(root, "cropped")
+    process_videos(source_dir, target_dir, ext="mp4")
+
+
+def gen_split(root: str):
+    videos = list(filter(lambda x: x.endswith('.mp4'), os.listdir(os.path.join(root, 'cropped'))))
+    total_num = len(videos)
+
+    with open(os.path.join(root, "train.txt"), "w") as f:
+        for i in range(int(total_num * 0.8)):
+            f.write(videos[i][:-4] + "\n")
+
+    with open(os.path.join(root, "val.txt"), "w") as f:
+        for i in range(int(total_num * 0.8), int(total_num * 0.9)):
+            f.write(videos[i][:-4] + "\n")
+
+    with open(os.path.join(root, "test.txt"), "w") as f:
+        for i in range(int(total_num * 0.9), total_num):
+            f.write(videos[i][:-4] + "\n")
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--data_dir", help="Root directory of CelebV-HQ")
+args = parser.parse_args()
+
+if __name__ == '__main__':
+    data_root = args.data_dir
+    crop_face(data_root)
+
+    if not os.path.exists(os.path.join(data_root, "train.txt")) or \
+            not os.path.exists(os.path.join(data_root, "val.txt")) or \
+            not os.path.exists(os.path.join(data_root, "test.txt")):
+        gen_split(data_root)
diff --git a/requirements.txt b/requirements.txt
index b44d567..646d547 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,13 +2,14 @@ opencv-python>=4.6
 numpy>=1.23
 einops~=0.4
 torchvision>=0.12.0
-torch==1.11.0
+torch
 pyyaml>=6.0
 tqdm>=4.64.0
 scikit-image>=0.19.3
 matplotlib>=3.5.2
 pillow>=9.2.0
 pandas~=1.4.3
-marlin_pytorch==0.2.1
+marlin_pytorch==0.3.4
 pytorch_lightning==1.7.*
 ffmpeg-python>=0.2.0
+torchmetrics == 0.11.*
diff --git a/util/earlystop_lr.py b/util/earlystop_lr.py
new file mode 100644
index 0000000..dbb6060
--- /dev/null
+++ b/util/earlystop_lr.py
@@ -0,0 +1,36 @@
+from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning.callbacks import Callback
+import re
+
+
+class EarlyStoppingLR(Callback):
+
+    def __init__(self, lr_threshold: float, mode="all"):
+        self.lr_threshold = lr_threshold
+
+        if mode in ("any", "all"):
+            self.mode = mode
+        else:
+            raise ValueError(f"mode must be one of ('any', 'all')")
+
+    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        self._run_early_stop_checking(trainer)
+
+    def _run_early_stop_checking(self, trainer: Trainer) -> None:
+        metrics = trainer._logger_connector.callback_metrics
+        if len(metrics) == 0:
+            return
+        all_lr = []
+        for key, value in metrics.items():
+            if re.match(r"opt\d+_lr\d+", key):
+                all_lr.append(value)
+
+        if len(all_lr) == 0:
+            return
+
+        if self.mode == "all":
+            if all(lr <= self.lr_threshold for lr in all_lr):
+                trainer.should_stop = True
+        elif self.mode == "any":
+            if any(lr <= self.lr_threshold for lr in all_lr):
+                trainer.should_stop = True
diff --git a/util/lr_logger.py b/util/lr_logger.py
new file mode 100644
index 0000000..9673cb0
--- /dev/null
+++ b/util/lr_logger.py
@@ -0,0 +1,13 @@
+from pytorch_lightning import Callback, Trainer, LightningModule
+
+
+class LrLogger(Callback):
+    """Log learning rate in each epoch start."""
+
+    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        for i, optimizer in enumerate(trainer.optimizers):
+            for j, params in enumerate(optimizer.param_groups):
+                key = f"opt{i}_lr{j}"
+                value = params["lr"]
+                pl_module.logger.log_metrics({key: value}, step=trainer.global_step)
+                pl_module.log(key, value, logger=False, sync_dist=pl_module.distributed)
diff --git a/util/misc.py b/util/misc.py
index 4e2a827..702e4ab 100644
--- a/util/misc.py
+++ b/util/misc.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 import re
 
 import torch
@@ -44,3 +45,14 @@ def sample_indexes(total_frames: int, n_frames: int, temporal_sample_rate: int)
         print(f"total_frames: {total_frames}, n_frames: {n_frames}, temporal_sample_rate: {temporal_sample_rate}")
         raise e
     return torch.arange(n_frames) * temporal_sample_rate + start_ind
+
+
+def read_text(path: str, encoding: str = "UTF-8") -> str:
+    with open(path, "r", encoding=encoding) as file:
+        text = file.read()
+    return text
+
+
+def read_json(path: str):
+    with open(path, "r") as file:
+        return json.load(file)
diff --git a/util/seed.py b/util/seed.py
new file mode 100644
index 0000000..2c53528
--- /dev/null
+++ b/util/seed.py
@@ -0,0 +1,53 @@
+import random
+from typing import Callable
+
+import numpy as np
+import torch
+from torch import Generator
+
+
+class Seed:
+    seed: int = None
+
+    @classmethod
+    def torch(cls, seed: int) -> None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    @classmethod
+    def python(cls, seed: int) -> None:
+        random.seed(seed)
+
+    @classmethod
+    def numpy(cls, seed: int) -> None:
+        np.random.seed(seed)
+
+    @classmethod
+    def set(cls, seed: int, use_deterministic_algorithms: bool = False) -> None:
+        cls.torch(seed)
+        cls.python(seed)
+        cls.numpy(seed)
+        cls.seed = seed
+        torch.use_deterministic_algorithms(use_deterministic_algorithms)
+
+    @classmethod
+    def _is_set(cls) -> bool:
+        return cls.seed is not None
+
+    @classmethod
+    def get_loader_worker_init(cls) -> Callable[[int], None]:
+        def seed_worker(worker_id):
+            worker_seed = torch.initial_seed() % 2 ** 32
+            np.random.seed(worker_seed)
+            random.seed(worker_seed)
+
+        if cls._is_set():
+            return seed_worker
+        else:
+            return lambda x: None
+
+    @classmethod
+    def get_torch_generator(cls, device="cpu") -> Generator:
+        g = torch.Generator(device)
+        g.manual_seed(cls.seed)
+        return g
diff --git a/util/system_stats_logger.py b/util/system_stats_logger.py
new file mode 100644
index 0000000..91bbd43
--- /dev/null
+++ b/util/system_stats_logger.py
@@ -0,0 +1,22 @@
+from pytorch_lightning import Callback, Trainer, LightningModule
+
+
+class SystemStatsLogger(Callback):
+    """Log system stats for each training epoch"""
+
+    def __init__(self):
+        try:
+            import psutil
+        except ImportError:
+            raise ImportError("psutil is required to use SystemStatsLogger")
+        self.psutil = psutil
+
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        cpu_usage = self.psutil.cpu_percent()
+        memory_usage = self.psutil.virtual_memory().percent
+        logged_info = {
+            "cpu_usage": cpu_usage,
+            "memory_usage": memory_usage
+        }
+        pl_module.logger.log_metrics(logged_info, step=trainer.global_step)
+        pl_module.log_dict(logged_info, logger=False, sync_dist=pl_module.distributed)