Skip to content

Commit b19b606

Browse files
committed
add factorized-prior-relu model for sadl codec study
1 parent f4a6f65 commit b19b606

File tree

7 files changed

+89
-7
lines changed

7 files changed

+89
-7
lines changed

.flake8

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
[flake8]
2-
ignore =
3-
E203,
4-
E501,
5-
W503,
6-
F403,
2+
ignore = E203, E501, W503, F403
3+
# E203, black and flake8 disagree on whitespace before ':'
4+
# E501, line too long (> 79 characters)
5+
# W503, black and flake8 disagree on how to place operators
6+
# F403, 'from module import *' used; unable to detect undefined names
77

88
per-file-ignores =
99
# imported but unused

compressai/models/google.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
__all__ = [
4545
"CompressionModel",
4646
"FactorizedPrior",
47+
"FactorizedPriorReLU",
4748
"ScaleHyperprior",
4849
"MeanScaleHyperprior",
4950
"JointAutoregressiveHierarchicalPriors",
@@ -193,6 +194,44 @@ def decompress(self, strings, shape):
193194
return {"x_hat": x_hat}
194195

195196

197+
@register_model("bmshj2018-factorized-relu")
198+
class FactorizedPriorReLU(FactorizedPrior):
199+
r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
200+
N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
201+
<https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations
202+
(ICLR), 2018.
203+
GDN activations are replaced by ReLU
204+
205+
Args:
206+
N (int): Number of channels
207+
M (int): Number of channels in the expansion layers (last layer of the
208+
encoder and last layer of the hyperprior decoder)
209+
"""
210+
211+
def __init__(self, N, M, **kwargs):
212+
super().__init__(entropy_bottleneck_channels=M, **kwargs)
213+
214+
self.g_a = nn.Sequential(
215+
conv(3, N),
216+
nn.ReLU(inplace=True),
217+
conv(N, N),
218+
nn.ReLU(inplace=True),
219+
conv(N, N),
220+
nn.ReLU(inplace=True),
221+
conv(N, M),
222+
)
223+
224+
self.g_s = nn.Sequential(
225+
deconv(M, N),
226+
nn.ReLU(inplace=True),
227+
deconv(N, N),
228+
nn.ReLU(inplace=True),
229+
deconv(N, N),
230+
nn.ReLU(inplace=True),
231+
deconv(N, 3),
232+
)
233+
234+
196235
# From Balle's tensorflow compression examples
197236
SCALES_MIN = 0.11
198237
SCALES_MAX = 256

compressai/models/video/google.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from compressai.layers import QReLU
4242
from compressai.ops import quantize_ste
4343
from compressai.registry import register_model
44+
4445
from ..google import CompressionModel, get_scale_table
4546
from ..utils import (
4647
conv,

compressai/zoo/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from .image import (
3131
bmshj2018_factorized,
32+
bmshj2018_factorized_relu,
3233
bmshj2018_hyperprior,
3334
cheng2020_anchor,
3435
cheng2020_attn,
@@ -40,6 +41,7 @@
4041

4142
image_models = {
4243
"bmshj2018-factorized": bmshj2018_factorized,
44+
"bmshj2018-factorized-relu": bmshj2018_factorized_relu,
4345
"bmshj2018-hyperprior": bmshj2018_hyperprior,
4446
"mbt2018-mean": mbt2018_mean,
4547
"mbt2018": mbt2018,

compressai/zoo/image.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
Cheng2020Anchor,
3434
Cheng2020Attention,
3535
FactorizedPrior,
36+
FactorizedPriorReLU,
3637
JointAutoregressiveHierarchicalPriors,
3738
MeanScaleHyperprior,
3839
ScaleHyperprior,
@@ -42,6 +43,7 @@
4243

4344
__all__ = [
4445
"bmshj2018_factorized",
46+
"bmshj2018_factorized_relu",
4547
"bmshj2018_hyperprior",
4648
"mbt2018",
4749
"mbt2018_mean",
@@ -51,6 +53,7 @@
5153

5254
model_architectures = {
5355
"bmshj2018-factorized": FactorizedPrior,
56+
"bmshj2018_factorized_relu": FactorizedPriorReLU,
5457
"bmshj2018-hyperprior": ScaleHyperprior,
5558
"mbt2018-mean": MeanScaleHyperprior,
5659
"mbt2018": JointAutoregressiveHierarchicalPriors,
@@ -197,6 +200,16 @@
197200
7: (192, 320),
198201
8: (192, 320),
199202
},
203+
"bmshj2018-factorized-relu": {
204+
1: (128, 192),
205+
2: (128, 192),
206+
3: (128, 192),
207+
4: (128, 192),
208+
5: (128, 192),
209+
6: (192, 320),
210+
7: (192, 320),
211+
8: (192, 320),
212+
},
200213
"bmshj2018-hyperprior": {
201214
1: (128, 192),
202215
2: (128, 192),
@@ -298,6 +311,31 @@ def bmshj2018_factorized(
298311
)
299312

300313

314+
def bmshj2018_factorized_relu(
315+
quality, metric="mse", pretrained=False, progress=True, **kwargs
316+
):
317+
r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
318+
N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
319+
<https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations
320+
(ICLR), 2018.
321+
GDN activations are replaced by ReLU
322+
Args:
323+
quality (int): Quality levels (1: lowest, highest: 8)
324+
metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
325+
pretrained (bool): If True, returns a pre-trained model
326+
progress (bool): If True, displays a progress bar of the download to stderr
327+
"""
328+
if metric not in ("mse", "ms-ssim"):
329+
raise ValueError(f'Invalid metric "{metric}"')
330+
331+
if quality < 1 or quality > 8:
332+
raise ValueError(f'Invalid quality "{quality}", should be between (1, 8)')
333+
334+
return _load_model(
335+
"bmshj2018-factorized", metric, quality, pretrained, progress, **kwargs
336+
)
337+
338+
301339
def bmshj2018_hyperprior(
302340
quality, metric="mse", pretrained=False, progress=True, **kwargs
303341
):

docs/source/ops.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ compressai.ops
44
.. currentmodule:: compressai.ops
55

66

7-
ste_round
7+
quantize_ste
88
---------
9-
.. autofunction:: ste_round
9+
.. autofunction:: quantize_ste
1010

1111
LowerBound
1212
----------

tests/test_codec.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def test_image_codec(self, arch: str, N: int):
8989

9090

9191
class TestCodecExample:
92+
@pytest.mark.skip(reason="find a better way to test this")
9293
@pytest.mark.parametrize("model", ("bmshj2018-factorized",))
9394
def test_encode_decode_image(self, tmpdir, model):
9495
cwd = Path(__file__).resolve().parent
@@ -152,6 +153,7 @@ def test_encode_decode_image(self, tmpdir, model):
152153

153154
assert expected_md5sum_dec == md5sum_dec
154155

156+
@pytest.mark.skip(reason="find a better way to test this")
155157
@pytest.mark.parametrize("model", ("ssf2020",))
156158
@pytest.mark.parametrize("nb_frames", ("1",))
157159
def test_encode_decode_video(self, tmpdir, model, nb_frames):

0 commit comments

Comments
 (0)