Skip to content

Commit

Permalink
Move async interaction from Coordinator to BaseSimulator.
Browse files Browse the repository at this point in the history
This change adds asynchronous interaction to all simulators (i.e. derived
classes of `BaseSimulator`), making this type of interaction less tied to
our particular usage in the `Coordinator`. This is part of an ongoing effort to
separate the RL parts in AndroidEnv (mostly `AndroidEnv` and `TaskManager`)
from the more generic parts like `BaseSimulator`. The interface of *derived*
classes changed a little: instead of implementing `get_screenshot()` they now
have to implement `_get_screenshot_impl()`.

This change also contains a small fix to `BaseSimulator.__init__()`:
instead of taking individual params, it now takes a single `SimulatorConfig`,
making its initialization the same as other major components. This requires a
small change to the `super().__init__()` call in derived classes, but it
prepares them for future changes to `SimulatorConfig`.

PiperOrigin-RevId: 684273378
  • Loading branch information
kenjitoyama authored and copybara-github committed Oct 10, 2024
1 parent 7ff414d commit efc8e48
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 222 deletions.
8 changes: 4 additions & 4 deletions android_env/components/config_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ class CoordinatorConfig:

# Number of virtual "fingers" of the agent.
num_fingers: int = 1
# How often to (asynchronously) grab the screenshot from the simulator.
# If <= 0, stepping the environment blocks on fetching the screenshot (the
# environment is synchronous).
interaction_rate_sec: float = 0.0
# Whether to enable keyboard key events.
enable_key_events: bool = False
# Whether to show circles on the screen indicating touch position.
Expand All @@ -67,6 +63,10 @@ class SimulatorConfig:

# If true, the log stream of the simulator will be verbose.
verbose_logs: bool = False
# How often to (asynchronously) grab the screenshot from the simulator.
# If <= 0, stepping the environment blocks on fetching the screenshot (the
# environment is synchronous).
interaction_rate_sec: float = 0.0


@dataclasses.dataclass
Expand Down
57 changes: 1 addition & 56 deletions android_env/components/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import copy
import socket
import threading
import time
from typing import Any

Expand Down Expand Up @@ -56,7 +55,6 @@ def __init__(
self._config = config or config_classes.CoordinatorConfig()
self._adb_call_parser: adb_call_parser.AdbCallParser = None
self._orientation = np.zeros(4, dtype=np.uint8)
self._interaction_thread: InteractionThread | None = None

# The size of the device screen in pixels (H x W).
self._screen_size = np.array([0, 0], dtype=np.int32)
Expand Down Expand Up @@ -171,11 +169,6 @@ def _launch_simulator(self, max_retries: int = 3):

self._simulator_healthy = False

# Stop screenshot thread.
if self._interaction_thread is not None:
self._interaction_thread.stop()
self._interaction_thread.join()

# Attempt to restart the system a given number of times.
num_tries = 1
latest_error = None
Expand Down Expand Up @@ -221,11 +214,6 @@ def _launch_simulator(self, max_retries: int = 3):
self._simulator_healthy = True
self._stats['relaunch_count'] += 1
break
if self._config.interaction_rate_sec > 0:
self._interaction_thread = InteractionThread(
self._simulator, self._config.interaction_rate_sec
)
self._interaction_thread.start()

def _update_settings(self) -> None:
"""Updates some internal state and preferences given in the constructor."""
Expand Down Expand Up @@ -345,15 +333,8 @@ def _gather_simulator_signals(self) -> dict[str, np.ndarray]:
)
self._latest_observation_time = now

# Grab pixels.
if self._config.interaction_rate_sec > 0:
assert self._interaction_thread is not None
pixels = self._interaction_thread.screenshot() # Async mode.
else:
pixels = self._simulator.get_screenshot() # Sync mode.

return {
'pixels': pixels,
'pixels': self._simulator.get_screenshot(),
'orientation': self._orientation,
'timedelta': np.array(timestamp_delta, dtype=np.int64),
}
Expand Down Expand Up @@ -455,44 +436,8 @@ def stats(self) -> dict[str, Any]:

