Skip to content

Commit 98192bb

Browse files
committed
Update testing scripts and readme.
1 parent 56753b0 commit 98192bb

File tree

4 files changed

+106
-179
lines changed

4 files changed

+106
-179
lines changed

README.md

Lines changed: 19 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,19 @@
1-
## Replacing Mobile Camera ISP with a Single Deep Learning Model
2-
3-
<br/>
1+
## PyNet-V2 Mobile: Efficient On-Device Photo Processing With Neural Networks
42

53
<img src="http://people.ee.ethz.ch/~ihnatova/assets/img/pynet/pynet_teaser.jpg"/>
64

7-
<br/>
85

9-
#### 1. Overview [[Paper]](https://arxiv.org/pdf/2002.05509.pdf) [[PyTorch Implementation]](https://github.com/aiff22/PyNET-PyTorch) [[Project Webpage]](http://people.ee.ethz.ch/~ihnatova/pynet.html)
6+
#### 1. Overview [[Paper (in progress)]]() [[Project Webpage (in progress)]]()
107

11-
This repository provides the implementation of the RAW-to-RGB mapping approach and PyNET CNN presented in [this paper](https://arxiv.org/). The model is trained to convert **RAW Bayer data** obtained directly from mobile camera sensor into photos captured with a professional Canon 5D DSLR camera, thus replacing the entire hand-crafted ISP camera pipeline. The provided pre-trained PyNET model can be used to generate full-resolution **12MP photos** from RAW (DNG) image files captured using the Sony Exmor IMX380 camera sensor. More visual results of this approach for the Huawei P20 and BlackBerry KeyOne smartphones can be found [here](http://people.ee.ethz.ch/~ihnatova/pynet.html#demo).
8+
This repository provides the implementation of further improvement of the PyNet model originally presented in [this paper](https://arxiv.org/abs/2002.05509).
129

13-
<br/>
1410

1511
#### 2. Prerequisites
1612

1713
- Python: scipy, numpy, imageio and pillow packages
1814
- [TensorFlow 1.X](https://www.tensorflow.org/install/) + [CUDA cuDNN](https://developer.nvidia.com/cudnn)
1915
- Nvidia GPU
2016

21-
<br/>
2217

2318
#### 3. First steps
2419

@@ -29,91 +24,36 @@ This repository provides the implementation of the RAW-to-RGB mapping approach a
2924

3025
<sub>*Please note that Google Drive has a quota limiting the number of downloads per day. To avoid it, you can login to your Google account and press "Add to My Drive" button instead of a direct download. Please check [this issue](https://github.com/aiff22/PyNET/issues/4) for more information.* </sub>
3126

32-
<br/>
33-
34-
35-
#### 4. PyNET CNN
36-
37-
<br/>
3827

39-
<img src="http://people.ee.ethz.ch/~ihnatova/assets/img/pynet/pynet_architecture_github.png" alt="drawing" width="1000"/>
4028

41-
<br/>
29+
#### 4. PyNet-V2 Mobile CNN
4230

43-
PyNET architecture has an inverted pyramidal shape and is processing the images at **five different scales** (levels). The model is trained sequentially, starting from the lowest 5th layer, which allows to achieve good reconstruction results at smaller image resolutions. After the bottom layer is pre-trained, the same procedure is applied to the next level till the training is done on the original resolution. Since each higher level is getting **upscaled high-quality features** from the lower part of the model, it mainly learns to reconstruct the missing low-level details and refines the results. In this work, we are additionally using one transposed convolutional layer (Level 0) on top of the model that upsamples the image to its target size.
31+
[WIP]
4432

45-
<br/>
4633

4734
#### 5. Training the model
4835

49-
The model is trained level by level, starting from the lowest (5th) one:
50-
51-
```bash
52-
python train_model.py level=<level>
53-
```
54-
55-
Obligatory parameters:
56-
57-
>```level```: **```5, 4, 3, 2, 1, 0```**
58-
59-
Optional parameters and their default values:
60-
61-
>```batch_size```: **```50```** &nbsp; - &nbsp; batch size [small values can lead to unstable training] <br/>
62-
>```train_size```: **```30000```** &nbsp; - &nbsp; the number of training patches randomly loaded each 1000 iterations <br/>
63-
>```eval_step```: **```1000```** &nbsp; - &nbsp; each ```eval_step``` iterations the accuracy is computed and the model is saved <br/>
64-
>```learning_rate```: **```5e-5```** &nbsp; - &nbsp; learning rate <br/>
65-
>```restore_iter```: **```None```** &nbsp; - &nbsp; iteration to restore (when not specified, the last saved model for PyNET's ```level+1``` is loaded)<br/>
66-
>```num_train_iters```: **```5K, 5K, 20K, 20K, 35K, 100K (for levels 5 - 0)```** &nbsp; - &nbsp; the number of training iterations <br/>
67-
>```vgg_dir```: **```vgg_pretrained/imagenet-vgg-verydeep-19.mat```** &nbsp; - &nbsp; path to the pre-trained VGG-19 network <br/>
68-
>```dataset_dir```: **```raw_images/```** &nbsp; - &nbsp; path to the folder with **Zurich RAW to RGB dataset** <br/>
69-
70-
</br>
71-
72-
Below we provide the commands used for training the model on the Nvidia Tesla V100 GPU with 16GB of RAM. When using GPUs with smaller amount of memory, the batch size and the number of training iterations should be adjusted accordingly:
36+
The model is trained level by level, starting from the lowest. The script below incorporates all training steps:
7337

7438
```bash
75-
python train_model.py level=5 batch_size=50 num_train_iters=5000
76-
python train_model.py level=4 batch_size=50 num_train_iters=5000
77-
python train_model.py level=3 batch_size=48 num_train_iters=20000
78-
python train_model.py level=2 batch_size=18 num_train_iters=20000
79-
python train_model.py level=1 batch_size=12 num_train_iters=35000
80-
python train_model.py level=0 batch_size=10 num_train_iters=100000
39+
./train.sh
8140
```
8241

83-
<br/>
8442

8543
#### 6. Test the provided pre-trained models on full-resolution RAW image files
8644

8745
```bash
88-
python test_model.py level=0 orig=true
46+
python test_model_keras.py
8947
```
9048

9149
Optional parameters:
9250

93-
>```use_gpu```: **```true```**,**```false```** &nbsp; - &nbsp; run the model on GPU or CPU <br/>
94-
>```dataset_dir```: **```raw_images/```** &nbsp; - &nbsp; path to the folder with **Zurich RAW to RGB dataset** <br/>
51+
>```--model```: - &nbsp; path to the Keras model checkpoint <br/>
52+
>```--inp_path```: **```raw_images/test/```** &nbsp; - &nbsp; path to the folder with **Zurich RAW to RGB dataset** <br/>
53+
>```--out_path```: **```.```** &nbsp; - &nbsp; path to the output images <br/>
9554
96-
<br/>
9755

98-
#### 7. Test the obtained model on full-resolution RAW image files
99-
100-
```bash
101-
python test_model.py level=<level>
102-
```
103-
104-
Obligatory parameters:
105-
106-
>```level```: **```5, 4, 3, 2, 1, 0```**
107-
108-
Optional parameters:
109-
110-
>```restore_iter```: **```None```** &nbsp; - &nbsp; iteration to restore (when not specified, the last saved model for level=```<level>``` is loaded)<br/>
111-
>```use_gpu```: **```true```**,**```false```** &nbsp; - &nbsp; run the model on GPU or CPU <br/>
112-
>```dataset_dir```: **```raw_images/```** &nbsp; - &nbsp; path to the folder with **Zurich RAW to RGB dataset** <br/>
113-
114-
<br/>
115-
116-
#### 8. Folder structure
56+
#### 7. Folder structure
11757

11858
>```models/``` &nbsp; - &nbsp; logs and models that are saved during the training process <br/>
11959
>```models/original/``` &nbsp; - &nbsp; the folder with the provided pre-trained PyNET model <br/>
@@ -123,44 +63,34 @@ Optional parameters:
12363
>```vgg-pretrained/``` &nbsp; - &nbsp; the folder with the pre-trained VGG-19 network <br/>
12464
12565
>```load_dataset.py``` &nbsp; - &nbsp; python script that loads training data <br/>
126-
>```model.py``` &nbsp; - &nbsp; PyNET implementation (TensorFlow) <br/>
127-
>```train_model.py``` &nbsp; - &nbsp; implementation of the training procedure <br/>
128-
>```test_model.py``` &nbsp; - &nbsp; applying the pre-trained model to full-resolution test images <br/>
66+
>```model.py``` &nbsp; - &nbsp; PyNET implementation (Keras) <br/>
67+
>```train_model_keras.py``` &nbsp; - &nbsp; implementation of the training procedure <br/>
68+
>```test_model_keras.py``` &nbsp; - &nbsp; applying the pre-trained model to full-resolution test images <br/>
12969
>```utils.py``` &nbsp; - &nbsp; auxiliary functions <br/>
13070
>```vgg.py``` &nbsp; - &nbsp; loading the pre-trained vgg-19 network <br/>
13171
132-
<br/>
13372

13473
#### 9. Bonus files
13574

13675
These files can be useful for further experiments with the model / dataset:
13776

13877
>```dng_to_png.py``` &nbsp; - &nbsp; convert raw DNG camera files to PyNET's input format <br/>
139-
>```evaluate_accuracy.py``` &nbsp; - &nbsp; compute PSNR and MS-SSIM scores on Zurich RAW-to-RGB dataset for your own model <br/>
78+
>```ckpt2pb_keras.py``` &nbsp; - &nbsp; converts Keras checkpoint to TFLite format <br/>
79+
>```evaluate_accuracy_tflite.py``` &nbsp; - &nbsp; compute PSNR and MS-SSIM scores on Zurich RAW-to-RGB dataset for TFLite model <br/>
14080
141-
<br/>
14281

14382
#### 10. License
14483

145-
Copyright (C) 2020 Andrey Ignatov. All rights reserved.
84+
Copyright (C) 2022 Andrey Ignatov. All rights reserved.
14685

14786
Licensed under the [CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International)](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
14887

14988
The code is released for academic research use only.
15089

151-
<br/>
15290

15391
#### 11. Citation
92+
[WIP]
15493

155-
```
156-
@article{ignatov2020replacing,
157-
title={Replacing Mobile Camera ISP with a Single Deep Learning Model},
158-
author={Ignatov, Andrey and Van Gool, Luc and Timofte, Radu},
159-
journal={arXiv preprint arXiv:2002.05509},
160-
year={2020}
161-
}
162-
```
163-
<br/>
16494

16595
#### 12. Any further questions?
16696

model.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,6 @@ def PyNET(input, instance_norm=False, instance_norm_level_1=False):
5959
conv_l1_out = _conv_layer(conv_l1_out, 3, 3, 1, relu=False, instance_norm=False)
6060
output_l1 = tf.nn.tanh(conv_l1_out) * 0.58 + 0.5
6161

62-
with tf.name_scope("generator_0_"):
63-
conv_l0 = _upsample_layer(conv_l1_d14, 32, 3, 2)
64-
conv_l0_out = _conv_layer(conv_l0, 3 * k * k, 3, 1, relu=False, instance_norm=False)
65-
66-
# -> Output: Level 0
67-
output_l0 = tf.nn.tanh(conv_l0_out) * 0.58 + 0.5
68-
6962
return None, output_l1, output_l2, output_l3
7063

7164

test_model.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

test_model_keras.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2022 by Andrey Ignatov. All Rights Reserved.
2+
3+
import numpy as np
4+
import tensorflow.compat.v1 as tf
5+
import tensorflow_addons as tfa
6+
import imageio
7+
import sys
8+
import os
9+
import importlib
10+
import rawpy
11+
import cv2
12+
from tensorflow.keras.models import load_model
13+
import argparse
14+
15+
from load_dataset import extract_bayer_channels
16+
17+
IMAGE_HEIGHT, IMAGE_WIDTH = 1472, 1984
18+
DSLR_SCALE = 2
19+
20+
21+
dataset_dir = 'raw_images/'
22+
dslr_dir = 'fujifilm/'
23+
phone_dir = 'mediatek_raw/'
24+
25+
26+
def main():
27+
"""Test model"""
28+
parser = argparse.ArgumentParser(
29+
description='Test model',
30+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
31+
)
32+
parser.add_argument(
33+
'--model', help='Path to model checkpoint.', type=str, default='model.h5', required=True)
34+
parser.add_argument(
35+
'--inp_path', help='Path to the input data.', type=str, default='raw_images/test', required=True)
36+
parser.add_argument(
37+
'--out_path', help='Path to the output images.', type=str, default='.', required=True)
38+
args = parser.parse_args()
39+
40+
41+
spec = importlib.util.spec_from_file_location('pynet.model', 'model.py')
42+
module = importlib.util.module_from_spec(spec)
43+
spec.loader.exec_module(module)
44+
PyNET = module.PyNET
45+
46+
phone_ = tf.keras.Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 4))
47+
# Loading pre-trained model
48+
_, enhanced, _, _ = \
49+
PyNET(phone_, instance_norm=True, instance_norm_level_1=False)
50+
51+
52+
print("Initializing variables")
53+
54+
model = tf.keras.Model(inputs=phone_, outputs=enhanced)
55+
prev_model = load_model(args.model, compile=False)
56+
for i, layer in enumerate(prev_model.layers):
57+
for k in model.layers:
58+
if k.name == layer.name:
59+
k.set_weights(layer.get_weights())
60+
61+
62+
# Processing full-resolution RAW images
63+
test_dir = args.inp_path
64+
test_photos = [f for f in os.listdir(test_dir) if os.path.isfile(test_dir + f)]
65+
66+
for photo in test_photos:
67+
with rawpy.imread(test_dir + photo) as raw:
68+
I = extract_bayer_channels(raw.raw_image)
69+
print("Processing image " + photo)
70+
71+
I = I[0:IMAGE_HEIGHT, 0:IMAGE_WIDTH, :]
72+
I = np.reshape(I, [1, I.shape[0], I.shape[1], 4])
73+
74+
# Run inference
75+
76+
enhanced_tensor = model.predict([I])
77+
enhanced_image = np.reshape(enhanced_tensor, [int(I.shape[1] * DSLR_SCALE), int(I.shape[2] * DSLR_SCALE), 3])
78+
79+
# Save the results as .png images
80+
photo_name = photo.rsplit(".", 1)[0]
81+
enhanced_image = cv2.cvtColor(enhanced_image, cv2.COLOR_RGB2BGR)
82+
enhanced_image = np.uint8(np.clip(enhanced_image, 0.0, 1.0) * 255.0)
83+
cv2.imwrite(os.path.join(args.out_path, photo_name + ".png"), enhanced_image)
84+
85+
86+
if __name__ == '__main__':
87+
main()

0 commit comments

Comments
 (0)