Skip to content

Commit

Permalink
Make LogEncoder more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
krzentner committed Jun 2, 2022
1 parent 3492f44 commit af07a6c
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 89 deletions.
2 changes: 2 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ disable =
no-else-return,
# Discourages small interfaces
too-few-public-methods,
# Too much old code
consider-using-f-string,

[REPORTS]
msg-template = {path}:{line:3d},{column}: {msg} ({symbol})
Expand Down
220 changes: 131 additions & 89 deletions src/garage/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,18 +455,29 @@ def dump_json(filename, data):
filename(str): Filename for the file.
data(dict): Data to save to file.
Raises:
KeyboardInterrupt: If the user issued a KeyboardInterrupt.
"""
pathlib.Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True)
with open(filename, 'w') as f:
with open(filename, 'w', encoding='utf-8') as f:
# We do our own circular reference handling.
# Sometimes sort_keys fails because the keys don't get made into
# strings early enough.
json.dump(data,
f,
indent=2,
sort_keys=False,
cls=LogEncoder,
check_circular=False)
# This feature is useful, but causes way too many weird errors.
# For this reason we catch almost any exception.
# pylint: disable=broad-except
try:
json.dump(data,
f,
indent=2,
sort_keys=False,
cls=LogEncoder,
check_circular=False)
except KeyboardInterrupt as e:
raise e
except Exception:
pass


def get_metadata():
Expand Down Expand Up @@ -579,44 +590,80 @@ def __init__(self, *args, **kwargs):
'itertools',
}

def default(self, o):
def default(self, o, path=''):
"""Perform JSON encoding.
Args:
o (object): Object to encode.
Raises:
TypeError: If `o` cannot be turned into JSON even using `repr(o)`.
path (str): "Path" to o for describing backreferences.
Returns:
dict or str or float or bool: Object encoded in JSON.
"""
# Why is this method hidden? What does that mean?
# pylint: disable=method-hidden
# pylint: disable=too-many-branches
# pylint: disable=too-many-return-statements
# This circular reference checking code was copied from the standard
# library json implementation, but it outputs a repr'd string instead
# of ValueError on a circular reference.
if isinstance(o, (int, bool, float, str)):
# library json implementation, but it detects all backreferences
# instead of just circular ones and outputs a message with the path to
# the prior reference instead of raising a ValueError
if isinstance(o, (int, bool, float, str, type(None))):
return o
else:
markerid = id(o)
if markerid in self._markers:
return 'circular ' + repr(o)
original_path = self._markers[markerid]
return f'reference to {original_path}'
else:
self._markers[markerid] = o
try:
return self._default_inner(o)
finally:
del self._markers[markerid]
self._markers[markerid] = path
return self._default_general_cases(o, path)

def _default_inner(self, o):
"""Perform JSON encoding.
def _default_general_cases(self, o, path):
"""Handle JSON encoding among the "general" cases.
First, tries to just use the default encoder, then various special
cases, then by turning o into a list, then by calling repr, and finally
by returning the string '$invalid'
Args:
o (object): Object to encode.
path (str): "Path" to o for describing backreferences.
Returns:
dict or str or float or bool: Object encoded in JSON.
"""
try:
return json.JSONEncoder.default(self, o)
except TypeError:
pass
try:
return self._default_special_cases(o, path)
except TypeError:
pass
try:
# This case handles many built-in datatypes like deques
return [
self.default(v, f'{path}/{i}') for i, v in enumerate(list(o))
]
except TypeError:
pass
try:
# This case handles most other weird objects.
return repr(o)
except TypeError:
pass
except ValueError:
pass
return '$invalid'

