Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refcator] Extract message types as MessageType Enum in ProcessEnvironment #938

Merged
merged 1 commit into from
Jul 21, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions alf/environments/process_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from absl import logging
import atexit
from enum import Enum
import multiprocessing
import numpy as np
import sys
Expand Down Expand Up @@ -49,16 +50,21 @@ def _tensor_to_array(obj):
return nest.map_structure(_tensor_to_array, data)


class ProcessEnvironment(object):
class _MessageType(Enum):
"""Message types for communication via the pipe.

The ProcessEnvironment uses pipe to perform IPC, where each of the message
has a message type. This Enum provides all the available message types.
"""
READY = 1
ACCESS = 2
CALL = 3
RESULT = 4
EXCEPTION = 5
CLOSE = 6

# Message types for communication via the pipe.
_READY = 1
_ACCESS = 2
_CALL = 3
_RESULT = 4
_EXCEPTION = 5
_CLOSE = 6

class ProcessEnvironment(object):
def __init__(self, env_constructor, env_id=None, flatten=False):
"""Step environment in a separate process for lock free paralellism.

Expand Down Expand Up @@ -109,7 +115,7 @@ def wait_start(self):
self._conn.close()
self._process.join(5)
raise result
assert result == self._READY, result
assert result == _MessageType.READY, result

def env_info_spec(self):
if not self._env_info_spec:
Expand Down Expand Up @@ -148,7 +154,7 @@ def __getattr__(self, name):
Returns:
Value of the attribute.
"""
self._conn.send((self._ACCESS, name))
self._conn.send((_MessageType.ACCESS, name))
return self._receive()

def call(self, name, *args, **kwargs):
Expand All @@ -164,13 +170,13 @@ def call(self, name, *args, **kwargs):
"""
payload = name, args, kwargs
payload = tensor_to_array(payload)
self._conn.send((self._CALL, payload))
self._conn.send((_MessageType.CALL, payload))
return self._receive

def close(self):
"""Send a close message to the external process and join it."""
try:
self._conn.send((self._CLOSE, None))
self._conn.send((_MessageType.CLOSE, None))
self._conn.close()
except IOError:
# The connection was already closed.
Expand Down Expand Up @@ -223,10 +229,10 @@ def _receive(self):
payload = array_to_tensor(payload)

# Re-raise exceptions in the main process.
if message == self._EXCEPTION:
if message == _MessageType.EXCEPTION:
stacktrace = payload
raise Exception(stacktrace)
if message == self._RESULT:
if message == _MessageType.RESULT:
return payload
self.close()
raise KeyError(
Expand All @@ -248,7 +254,7 @@ def _worker(self, conn, env_constructor, env_id=None, flatten=False):
alf.set_default_device("cpu")
env = env_constructor(env_id=env_id)
action_spec = env.action_spec()
conn.send(self._READY) # Ready.
conn.send(_MessageType.READY) # Ready.
while True:
try:
# Only block for short times to have keyboard exceptions be raised.
Expand All @@ -257,12 +263,12 @@ def _worker(self, conn, env_constructor, env_id=None, flatten=False):
message, payload = conn.recv()
except (EOFError, KeyboardInterrupt):
break
if message == self._ACCESS:
if message == _MessageType.ACCESS:
name = payload
result = getattr(env, name)
conn.send((self._RESULT, result))
conn.send((_MessageType.RESULT, result))
continue
if message == self._CALL:
if message == _MessageType.CALL:
name, args, kwargs = payload
if flatten and name == 'step':
args = [nest.pack_sequence_as(action_spec, args[0])]
Expand All @@ -272,9 +278,9 @@ def _worker(self, conn, env_constructor, env_id=None, flatten=False):
assert all([
not isinstance(x, torch.Tensor) for x in result
]), ("Tensor result is not allowed: %s" % name)
conn.send((self._RESULT, result))
conn.send((_MessageType.RESULT, result))
continue
if message == self._CLOSE:
if message == _MessageType.CLOSE:
assert payload is None
env.close()
break
Expand All @@ -285,7 +291,7 @@ def _worker(self, conn, env_constructor, env_id=None, flatten=False):
stacktrace = ''.join(traceback.format_exception(etype, evalue, tb))
message = 'Error in environment process: {}'.format(stacktrace)
logging.error(message)
conn.send((self._EXCEPTION, stacktrace))
conn.send((_MessageType.EXCEPTION, stacktrace))
finally:
conn.close()

Expand Down