Skip to content

Commit

Permalink
Log all start info to experiment.json (#2024)
Browse files Browse the repository at this point in the history
  • Loading branch information
krzentner authored Nov 6, 2020
1 parent cdda5fc commit e1c576f
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 13 deletions.
144 changes: 131 additions & 13 deletions src/garage/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import pathlib
import subprocess
import warnings
import weakref

import dateutil.tz
import dowel
from dowel import logger
import numpy as np

import __main__ as main

Expand Down Expand Up @@ -456,7 +458,15 @@ def dump_json(filename, data):
"""
pathlib.Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True)
with open(filename, 'w') as f:
json.dump(data, f, indent=2, sort_keys=True, cls=LogEncoder)
# We do our own circular reference handling.
# Sometimes sort_keys fails because the keys don't get made into
# strings early enough.
json.dump(data,
f,
indent=2,
sort_keys=False,
cls=LogEncoder,
check_circular=False)


def get_metadata():
Expand Down Expand Up @@ -550,28 +560,136 @@ def make_launcher_archive(*, git_root_path, log_dir):


class LogEncoder(json.JSONEncoder):
    """Encoder to be used as cls in json.dump.

    Objects that the standard encoder cannot handle are converted into
    JSON-compatible structures (dicts, lists, strings, numbers) on a
    best-effort basis, falling back to ``repr(o)`` as a last resort.

    Args:
        args (object): Passed to super class.
        kwargs (dict): Passed to super class.

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps id(obj) -> obj for every object currently being encoded by
        # `default`, so that reference cycles can be detected.
        self._markers = {}

    # Modules whose contents cannot be meaningfully or safely jsonified.
    BLOCKED_MODULES = {
        'tensorflow',
        'ray',
        'itertools',
    }

    def default(self, o):
        """Perform JSON encoding.

        Args:
            o (object): Object to encode.

        Raises:
            TypeError: If `o` cannot be turned into JSON even using `repr(o)`.

        Returns:
            dict or list or str or float or bool or None: Object encoded in
                JSON-compatible form.

        """
        # pylint reports `method-hidden` here because
        # json.JSONEncoder.__init__ may assign an instance attribute named
        # `default` (when a `default=` callable is passed), which would hide
        # this method.
        # pylint: disable=method-hidden
        # This circular reference checking code was copied from the standard
        # library json implementation, but it outputs a repr'd string instead
        # of ValueError on a circular reference.
        # None is passed through as well: this method is called recursively
        # on attribute / item values, and without this check a nested None
        # would be written as the string "None" instead of JSON null.
        if o is None or isinstance(o, (int, bool, float, str)):
            return o
        markerid = id(o)
        if markerid in self._markers:
            return 'circular ' + repr(o)
        self._markers[markerid] = o
        try:
            return self._default_inner(o)
        finally:
            # Only objects on the current encoding path count as circular.
            del self._markers[markerid]

    def _default_inner(self, o):
        """Perform JSON encoding of an object already marked as in-progress.

        Args:
            o (object): Object to encode.

        Raises:
            TypeError: If `o` cannot be turned into JSON even using `repr(o)`.
            ValueError: If raised by calling repr on an object.

        Returns:
            dict or list or str or float or bool: Object encoded in
                JSON-compatible form.

        """
        # pylint: disable=too-many-branches
        # pylint: disable=too-many-return-statements
        try:
            return json.JSONEncoder.default(self, o)
        except TypeError as err:
            if isinstance(o, dict):
                # JSON object keys must be strings.
                return {(k if isinstance(k, str) else repr(k)):
                        self.default(v)
                        for (k, v) in o.items()}
            if isinstance(o, weakref.ref):
                return repr(o)
            if type(o).__module__.split('.')[0] in self.BLOCKED_MODULES:
                return repr(o)
            if isinstance(o, type):
                return {'$typename': o.__module__ + '.' + o.__name__}
            if isinstance(o, np.number):
                # For some reason these aren't natively considered
                # serializable.
                # JSON doesn't actually have ints, so always use a float.
                return float(o)
            # np.bool_ is the canonical name; the np.bool8 alias was
            # deprecated and later removed from NumPy.
            if isinstance(o, np.bool_):
                return bool(o)
            if isinstance(o, enum.Enum):
                return {
                    '$enum':
                    o.__module__ + '.' + o.__class__.__name__ + '.' + o.name
                }
            if isinstance(o, np.ndarray):
                return repr(o)
            if hasattr(o, '__dict__') or hasattr(o, '__slots__'):
                obj_dict = getattr(o, '__dict__', None)
                if obj_dict is not None:
                    data = {k: self.default(v) for (k, v) in obj_dict.items()}
                else:
                    slots = getattr(o, '__slots__', ())
                    if isinstance(slots, str):
                        # A bare string declares a single slot.
                        slots = (slots, )
                    data = {}
                    for s in slots:
                        try:
                            value = getattr(o, s)
                        except AttributeError:
                            # Declared slot that was never assigned.
                            continue
                        data[s] = self.default(value)
                t = type(o)
                data['$type'] = t.__module__ + '.' + t.__name__
                return data
            if callable(o) and hasattr(o, '__name__'):
                if getattr(o, '__module__', None) is not None:
                    return {'$function': o.__module__ + '.' + o.__name__}
                return repr(o)
            try:
                # This case handles many built-in datatypes like deques
                return [self.default(v) for v in list(o)]
            except TypeError:
                pass
            try:
                # This case handles most other weird objects.
                return repr(o)
            except TypeError:
                pass
            raise err
5 changes: 5 additions & 0 deletions src/garage/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# This is avoiding a circular import
from garage.experiment.deterministic import get_seed, set_seed
from garage.experiment.experiment import dump_json
from garage.experiment.snapshotter import Snapshotter
from garage.sampler.default_worker import DefaultWorker
from garage.sampler.worker_factory import WorkerFactory
Expand Down Expand Up @@ -518,6 +519,10 @@ def train(self,
self._plot = plot
self._start_worker()

log_dir = self._snapshotter.snapshot_dir
summary_file = os.path.join(log_dir, 'experiment.json')
dump_json(summary_file, self)

average_return = self._algo.train(self)
self._shutdown_worker()

Expand Down

0 comments on commit e1c576f

Please sign in to comment.