Skip to content

Commit 59f3ac4

Browse files
LDempeg2: Update strength logic handling
1 parent 4a3abc5 commit 59f3ac4

File tree

2 files changed

+41
-48
lines changed

2 files changed

+41
-48
lines changed

lvsfunc/models/base.py

Lines changed: 34 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
from vskernels import Catrom, Lanczos
88
from vsscale import autoselect_backend
99
from vstools import (ColorRange, CustomValueError, DependencyNotFoundError,
10-
FileWasNotFoundError, FunctionUtil,
11-
InvalidColorFamilyError, LengthMismatchError, Matrix,
12-
SPath, UnsupportedVideoFormatError, depth, inject_self,
13-
iterate, join, normalize_planes, split, vs)
10+
FileWasNotFoundError, FunctionUtil, Matrix, SPath, depth,
11+
inject_self, iterate, join, normalize_planes, split, vs, get_peak_value)
1412

1513
__all__: list[str] = []
1614

@@ -198,57 +196,52 @@ def _set_model_path(self) -> None:
198196

199197

200198
class Base1xModelWithStrength(Base1xModel):
201-
"""Base class for 1x models to reconstruct high-frequency information."""
199+
"""Base class for 1x models to reconstruct high-frequency information with strength control."""
202200

203-
_strength: float | vs.VideoNode | None = None
204-
"""Strength of the model."""
201+
def _initialize_strength(
202+
self, clip: vs.VideoNode, strength: SupportsFloat | vs.VideoNode = 10.0
203+
) -> vs.VideoNode:
204+
self._strength_clip = strength
205205

206-
def _initialize_strength(self, clip: vs.VideoNode, strength: SupportsFloat | vs.VideoNode | None = None) -> None:
207-
self._strength = strength
206+
if isinstance(strength, (float, int)):
207+
self._set_strength_clip(clip, strength)
208+
elif isinstance(strength, vs.VideoNode):
209+
self._norm_str_clip(clip)
208210

209-
if isinstance(self._strength, SupportsFloat):
210-
self._set_strength_clip(clip)
211-
elif isinstance(self._strength, vs.VideoNode):
212-
self._norm_str_clip()
213-
else:
214-
raise UnsupportedVideoFormatError(
215-
'`strength` must be a float or a GRAYS clip', self._func.func, type(self._strength)
216-
)
211+
return self._strength_clip
212+
213+
def _set_strength_clip(self, clip: vs.VideoNode, strength: float) -> None:
214+
norm_strength = strength / 100.0 * get_peak_value(16 if self._fp16 else 32)
217215

218-
def _set_strength_clip(self, clip: vs.VideoNode) -> None:
219-
self._strength = clip.std.BlankClip(
220-
format=vs.GRAYH if self._fp16 else vs.GRAYS, color=float(self._strength) / 255, keep=True
216+
self._strength_clip = expr_func(
217+
[clip.std.BlankClip(format=vs.GRAYH if self._fp16 else vs.GRAYS, keep=True)],
218+
f"x {norm_strength} +"
221219
)
222220

223-
def _norm_str_clip(self) -> None:
224-
assert not isinstance(self._strength, float), 'The dev must run `_initialize_strength` in the apply method!'
221+
def _norm_str_clip(self, clip: vs.VideoNode) -> None:
222+
str_clip = self._strength_clip
225223

226-
assert (fmt := self._strength.format)
227-
fmt_name = fmt.name.upper()
224+
if str_clip.format.color_family != vs.GRAY:
225+
raise ValueError("Strength clip must be GRAY")
228226

229-
InvalidColorFamilyError.check(
230-
fmt, vs.GRAY, self._func.func, '"strength" must be of {correct} color family, not {wrong}!'
231-
)
227+
if str_clip.format.id == vs.GRAY8:
228+
str_clip = expr_func(str_clip, 'x 255 /', vs.GRAYH if self._fp16 else vs.GRAYS)
229+
elif self._fp16 and str_clip.format.id != vs.GRAYH:
230+
str_clip = depth(str_clip, 16, vs.FLOAT)
232231

233-
if fmt.id == vs.GRAY8:
234-
self._strength = expr_func(self._strength, 'x 255 /', vs.GRAYH if self._fp16 else vs.GRAYS)
235-
elif fmt.id not in {vs.GRAYH, vs.GRAYS}:
236-
raise UnsupportedVideoFormatError(
237-
f'`strength` must be GRAY8, GRAYH, or GRAYS, not {fmt_name}!', self._func.func
238-
)
239-
elif self._fp16 and fmt.id != vs.GRAYH:
240-
self._strength = depth(self._strength, 16, vs.FLOAT)
232+
if str_clip.width != clip.width or str_clip.height != clip.height:
233+
str_clip = Catrom.scale(str_clip, clip.width, clip.height)
241234

242-
if self._strength.width != self._func.work_clip.width or self._strength.height != self._func.work_clip.height:
243-
self._strength = Catrom.scale(self._strength, self._func.work_clip.width, self._func.work_clip.height)
235+
if str_clip.num_frames != clip.num_frames:
236+
raise ValueError("Strength clip must have the same number of frames as the input clip")
244237

245-
if self._strength.num_frames != self._func.work_clip.num_frames:
246-
raise LengthMismatchError(self._func.func, '`strength` must be the same length as \'clip\'')
238+
self._strength_clip = str_clip
247239

248240
def _should_process(self, strength: SupportsFloat | vs.VideoNode | None | Literal[False] = False) -> bool:
249-
strength = self._strength if strength is False else strength
241+
if hasattr(self, '_strength_clip'):
242+
return self._strength_clip is not False
250243

251-
return not (isinstance(strength, int) and strength >= 0)
244+
return (strength is None) or not (isinstance(strength, (int, float)) and strength <= 0.0)
252245

253246

254247
def get_models_path() -> SPath:

lvsfunc/models/dempeg2.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def apply(
3232
:param clip: The clip to process.
3333
:param strength: The "strength" of the model. Works by merging the clip with a "strength" mask,
3434
just like DPIR. Higher values remove more noise, but also more detail.
35-
Sane values are between 75 and 125. A custom strength clip can be provided instead.
36-
If None, do not apply a strength mask.
35+
Sane values are between 50-100. A custom strength clip can be provided instead.
36+
If None, do not apply a strength mask. Values above 100 are not recommended.
3737
Default: None.
3838
:param show_mask: Whether to show the strength mask. Default: False.
3939
:param iterations: The number of iterations to apply the model.
@@ -45,7 +45,7 @@ def apply(
4545
:return: The processed clip.
4646
"""
4747

48-
if not self._should_process():
48+
if not self._should_process(strength):
4949
return clip
5050

5151
nplanes = normalize_planes(clip, planes)
@@ -57,14 +57,14 @@ def apply(
5757

5858
processed = super().apply(clip, **kwargs)
5959

60-
if self._strength is None:
60+
if strength is None:
6161
return depth(processed, clip)
6262

63-
self._initialize_strength(clip, strength)
63+
strength_mask = self._initialize_strength(clip, strength)
6464

6565
if show_mask:
66-
return self._strength
66+
return strength_mask
6767

68-
limited = depth(clip, processed).std.MaskedMerge(processed, self._strength, nplanes)
68+
limited = depth(clip, processed).std.MaskedMerge(processed, strength_mask, nplanes)
6969

7070
return depth(limited, clip)

0 commit comments

Comments
 (0)