def close(self):
"""Cleans up the state of this Coordinator."""
if self._interaction_thread is not None:
self._interaction_thread.stop()
self._interaction_thread.join()

if hasattr(self, '_task_manager'):
self._task_manager.stop()
if hasattr(self, '_simulator'):
self._simulator.close()


class InteractionThread(threading.Thread):
  """A thread that interacts with a simulator.

  The thread repeatedly calls `simulator.get_screenshot()` and keeps the most
  recent result, which callers read back through `screenshot()`.
  """

  def __init__(
      self, simulator: base_simulator.BaseSimulator, interaction_rate_sec: float
  ):
    """Initializes the thread and fetches a first screenshot.

    Args:
      simulator: The simulator to grab screenshots from.
      interaction_rate_sec: Target period, in seconds, between the start of two
        consecutive screenshot fetches.
    """
    super().__init__()
    self._simulator = simulator
    self._interaction_rate_sec = interaction_rate_sec
    self._should_stop = threading.Event()
    # Fetch one frame right away so that `screenshot()` has a value to return
    # even before `run()` has completed its first iteration.
    self._screenshot = self._simulator.get_screenshot()

  def run(self):
    """Fetches screenshots at the configured rate until `stop()` is called."""
    while not self._should_stop.is_set():
      # Time each iteration from its own start. Measuring from the previous
      # iteration's fetch would count the previous sleep as elapsed time and
      # cause every other iteration to skip its sleep.
      iteration_start = time.monotonic()
      self._screenshot = self._simulator.get_screenshot()

      remaining = self._interaction_rate_sec - (
          time.monotonic() - iteration_start
      )
      if remaining > 0.0:
        # Waiting on the stop event (rather than `time.sleep()`) lets `stop()`
        # wake the thread immediately instead of after a full interval.
        self._should_stop.wait(remaining)
    logging.info('InteractionThread.run() finished.')

  def stop(self):
    """Signals `run()` to exit; callers still need to `join()` the thread."""
    logging.info('Stopping InteractionThread.')
    self._should_stop.set()

  def screenshot(self):
    """Returns the most recently fetched screenshot."""
    return self._screenshot
148 changes: 0 additions & 148 deletions android_env/components/coordinator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,6 @@
import numpy as np


class MockScreenshotGetter:
  """Fake screenshot source that returns a different array on every call."""

  def __init__(self):
    # Number of screenshots handed out so far.
    self._screenshot_index = 0

  def get_screenshot(self):
    """Bumps the call counter and returns it wrapped in a 3-D array."""
    next_index = self._screenshot_index + 1
    self._screenshot_index = next_index
    return np.array(next_index, ndmin=3)


class CoordinatorTest(parameterized.TestCase):

def setUp(self):
Expand Down Expand Up @@ -146,145 +137,6 @@ def fake_rl_step(simulator_signals):
self.assertEqual(timestep.reward, 0.0)
self.assertTrue(timestep.last())

@mock.patch.object(time, 'sleep', autospec=True)
def test_process_action_error_async(self, unused_mock_sleep):
mock_interaction_thread = mock.create_autospec(
coordinator_lib.InteractionThread)
with mock.patch.object(
coordinator_lib,
'InteractionThread',
autospec=True,
return_value=mock_interaction_thread):
coordinator = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(
num_fingers=1, interaction_rate_sec=0.016
),
)

def fake_rl_step(agent_action, simulator_signals):
del agent_action
self.assertFalse(simulator_signals['simulator_healthy'])
return dm_env.truncation(reward=0.0, observation=None)

self._task_manager.rl_step.side_effect = fake_rl_step
mock_interaction_thread.screenshot.side_effect = errors.ReadObservationError(
)
timestep = coordinator.rl_step(
agent_action={
'action_type': np.array(action_type.ActionType.LIFT),
'touch_position': np.array([0.5, 0.5]),
})
self.assertIsNone(timestep.observation)
self.assertEqual(timestep.reward, 0.0)
self.assertTrue(timestep.last())
coordinator.close()

