-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add real-esrgan-general-x4v3 for new real-esrgan model real-esrgan-ge…
…neral-x4v3
- Loading branch information
xuejiejie
committed
Aug 19, 2024
1 parent
a6ef9e3
commit b6468d4
Showing
12 changed files
with
1,491 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
cmake_minimum_required(VERSION 3.16) | ||
project(real-esrgan) | ||
|
||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/") | ||
|
||
add_definitions(-std=c++17) | ||
add_definitions(-DAPI_EXPORTS) | ||
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) | ||
#set(CMAKE_CXX_STANDARD 11) | ||
set(CMAKE_BUILD_TYPE Debug) | ||
|
||
#find_package(CUDA REQUIRED) | ||
|
||
INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/src/include) | ||
|
||
# cuda | ||
FIND_PACKAGE(CUDA REQUIRED) | ||
#INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS}) | ||
include_directories(/usr/local/cuda/include) | ||
link_directories(/usr/local/cuda/lib64) | ||
|
||
# <------------------------TensorRT Related-------------------------> | ||
include_directories(YOUR_TENSORRT_INCLUDE_DIR) # TensorRT-8.6.1.6/include | ||
link_directories(YOUR_TENSORRT_LIB_DIR) # TensorRT-8.6.1.6/lib | ||
|
||
# <------------------------OpenCV Related-------------------------> | ||
# opencv | ||
FIND_PACKAGE(OpenCV REQUIRED) | ||
INCLUDE_DIRECTORIES(${OpenCV_INCLUDE_DIRS}) | ||
|
||
set(CMAKE_CXX_STANDARD 17) | ||
|
||
add_executable(${PROJECT_NAME} main.cpp) | ||
|
||
cuda_add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/src/pixel_shuffle/pixel_shuffle.cu) | ||
target_link_libraries(myplugins nvinfer cudart) | ||
|
||
|
||
TARGET_LINK_LIBRARIES(${PROJECT_NAME} nvinfer) | ||
TARGET_LINK_LIBRARIES(${PROJECT_NAME} cudart) | ||
TARGET_LINK_LIBRARIES(${PROJECT_NAME} ${OpenCV_LIBS}) | ||
TARGET_LINK_LIBRARIES(${PROJECT_NAME} myplugins) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Real-ESRGAN realesr-general-x4v3 model | ||
|
||
## How to Run | ||
0. Replace YOUR_TENSORRT_INCLUDE_DIR and YOUR_TENSORRT_LIB_DIR in CMakeLists.txt with your TensorRT include and lib directories. | ||
1. generate .wts from pytorch with .pt | ||
``` | ||
git clone https://github.com/xinntao/Real-ESRGAN.git | ||
cd Real-ESRGAN | ||
# Install basicsr - https://github.com/xinntao/BasicSR | ||
# We use BasicSR for both training and inference | ||
pip install basicsr | ||
# facexlib and gfpgan are for face enhancement | ||
pip install facexlib | ||
pip install gfpgan | ||
pip install -r requirements.txt | ||
python setup.py develop | ||
``` | ||
download realesr-general-x4v3.pth (and realesr-general-wdn-x4v3.pth if needed) from | ||
https://github.com/xinntao/Real-ESRGAN/releases | ||
|
||
``` | ||
cp {tensorrtx}/real-esrgan-general-x4v3/gen_wts.py {xinntao}/Real-ESRGAN | ||
cd {xinntao}/Real-ESRGAN | ||
python gen_wts.py | ||
// a file 'real-esrgan.wts' will be generated. | ||
``` | ||
|
||
**Be aware that if you need both realesr-general-x4v3.pth and realesr-general-wdn-x4v3.pth, please write a Python script to average all weights of realesr-general-x4v3.pth and realesr-general-wdn-x4v3.pth (from {xinntao}/Real-ESRGAN), then save it as a .pth file, and use this new file to generate a .wts file.** | ||
|
||
2. build tensorrtx/real-esrgan-general-x4v3 and run | ||
|
||
``` | ||
cd {tensorrtx}/real-esrgan-general-x4v3/ | ||
mkdir build | ||
cd build | ||
cp {xinntao}/Real-ESRGAN/real-esrgan.wts {tensorrtx}/real-esrgan/weights/ | ||
cmake .. | ||
make | ||
./real-esrgan your_images_dir | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# source: | ||
# https://github.com/NVIDIA/tensorrt-laboratory/blob/master/cmake/FindTensorRT.cmake | ||
|
||
# This module defines the following variables: | ||
# | ||
# :: | ||
# | ||
# TensorRT_INCLUDE_DIRS | ||
# TensorRT_LIBRARIES | ||
# TensorRT_FOUND | ||
# | ||
# :: | ||
# | ||
# TensorRT_VERSION_STRING - version (x.y.z) | ||
# TensorRT_VERSION_MAJOR - major version (x) | ||
# TensorRT_VERSION_MINOR - minor version (y) | ||
# TensorRT_VERSION_PATCH - patch version (z) | ||
# | ||
# Hints | ||
# ^^^^^ | ||
# A user may set ``TensorRT_DIR`` to an installation root to tell this module where to look. | ||
# | ||
set(_TensorRT_SEARCHES) | ||
|
||
if(TensorRT_DIR) | ||
set(_TensorRT_SEARCH_ROOT PATHS ${TensorRT_DIR} NO_DEFAULT_PATH) | ||
list(APPEND _TensorRT_SEARCHES _TensorRT_SEARCH_ROOT) | ||
endif() | ||
|
||
# appends some common paths | ||
set(_TensorRT_SEARCH_NORMAL | ||
PATHS "/usr" | ||
) | ||
list(APPEND _TensorRT_SEARCHES _TensorRT_SEARCH_NORMAL) | ||
|
||
# Include dir | ||
foreach(search ${_TensorRT_SEARCHES}) | ||
find_path(TensorRT_INCLUDE_DIR NAMES NvInfer.h ${${search}} PATH_SUFFIXES include) | ||
endforeach() | ||
|
||
if(NOT TensorRT_LIBRARY) | ||
foreach(search ${_TensorRT_SEARCHES}) | ||
find_library(TensorRT_LIBRARY NAMES nvinfer ${${search}} PATH_SUFFIXES lib) | ||
endforeach() | ||
endif() | ||
|
||
if(NOT TensorRT_PARSERS_LIBRARY) | ||
foreach(search ${_TensorRT_SEARCHES}) | ||
find_library(TensorRT_NVPARSERS_LIBRARY NAMES nvparsers ${${search}} PATH_SUFFIXES lib) | ||
endforeach() | ||
endif() | ||
|
||
if(NOT TensorRT_NVONNXPARSER_LIBRARY) | ||
foreach(search ${_TensorRT_SEARCHES}) | ||
find_library(TensorRT_NVONNXPARSER_LIBRARY NAMES nvonnxparser ${${search}} PATH_SUFFIXES lib) | ||
endforeach() | ||
endif() | ||
|
||
mark_as_advanced(TensorRT_INCLUDE_DIR) | ||
|
||
if(TensorRT_INCLUDE_DIR AND EXISTS "${TensorRT_INCLUDE_DIR}/NvInfer.h") | ||
file(STRINGS "${TensorRT_INCLUDE_DIR}/NvInfer.h" TensorRT_MAJOR REGEX "^#define NV_TENSORRT_MAJOR [0-9]+.*$") | ||
file(STRINGS "${TensorRT_INCLUDE_DIR}/NvInfer.h" TensorRT_MINOR REGEX "^#define NV_TENSORRT_MINOR [0-9]+.*$") | ||
file(STRINGS "${TensorRT_INCLUDE_DIR}/NvInfer.h" TensorRT_PATCH REGEX "^#define NV_TENSORRT_PATCH [0-9]+.*$") | ||
|
||
string(REGEX REPLACE "^#define NV_TENSORRT_MAJOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MAJOR "${TensorRT_MAJOR}") | ||
string(REGEX REPLACE "^#define NV_TENSORRT_MINOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MINOR "${TensorRT_MINOR}") | ||
string(REGEX REPLACE "^#define NV_TENSORRT_PATCH ([0-9]+).*$" "\\1" TensorRT_VERSION_PATCH "${TensorRT_PATCH}") | ||
set(TensorRT_VERSION_STRING "${TensorRT_VERSION_MAJOR}.${TensorRT_VERSION_MINOR}.${TensorRT_VERSION_PATCH}") | ||
endif() | ||
|
||
include(FindPackageHandleStandardArgs) | ||
FIND_PACKAGE_HANDLE_STANDARD_ARGS(TensorRT REQUIRED_VARS TensorRT_LIBRARY TensorRT_INCLUDE_DIR VERSION_VAR TensorRT_VERSION_STRING) | ||
|
||
if(TensorRT_FOUND) | ||
set(TensorRT_INCLUDE_DIRS ${TensorRT_INCLUDE_DIR}) | ||
|
||
if(NOT TensorRT_LIBRARIES) | ||
set(TensorRT_LIBRARIES ${TensorRT_LIBRARY} ${TensorRT_NVONNXPARSER_LIBRARY} ${TensorRT_NVPARSERS_LIBRARY}) | ||
endif() | ||
|
||
if(NOT TARGET TensorRT::TensorRT) | ||
add_library(TensorRT::TensorRT UNKNOWN IMPORTED) | ||
set_target_properties(TensorRT::TensorRT PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}") | ||
set_property(TARGET TensorRT::TensorRT APPEND PROPERTY IMPORTED_LOCATION "${TensorRT_LIBRARY}") | ||
endif() | ||
endif() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import argparse | ||
import os | ||
import struct | ||
from realesrgan import RealESRGANer | ||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact | ||
|
||
from basicsr.archs.rrdbnet_arch import RRDBNet | ||
from basicsr.utils.download_util import load_file_from_url | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-i', '--input', type=str, help='Input image or folder') | ||
parser.add_argument( | ||
'-n', | ||
'--model_name', | ||
type=str, | ||
default='realesr-general-x4v3', | ||
help=('RealESRGAN_x2plus Model names: ' | ||
'realesr-animevideov3 | realesr-general-x4v3')) | ||
parser.add_argument('-o', '--output', type=str, help='Output folder') | ||
parser.add_argument( | ||
'-dn', | ||
'--denoise_strength', | ||
type=float, | ||
default=0.5, | ||
help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. ' | ||
'Only used for the realesr-general-x4v3 model')) | ||
parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image') | ||
parser.add_argument( | ||
'--model_path', type=str, default=None, help='[Option] Model path. Usually, you do not need to specify it') | ||
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image') | ||
parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing') | ||
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding') | ||
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border') | ||
parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face') | ||
parser.add_argument( | ||
'--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).') | ||
parser.add_argument( | ||
'--alpha_upsampler', | ||
type=str, | ||
default='realesrgan', | ||
help='The upsampler for the alpha channels. Options: realesrgan | bicubic') | ||
parser.add_argument( | ||
'--ext', | ||
type=str, | ||
default='auto', | ||
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs') | ||
parser.add_argument( | ||
'-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu') | ||
|
||
args = parser.parse_args() | ||
|
||
# determine models according to model names | ||
args.model_name = args.model_name.split('.')[0] | ||
if args.model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model | ||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) | ||
netscale = 4 | ||
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'] | ||
elif args.model_name == 'RealESRNet_x4plus': # x4 RRDBNet model | ||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) | ||
netscale = 4 | ||
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth'] | ||
elif args.model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks | ||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) | ||
netscale = 4 | ||
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'] | ||
elif args.model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model | ||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) | ||
netscale = 2 | ||
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'] | ||
elif args.model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size) | ||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') | ||
netscale = 4 | ||
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth'] | ||
elif args.model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size) | ||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') | ||
netscale = 4 | ||
file_url = [ | ||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth', | ||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth' | ||
] | ||
|
||
# determine model paths | ||
if args.model_path is not None: | ||
model_path = args.model_path | ||
else: | ||
model_path = os.path.join('weights', args.model_name + '.pth') | ||
if not os.path.isfile(model_path): | ||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) | ||
for url in file_url: | ||
# model_path will be updated | ||
model_path = load_file_from_url( | ||
url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None) | ||
|
||
# use dni to control the denoise strength | ||
dni_weight = None | ||
if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1: | ||
# wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3') | ||
# model_path = [model_path, wdn_model_path] | ||
# dni_weight = [args.denoise_strength, 1 - args.denoise_strength] | ||
model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-x4v3-cat') | ||
dni_weight = None | ||
|
||
# restorer | ||
upsampler = RealESRGANer( | ||
scale=netscale, | ||
model_path=model_path, | ||
dni_weight=dni_weight, | ||
model=model, | ||
tile=args.tile, | ||
tile_pad=args.tile_pad, | ||
pre_pad=args.pre_pad, | ||
half=not args.fp32, | ||
gpu_id=args.gpu_id) | ||
|
||
if os.path.isfile('real-esrgan.wts'): | ||
print('Already, real-esrgan.wts file exists.') | ||
else: | ||
print('making real-esrgan.wts file ...') | ||
f = open("real-esrgan.wts", 'w') | ||
f.write("{}\n".format(len(upsampler.model.state_dict().keys()))) | ||
for k, v in upsampler.model.state_dict().items(): | ||
print('key: ', k) | ||
print('value: ', v.shape) | ||
vr = v.reshape(-1).cpu().numpy() | ||
f.write("{} {}".format(k, len(vr))) | ||
for vv in vr: | ||
f.write(" ") | ||
f.write(struct.pack(">f", float(vv)).hex()) | ||
f.write("\n") | ||
print('Completed real-esrgan.wts file!') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.