Adding the LPIPS metrics #473

Open. Wants to merge 16 commits into base: master.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include basicsr/ops/dcn/src/*.cu basicsr/ops/dcn/src/*.cpp
include basicsr/ops/fused_act/src/*.cu basicsr/ops/fused_act/src/*.cpp
include basicsr/ops/upfirdn2d/src/*.cu basicsr/ops/upfirdn2d/src/*.cpp
include basicsr/metrics/niqe_pris_params.npz
include VERSION
include requirements.txt
39 changes: 22 additions & 17 deletions README.md
@@ -65,21 +65,21 @@ Other recommended projects:<br>
We provide simple pipelines to train/test/inference models for a quick start.
These pipelines/commands cannot cover all the cases and more details are in the following sections.

| GAN | | | | | |
| :--- | :---: | :---: | :--- | :---: | :---: |
| StyleGAN2 | [Train](docs/HOWTOs.md#How-to-train-StyleGAN2) | [Inference](docs/HOWTOs.md#How-to-inference-StyleGAN2) | | | |
| **Face Restoration** | | | | | |
| DFDNet | - | [Inference](docs/HOWTOs.md#How-to-inference-DFDNet) | | | |
| **Super Resolution** | | | | | |
| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO*|
| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO*|
| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr)|
| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
| **Deblurring** | | | | | |
| DeblurGANv2 | - | *TODO* | | | |
| **Denoise** | | | | | |
| RIDNet | - | *TODO* | CBDNet | - | *TODO*|
| GAN | | | | | |
| :------------------- | :--------------------------------------------: | :----------------------------------------------------: | :------- | :--------------------------------------------: | :----------------------------------------------------: |
| StyleGAN2 | [Train](docs/HOWTOs.md#How-to-train-StyleGAN2) | [Inference](docs/HOWTOs.md#How-to-inference-StyleGAN2) | | | |
| **Face Restoration** | | | | | |
| DFDNet | - | [Inference](docs/HOWTOs.md#How-to-inference-DFDNet) | | | |
| **Super Resolution** | | | | | |
| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO* |
| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO* |
| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr) |
| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
| **Deblurring** | | | | | |
| DeblurGANv2 | - | *TODO* | | | |
| **Denoise** | | | | | |
| RIDNet | - | *TODO* | CBDNet | - | *TODO* |

## :wrench: Dependencies and Installation

@@ -114,7 +114,7 @@ Please see [project boards](https://github.com/xinntao/BasicSR/projects).

Please see [DesignConvention.md](docs/DesignConvention.md) for the designs and conventions of the BasicSR codebase.<br>
The figure below shows the overall framework. More descriptions for each component: <br>
**[Datasets.md](docs/Datasets.md)**&emsp;|&emsp;**[Models.md](docs/Models.md)**&emsp;|&emsp;**[Config.md](Config.md)**&emsp;|&emsp;**[Logging.md](docs/Logging.md)**
**[Datasets.md](docs/Datasets.md)**&emsp;|&emsp;**[Models.md](docs/Models.md)**&emsp;|&emsp;**[Config.md](docs/Config.md)**&emsp;|&emsp;**[Logging.md](docs/Logging.md)**

![overall_structure](./assets/overall_structure.png)

@@ -144,7 +144,12 @@ The following is a BibTeX reference. The BibTeX entry requires the `url` LaTeX p

If you have any questions, please email `[email protected]`.

<br>

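- **QQ group**: scan the QR code on the left, or search for QQ group number 320960100. Answer to the join question: 互帮互助共同进步 ("help each other and improve together")
- **WeChat group**: because the WeChat group has more than 200 members, joining requires an invitation. If you would like to join, first add Liangbin's personal WeChat (QR code on the right); he will add you to the group when he is free.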

<p align="center">
<img src="https://user-images.githubusercontent.com/17445847/134879983-6f2d663b-16e7-49f2-97e1-7c53c8a5f71a.jpg" height="300"> &emsp; &emsp;
<img src="https://user-images.githubusercontent.com/17445847/135756881-51b73150-40ff-4eaa-8a16-2e98ecbfa457.png" height="300">
<img src="https://user-images.githubusercontent.com/17445847/139572512-8e192aac-00fa-432b-ac8e-a33026b019df.png" height="300">
</p>
39 changes: 22 additions & 17 deletions README_CN.md
@@ -62,21 +62,21 @@ BasicSR (**Basic** **S**uper **R**estoration) is an open-source, PyTorch-based

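We provide simple pipelines for getting started quickly with training/testing/inference of models. These commands do not cover every use case; see the sections below for more details.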

| GAN | | | | | |
| :--- | :---: | :---: | :--- | :---: | :---: |
| StyleGAN2 | [Train](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [Test](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | | | |
| **Face Restoration** | | | | | |
| DFDNet | - | [Test](docs/HOWTOs_CN.md#如何测试-DFDNet) | | | |
| **Super Resolution** | | | | | |
| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO*|
| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO*|
| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr)|
| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
| **Deblurring** | | | | | |
| DeblurGANv2 | - | *TODO* | | | |
| **Denoise** | | | | | |
| RIDNet | - | *TODO* | CBDNet | - | *TODO*|
| GAN | | | | | |
| :------------------- | :------------------------------------------: | :------------------------------------------: | :------- | :--------------------------------------------: | :----------------------------------------------------: |
| StyleGAN2 | [Train](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [Test](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | | | |
| **Face Restoration** | | | | | |
| DFDNet | - | [Test](docs/HOWTOs_CN.md#如何测试-DFDNet) | | | |
| **Super Resolution** | | | | | |
| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO* |
| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO* |
| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr) |
| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
| **Deblurring** | | | | | |
| DeblurGANv2 | - | *TODO* | | | |
| **Denoise** | | | | | |
| RIDNet | - | *TODO* | CBDNet | - | *TODO* |

## :wrench: Dependencies and Installation

@@ -112,7 +112,7 @@ For detailed instructions refer to [INSTALL.md](INSTALL.md).

See [DesignConvention_CN.md](docs/DesignConvention_CN.md).<br>
The figure below summarizes the overall framework. For more details on each module, see: <br>
**[Datasets_CN.md](docs/Datasets_CN.md)**&emsp;|&emsp;**[Models_CN.md](docs/Models_CN.md)**&emsp;|&emsp;**[Config_CN.md](Config_CN.md)**&emsp;|&emsp;**[Logging_CN.md](docs/Logging_CN.md)**
**[Datasets_CN.md](docs/Datasets_CN.md)**&emsp;|&emsp;**[Models_CN.md](docs/Models_CN.md)**&emsp;|&emsp;**[Config_CN.md](docs/Config_CN.md)**&emsp;|&emsp;**[Logging_CN.md](docs/Logging_CN.md)**

![overall_structure](./assets/overall_structure.png)

@@ -142,7 +142,12 @@ For detailed instructions refer to [INSTALL.md](INSTALL.md).

If you have any questions, please email `[email protected]`.

<br>

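- **QQ group**: scan the QR code on the left, or search for QQ group number 320960100. Answer to the join question: 互帮互助共同进步 ("help each other and improve together")
- **WeChat group**: because the WeChat group has more than 200 members, joining requires an invitation. If you would like to join, first add Liangbin's personal WeChat (QR code on the right); he will add you to the group when he is free.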

<p align="center">
<img src="https://user-images.githubusercontent.com/17445847/134879983-6f2d663b-16e7-49f2-97e1-7c53c8a5f71a.jpg" height="300"> &emsp; &emsp;
<img src="https://user-images.githubusercontent.com/17445847/135756881-51b73150-40ff-4eaa-8a16-2e98ecbfa457.png" height="300">
<img src="https://user-images.githubusercontent.com/17445847/139572512-8e192aac-00fa-432b-ac8e-a33026b019df.png" height="300">
</p>
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
1.3.4.4
1.3.4.7
60 changes: 58 additions & 2 deletions basicsr/archs/discriminator_arch.py
@@ -1,7 +1,7 @@
from torch import nn as nn

from basicsr.utils.registry import ARCH_REGISTRY

from torch.nn import functional as F
from torch.nn.utils import spectral_norm

@ARCH_REGISTRY.register()
class VGGStyleDiscriminator128(nn.Module):
@@ -147,3 +147,59 @@ def forward(self, x):
        feat = self.lrelu(self.linear1(feat))
        out = self.linear2(feat)
        return out


@ARCH_REGISTRY.register()
class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN)"""

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm

        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)

        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))

        # extra
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))

        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0

        # extra
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out
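
Review note (illustrative, not part of the diff): with `skip_connection=True`, the additions `x4 + x2`, `x5 + x1`, and `x6 + x0` need the upsampled feature maps to match the encoder maps in spatial size. The sketch below traces one spatial axis through the layers using the standard convolution output-size formula, with kernel, stride, and padding values copied from the constructor above. `conv_out`, `trace_sizes`, and `skips_align` are hypothetical helper names introduced only for this note. Under these assumptions the sizes line up when the input side is a multiple of 8.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_sizes(size):
    """Trace one spatial axis through the encoder and decoder shown above.

    Returns ((x0, x1, x2), (x4, x5, x6)) sizes, ignoring the skip additions.
    """
    s0 = conv_out(size, 3, 1, 1)      # conv0: 3x3, stride 1, keeps the size
    s1 = conv_out(s0, 4, 2, 1)        # conv1: 4x4, stride 2
    s2 = conv_out(s1, 4, 2, 1)        # conv2: 4x4, stride 2
    s3 = conv_out(s2, 4, 2, 1)        # conv3: 4x4, stride 2
    u4 = conv_out(2 * s3, 3, 1, 1)    # interpolate x2, then conv4
    u5 = conv_out(2 * u4, 3, 1, 1)    # interpolate x2, then conv5
    u6 = conv_out(2 * u5, 3, 1, 1)    # interpolate x2, then conv6
    return (s0, s1, s2), (u4, u5, u6)


def skips_align(size):
    """True when x4/x2, x5/x1 and x6/x0 would have equal sizes on this axis."""
    (s0, s1, s2), (u4, u5, u6) = trace_sizes(size)
    return (u4, u5, u6) == (s2, s1, s0)


if __name__ == "__main__":
    print(trace_sizes(128))   # ((128, 64, 32), (32, 64, 128))
    print(skips_align(104))   # True  (104 = 8 * 13)
    print(skips_align(100))   # False (100 is not a multiple of 8)
```

If inputs of other sizes are expected, padding or cropping before the discriminator would be one way to keep the skip additions valid; the diff itself does not say how callers handle this.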
8 changes: 7 additions & 1 deletion basicsr/archs/ecbsr_arch.py
@@ -227,6 +227,8 @@ class ECBSR(nn.Module):

    def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
        super(ECBSR, self).__init__()
        self.num_in_ch = num_in_ch
        self.scale = scale

        backbone = []
        backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
@@ -240,6 +242,10 @@ def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_
        self.upsampler = nn.PixelShuffle(scale)

    def forward(self, x):
        y = self.backbone(x) + x  # will repeat the input in the channel dimension (repeat scale * scale times)
        if self.num_in_ch > 1:
            shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
        else:
            shortcut = x  # will repeat the input in the channel dimension (repeat scale * scale times)
        y = self.backbone(x) + shortcut
        y = self.upsampler(y)
        return y
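
Review note (illustrative, not part of the diff): the new multi-channel branch uses `repeat_interleave` along the channel dimension, so each input channel ends up as `scale * scale` consecutive channels. That is the layout `nn.PixelShuffle` consumes per output channel. The pure-Python sketch below mimics both operations on nested lists for a single sample to show that the shortcut then becomes a nearest-neighbour upsampling of the input, whereas tiling the channels would mix them. All function names here are hypothetical stand-ins, not PyTorch or BasicSR APIs.

```python
def repeat_interleave_channels(x, repeats):
    """Mimic repeat_interleave on the channel axis: c0, c0, ..., c1, c1, ..."""
    return [channel for channel in x for _ in range(repeats)]


def pixel_shuffle(x, r):
    """Rearrange (C * r * r, H, W) into (C, H * r, W * r), PixelShuffle-style."""
    c_out = len(x) // (r * r)
    h, w = len(x[0]), len(x[0][0])
    out = [[[None] * (w * r) for _ in range(h * r)] for _ in range(c_out)]
    for c in range(c_out):
        for i in range(h * r):
            for j in range(w * r):
                # Output pixel (i, j) reads from sub-channel (i % r, j % r).
                src = c * r * r + (i % r) * r + (j % r)
                out[c][i][j] = x[src][i // r][j // r]
    return out


def nearest_upsample(channel, r):
    """Nearest-neighbour upsampling of one H x W channel by a factor r."""
    h, w = len(channel), len(channel[0])
    return [[channel[i // r][j // r] for j in range(w * r)] for i in range(h * r)]


if __name__ == "__main__":
    scale = 2
    # Three 2x2 channels with distinct values: 100 * c + 10 * row + col.
    x = [[[100 * c + 10 * i + j for j in range(2)] for i in range(2)] for c in range(3)]
    expected = [nearest_upsample(ch, scale) for ch in x]

    interleaved = pixel_shuffle(repeat_interleave_channels(x, scale * scale), scale)
    print(interleaved == expected)  # True

    tiled = pixel_shuffle(x * (scale * scale), scale)  # c0, c1, c2, c0, c1, c2, ...
    print(tiled == expected)  # False
```

For the sum `self.backbone(x) + shortcut` to be shape-compatible in the multi-channel case, the backbone output would need `num_in_ch * scale * scale` channels; that is an inference from the shapes, not something the diff states.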