From af07a6c848c38564ec908a33dff03f3bbca95c07 Mon Sep 17 00:00:00 2001 From: "K.R. Zentner" Date: Wed, 1 Jun 2022 02:27:43 -0700 Subject: [PATCH] Make LogEncoder more robust --- .pylintrc | 2 + src/garage/experiment/experiment.py | 220 ++++++++++++-------- tests/garage/experiment/test_log_encoder.py | 59 ++++++ 3 files changed, 192 insertions(+), 89 deletions(-) create mode 100644 tests/garage/experiment/test_log_encoder.py diff --git a/.pylintrc b/.pylintrc index b2424cdb4f..1fd0370126 100644 --- a/.pylintrc +++ b/.pylintrc @@ -34,6 +34,8 @@ disable = no-else-return, # Discourages small interfaces too-few-public-methods, + # Too much old code + consider-using-f-string, [REPORTS] msg-template = {path}:{line:3d},{column}: {msg} ({symbol}) diff --git a/src/garage/experiment/experiment.py b/src/garage/experiment/experiment.py index 101540a0cb..17f0c2c08f 100644 --- a/src/garage/experiment/experiment.py +++ b/src/garage/experiment/experiment.py @@ -455,18 +455,29 @@ def dump_json(filename, data): filename(str): Filename for the file. data(dict): Data to save to file. + Raises: + KeyboardInterrupt: If the user issued a KeyboardInterrupt. + """ pathlib.Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True) - with open(filename, 'w') as f: + with open(filename, 'w', encoding='utf-8') as f: # We do our own circular reference handling. # Sometimes sort_keys fails because the keys don't get made into # strings early enough. - json.dump(data, - f, - indent=2, - sort_keys=False, - cls=LogEncoder, - check_circular=False) + # This feature is useful, but causes way too many weird errors. + # For this reason we catch almost any exception. + # pylint: disable=broad-except + try: + json.dump(data, + f, + indent=2, + sort_keys=False, + cls=LogEncoder, + check_circular=False) + except KeyboardInterrupt as e: + raise e + except Exception: + pass def get_metadata(): @@ -579,44 +590,80 @@ def __init__(self, *args, **kwargs): 'itertools', } - def default(self, o): + def default(self, o, path=''): """Perform JSON encoding. Args: o (object): Object to encode. - - Raises: - TypeError: If `o` cannot be turned into JSON even using `repr(o)`. + path (str): "Path" to o for describing backreferences. Returns: dict or str or float or bool: Object encoded in JSON. """ - # Why is this method hidden? What does that mean? - # pylint: disable=method-hidden - # pylint: disable=too-many-branches - # pylint: disable=too-many-return-statements # This circular reference checking code was copied from the standard - # library json implementation, but it outputs a repr'd string instead - # of ValueError on a circular reference. - if isinstance(o, (int, bool, float, str)): + # library json implementation, but it detects all backreferences + # instead of just circular ones and outputs a message with the path to + # the prior reference instead of raising a ValueError + if isinstance(o, (int, bool, float, str, type(None))): return o else: markerid = id(o) if markerid in self._markers: - return 'circular ' + repr(o) + original_path = self._markers[markerid] + return f'reference to {original_path}' else: - self._markers[markerid] = o - try: - return self._default_inner(o) - finally: - del self._markers[markerid] + self._markers[markerid] = path + return self._default_general_cases(o, path) - def _default_inner(self, o): - """Perform JSON encoding. + def _default_general_cases(self, o, path): + """Handle JSON encoding among the "general" cases. + + First, tries to just use the default encoder, then various special + cases, then by turning o into a list, then by calling repr, and finally + by returning the string '$invalid' Args: o (object): Object to encode. + path (str): "Path" to o for describing backreferences. + + Returns: + dict or str or float or bool: Object encoded in JSON. + + """ + try: + return json.JSONEncoder.default(self, o) + except TypeError: + pass + try: + return self._default_special_cases(o, path) + except TypeError: + pass + try: + # This case handles many built-in datatypes like deques + return [ + self.default(v, f'{path}/{i}') for i, v in enumerate(list(o)) + ] + except TypeError: + pass + try: + # This case handles most other weird objects. + return repr(o) + except TypeError: + pass + except ValueError: + pass + return '$invalid' + + def _default_special_cases(self, o, path): + """Handle various special cases we frequently want to JSON encode. + + Note that these cases aren't _that_ special, and include dicts, enums, + np.numbers, etc. + + Args: + o (object): Object to encode. + path (str): "Path" to o for describing backreferences. Raises: TypeError: If `o` cannot be turned into JSON even using `repr(o)`. @@ -626,70 +673,65 @@ def _default_inner(self, o): dict or str or float or bool: Object encoded in JSON. """ - # Why is this method hidden? What does that mean? - # pylint: disable=method-hidden # pylint: disable=too-many-branches # pylint: disable=too-many-return-statements - # This circular reference checking code was copied from the standard - # library json implementation, but it outputs a repr'd string instead - # of ValueError on a circular reference. - try: - return json.JSONEncoder.default(self, o) - except TypeError as err: - if isinstance(o, dict): - data = {} - for (k, v) in o.items(): - if isinstance(k, str): - data[k] = self.default(v) - else: - data[repr(k)] = self.default(v) - return data - elif isinstance(o, weakref.ref): - return repr(o) - elif type(o).__module__.split('.')[0] in self.BLOCKED_MODULES: - return repr(o) - elif isinstance(o, type): - return {'$typename': o.__module__ + '.' + o.__name__} - elif isinstance(o, np.number): - # For some reason these aren't natively considered - # serializable. - # JSON doesn't actually have ints, so always use a float. - return float(o) - elif isinstance(o, np.bool8): - return bool(o) - elif isinstance(o, enum.Enum): - return { - '$enum': - o.__module__ + '.' + o.__class__.__name__ + '.' + o.name - } - elif isinstance(o, np.ndarray): - return repr(o) - elif hasattr(o, '__dict__') or hasattr(o, '__slots__'): - obj_dict = getattr(o, '__dict__', None) - if obj_dict is not None: - data = {k: self.default(v) for (k, v) in obj_dict.items()} - else: - data = { - s: self.default(getattr(o, s)) - for s in o.__slots__ - } - t = type(o) - data['$type'] = t.__module__ + '.' + t.__name__ - return data - elif callable(o) and hasattr(o, '__name__'): - if getattr(o, '__module__', None) is not None: - return {'$function': o.__module__ + '.' + o.__name__} + if isinstance(o, dict): + data = {} + for (k, v) in o.items(): + if isinstance(k, str): + data[k] = self.default(v, f'{path}/{k}') else: - return repr(o) + data[repr(k)] = self.default(v, f'{path}/{k!r}') + return data + elif isinstance(o, weakref.ref): + return repr(o) + elif type(o).__module__.split('.')[0] in self.BLOCKED_MODULES: + return repr(o) + elif isinstance(o, type): + return {'$typename': o.__module__ + '.' + o.__name__} + elif isinstance(o, np.bool8): + return bool(o) + elif isinstance(o, np.number): + # For some reason these aren't natively considered + # serializable. + # JSON doesn't actually have ints, so always use a float. + # Some strange numpy "number" types can actually be None, + # so this case can actually fail as well, which will then fall back + # to one of the general cases. + return float(o) + elif isinstance(o, enum.Enum): + return { + '$enum': + o.__module__ + '.' + o.__class__.__name__ + '.' + o.name + } + elif isinstance(o, np.ndarray): + return repr(o) + elif hasattr(o, '__dict__') or hasattr(o, '__slots__'): + obj_dict = getattr(o, '__dict__', None) + if obj_dict is not None: + # Some objects will change their fields while being + # iterated over, so make a copy of their dictionary. + obj_dict = obj_dict.copy() + data = { + k: self.default(v, f'{path}/{k}') + for (k, v) in obj_dict.items() + # There's a lot of spam from empty dict / list fields + # The output of this JSONEncoder is not intended to be + # loaded back into the original objects anyways. + if not isinstance(v, (list, dict, set, tuple)) or v + } else: - try: - # This case handles many built-in datatypes like deques - return [self.default(v) for v in list(o)] - except TypeError: - pass - try: - # This case handles most other weird objects. - return repr(o) - except TypeError: - pass - raise err + data = { + s: self.default(getattr(o, s), f'{path}/{s}') + for s in o.__slots__ + } + t = type(o) + data['$type'] = t.__module__ + '.' + t.__name__ + return data + elif callable(o) and hasattr(o, '__name__'): + if getattr(o, '__module__', None) is not None: + return {'$function': o.__module__ + '.' + o.__name__} + else: + return repr(o) + else: + raise TypeError('Could not JSON encode object') diff --git a/tests/garage/experiment/test_log_encoder.py b/tests/garage/experiment/test_log_encoder.py new file mode 100644 index 0000000000..f4998a5df5 --- /dev/null +++ b/tests/garage/experiment/test_log_encoder.py @@ -0,0 +1,59 @@ +import json + +import numpy as np +import torch + +from garage.envs import GymEnv, normalize +from garage.experiment import deterministic +from garage.experiment.experiment import LogEncoder +from garage.plotter import Plotter +from garage.sampler import LocalSampler +from garage.torch.algos import PPO +from garage.torch.policies import GaussianMLPPolicy +from garage.torch.value_functions import GaussianMLPValueFunction +from garage.trainer import Trainer + +from tests.fixtures import snapshot_config + + +def test_encode_none_timedelta64(): + value = {'test': np.timedelta64(None)} + encoded = json.dumps(value, + indent=2, + sort_keys=False, + cls=LogEncoder, + check_circular=False) + assert 'test' in encoded + + +def test_encode_trainer(): + env = normalize(GymEnv('InvertedDoublePendulum-v2')) + policy = GaussianMLPPolicy( + env_spec=env.spec, + hidden_sizes=(64, 64), + hidden_nonlinearity=torch.tanh, + output_nonlinearity=None, + ) + value_function = GaussianMLPValueFunction(env_spec=env.spec) + sampler = LocalSampler(agents=policy, + envs=env, + max_episode_length=env.spec.max_episode_length, + is_tf_worker=False) + + trainer = Trainer(snapshot_config) + algo = PPO(env_spec=env.spec, + policy=policy, + value_function=value_function, + sampler=sampler, + discount=0.99, + gae_lambda=0.97, + lr_clip_range=2e-1) + + trainer.setup(algo, env) + encoded = json.dumps(trainer, + indent=2, + sort_keys=False, + cls=LogEncoder, + check_circular=False) + print(encoded) + assert 'value_function' in encoded