-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest.py
31 lines (25 loc) · 1.02 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.models import load_model
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.layers import add
from tensorflow.keras.callbacks import TensorBoard
import numpy as np
from scipy.misc import imread, imsave, imresize
from pathlib import Path
from model import SRRAM
from dataset import Dataset
import utils
import argparse
def _parse_args():
    """Parse the command-line options of this inference script.

    Returns:
        argparse.Namespace with ``model_dir`` (directory containing
        ``model.h5``) and ``filename`` (input image path), both ``str``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--filename', type=str, required=True)
    return parser.parse_args()


def _scale_factor_from_dir(model_dir):
    """Return the integer upscaling factor encoded in *model_dir*'s name.

    The third ``_``-separated field of the directory stem is expected to look
    like ``srf<N>`` (e.g. ``<a>_<b>_srf4``). The ``srf`` prefix is stripped
    and the remainder is parsed as an int.

    Raises:
        ValueError: if the name has fewer than three ``_`` fields, the field
            is not an integer, or the factor is not positive.
    """
    fields = Path(model_dir).stem.split('_')
    try:
        factor = int(fields[2].replace('srf', ''))
    except (IndexError, ValueError) as err:
        raise ValueError(
            "cannot infer scale factor from model_dir {!r}; expected a name "
            "like '<a>_<b>_srf<N>'".format(model_dir)) from err
    if factor < 1:
        raise ValueError(
            "scale factor must be >= 1, got {} from {!r}".format(
                factor, model_dir))
    return factor


def _downscale(img, scale_factor):
    """Bicubically shrink *img* by *scale_factor* along height and width.

    Integer division is used, so trailing rows/columns that do not fill a
    whole ``scale_factor`` step are effectively dropped.

    Raises:
        ValueError: if the image is smaller than *scale_factor* in either
            dimension (the low-resolution image would be empty).
    """
    height = img.shape[0] // scale_factor
    width = img.shape[1] // scale_factor
    if height == 0 or width == 0:
        raise ValueError(
            "image of size {}x{} is too small for scale factor {}".format(
                img.shape[0], img.shape[1], scale_factor))
    # scipy.misc.imresize takes the target size as (height, width).
    return imresize(img, (height, width), interp='bicubic')


def main():
    """Downscale an image, super-resolve it with a saved model, save a .bmp.

    Reads ``<model_dir>/model.h5`` and the image at ``--filename``. It shrinks
    the image by the factor encoded in ``model_dir``'s name, runs the model
    on it, and writes the result to ``<input stem>.bmp`` in the current
    working directory.
    """
    flags = _parse_args()
    # Parse the scale first so a malformed directory name fails before the
    # comparatively slow model load.
    scale_factor = _scale_factor_from_dir(flags.model_dir)
    model = load_model(str(Path(flags.model_dir) / 'model.h5'))

    # mode='RGB' converts grayscale/palette/RGBA inputs to a 3-channel array.
    # Without it a 2-D grayscale array breaks the 4-D batch indexing below.
    img = imread(flags.filename, mode='RGB')
    low_res = _downscale(img, scale_factor)

    # Add a batch axis for predict(), then drop it again.
    sr = np.squeeze(model.predict(low_res[None, :, :, :]), axis=0)

    # scipy.misc.imsave byte-scales non-uint8 arrays by their own min/max,
    # which would stretch the output's contrast. Convert explicitly instead.
    # NOTE(review): assumes the model emits values on the same 0-255 scale as
    # its un-normalized uint8 input; confirm against model.SRRAM.
    out = np.clip(np.round(sr), 0, 255).astype(np.uint8)

    # NOTE(review): if the input is itself a .bmp in the current directory,
    # this overwrites it.
    imsave(Path(flags.filename).stem + '.bmp', out)


if __name__ == '__main__':
    main()