Skip to content

Commit

Permalink
add real-esrgan-general-x4v3 for new real-esrgan model real-esrgan-ge…
Browse files Browse the repository at this point in the history
…neral-x4v3
  • Loading branch information
xuejiejie committed Aug 19, 2024
1 parent a6ef9e3 commit b6468d4
Show file tree
Hide file tree
Showing 12 changed files with 1,491 additions and 0 deletions.
42 changes: 42 additions & 0 deletions real-esrgan-general-x4v3/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
cmake_minimum_required(VERSION 3.16)
project(real-esrgan)

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/")

add_definitions(-std=c++17)
add_definitions(-DAPI_EXPORTS)
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
#set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)

#find_package(CUDA REQUIRED)

INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/src/include)

# cuda
FIND_PACKAGE(CUDA REQUIRED)
#INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS})
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)

# <------------------------TensorRT Related------------------------->
include_directories(YOUR_TENSORRT_INCLUDE_DIR) # TensorRT-8.6.1.6/include
link_directories(YOUR_TENSORRT_LIB_DIR) # TensorRT-8.6.1.6/lib

# <------------------------OpenCV Related------------------------->
# opencv
FIND_PACKAGE(OpenCV REQUIRED)
INCLUDE_DIRECTORIES(${OpenCV_INCLUDE_DIRS})

set(CMAKE_CXX_STANDARD 17)

add_executable(${PROJECT_NAME} main.cpp)

cuda_add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/src/pixel_shuffle/pixel_shuffle.cu)
target_link_libraries(myplugins nvinfer cudart)


TARGET_LINK_LIBRARIES(${PROJECT_NAME} nvinfer)
TARGET_LINK_LIBRARIES(${PROJECT_NAME} cudart)
TARGET_LINK_LIBRARIES(${PROJECT_NAME} ${OpenCV_LIBS})
TARGET_LINK_LIBRARIES(${PROJECT_NAME} myplugins)
41 changes: 41 additions & 0 deletions real-esrgan-general-x4v3/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Real-ESRGAN realesr-general-x4v3 model

## How to Run
0. Replace YOUR_TENSORRT_INCLUDE_DIR and YOUR_TENSORRT_LIB_DIR in CMakeLists.txt with your TensorRT include and lib directories.
1. generate .wts from pytorch with .pt
```
git clone https://github.com/xinntao/Real-ESRGAN.git
cd Real-ESRGAN
# Install basicsr - https://github.com/xinntao/BasicSR
# We use BasicSR for both training and inference
pip install basicsr
# facexlib and gfpgan are for face enhancement
pip install facexlib
pip install gfpgan
pip install -r requirements.txt
python setup.py develop
```
download realesr-general-x4v3.pth (and realesr-general-wdn-x4v3.pth if needed) from
https://github.com/xinntao/Real-ESRGAN/releases

```
cp {tensorrtx}/real-esrgan-general-x4v3/gen_wts.py {xinntao}/Real-ESRGAN
cd {xinntao}/Real-ESRGAN
python gen_wts.py
// a file 'real-esrgan.wts' will be generated.
```

**Be aware that if you need both realesr-general-x4v3.pth and realesr-general-wdn-x4v3.pth, please write a Python script to average all weights of realesr-general-x4v3.pth and realesr-general-wdn-x4v3.pth (from {xinntao}/Real-ESRGAN), then save it as a .pth file, and use this new file to generate a .wts file.**

2. build tensorrtx/real-esrgan-general-x4v3 and run

```
cd {tensorrtx}/real-esrgan-general-x4v3/
mkdir build
cd build
cp {xinntao}/Real-ESRGAN/real-esrgan.wts {tensorrtx}/real-esrgan/weights/
cmake ..
make
./real-esrgan your_images_dir
```
87 changes: 87 additions & 0 deletions real-esrgan-general-x4v3/cmake/FindTensorRT.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# source:
# https://github.com/NVIDIA/tensorrt-laboratory/blob/master/cmake/FindTensorRT.cmake

# This module defines the following variables:
#
# ::
#
# TensorRT_INCLUDE_DIRS
# TensorRT_LIBRARIES
# TensorRT_FOUND
#
# ::
#
# TensorRT_VERSION_STRING - version (x.y.z)
# TensorRT_VERSION_MAJOR - major version (x)
# TensorRT_VERSION_MINOR - minor version (y)
# TensorRT_VERSION_PATCH - patch version (z)
#
# Hints
# ^^^^^
# A user may set ``TensorRT_DIR`` to an installation root to tell this module where to look.
#
set(_TensorRT_SEARCHES)

if(TensorRT_DIR)
set(_TensorRT_SEARCH_ROOT PATHS ${TensorRT_DIR} NO_DEFAULT_PATH)
list(APPEND _TensorRT_SEARCHES _TensorRT_SEARCH_ROOT)
endif()

# appends some common paths
set(_TensorRT_SEARCH_NORMAL
PATHS "/usr"
)
list(APPEND _TensorRT_SEARCHES _TensorRT_SEARCH_NORMAL)

# Include dir
foreach(search ${_TensorRT_SEARCHES})
find_path(TensorRT_INCLUDE_DIR NAMES NvInfer.h ${${search}} PATH_SUFFIXES include)
endforeach()

if(NOT TensorRT_LIBRARY)
foreach(search ${_TensorRT_SEARCHES})
find_library(TensorRT_LIBRARY NAMES nvinfer ${${search}} PATH_SUFFIXES lib)
endforeach()
endif()