def _default_special_cases(self, o, path):
"""Handle various special cases we frequently want to JSON encode.
Note that these cases aren't _that_ special, and include dicts, enums,
np.numbers, etc.
Args:
o (object): Object to encode.
path (str): "Path" to o for describing backreferences.
Raises:
TypeError: If `o` cannot be turned into JSON even using `repr(o)`.
Expand All @@ -626,70 +673,65 @@ def _default_inner(self, o):
dict or str or float or bool: Object encoded in JSON.
"""
# Why is this method hidden? What does that mean?
# pylint: disable=method-hidden
# pylint: disable=too-many-branches
# pylint: disable=too-many-return-statements
# This circular reference checking code was copied from the standard
# library json implementation, but it outputs a repr'd string instead
# of ValueError on a circular reference.
try:
return json.JSONEncoder.default(self, o)
except TypeError as err:
if isinstance(o, dict):
data = {}
for (k, v) in o.items():
if isinstance(k, str):
data[k] = self.default(v)
else:
data[repr(k)] = self.default(v)
return data
elif isinstance(o, weakref.ref):
return repr(o)
elif type(o).__module__.split('.')[0] in self.BLOCKED_MODULES:
return repr(o)
elif isinstance(o, type):
return {'$typename': o.__module__ + '.' + o.__name__}
elif isinstance(o, np.number):
# For some reason these aren't natively considered
# serializable.
# JSON doesn't actually have ints, so always use a float.
return float(o)
elif isinstance(o, np.bool8):
return bool(o)
elif isinstance(o, enum.Enum):
return {
'$enum':
o.__module__ + '.' + o.__class__.__name__ + '.' + o.name
}
elif isinstance(o, np.ndarray):
return repr(o)
elif hasattr(o, '__dict__') or hasattr(o, '__slots__'):
obj_dict = getattr(o, '__dict__', None)
if obj_dict is not None:
data = {k: self.default(v) for (k, v) in obj_dict.items()}
else:
data = {
s: self.default(getattr(o, s))
for s in o.__slots__
}
t = type(o)
data['$type'] = t.__module__ + '.' + t.__name__
return data
elif callable(o) and hasattr(o, '__name__'):
if getattr(o, '__module__', None) is not None:
return {'$function': o.__module__ + '.' + o.__name__}
if isinstance(o, dict):
data = {}
for (k, v) in o.items():
if isinstance(k, str):
data[k] = self.default(v, f'{path}/{k}')
else:
return repr(o)
data[repr(k)] = self.default(v, f'{path}/{k!r}')
return data
elif isinstance(o, weakref.ref):
return repr(o)
elif type(o).__module__.split('.')[0] in self.BLOCKED_MODULES:
return repr(o)
elif isinstance(o, type):
return {'$typename': o.__module__ + '.' + o.__name__}
elif isinstance(o, np.bool8):
return bool(o)
elif isinstance(o, np.number):
# For some reason these aren't natively considered
# serializable.
# JSON doesn't actually have ints, so always use a float.
# Some strange numpy "number" types can actually be None,
# so this case can actually fail as well, which will then fall back
# to one of the general cases.
return float(o)
elif isinstance(o, enum.Enum):
return {
'$enum':
o.__module__ + '.' + o.__class__.__name__ + '.' + o.name
}
elif isinstance(o, np.ndarray):
return repr(o)
elif hasattr(o, '__dict__') or hasattr(o, '__slots__'):
obj_dict = getattr(o, '__dict__', None)
if obj_dict is not None:
# Some objects will change their fields while being
# iterated over, so make a copy of their dictionary.
obj_dict = obj_dict.copy()
data = {
k: self.default(v, f'{path}/{k}')
for (k, v) in obj_dict.items()
# There's a lot of spam from empty dict / list fields
# The output of this JSONEncoder is not intended to be
# loaded back into the original objects anyways.
if not isinstance(v, (list, dict, set, tuple)) or v
}
else:
try:
# This case handles many built-in datatypes like deques
return [self.default(v) for v in list(o)]
except TypeError:
pass
try:
# This case handles most other weird objects.
return repr(o)
except TypeError:
pass
raise err
data = {
s: self.default(getattr(o, s), f'{path}/{s}')
for s in o.__slots__
}
t = type(o)
data['$type'] = t.__module__ + '.' + t.__name__
return data
elif callable(o) and hasattr(o, '__name__'):
if getattr(o, '__module__', None) is not None:
return {'$function': o.__module__ + '.' + o.__name__}
else:
return repr(o)
else:
raise TypeError('Could not JSON encode object')
59 changes: 59 additions & 0 deletions tests/garage/experiment/test_log_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import json

import numpy as np
import torch

from garage.envs import GymEnv, normalize
from garage.experiment import deterministic
from garage.experiment.experiment import LogEncoder
from garage.plotter import Plotter
from garage.sampler import LocalSampler
from garage.torch.algos import PPO
from garage.torch.policies import GaussianMLPPolicy
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer

from tests.fixtures import snapshot_config


def test_encode_none_timedelta64():
value = {'test': np.timedelta64(None)}
encoded = json.dumps(value,
indent=2,
sort_keys=False,
cls=LogEncoder,
check_circular=False)
assert 'test' in encoded


def test_encode_trainer():
env = normalize(GymEnv('InvertedDoublePendulum-v2'))
policy = GaussianMLPPolicy(
env_spec=env.spec,
hidden_sizes=(64, 64),
hidden_nonlinearity=torch.tanh,
output_nonlinearity=None,
)
value_function = GaussianMLPValueFunction(env_spec=env.spec)
sampler = LocalSampler(agents=policy,
envs=env,
max_episode_length=env.spec.max_episode_length,
is_tf_worker=False)

trainer = Trainer(snapshot_config)
algo = PPO(env_spec=env.spec,
policy=policy,
value_function=value_function,
sampler=sampler,
discount=0.99,
gae_lambda=0.97,
lr_clip_range=2e-1)

trainer.setup(algo, env)
encoded = json.dumps(trainer,
indent=2,
sort_keys=False,
cls=LogEncoder,
check_circular=False)
print(encoded)
assert 'value_function' in encoded

0 comments on commit af07a6c

Please sign in to comment.