|
7 | 7 | from vskernels import Catrom, Lanczos
|
8 | 8 | from vsscale import autoselect_backend
|
9 | 9 | from vstools import (ColorRange, CustomValueError, DependencyNotFoundError,
|
10 |
| - FileWasNotFoundError, FunctionUtil, |
11 |
| - InvalidColorFamilyError, LengthMismatchError, Matrix, |
12 |
| - SPath, UnsupportedVideoFormatError, depth, inject_self, |
13 |
| - iterate, join, normalize_planes, split, vs) |
| 10 | + FileWasNotFoundError, FunctionUtil, Matrix, SPath, depth, |
| 11 | + inject_self, iterate, join, normalize_planes, split, vs, get_peak_value) |
14 | 12 |
|
15 | 13 | __all__: list[str] = []
|
16 | 14 |
|
@@ -198,57 +196,52 @@ def _set_model_path(self) -> None:
|
198 | 196 |
|
199 | 197 |
|
200 | 198 | class Base1xModelWithStrength(Base1xModel):
|
201 |
| - """Base class for 1x models to reconstruct high-frequency information.""" |
| 199 | + """Base class for 1x models to reconstruct high-frequency information with strength control.""" |
202 | 200 |
|
203 |
| - _strength: float | vs.VideoNode | None = None |
204 |
| - """Strength of the model.""" |
| 201 | + def _initialize_strength( |
| 202 | + self, clip: vs.VideoNode, strength: SupportsFloat | vs.VideoNode = 10.0 |
| 203 | + ) -> vs.VideoNode: |
| 204 | + self._strength_clip = strength |
205 | 205 |
|
206 |
| - def _initialize_strength(self, clip: vs.VideoNode, strength: SupportsFloat | vs.VideoNode | None = None) -> None: |
207 |
| - self._strength = strength |
| 206 | + if isinstance(strength, (float, int)): |
| 207 | + self._set_strength_clip(clip, strength) |
| 208 | + elif isinstance(strength, vs.VideoNode): |
| 209 | + self._norm_str_clip(clip) |
208 | 210 |
|
209 |
| - if isinstance(self._strength, SupportsFloat): |
210 |
| - self._set_strength_clip(clip) |
211 |
| - elif isinstance(self._strength, vs.VideoNode): |
212 |
| - self._norm_str_clip() |
213 |
| - else: |
214 |
| - raise UnsupportedVideoFormatError( |
215 |
| - '`strength` must be a float or a GRAYS clip', self._func.func, type(self._strength) |
216 |
| - ) |
| 211 | + return self._strength_clip |
| 212 | + |
| 213 | + def _set_strength_clip(self, clip: vs.VideoNode, strength: float) -> None: |
| 214 | + norm_strength = strength / 100.0 * get_peak_value(16 if self._fp16 else 32) |
217 | 215 |
|
218 |
| - def _set_strength_clip(self, clip: vs.VideoNode) -> None: |
219 |
| - self._strength = clip.std.BlankClip( |
220 |
| - format=vs.GRAYH if self._fp16 else vs.GRAYS, color=float(self._strength) / 255, keep=True |
| 216 | + self._strength_clip = expr_func( |
| 217 | + [clip.std.BlankClip(format=vs.GRAYH if self._fp16 else vs.GRAYS, keep=True)], |
| 218 | + f"x {norm_strength} +" |
221 | 219 | )
|
222 | 220 |
|
223 |
| - def _norm_str_clip(self) -> None: |
224 |
| - assert not isinstance(self._strength, float), 'The dev must run `_initialize_strength` in the apply method!' |
| 221 | + def _norm_str_clip(self, clip: vs.VideoNode) -> None: |
| 222 | + str_clip = self._strength_clip |
225 | 223 |
|
226 |
| - assert (fmt := self._strength.format) |
227 |
| - fmt_name = fmt.name.upper() |
| 224 | + if str_clip.format.color_family != vs.GRAY: |
| 225 | + raise ValueError("Strength clip must be GRAY") |
228 | 226 |
|
229 |
| - InvalidColorFamilyError.check( |
230 |
| - fmt, vs.GRAY, self._func.func, '"strength" must be of {correct} color family, not {wrong}!' |
231 |
| - ) |
| 227 | + if str_clip.format.id == vs.GRAY8: |
| 228 | + str_clip = expr_func(str_clip, 'x 255 /', vs.GRAYH if self._fp16 else vs.GRAYS) |
| 229 | + elif self._fp16 and str_clip.format.id != vs.GRAYH: |
| 230 | + str_clip = depth(str_clip, 16, vs.FLOAT) |
232 | 231 |
|
233 |
| - if fmt.id == vs.GRAY8: |
234 |
| - self._strength = expr_func(self._strength, 'x 255 /', vs.GRAYH if self._fp16 else vs.GRAYS) |
235 |
| - elif fmt.id not in {vs.GRAYH, vs.GRAYS}: |
236 |
| - raise UnsupportedVideoFormatError( |
237 |
| - f'`strength` must be GRAY8, GRAYH, or GRAYS, not {fmt_name}!', self._func.func |
238 |
| - ) |
239 |
| - elif self._fp16 and fmt.id != vs.GRAYH: |
240 |
| - self._strength = depth(self._strength, 16, vs.FLOAT) |
| 232 | + if str_clip.width != clip.width or str_clip.height != clip.height: |
| 233 | + str_clip = Catrom.scale(str_clip, clip.width, clip.height) |
241 | 234 |
|
242 |
| - if self._strength.width != self._func.work_clip.width or self._strength.height != self._func.work_clip.height: |
243 |
| - self._strength = Catrom.scale(self._strength, self._func.work_clip.width, self._func.work_clip.height) |
| 235 | + if str_clip.num_frames != clip.num_frames: |
| 236 | + raise ValueError("Strength clip must have the same number of frames as the input clip") |
244 | 237 |
|
245 |
| - if self._strength.num_frames != self._func.work_clip.num_frames: |
246 |
| - raise LengthMismatchError(self._func.func, '`strength` must be the same length as \'clip\'') |
| 238 | + self._strength_clip = str_clip |
247 | 239 |
|
248 | 240 | def _should_process(self, strength: SupportsFloat | vs.VideoNode | None | Literal[False] = False) -> bool:
|
249 |
| - strength = self._strength if strength is False else strength |
| 241 | + if hasattr(self, '_strength_clip'): |
| 242 | + return self._strength_clip is not False |
250 | 243 |
|
251 |
| - return not (isinstance(strength, int) and strength >= 0) |
| 244 | + return (strength is None) or not (isinstance(strength, (int, float)) and strength <= 0.0) |
252 | 245 |
|
253 | 246 |
|
254 | 247 | def get_models_path() -> SPath:
|
|
0 commit comments