1818from typing import Any , TypeVar , Union
1919
2020import pytest
21+ from typing_extensions import ContextManager
2122
2223from tests .cases import assert_not_warns
2324from vanguard .base import GPController
2425from vanguard .decoratorutils import Decorator , TopMostDecorator , errors , wraps_class
2526from vanguard .decoratorutils .errors import TopmostDecoratorError
27+ from vanguard .utils import multi_context
2628from vanguard .vanilla import GaussianGPController
2729
2830ControllerT = TypeVar ("ControllerT" , bound = GPController )
@@ -53,7 +55,7 @@ def __init__(self, **kwargs: Any) -> None:
5355 super ().__init__ (framework_class = SimpleNumber , required_decorators = {}, ** kwargs )
5456
5557 def _decorate_class (self , cls : type [SimpleNumber ]) -> type [SimpleNumber ]:
56- @wraps_class (cls )
58+ @wraps_class (cls , decorator_source = self )
5759 class InnerClass (cls ):
5860 """A wrapper for normalising y inputs and variance."""
5961
@@ -319,29 +321,40 @@ def add_5(self) -> Union[int, float]:
319321
320322 # If we're not ignoring the errors, then specify what the expected error/warning is.
321323 if who_is_wrong == "subclass" :
322- blame_class = "NewNumber"
324+ blame_classes = ["NewNumber" ]
325+ elif who_is_wrong == "superclass" :
326+ blame_classes = ["MiddleNumber" ]
323327 elif who_is_wrong == "superclass" or who_is_wrong == "both" :
324- blame_class = "MiddleNumber"
328+ blame_classes = [ "MiddleNumber" , "NewNumber" ]
325329 else :
326330 raise ValueError (who_is_wrong )
327331
332+ contexts : list [ContextManager ]
328333 if mode == "new_method" :
329- expected_message = (
334+ expected_messages = [
330335 f"{ SquareResult .__name__ !r} : The class '{ blame_class } ' has added the following unexpected methods"
336+ for blame_class in blame_classes
337+ ]
338+ contexts = (
339+ [pytest .raises (errors .UnexpectedMethodError , match = expected_messages [0 ])]
340+ if raise_instead
341+ else [pytest .warns (errors .UnexpectedMethodWarning , match = msg ) for msg in expected_messages ]
331342 )
332- expected_types = errors .UnexpectedMethodError , errors .UnexpectedMethodWarning
333343 elif mode == "override_method" :
334- expected_message = (
344+ expected_messages = [
335345 f"{ SquareResult .__name__ !r} : The class '{ blame_class } ' has overwritten the following methods"
346+ for blame_class in blame_classes
347+ ]
348+ contexts = (
349+ [pytest .raises (errors .OverwrittenMethodError , match = expected_messages [0 ])]
350+ if raise_instead
351+ else [pytest .warns (errors .OverwrittenMethodWarning , match = msg ) for msg in expected_messages ]
336352 )
337- expected_types = errors .OverwrittenMethodError , errors .OverwrittenMethodWarning
338353 else :
339354 raise ValueError (mode )
340355
341- if raise_instead :
342- context = pytest .raises (expected_types [0 ], match = expected_message )
343- else :
344- context = pytest .warns (expected_types [1 ], match = expected_message )
356+ context = multi_context (contexts )
357+
345358 elif ignore == "specific" :
346359 # Ignore by setting `ignore_methods`.
347360 if mode == "new_method" :
@@ -358,9 +371,11 @@ def add_5(self) -> Union[int, float]:
358371 else :
359372 raise ValueError (ignore )
360373
361- # Actually perform the decoration, and check for errors/warnings as appropriate
362- with context :
363- SquareResult (** kwargs )(NewNumber )
374+ # Actually perform the decoration, and check for errors/warnings as appropriate. No warnings other than those
375+ # we expect should be raised.
376+ with assert_not_warns (errors .OverwrittenMethodWarning , errors .UnexpectedMethodWarning ):
377+ with context :
378+ SquareResult (** kwargs )(NewNumber )
364379
365380
366381class TestInvalidDecoration :
@@ -394,7 +409,7 @@ class NewNumber(SimpleNumber): # pylint: disable=unused-variable
394409 def test_passed_requirements (self ) -> None :
395410 """Test that when all requirements for a decorator are satisfied, no error is thrown."""
396411
397- @RequiresSquareResult (ignore_methods = ( "__init__" , "add_5" ) )
412+ @RequiresSquareResult (ignore_all = True )
398413 @SquareResult ()
399414 class NewNumber (SimpleNumber ): # pylint: disable=unused-variable
400415 """Declaring this class should not throw any error, as all requirements are satisfied."""
0 commit comments