add examples
xinntao committed Aug 29, 2021
1 parent dbfc3b3 commit bf4172b
Showing 15 changed files with 476 additions and 6 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# ProjectTemplate-Python
# BasicSR Examples

[English](README.md) **|** [简体中文](README_CN.md)

In this repository, we give simple examples to illustrate how to use [`BasicSR`](https://github.com/xinntao/BasicSR) in your own project.



[English](README.md) **|** [简体中文](README_CN.md)   [GitHub](https://github.com/xinntao/ProjectTemplate-Python) **|** [Gitee码云](https://gitee.com/xinntao/ProjectTemplate-Python)

## File Modification

65 changes: 63 additions & 2 deletions README_CN.md
@@ -1,6 +1,67 @@
# ProjectTemplate-Python
# BasicSR Examples

[English](README.md) **|** [简体中文](README_CN.md)   [GitHub](https://github.com/xinntao/ProjectTemplate-Python) **|** [Gitee码云](https://gitee.com/xinntao/ProjectTemplate-Python)
[English](README.md) **|** [简体中文](README_CN.md)

[`BasicSR`](https://github.com/xinntao/BasicSR) **|** [`simple example`](https://github.com/xinntao/BasicSR-examples/tree/master) **|** [`installation example`](https://github.com/xinntao/BasicSR-examples/tree/installation)

Projects that use BasicSR:
:white_check_mark: [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN): Practical Algorithms for General Image Restoration
:white_check_mark: [GFPGAN](https://github.com/TencentARC/GFPGAN): Practical Algorithms for Real-world Face Restoration

If your open-source project uses BasicSR, feel free to contact me and I will add it to the list above.

---

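In this repository, we provide simple examples that show how to use [`BasicSR`](https://github.com/xinntao/BasicSR) in your own project.<br>
`BasicSR` can be used in two ways:
:one: Git clone the entire BasicSR code base. You can then see the complete BasicSR code and modify it to suit your needs.
:two: Use BasicSR as a [python package](https://pypi.org/project/basicsr/) (that is, installed through pip). It provides the training framework, the pipeline, and some basic functions, so you can conveniently build your own project on top of BasicSR.

Our examples mainly target the second way, that is, how to build your own project concisely on top of the basicsr package.<br>

There are two ways to use the basicsr python package, and we provide them in two separate branches:<br>
:one: [Simple mode](https://github.com/xinntao/BasicSR-examples/tree/master): the project repository can be run without being installed. It has limitations, though: importing from a complex module hierarchy is inconvenient, and functions in this project are hard to access from other locations.
:two: [Installation mode](https://github.com/xinntao/BasicSR-examples/tree/installation): the project repository needs to be installed with `python setup.py develop`. After installation, both importing and using it are more convenient.

As a simple introduction and walkthrough, we use the simple-mode example here, but we recommend the installation mode for practical use.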

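## Preparation

Most deep-learning projects can be divided into the following parts:

1. **data**: defines the training data that is fed into the training process of the model.
2. **arch** (architecture): defines the network structure and the forward steps.
3. **model**: defines the components needed in training (such as the loss) and one complete training step (including forward propagation, back-propagation, gradient optimization, etc.), as well as other functions such as validation.
4. training pipeline: defines the training flow, that is, it connects the dataloader, the model, validation, checkpoint saving, and so on.

When we develop a new method, we are usually improving **data**, **arch**, and **model**, while many of the pipelines and basic functions are actually shared. We therefore want to focus on developing the main functions rather than reinventing the wheel.<br>
That is why BasicSR exists: it separates out many of these shared functions, so we only need to care about developing **data**, **arch**, and **model**.<br>
To make it even easier to use, we provide the basicsr package. You can install it with `pip install basicsr` and then use the BasicSR training pipeline and the functions already developed in BasicSR.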
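
The split into data, arch, and model relies on classes being selectable by name from a config. The sketch below shows one way such a name-based registry could work. It is a minimal stand-in written for illustration: the `Registry` class, the `TinyArch` class, the `build_from_opt` helper, and the `type` key convention used here are assumptions, not BasicSR's actual implementation.

```python
class Registry:
    """Hypothetical minimal registry that maps class names to classes."""

    def __init__(self, name):
        self.name = name
        self._obj_map = {}

    def register(self):
        """Return a decorator that stores the decorated class under its own name."""
        def deco(cls):
            if cls.__name__ in self._obj_map:
                raise KeyError(f'{cls.__name__} is already registered in {self.name}')
            self._obj_map[cls.__name__] = cls
            return cls
        return deco

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(f'{name} is not registered in {self.name}')
        return self._obj_map[name]


# Illustrative registry; this is not the ARCH_REGISTRY object from basicsr.
ARCH_REGISTRY = Registry('arch')


@ARCH_REGISTRY.register()
class TinyArch:
    def __init__(self, num_feat=64):
        self.num_feat = num_feat


def build_from_opt(registry, opt):
    """Instantiate the class named by opt['type'] with the remaining keys as kwargs."""
    opt = dict(opt)  # copy so the caller's config is not mutated
    cls = registry.get(opt.pop('type'))
    return cls(**opt)


net = build_from_opt(ARCH_REGISTRY, {'type': 'TinyArch', 'num_feat': 32})
print(type(net).__name__, net.num_feat)
```

With this pattern, adding a new architecture only requires defining and decorating a class; the code that builds objects from a config does not need to change.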

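## A Simple Example

Below, we use a simple example to show how to build your own project with BasicSR.

### Purpose

Let us assume a super-resolution task: the input is a low-resolution image, and the output is a high-resolution image with a sharpening effect.<br>
In this task, what we need to do is build our own data loader, architecture, and model. We describe each of them below.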

### data

### arch

### model

###

debug mode





In this repository, we give simple examples to illustrate how to use [`BasicSR`](https://github.com/xinntao/BasicSR) in your own project.

## File Modification

11 changes: 11 additions & 0 deletions archs/__init__.py
@@ -0,0 +1,11 @@
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import arch modules for registry
# scan all the files that end with '_arch.py' under the archs folder
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
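
The scan-and-import step above can be tried without BasicSR installed. The sketch below builds a throwaway package in a temporary directory and imports only the files whose names end with `_arch.py`. The package name `demo_archs`, the helper `find_module_names`, and the use of `os.listdir` in place of `basicsr.utils.scandir` are assumptions made for illustration.

```python
import importlib
import os
import sys
import tempfile


def find_module_names(folder, suffix):
    """Return module names (file names without '.py') in folder that end with suffix."""
    return sorted(
        os.path.splitext(name)[0]
        for name in os.listdir(folder)
        if name.endswith(suffix))


# Build a throwaway package so the sketch is self-contained.
tmp_root = tempfile.mkdtemp()
pkg_dir = os.path.join(tmp_root, 'demo_archs')
os.makedirs(pkg_dir)
files = {
    '__init__.py': '',
    'tiny_arch.py': 'NAME = "tiny"\n',
    'helper.py': 'NAME = "helper"\n',  # does not match the suffix, so it is skipped
}
for filename, source in files.items():
    with open(os.path.join(pkg_dir, filename), 'w') as f:
        f.write(source)

# Make the temporary package importable, then import the matching modules.
sys.path.insert(0, tmp_root)
importlib.invalidate_caches()
module_names = find_module_names(pkg_dir, '_arch.py')
modules = [importlib.import_module(f'demo_archs.{name}') for name in module_names]
print(module_names)
```

Importing a module is what triggers any registration decorators inside it, which is why the `__init__.py` files in this commit import every matching file eagerly.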
52 changes: 52 additions & 0 deletions archs/example_arch.py
@@ -0,0 +1,52 @@
from torch import nn as nn
from torch.nn import functional as F

from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class ExampleArch(nn.Module):
    """Example architecture.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        upscale (int): Upsampling factor. Default: 4.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, upscale=4):
        super(ExampleArch, self).__init__()
        self.upscale = upscale

        # feature extraction
        self.conv1 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # upsampling: two PixelShuffle(2) stages give a fixed 4x upsampling,
        # so the residual add in forward only matches shapes when upscale == 4
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)

        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialization
        default_init_weights(
            [self.conv1, self.conv2, self.conv3, self.upconv1, self.upconv2, self.conv_hr, self.conv_last], 0.1)

    def forward(self, x):
        feat = self.lrelu(self.conv1(x))
        feat = self.lrelu(self.conv2(feat))
        feat = self.lrelu(self.conv3(feat))

        out = self.lrelu(self.pixel_shuffle(self.upconv1(feat)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))

        out = self.conv_last(self.lrelu(self.conv_hr(out)))
        # residual connection: add the bilinearly upsampled input
        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        out += base
        return out
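
To see why `ExampleArch` produces a 4x larger output, the tensor shapes can be traced by hand. The sketch below does that arithmetic in plain Python, without torch. The helper names `conv_shape`, `pixel_shuffle_shape`, and `example_arch_output_shape` are hypothetical and exist only for this illustration; shapes are given as `(channels, height, width)`.

```python
def conv_shape(shape, out_ch):
    """Shape after a 3x3 convolution with stride 1 and padding 1 (H and W are unchanged)."""
    _, h, w = shape
    return (out_ch, h, w)


def pixel_shuffle_shape(shape, r):
    """Shape after pixel shuffle with factor r: (C, H, W) -> (C / r^2, H * r, W * r)."""
    c, h, w = shape
    if c % (r * r) != 0:
        raise ValueError('channel count must be divisible by r * r')
    return (c // (r * r), h * r, w * r)


def example_arch_output_shape(in_shape, num_out_ch=3, num_feat=64):
    """Trace the shapes through the layers of ExampleArch for one input image."""
    feat = conv_shape(in_shape, num_feat)  # conv1, conv2, conv3 keep H and W
    out = pixel_shuffle_shape(conv_shape(feat, num_feat * 4), 2)  # upconv1 + shuffle
    out = pixel_shuffle_shape(conv_shape(out, num_feat * 4), 2)   # upconv2 + shuffle
    return conv_shape(out, num_out_ch)  # conv_hr, then conv_last


print(example_arch_output_shape((3, 32, 32)))  # (3, 128, 128)
```

Because the two shuffle stages always multiply height and width by 4, the bilinear `base` branch must also upsample by 4 for the final addition to be shape-compatible.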
11 changes: 11 additions & 0 deletions data/__init__.py
@@ -0,0 +1,11 @@
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import dataset modules for registry
# scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
88 changes: 88 additions & 0 deletions data/example_dataset.py
@@ -0,0 +1,88 @@
import cv2
import os
import torch
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.degradations import add_jpg_compression
from basicsr.data.transforms import augment, mod_crop, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class ExampleDataset(data.Dataset):
    """Example dataset.

    1. Read GT image
    2. Generate LQ (Low Quality) image with cv2 bicubic downsampling and JPEG compression

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            gt_size (int): Cropped patch size for gt patches.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).
            scale (int): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(ExampleDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder = opt['dataroot_gt']
        # it now only supports folder mode, for other modes such as lmdb and meta_info file, please see:
        # https://github.com/xinntao/BasicSR/blob/master/basicsr/data/
        self.paths = [os.path.join(self.gt_folder, v) for v in list(scandir(self.gt_folder))]

    def __getitem__(self, index):
        # create the file client lazily, on first access
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        img_gt = mod_crop(img_gt, scale)

        # generate lq image
        # downsample
        h, w = img_gt.shape[0:2]
        img_lq = cv2.resize(img_gt, (w // scale, h // scale), interpolation=cv2.INTER_CUBIC)
        # add JPEG compression
        img_lq = add_jpg_compression(img_lq, quality=70)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # round lq to 8-bit levels and clamp to [0, 1]
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        # lq is generated from gt on the fly, so both paths point to the gt file
        return {'lq': img_lq, 'gt': img_gt, 'lq_path': gt_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
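
The size bookkeeping and the 8-bit rounding in `__getitem__` are easy to check on scalars. The sketch below reproduces that arithmetic in plain Python. The helpers `mod_crop_size`, `lq_size`, and `quantize_8bit` are hypothetical names for this illustration; it assumes `mod_crop` trims height and width down to the nearest multiple of `scale`, and that the rounding step is meant to mimic storing the LQ image at 8-bit precision.

```python
def mod_crop_size(h, w, scale):
    """Height and width after trimming each down to a multiple of scale."""
    return h - h % scale, w - w % scale


def lq_size(h, w, scale):
    """Height and width of the downsampled LQ image for a GT image of size (h, w)."""
    h, w = mod_crop_size(h, w, scale)
    return h // scale, w // scale


def quantize_8bit(value):
    """Scalar version of clamp(round(value * 255), 0, 255) / 255."""
    level = min(max(round(value * 255.0), 0), 255)
    return level / 255.0


print(mod_crop_size(481, 322, 4))  # (480, 320)
print(lq_size(481, 322, 4))        # (120, 80)
print(quantize_8bit(1.2))          # 1.0, values above the range are clamped
print(quantize_8bit(-0.1))         # 0.0, values below the range are clamped
```

Trimming first guarantees that the GT size is exactly `scale` times the LQ size, which paired cropping and the loss computation both depend on.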
3 changes: 3 additions & 0 deletions datasets/README.md
@@ -0,0 +1,3 @@
# Soft link your dataset here

`ln -s xxx ./`
Empty file added experiments/README.md
Empty file.
11 changes: 11 additions & 0 deletions models/__init__.py
@@ -0,0 +1,11 @@
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import model modules for registry
# scan all the files that end with '_model.py' under the model folder
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
86 changes: 86 additions & 0 deletions models/example_model.py
@@ -0,0 +1,86 @@
from collections import OrderedDict

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.models.sr_model import SRModel
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register() # This line is necessary to register the model
class ExampleModel(SRModel):
    """Example model based on the SRModel class.

    In this example model, we want to implement a new model that trains with both L1 and L2 loss.

    Newly defined functions:
        init_training_settings(self)
        feed_data(self, data)
        optimize_parameters(self, current_iter)

    Inherited functions:
        __init__(self, opt)
        setup_optimizers(self)
        test(self)
        dist_validation(self, dataloader, current_iter, tb_logger, save_img)
        nondist_validation(self, dataloader, current_iter, tb_logger, save_img)
        _log_validation_metric_values(self, current_iter, dataset_name, tb_logger)
        get_current_visuals(self)
        save(self, epoch, current_iter)
    """

    def init_training_settings(self):
        self.net_g.train()
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define losses
        self.l1_pix = build_loss(train_opt['l1_opt']).to(self.device)
        self.l2_pix = build_loss(train_opt['l2_opt']).to(self.device)

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def feed_data(self, data):
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

    def optimize_parameters(self, current_iter):
        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_total = 0
        loss_dict = OrderedDict()
        # l1 loss
        l_l1 = self.l1_pix(self.output, self.gt)
        l_total += l_l1
        loss_dict['l_l1'] = l_l1
        # l2 loss
        l_l2 = self.l2_pix(self.output, self.gt)
        l_total += l_l2
        loss_dict['l_l2'] = l_l2

        l_total.backward()
        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)
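
The two ideas in `optimize_parameters`, summing an L1 and an L2 pixel loss and keeping an exponential moving average (EMA) of the weights, can be illustrated on plain lists of floats. The sketch below is a simplification under stated assumptions: both losses use mean reduction with a weight of 1, and the EMA rule is `ema = decay * ema + (1 - decay) * param`. The function names `l1_loss`, `l2_loss`, and `ema_update` are hypothetical and do not come from BasicSR.

```python
def l1_loss(pred, target):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def l2_loss(pred, target):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def ema_update(ema_params, params, decay):
    """Return new EMA parameters: decay * ema + (1 - decay) * current."""
    return [decay * e + (1 - decay) * p for e, p in zip(ema_params, params)]


pred, target = [0.0, 1.0], [0.5, 0.0]
l_total = l1_loss(pred, target) + l2_loss(pred, target)
print(l_total)  # 0.75 + 0.625 = 1.375

# decay = 0 copies the current weights, matching the role of model_ema(0) above
print(ema_update([1.0, 2.0], [3.0, 4.0], decay=0))    # [3.0, 4.0]
print(ema_update([1.0, 2.0], [3.0, 4.0], decay=0.5))  # [2.0, 3.0]
```

A decay close to 1 makes the EMA weights change slowly, which tends to smooth out the noise of individual optimizer steps.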