generated from xinntao/ProjectTemplate-Python
Showing 15 changed files with 476 additions and 6 deletions.
@@ -0,0 +1,11 @@
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import arch modules for registry
# scan all the files that end with '_arch.py' under the archs folder
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
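Importing this package is what triggers registration: every `*_arch.py` module found by the scan is imported, and its `@ARCH_REGISTRY.register()` decorator runs as a side effect. A minimal sketch of consuming the registry afterwards, assuming BasicSR's `build_network` helper (the same one imported by the model file later in this commit):

```python
import archs  # importing the package imports every *_arch.py and runs its register() decorators

from basicsr.archs import build_network

# build_network pops the 'type' key, looks that name up in ARCH_REGISTRY and
# forwards the remaining entries as keyword arguments to the class constructor.
# 'ExampleArch' is the architecture defined further down in this commit.
net = build_network(dict(type='ExampleArch', num_in_ch=3, num_out_ch=3, num_feat=64, upscale=4))
```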
@@ -0,0 +1,52 @@
from torch import nn as nn
from torch.nn import functional as F

from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class ExampleArch(nn.Module):
    """Example architecture.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        upscale (int): Upsampling factor. Default: 4.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, upscale=4):
        super(ExampleArch, self).__init__()
        self.upscale = upscale

        self.conv1 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)

        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialization
        default_init_weights(
            [self.conv1, self.conv2, self.conv3, self.upconv1, self.upconv2, self.conv_hr, self.conv_last], 0.1)

    def forward(self, x):
        feat = self.lrelu(self.conv1(x))
        feat = self.lrelu(self.conv2(feat))
        feat = self.lrelu(self.conv3(feat))

        out = self.lrelu(self.pixel_shuffle(self.upconv1(feat)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))

        out = self.conv_last(self.lrelu(self.conv_hr(out)))
        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        out += base
        return out
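As a quick sanity check (not part of the commit), the architecture can be smoke-tested on a dummy tensor; the module path `archs.example_arch` is an assumption based on the `'_arch.py'` scanning convention above:

```python
import torch

from archs.example_arch import ExampleArch  # module path assumed from the '_arch.py' convention

net = ExampleArch(num_in_ch=3, num_out_ch=3, num_feat=64, upscale=4)
net.eval()
with torch.no_grad():
    lq = torch.rand(1, 3, 32, 32)  # dummy low-quality input
    out = net(lq)
print(out.shape)  # torch.Size([1, 3, 128, 128]): two PixelShuffle(2) stages give the 4x upscale
```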
@@ -0,0 +1,11 @@
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import dataset modules for registry
# scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
@@ -0,0 +1,88 @@
import cv2
import os
import torch
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.degradations import add_jpg_compression
from basicsr.data.transforms import augment, mod_crop, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class ExampleDataset(data.Dataset):
    """Example dataset.

    1. Read GT image
    2. Generate LQ (Low Quality) image with cv2 bicubic downsampling and JPEG compression

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwargs.
            gt_size (int): Cropped patch size for gt patches.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(ExampleDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder = opt['dataroot_gt']
        # it now only supports folder mode; for other modes such as lmdb and meta_info file, please see:
        # https://github.com/xinntao/BasicSR/blob/master/basicsr/data/
        self.paths = [os.path.join(self.gt_folder, v) for v in list(scandir(self.gt_folder))]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        img_gt = mod_crop(img_gt, scale)

        # generate lq image
        # downsample
        h, w = img_gt.shape[0:2]
        img_lq = cv2.resize(img_gt, (w // scale, h // scale), interpolation=cv2.INTER_CUBIC)
        # add JPEG compression
        img_lq = add_jpg_compression(img_lq, quality=70)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # round to simulate the quantization of a saved 8-bit LQ image
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': gt_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
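For reference, a rough sketch of driving the dataset directly (outside the commit); the module path, data folder, and option values are assumptions chosen to mirror the docstring keys above:

```python
from data.example_dataset import ExampleDataset  # module path assumed from the '_dataset.py' convention

# the keys mirror the docstring above; the concrete values are hypothetical
opt = {
    'dataroot_gt': 'datasets/example',   # folder of GT images (hypothetical path)
    'io_backend': {'type': 'disk'},
    'scale': 4,
    'phase': 'train',
    'gt_size': 128,
    'use_flip': True,
    'use_rot': True,
}
dataset = ExampleDataset(opt)
sample = dataset[0]
print(sample['gt'].shape, sample['lq'].shape)  # expected: (3, 128, 128) and (3, 32, 32)
```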
@@ -0,0 +1,3 @@
# Soft link your dataset here

`ln -s xxx ./`
Empty file.
@@ -0,0 +1,11 @@
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import model modules for registry
# scan all the files that end with '_model.py' under the model folder
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
@@ -0,0 +1,86 @@
from collections import OrderedDict

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.models.sr_model import SRModel
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register()  # this line is necessary to register the model
class ExampleModel(SRModel):
    """Example model based on the SRModel class.

    In this example model, we want to implement a new model that trains with both L1 and L2 loss.

    Newly defined functions:
        init_training_settings(self)
        feed_data(self, data)
        optimize_parameters(self, current_iter)

    Inherited functions:
        __init__(self, opt)
        setup_optimizers(self)
        test(self)
        dist_validation(self, dataloader, current_iter, tb_logger, save_img)
        nondist_validation(self, dataloader, current_iter, tb_logger, save_img)
        _log_validation_metric_values(self, current_iter, dataset_name, tb_logger)
        get_current_visuals(self)
        save(self, epoch, current_iter)
    """

    def init_training_settings(self):
        self.net_g.train()
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define losses
        self.l1_pix = build_loss(train_opt['l1_opt']).to(self.device)
        self.l2_pix = build_loss(train_opt['l2_opt']).to(self.device)

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def feed_data(self, data):
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

    def optimize_parameters(self, current_iter):
        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_total = 0
        loss_dict = OrderedDict()
        # l1 loss
        l_l1 = self.l1_pix(self.output, self.gt)
        l_total += l_l1
        loss_dict['l_l1'] = l_l1
        # l2 loss
        l_l2 = self.l2_pix(self.output, self.gt)
        l_total += l_l2
        loss_dict['l_l2'] = l_l2

        l_total.backward()
        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)
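For completeness, a rough sketch of the `train` section of the options that `init_training_settings` reads; the loss types and weights are assumptions, not values fixed by the template (any losses registered in `basicsr.losses` would work):

```python
# hypothetical 'train' options consumed by init_training_settings above
train_opt = {
    'ema_decay': 0.999,  # > 0 enables the EMA copy of net_g
    'l1_opt': {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'},   # assumed loss choice
    'l2_opt': {'type': 'MSELoss', 'loss_weight': 1.0, 'reduction': 'mean'},  # assumed loss choice
    # optimizer and scheduler entries are handled by the inherited
    # setup_optimizers() / setup_schedulers() and are omitted here
}
```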