def test_async_step_faster_than_screenshot(self):
"""Return same screenshot when step is faster than the interaction rate."""
screenshot_getter = MockScreenshotGetter()
self._simulator.get_screenshot.side_effect = screenshot_getter.get_screenshot
def fake_rl_step(simulator_signals):
return dm_env.transition(
reward=10.0,
observation={
'pixels': simulator_signals['pixels'],
'orientation': simulator_signals['orientation'],
'timedelta': simulator_signals['timedelta'],
'extras': {
'extra': [0.0]
}
})
self._task_manager.rl_step.side_effect = fake_rl_step
coordinator = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(
num_fingers=1, interaction_rate_sec=0.5
),
)
ts1 = coordinator.rl_step(
agent_action={
'action_type': np.array(action_type.ActionType.LIFT),
'touch_position': np.array([0.5, 0.5]),
})
ts2 = coordinator.rl_step(
agent_action={
'action_type': np.array(action_type.ActionType.LIFT),
'touch_position': np.array([0.5, 0.5]),
})
np.testing.assert_almost_equal(ts2.observation['pixels'],
ts1.observation['pixels'])
coordinator.close()

def test_async_step_slower_than_screenshot(self):
"""Return different screenshots when step slower than the interaction rate."""
screenshot_getter = MockScreenshotGetter()
self._simulator.get_screenshot.side_effect = screenshot_getter.get_screenshot

def fake_rl_step(simulator_signals):
return dm_env.transition(
reward=10.0,
observation={
'pixels': simulator_signals['pixels'],
'orientation': simulator_signals['orientation'],
'timedelta': simulator_signals['timedelta'],
'extras': {
'extra': [0.0]
}
})

self._task_manager.rl_step.side_effect = fake_rl_step
coordinator = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(
num_fingers=1, interaction_rate_sec=0.01
),
)
ts1 = coordinator.rl_step(
agent_action={
'action_type': np.array(action_type.ActionType.LIFT),
'touch_position': np.array([0.5, 0.5]),
})
time.sleep(0.5)
ts2 = coordinator.rl_step(
agent_action={
'action_type': np.array(action_type.ActionType.LIFT),
'touch_position': np.array([0.5, 0.5]),
})
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal,
ts2.observation['pixels'],
ts1.observation['pixels'])
coordinator.close()

def test_interaction_thread_closes_upon_relaunch(self):
"""Async coordinator should kill the InteractionThread when relaunching."""
mock_interaction_thread = mock.create_autospec(
coordinator_lib.InteractionThread)
with mock.patch.object(
coordinator_lib,
'InteractionThread',
autospec=True,
return_value=mock_interaction_thread):
coordinator = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(
num_fingers=1,
periodic_restart_time_min=1e-6,
interaction_rate_sec=0.5,
),
)
mock_interaction_thread.stop.assert_not_called()
mock_interaction_thread.join.assert_not_called()
time.sleep(0.1)
coordinator.rl_reset()
mock_interaction_thread.stop.assert_called_once()
mock_interaction_thread.join.assert_called_once()
coordinator.close()

@mock.patch.object(time, 'sleep', autospec=True)
def test_execute_action_touch(self, unused_mock_sleep):

Expand Down
69 changes: 64 additions & 5 deletions android_env/components/simulators/base_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
"""A base class for talking to different types of Android simulators."""

import abc
from collections.abc import Callable
import threading
import time

from absl import logging
from android_env.components import adb_controller
from android_env.components import config_classes
from android_env.components import errors
from android_env.components import log_stream
from android_env.proto import state_pb2
Expand All @@ -28,18 +32,19 @@
class BaseSimulator(metaclass=abc.ABCMeta):
"""An interface for communicating with an Android simulator."""