if(NOT TensorRT_PARSERS_LIBRARY)
foreach(search ${_TensorRT_SEARCHES})
find_library(TensorRT_NVPARSERS_LIBRARY NAMES nvparsers ${${search}} PATH_SUFFIXES lib)
endforeach()
endif()

if(NOT TensorRT_NVONNXPARSER_LIBRARY)
foreach(search ${_TensorRT_SEARCHES})
find_library(TensorRT_NVONNXPARSER_LIBRARY NAMES nvonnxparser ${${search}} PATH_SUFFIXES lib)
endforeach()
endif()

mark_as_advanced(TensorRT_INCLUDE_DIR)

if(TensorRT_INCLUDE_DIR AND EXISTS "${TensorRT_INCLUDE_DIR}/NvInfer.h")
file(STRINGS "${TensorRT_INCLUDE_DIR}/NvInfer.h" TensorRT_MAJOR REGEX "^#define NV_TENSORRT_MAJOR [0-9]+.*$")
file(STRINGS "${TensorRT_INCLUDE_DIR}/NvInfer.h" TensorRT_MINOR REGEX "^#define NV_TENSORRT_MINOR [0-9]+.*$")
file(STRINGS "${TensorRT_INCLUDE_DIR}/NvInfer.h" TensorRT_PATCH REGEX "^#define NV_TENSORRT_PATCH [0-9]+.*$")

string(REGEX REPLACE "^#define NV_TENSORRT_MAJOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MAJOR "${TensorRT_MAJOR}")
string(REGEX REPLACE "^#define NV_TENSORRT_MINOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MINOR "${TensorRT_MINOR}")
string(REGEX REPLACE "^#define NV_TENSORRT_PATCH ([0-9]+).*$" "\\1" TensorRT_VERSION_PATCH "${TensorRT_PATCH}")
set(TensorRT_VERSION_STRING "${TensorRT_VERSION_MAJOR}.${TensorRT_VERSION_MINOR}.${TensorRT_VERSION_PATCH}")
endif()

include(FindPackageHandleStandardArgs)
FIND_PACKAGE_HANDLE_STANDARD_ARGS(TensorRT REQUIRED_VARS TensorRT_LIBRARY TensorRT_INCLUDE_DIR VERSION_VAR TensorRT_VERSION_STRING)

if(TensorRT_FOUND)
set(TensorRT_INCLUDE_DIRS ${TensorRT_INCLUDE_DIR})

if(NOT TensorRT_LIBRARIES)
set(TensorRT_LIBRARIES ${TensorRT_LIBRARY} ${TensorRT_NVONNXPARSER_LIBRARY} ${TensorRT_NVPARSERS_LIBRARY})
endif()

if(NOT TARGET TensorRT::TensorRT)
add_library(TensorRT::TensorRT UNKNOWN IMPORTED)
set_target_properties(TensorRT::TensorRT PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}")
set_property(TARGET TensorRT::TensorRT APPEND PROPERTY IMPORTED_LOCATION "${TensorRT_LIBRARY}")
endif()
endif()
136 changes: 136 additions & 0 deletions real-esrgan-general-x4v3/gen_wts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import argparse
import os
import struct
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url


def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, help='Input image or folder')
parser.add_argument(
'-n',
'--model_name',
type=str,
default='realesr-general-x4v3',
help=('RealESRGAN_x2plus Model names: '
'realesr-animevideov3 | realesr-general-x4v3'))
parser.add_argument('-o', '--output', type=str, help='Output folder')
parser.add_argument(
'-dn',
'--denoise_strength',
type=float,
default=0.5,
help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
'Only used for the realesr-general-x4v3 model'))
parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
parser.add_argument(
'--model_path', type=str, default=None, help='[Option] Model path. Usually, you do not need to specify it')
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
parser.add_argument(
'--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
parser.add_argument(
'--alpha_upsampler',
type=str,
default='realesrgan',
help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
parser.add_argument(
'--ext',
type=str,
default='auto',
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
parser.add_argument(
'-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')

args = parser.parse_args()

# determine models according to model names
args.model_name = args.model_name.split('.')[0]
if args.model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
elif args.model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
elif args.model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
netscale = 4
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
elif args.model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
netscale = 2
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
elif args.model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
netscale = 4
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
elif args.model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
netscale = 4
file_url = [
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
]

# determine model paths
if args.model_path is not None:
model_path = args.model_path
else:
model_path = os.path.join('weights', args.model_name + '.pth')
if not os.path.isfile(model_path):
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
for url in file_url:
# model_path will be updated
model_path = load_file_from_url(
url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)

# use dni to control the denoise strength
dni_weight = None
if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
# wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
# model_path = [model_path, wdn_model_path]
# dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-x4v3-cat')
dni_weight = None

# restorer
upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
dni_weight=dni_weight,
model=model,
tile=args.tile,
tile_pad=args.tile_pad,
pre_pad=args.pre_pad,
half=not args.fp32,
gpu_id=args.gpu_id)

if os.path.isfile('real-esrgan.wts'):
print('Already, real-esrgan.wts file exists.')
else:
print('making real-esrgan.wts file ...')
f = open("real-esrgan.wts", 'w')
f.write("{}\n".format(len(upsampler.model.state_dict().keys())))
for k, v in upsampler.model.state_dict().items():
print('key: ', k)
print('value: ', v.shape)
vr = v.reshape(-1).cpu().numpy()
f.write("{} {}".format(k, len(vr)))
for vv in vr:
f.write(" ")
f.write(struct.pack(">f", float(vv)).hex())
f.write("\n")
print('Completed real-esrgan.wts file!')


if __name__ == '__main__':
main()
Loading

0 comments on commit b6468d4

Please sign in to comment.