Skip to content

Commit a3ec996

Browse files
authored
Merge pull request #534 from gchq/fix/remaining-decorator-warnings
fix: decorator warnings raised by test_basedecorator.py
2 parents 5276c33 + fb20457 commit a3ec996

File tree

7 files changed

+98
-23
lines changed

7 files changed

+98
-23
lines changed

tests/units/decoratorutils/test_basedecorator.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
from typing import Any, TypeVar, Union
1919

2020
import pytest
21+
from typing_extensions import ContextManager
2122

2223
from tests.cases import assert_not_warns
2324
from vanguard.base import GPController
2425
from vanguard.decoratorutils import Decorator, TopMostDecorator, errors, wraps_class
2526
from vanguard.decoratorutils.errors import TopmostDecoratorError
27+
from vanguard.utils import multi_context
2628
from vanguard.vanilla import GaussianGPController
2729

2830
ControllerT = TypeVar("ControllerT", bound=GPController)
@@ -53,7 +55,7 @@ def __init__(self, **kwargs: Any) -> None:
5355
super().__init__(framework_class=SimpleNumber, required_decorators={}, **kwargs)
5456

5557
def _decorate_class(self, cls: type[SimpleNumber]) -> type[SimpleNumber]:
56-
@wraps_class(cls)
58+
@wraps_class(cls, decorator_source=self)
5759
class InnerClass(cls):
5860
"""A wrapper for normalising y inputs and variance."""
5961

@@ -319,29 +321,40 @@ def add_5(self) -> Union[int, float]:
319321

320322
# If we're not ignoring the errors, then specify what the expected error/warning is.
321323
if who_is_wrong == "subclass":
322-
blame_class = "NewNumber"
324+
blame_classes = ["NewNumber"]
325+
elif who_is_wrong == "superclass":
326+
blame_classes = ["MiddleNumber"]
323327
elif who_is_wrong == "superclass" or who_is_wrong == "both":
324-
blame_class = "MiddleNumber"
328+
blame_classes = ["MiddleNumber", "NewNumber"]
325329
else:
326330
raise ValueError(who_is_wrong)
327331