def __init__(self, verbose_logs: bool = False):
def __init__(self, config: config_classes.SimulatorConfig):
"""Instantiates a BaseSimulator object.
The simulator may be an emulator, virtual machine or even a physical device.
Each simulator has its own AdbController that is used for internal
bookkeeping.
Args:
verbose_logs: If true, the log stream of the simulator will be verbose.
config: Settings for this simulator.
"""

self._verbose_logs = verbose_logs
self._config = config
self._interaction_thread: InteractionThread | None = None

# An increasing number that tracks the attempt at launching the simulator.
self._num_launch_attempts: int = 0
Expand Down Expand Up @@ -74,6 +79,13 @@ def launch(self) -> None:
'above for more details.'
) from error

# Start interaction thread.
if self._config.interaction_rate_sec > 0:
self._interaction_thread = InteractionThread(
self._get_screenshot_impl, self._config.interaction_rate_sec
)
self._interaction_thread.start()

@abc.abstractmethod
def _launch_impl(self) -> None:
"""Platform specific launch implementation."""
Expand Down Expand Up @@ -131,9 +143,18 @@ def save_state(
"""
raise NotImplementedError('This simulator does not support save_state()')

@abc.abstractmethod
def get_screenshot(self) -> np.ndarray:
"""Returns pixels representing the current screenshot of the simulator.
"""Returns pixels representing the current screenshot of the simulator."""

if self._config.interaction_rate_sec > 0:
assert self._interaction_thread is not None
return self._interaction_thread.screenshot() # Async mode.
else:
return self._get_screenshot_impl() # Sync mode.

@abc.abstractmethod
def _get_screenshot_impl(self) -> np.ndarray:
"""Actual implementation of `get_screenshot()`.
The output numpy array should have shape [height, width, num_channels] and
can be loaded into PIL using Image.fromarray(img, mode='RGB') and be saved
Expand All @@ -142,3 +163,41 @@ def get_screenshot(self) -> np.ndarray:

def close(self):
"""Frees up resources allocated by this object."""

if self._interaction_thread is not None:
self._interaction_thread.stop()
self._interaction_thread.join()


class InteractionThread(threading.Thread):
  """A thread that gets screenshot in the background.

  The thread repeatedly calls the provided function and keeps the most recent
  result, which callers read back through `screenshot()`.
  """

  def __init__(
      self,
      get_screenshot_fn: Callable[[], np.ndarray],
      interaction_rate_sec: float,
  ):
    """Initializes the thread and fetches a first screenshot.

    Args:
      get_screenshot_fn: Zero-argument callable that returns the current
        screenshot.
      interaction_rate_sec: Target period, in seconds, between the start of two
        consecutive screenshot fetches.
    """
    super().__init__()
    self._get_screenshot_fn = get_screenshot_fn
    self._interaction_rate_sec = interaction_rate_sec
    self._should_stop = threading.Event()
    # Fetch one frame right away so that `screenshot()` has a value to return
    # even before `run()` has completed its first iteration.
    self._screenshot = self._get_screenshot_fn()

  def run(self):
    """Fetches screenshots at the configured rate until `stop()` is called."""
    while not self._should_stop.is_set():
      # Time each iteration from its own start. Measuring from the previous
      # iteration's fetch would count the previous sleep as elapsed time and
      # cause every other iteration to skip its sleep.
      iteration_start = time.monotonic()
      self._screenshot = self._get_screenshot_fn()
      remaining = self._interaction_rate_sec - (
          time.monotonic() - iteration_start
      )
      if remaining > 0.0:
        # Waiting on the stop event (rather than `time.sleep()`) lets `stop()`
        # wake the thread immediately instead of after a full interval.
        self._should_stop.wait(remaining)
    logging.info('InteractionThread.run() finished.')

  def stop(self):
    """Signals `run()` to exit; callers still need to `join()` the thread."""
    logging.info('Stopping InteractionThread.')
    self._should_stop.set()

  def screenshot(self) -> np.ndarray:
    """Returns the most recently fetched screenshot."""
    return self._screenshot
Loading

0 comments on commit efc8e48

Please sign in to comment.