Skip to content

Commit

Permalink
Make LogEncoder more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
krzentner committed Jun 1, 2022
1 parent 3492f44 commit 58874e8
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 47 deletions.
2 changes: 2 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ disable =
no-else-return,
# Discourages small interfaces
too-few-public-methods,
# Too much old code
consider-using-f-string,

[REPORTS]
msg-template = {path}:{line:3d},{column}: {msg} ({symbol})
Expand Down
101 changes: 54 additions & 47 deletions src/garage/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def dump_json(filename, data):
"""
pathlib.Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True)
with open(filename, 'w') as f:
with open(filename, 'w', encoding='utf-8') as f:
# We do our own circular reference handling.
# Sometimes sort_keys fails because the keys don't get made into
# strings early enough.
Expand Down Expand Up @@ -635,53 +635,60 @@ def _default_inner(self, o):
# of ValueError on a circular reference.
try:
return json.JSONEncoder.default(self, o)
except TypeError as err:
if isinstance(o, dict):
data = {}
for (k, v) in o.items():
if isinstance(k, str):
data[k] = self.default(v)
else:
data[repr(k)] = self.default(v)
return data
elif isinstance(o, weakref.ref):
return repr(o)
elif type(o).__module__.split('.')[0] in self.BLOCKED_MODULES:
return repr(o)
elif isinstance(o, type):
return {'$typename': o.__module__ + '.' + o.__name__}
elif isinstance(o, np.number):
# For some reason these aren't natively considered
# serializable.
# JSON doesn't actually have ints, so always use a float.
return float(o)
elif isinstance(o, np.bool8):
return bool(o)
elif isinstance(o, enum.Enum):
return {
'$enum':
o.__module__ + '.' + o.__class__.__name__ + '.' + o.name
}
elif isinstance(o, np.ndarray):
return repr(o)
elif hasattr(o, '__dict__') or hasattr(o, '__slots__'):
obj_dict = getattr(o, '__dict__', None)
if obj_dict is not None:
data = {k: self.default(v) for (k, v) in obj_dict.items()}
else:
data = {
s: self.default(getattr(o, s))
for s in o.__slots__
except TypeError:
try:
if isinstance(o, dict):
data = {}
for (k, v) in o.items():
if isinstance(k, str):
data[k] = self.default(v)
else:
data[repr(k)] = self.default(v)
return data
elif isinstance(o, weakref.ref):
return repr(o)
elif type(o).__module__.split('.')[0] in self.BLOCKED_MODULES:
return repr(o)
elif isinstance(o, type):
return {'$typename': o.__module__ + '.' + o.__name__}
elif isinstance(o, np.number):
# For some reason these aren't natively considered
# serializable.
# JSON doesn't actually have ints, so always use a float.
# Some strange numpy "number" types can actually be None,
# so this case can actually fail as well.
return float(o)
elif isinstance(o, np.bool8):
return bool(o)
elif isinstance(o, enum.Enum):
return {
'$enum':
o.__module__ + '.' + o.__class__.__name__ + '.' +
o.name
}
t = type(o)
data['$type'] = t.__module__ + '.' + t.__name__
return data
elif callable(o) and hasattr(o, '__name__'):
if getattr(o, '__module__', None) is not None:
return {'$function': o.__module__ + '.' + o.__name__}
else:
elif isinstance(o, np.ndarray):
return repr(o)
else:
elif hasattr(o, '__dict__') or hasattr(o, '__slots__'):
obj_dict = getattr(o, '__dict__', None)
if obj_dict is not None:
data = {
k: self.default(v)
for (k, v) in obj_dict.items()
}
else:
data = {
s: self.default(getattr(o, s))
for s in o.__slots__
}
t = type(o)
data['$type'] = t.__module__ + '.' + t.__name__
return data
elif callable(o) and hasattr(o, '__name__'):
if getattr(o, '__module__', None) is not None:
return {'$function': o.__module__ + '.' + o.__name__}
else:
return repr(o)
except TypeError:
try:
# This case handles many built-in datatypes like deques
return [self.default(v) for v in list(o)]
Expand All @@ -692,4 +699,4 @@ def _default_inner(self, o):
return repr(o)
except TypeError:
pass
raise err
return '$invalid'
15 changes: 15 additions & 0 deletions tests/garage/experiment/test_log_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import json

import numpy as np

from garage.experiment.experiment import LogEncoder


def test_encode_none_timedelta64():
value = {'test': np.timedelta64(None)}
encoded = json.dumps(value,
indent=2,
sort_keys=False,
cls=LogEncoder,
check_circular=False)
assert 'test' in encoded

0 comments on commit 58874e8

Please sign in to comment.