332+
contexts: list[ContextManager]
328333
if mode == "new_method":
329-
expected_message = (
334+
expected_messages = [
330335
f"{SquareResult.__name__!r}: The class '{blame_class}' has added the following unexpected methods"
336+
for blame_class in blame_classes
337+
]
338+
contexts = (
339+
[pytest.raises(errors.UnexpectedMethodError, match=expected_messages[0])]
340+
if raise_instead
341+
else [pytest.warns(errors.UnexpectedMethodWarning, match=msg) for msg in expected_messages]
331342
)
332-
expected_types = errors.UnexpectedMethodError, errors.UnexpectedMethodWarning
333343
elif mode == "override_method":
334-
expected_message = (
344+
expected_messages = [
335345
f"{SquareResult.__name__!r}: The class '{blame_class}' has overwritten the following methods"
346+
for blame_class in blame_classes
347+
]
348+
contexts = (
349+
[pytest.raises(errors.OverwrittenMethodError, match=expected_messages[0])]
350+
if raise_instead
351+
else [pytest.warns(errors.OverwrittenMethodWarning, match=msg) for msg in expected_messages]
336352
)
337-
expected_types = errors.OverwrittenMethodError, errors.OverwrittenMethodWarning
338353
else:
339354
raise ValueError(mode)
340355

341-
if raise_instead:
342-
context = pytest.raises(expected_types[0], match=expected_message)
343-
else:
344-
context = pytest.warns(expected_types[1], match=expected_message)
356+
context = multi_context(contexts)
357+
345358
elif ignore == "specific":
346359
# Ignore by setting `ignore_methods`.
347360
if mode == "new_method":
@@ -358,9 +371,11 @@ def add_5(self) -> Union[int, float]:
358371
else:
359372
raise ValueError(ignore)
360373

361-
# Actually perform the decoration, and check for errors/warnings as appropriate
362-
with context:
363-
SquareResult(**kwargs)(NewNumber)
374+
# Actually perform the decoration, and check for errors/warnings as appropriate. No warnings other than those
375+
# we expect should be raised.
376+
with assert_not_warns(errors.OverwrittenMethodWarning, errors.UnexpectedMethodWarning):
377+
with context:
378+
SquareResult(**kwargs)(NewNumber)
364379

365380

366381
class TestInvalidDecoration:
@@ -394,7 +409,7 @@ class NewNumber(SimpleNumber): # pylint: disable=unused-variable
394409
def test_passed_requirements(self) -> None:
395410
"""Test that when all requirements for a decorator are satisfied, no error is thrown."""
396411

397-
@RequiresSquareResult(ignore_methods=("__init__", "add_5"))
412+
@RequiresSquareResult(ignore_all=True)
398413
@SquareResult()
399414
class NewNumber(SimpleNumber): # pylint: disable=unused-variable
400415
"""Declaring this class should not throw any error, as all requirements are satisfied."""

tests/units/test_decorator_combinations.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -379,10 +379,18 @@ def _decorate_class(self, cls: T) -> T:
379379
# (upper, lower) -> (warning type, message regex)
380380
# Warnings we expect to be raised on decorator application.
381381
EXPECTED_COMBINATION_APPLY_WARNINGS: dict[tuple[type[Decorator], type[Decorator]], tuple[type[Warning], str]] = {
382-
(NormaliseY, DirichletMulticlassClassification): (
383-
BadCombinationWarning,
384-
"NormaliseY should not be used above classification decorators - this may lead to unexpected behaviour.",
385-
),
382+
**{
383+
(NormaliseY, lower): (
384+
BadCombinationWarning,
385+
"NormaliseY should not be used above classification decorators - this may lead to unexpected behaviour.",
386+
)
387+
for lower in [
388+
BinaryClassification,
389+
CategoricalClassification,
390+
DirichletMulticlassClassification,
391+
DirichletKernelMulticlassClassification,
392+
]
393+
},
386394
**{
387395
(VariationalInference, lower): (
388396
BadCombinationWarning,

tests/units/test_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
"""
1818

1919
import unittest
20+
from unittest.mock import MagicMock
2021

2122
import numpy as np
2223
import numpy.typing
2324
import pytest
2425
import torch
26+
from typing_extensions import ContextManager
2527

2628
from tests.cases import get_default_rng
2729
from vanguard.utils import (
@@ -30,6 +32,7 @@
3032
generator_append_constant,
3133
infinite_tensor_generator,
3234
instantiate_with_subset_of_kwargs,
35+
multi_context,
3336
optional_random_generator,
3437
)
3538

@@ -286,3 +289,37 @@ def test_generator_zero_dimensional(self):
286289

287290
with pytest.raises(ValueError, match="0-dimensional tensors are incompatible"):
288291
next(generator)
292+
293+
294+
@pytest.mark.parametrize("num_contexts", [1, 5])
295+
def test_multi_context(num_contexts: int):
296+
"""
297+
Test the `multi_context` context manager.
298+
299+
Given that the context manager syntax is ultimately just syntactic sugar for calling `ctx.__enter__` and
300+
`ctx.__exit__` (with some extra semantics around exception handling), all we really need to do is check that
301+
these methods are called on each context manager passed to `multi_context` at the appropriate times. We don't
302+
test the specifics of exception propagation, as that is (a) too complex for a single unit test, and (b) more the
303+
responsibility of the `contextlib.ExitStack` that `multi_context` uses internally.
304+
305+
Test with both a single context and multiple contexts being passed to `multi_context`.
306+
"""
307+
dummy_contexts = [MagicMock(spec=ContextManager)() for _ in range(num_contexts)]
308+
context = multi_context(dummy_contexts)
309+
310+
for ctx in dummy_contexts:
311+
# check that we haven't entered or left any of the contexts yet
312+
ctx.__enter__.assert_not_called()
313+
ctx.__exit__.assert_not_called()
314+
315+
# enter the multi context
316+
with context:
317+
for ctx in dummy_contexts:
318+
# check that we have entered but not left each of the contexts yet
319+
ctx.__enter__.assert_called_once()
320+
ctx.__exit__.assert_not_called()
321+
322+
for ctx in dummy_contexts:
323+
# check that we have left each of the contexts
324+
ctx.__enter__.assert_called_once()
325+
ctx.__exit__.assert_called_once()

vanguard/base/gpcontroller.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ def fit(
141141
warnings.warn(
142142
f"You are trying to set gradient_every (in this case to {gradient_every}) in batch mode."
143143
"This does not make mathematical sense and your value of gradient every will be ignored "
144-
" and replaced by 1."
144+
" and replaced by 1.",
145+
stacklevel=2,
145146
)
146147
gradient_every = 1
147148

vanguard/decoratorutils/basedecorator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ def _verify_class_has_no_newly_added_methods(self, cls: type[T], super_methods:
185185
cls_methods = {
186186
key
187187
for key, value in getmembers(cls, isfunction)
188+
# only functions that are actually from this class (as opposed to a superclass)
189+
if key in cls.__dict__
190+
# ignore functions defined in safe_updates
188191
if key not in self.safe_updates.get(self._get_method_implementation(cls, key), set())
189192
# beartype does weird things with __sizeof__; however, it's of no concern to us, and we never make use of
190193
# this dunder attribute. See https://github.com/beartype/beartype/blob/v0.19.0/beartype/_decor/_decortype.py
@@ -216,7 +219,7 @@ def _verify_class_has_no_newly_added_methods(self, cls: type[T], super_methods:
216219
else:
217220
warnings.warn(message, errors.UnexpectedMethodWarning, stacklevel=4)
218221

219-
overwritten_methods = {method for method in cls_methods if method in cls.__dict__} - ignore_methods
222+
overwritten_methods = cls_methods - ignore_methods - extra_methods
220223
if overwritten_methods:
221224
if __debug__:
222225
overwritten_method_messages = "\n".join(

vanguard/distribute/decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
275275
**self._expert_init_kwargs,
276276
)
277277

278-
def fit(self, n_sgd_iters: int = 10, gradient_every: int = 10) -> torch.Tensor:
278+
def fit(self, n_sgd_iters: int = 10, gradient_every: Optional[int] = None) -> torch.Tensor:
279279
"""
280280
Create the expert controllers.
281281

vanguard/utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
Contain some small utilities of use in some cases.
1717
"""
1818

19+
import contextlib
1920
import functools
2021
import os
2122
import warnings
22-
from collections.abc import Generator
23+
from collections.abc import Generator, Iterable
2324
from typing import Any, Callable, Optional, TypeVar
2425

2526
import numpy as np
2627
import numpy.typing
2728
import torch
29+
from typing_extensions import ContextManager
2830

2931
from vanguard.warnings import _RE_INCORRECT_LIKELIHOOD_PARAMETER
3032

@@ -253,3 +255,12 @@ def compose(functions: list[Callable[[T], T]]) -> Callable[[T], T]:
253255
:return: A single function of type (T -> T) that applies each of the passed functions in series.
254256
"""
255257
return lambda x: functools.reduce(lambda acc, f: f(acc), reversed(functions), x)
258+
259+
260+
@contextlib.contextmanager
261+
def multi_context(contexts: Iterable[ContextManager]):
262+
"""Combine multiple context managers into one."""
263+
with contextlib.ExitStack() as stack:
264+
for context in contexts:
265+
stack.enter_context(context)
266+
yield

0 commit comments

Comments
 (0)