From c23b3d83321e13e488d72bc8a68a1d1a50b2d50c Mon Sep 17 00:00:00 2001 From: Linsho Kaku Date: Tue, 31 Oct 2023 16:30:54 +0900 Subject: [PATCH] add CallFunction extension --- .../training/extensions/call_function.py | 62 +++++++++ .../extensions_tests/test_call_function.py | 119 ++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 pytorch_pfn_extras/training/extensions/call_function.py create mode 100644 tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_call_function.py diff --git a/pytorch_pfn_extras/training/extensions/call_function.py b/pytorch_pfn_extras/training/extensions/call_function.py new file mode 100644 index 00000000..9e3e9a82 --- /dev/null +++ b/pytorch_pfn_extras/training/extensions/call_function.py @@ -0,0 +1,62 @@ +import types +from typing import Any, Callable, Dict, Mapping, Optional, Sequence + +from pytorch_pfn_extras.reporting import Value, report +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) +from pytorch_pfn_extras.training.extension import PRIORITY_WRITER, Extension + + +class CallFunction(Extension): + def __init__( + self, + fn: Callable[..., Dict[str, Value]], + args: Optional[Sequence[Any]] = None, + kwargs: Optional[Mapping[str, Any]] = None, + report_keys: Optional[Sequence[str]] = None, + report_prefix: Optional[str] = None, + run_on_error: bool = False, + priority: int = PRIORITY_WRITER, + ) -> None: + """wrapper extension to call functions during the training loop + + Args: + fn (Callable[..., Dict[str, Value]]): Function to be called via extension. + args (Optional[Sequence[Any]], optional): Arguments to be passed to the function. Defaults to None. + kwargs (Optional[Mapping[str, Any]], optional): Keyword arguments you want to pass to the function. Defaults to None. + report_keys (Optional[Sequence[str]], optional): The key of the value to be reported among the values contained in the function's return dict. Defaults to None. + report_prefix (Optional[str], optional): If necessary, the prefix to attach to the function's return value when reporting it. Defaults to None. + run_on_error (bool, optional): Whether or not want to run when an error occurs during the training loop. Defaults to False. + priority (int, optional): When this Extension will be executed. Defaults to PRIORITY_WRITER. + """ + self._fn = fn + self._args = args or [] + self._kwargs = kwargs or {} + self._report_keys = set(report_keys) if report_keys else None + self._report_prefix = report_prefix + self._run_on_error = run_on_error + self.priority = priority + + def _call(self) -> None: + out = self._fn(*self._args, **self._kwargs) + if self._report_keys: + out = {k: v for k, v in out.items() if k in self._report_keys} + if self._report_prefix: + out = { + "/".join([self._report_prefix, k]): v for k, v in out.items() + } + + report(out) + + def __call__(self, manager: ExtensionsManagerProtocol) -> Any: + self._call() + + def on_error( + self, + manager: ExtensionsManagerProtocol, + exc: Exception, + tb: types.TracebackType, + ) -> None: + if self._run_on_error: + self._call() diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_call_function.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_call_function.py new file mode 100644 index 00000000..e77203b5 --- /dev/null +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_call_function.py @@ -0,0 +1,119 @@ +from typing import Dict +from unittest.mock import MagicMock + +import pytest +from pytorch_pfn_extras.training import ExtensionsManager +from pytorch_pfn_extras.training.extensions.call_function import CallFunction + + +def test_function_is_called() -> None: + fn = MagicMock() + args = [MagicMock()] + kwargs = {"a": MagicMock()} + + extension = CallFunction(fn=fn, args=args, kwargs=kwargs) + fn.assert_not_called() + extension(MagicMock()) + fn.assert_called_once_with(*args, **kwargs) + + +def test_report() -> None: + def add_fn(a: Dict[str, int], b: Dict[str, int]) -> Dict[str, int]: + (a_val,) = list(a.values()) + (b_val,) = list(b.values()) + return {"ret": a_val + b_val} + + a = {"0": 0} + b = {"0": 0} + + extension = CallFunction(fn=add_fn, args=[a, b]) + + epoch = 1 + iteration = 10 + manager = ExtensionsManager({}, {}, epoch, iters_per_epoch=iteration) + manager.extend(extension, trigger=(1, "iteration")) + a_val = 0 + b_val = 0 + while not manager.stop_trigger: + a["0"] = a_val + b["0"] = b_val + with manager.run_iteration(): + pass + assert manager.observation["ret"] == a_val + b_val + a_val += 1 + b_val += 1 + + +@pytest.mark.parametrize("report_keys", [["ret"], ["other"], ["ret", "other"]]) +def test_report_with_key(report_keys) -> None: + def add_fn(a: Dict[str, int], b: Dict[str, int]) -> Dict[str, int]: + (a_val,) = list(a.values()) + (b_val,) = list(b.values()) + return { + "ret": a_val + b_val, + "other": 0, + } + + a = {"0": 0} + b = {"0": 0} + + extension = CallFunction(fn=add_fn, args=[a, b], report_keys=report_keys) + + epoch = 1 + iteration = 10 + manager = ExtensionsManager({}, {}, epoch, iters_per_epoch=iteration) + manager.extend(extension, trigger=(1, "iteration")) + while not manager.stop_trigger: + with manager.run_iteration(): + pass + assert set(manager.observation.keys()) == set(report_keys) + + +def test_report_with_prefix() -> None: + def add_fn(a: Dict[str, int], b: Dict[str, int]) -> Dict[str, int]: + (a_val,) = list(a.values()) + (b_val,) = list(b.values()) + return {"ret": a_val + b_val} + + a = {"0": 0} + b = {"0": 0} + + extension = CallFunction(fn=add_fn, args=[a, b], report_prefix="prefix") + expected_keys = set(["prefix/ret"]) + + epoch = 1 + iteration = 10 + manager = ExtensionsManager({}, {}, epoch, iters_per_epoch=iteration) + manager.extend(extension, trigger=(1, "iteration")) + while not manager.stop_trigger: + with manager.run_iteration(): + pass + assert set(manager.observation.keys()) == expected_keys + + +@pytest.mark.parametrize("run_on_error", [True, False]) +def test_on_error(run_on_error) -> None: + fn = MagicMock() + args = [MagicMock()] + kwargs = {"a": MagicMock()} + + extension = CallFunction( + fn=fn, args=args, kwargs=kwargs, run_on_error=run_on_error + ) + epoch = 1 + iteration = 10 + manager = ExtensionsManager({}, {}, epoch, iters_per_epoch=iteration) + manager.extend(extension, trigger=(1000, "iteration")) + fn.assert_not_called() + with manager.run_iteration(): + pass + fn.assert_not_called() + + try: + with manager.run_iteration(): + raise RuntimeError + except RuntimeError: + if run_on_error: + fn.assert_called_once_with(*args, **kwargs) + else: + fn.assert_not_called()