Skip to content

Commit

Permalink
Move async interaction from Coordinator to BaseSimulator.
Browse files Browse the repository at this point in the history
This change adds asynchronous interaction to all simulators (i.e. derived
classes of `BaseSimulator`), making this type of interaction less tied to
its particular usage in the `Coordinator`. This is part of an ongoing effort to
separate the RL parts in AndroidEnv (mostly `AndroidEnv` and `TaskManager`)
from the more generic parts like `BaseSimulator`. The interface of *derived*
classes changed a little: instead of implementing `get_screenshot()` they now
have to implement `_get_screenshot_impl()`.

This change also contains a small fix to `BaseSimulator.__init__()`: instead
of taking individual params, it now takes a single `SimulatorConfig`,
making its initialization the same as that of other major components. This
requires a small change to the `super().__init__()` call in derived classes,
but it prepares them for future changes to `SimulatorConfig`.

PiperOrigin-RevId: 684273378
  • Loading branch information
kenjitoyama authored and copybara-github committed Oct 11, 2024
1 parent 7ff414d commit 0c5b733
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 222 deletions.
8 changes: 4 additions & 4 deletions android_env/components/config_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ class CoordinatorConfig:

# Number of virtual "fingers" of the agent.
num_fingers: int = 1
# How often to (asynchronously) grab the screenshot from the simulator.
# If <= 0, stepping the environment blocks on fetching the screenshot (the
# environment is synchronous).
interaction_rate_sec: float = 0.0
# Whether to enable keyboard key events.
enable_key_events: bool = False
# Whether to show circles on the screen indicating touch position.
Expand All @@ -67,6 +63,10 @@ class SimulatorConfig:

# If true, the log stream of the simulator will be verbose.
verbose_logs: bool = False
# How often to (asynchronously) grab the screenshot from the simulator.
# If <= 0, stepping the environment blocks on fetching the screenshot (the
# environment is synchronous).
interaction_rate_sec: float = 0.0


@dataclasses.dataclass
Expand Down
57 changes: 1 addition & 56 deletions android_env/components/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import copy
import socket
import threading
import time
from typing import Any

Expand Down Expand Up @@ -56,7 +55,6 @@ def __init__(
self._config = config or config_classes.CoordinatorConfig()
self._adb_call_parser: adb_call_parser.AdbCallParser = None
self._orientation = np.zeros(4, dtype=np.uint8)
self._interaction_thread: InteractionThread | None = None

# The size of the device screen in pixels (H x W).
self._screen_size = np.array([0, 0], dtype=np.int32)
Expand Down Expand Up @@ -171,11 +169,6 @@ def _launch_simulator(self, max_retries: int = 3):

self._simulator_healthy = False

# Stop screenshot thread.
if self._interaction_thread is not None:
self._interaction_thread.stop()
self._interaction_thread.join()

# Attempt to restart the system a given number of times.
num_tries = 1
latest_error = None
Expand Down Expand Up @@ -221,11 +214,6 @@ def _launch_simulator(self, max_retries: int = 3):
self._simulator_healthy = True
self._stats['relaunch_count'] += 1
break
if self._config.interaction_rate_sec > 0:
self._interaction_thread = InteractionThread(
self._simulator, self._config.interaction_rate_sec
)
self._interaction_thread.start()

def _update_settings(self) -> None:
"""Updates some internal state and preferences given in the constructor."""
Expand Down Expand Up @@ -345,15 +333,8 @@ def _gather_simulator_signals(self) -> dict[str, np.ndarray]:
)
self._latest_observation_time = now

# Grab pixels.
if self._config.interaction_rate_sec > 0:
assert self._interaction_thread is not None
pixels = self._interaction_thread.screenshot() # Async mode.
else:
pixels = self._simulator.get_screenshot() # Sync mode.

return {
'pixels': pixels,
'pixels': self._simulator.get_screenshot(),
'orientation': self._orientation,
'timedelta': np.array(timestamp_delta, dtype=np.int64),
}
Expand Down Expand Up @@ -455,44 +436,8 @@ def stats(self) -> dict[str, Any]:

def close(self):
  """Releases the resources held by this Coordinator.

  Stops and joins the interaction thread when one is set, then stops the task
  manager and closes the simulator if those attributes exist on the instance.
  """
  interaction_thread = self._interaction_thread
  if interaction_thread is not None:
    interaction_thread.stop()
    interaction_thread.join()

  # A private sentinel distinguishes "attribute absent" from any real value,
  # including `None`, exactly like a `hasattr()` check would.
  absent = object()

  task_manager = getattr(self, '_task_manager', absent)
  if task_manager is not absent:
    task_manager.stop()

  simulator = getattr(self, '_simulator', absent)
  if simulator is not absent:
    simulator.close()


class InteractionThread(threading.Thread):
  """A thread that repeatedly fetches screenshots from a simulator.

  The latest fetched screenshot is stored on the instance and returned by
  `screenshot()`, so readers get the stored value instead of calling the
  simulator themselves.
  """

  def __init__(
      self,
      simulator: 'base_simulator.BaseSimulator',
      interaction_rate_sec: float,
  ):
    """Initializes the thread and fetches one screenshot right away.

    Args:
      simulator: The object whose `get_screenshot()` is polled.
      interaction_rate_sec: Target period, in seconds, between the starts of
        two consecutive fetches. If a fetch takes longer than this, the next
        one starts immediately.
    """
    super().__init__()
    self._simulator = simulator
    self._interaction_rate_sec = interaction_rate_sec
    self._should_stop = threading.Event()
    # Fetch once up front so `screenshot()` has a value before `run()` starts.
    self._screenshot = self._simulator.get_screenshot()

  def run(self):
    """Fetches screenshots at the configured rate until `stop()` is called."""
    while not self._should_stop.is_set():
      # Time only this iteration's fetch. Measuring from the previous
      # iteration's timestamp would also count the time spent pausing, which
      # cancels every other pause and breaks the configured rate.
      fetch_start = time.monotonic()
      self._screenshot = self._simulator.get_screenshot()
      remaining = self._interaction_rate_sec - (time.monotonic() - fetch_start)
      if remaining > 0.0:
        # Waiting on the stop event (rather than sleeping) lets `stop()`
        # interrupt the pause instead of blocking `join()` for a full period.
        self._should_stop.wait(remaining)
    logging.info('InteractionThread.run() finished.')

  def stop(self):
    """Signals `run()` to exit; callers should `join()` afterwards."""
    logging.info('Stopping InteractionThread.')
    self._should_stop.set()

  def screenshot(self):
    """Returns the most recently stored screenshot."""
    return self._screenshot
148 changes: 0 additions & 148 deletions android_env/components/coordinator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,6 @@
import numpy as np


class MockScreenshotGetter:
  """Fake screenshot source that returns a different array on every call."""

  def __init__(self):
    # Number of screenshots handed out so far.
    self._frames_served = 0

  def get_screenshot(self):
    """Returns a rank-3 array holding the 1-based count of calls so far."""
    frame_number = self._frames_served + 1
    self._frames_served = frame_number
    return np.array(frame_number, ndmin=3)


class CoordinatorTest(parameterized.TestCase):

def setUp(self):
Expand Down Expand Up @@ -146,145 +137,6 @@ def fake_rl_step(simulator_signals):
self.assertEqual(timestep.reward, 0.0)
self.assertTrue(timestep.last())

# Checks that, with a positive `interaction_rate_sec`, an error raised while
# reading the interaction thread's screenshot yields a final timestep with no
# observation and zero reward.
@mock.patch.object(time, 'sleep', autospec=True)
def test_process_action_error_async(self, unused_mock_sleep):
# Substitute an autospec'd mock for `InteractionThread` so that constructing
# one inside `coordinator_lib` returns `mock_interaction_thread`.
mock_interaction_thread = mock.create_autospec(
coordinator_lib.InteractionThread)
with mock.patch.object(
coordinator_lib,
'InteractionThread',
autospec=True,
return_value=mock_interaction_thread):
# `interaction_rate_sec` > 0 selects the asynchronous screenshot path.
coordinator = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(
num_fingers=1, interaction_rate_sec=0.016
),
)

# NOTE(review): the other fake `rl_step` functions in this file take only
# `simulator_signals`; confirm this two-argument signature matches how the
# coordinator invokes `rl_step`.
def fake_rl_step(agent_action, simulator_signals):
del agent_action
self.assertFalse(simulator_signals['simulator_healthy'])
return dm_env.truncation(reward=0.0, observation=None)

self._task_manager.rl_step.side_effect = fake_rl_step
# Make reading the screenshot from the interaction thread raise.
mock_interaction_thread.screenshot.side_effect = errors.ReadObservationError(
)
timestep = coordinator.rl_step(
agent_action={
'action_type': np.array(action_type.ActionType.LIFT),
'touch_position': np.array([0.5, 0.5]),
})
# The step must end the episode with no observation and zero reward.
self.assertIsNone(timestep.observation)
self.assertEqual(timestep.reward, 0.0)
self.assertTrue(timestep.last())
coordinator.close()

def test_async_step_faster_than_screenshot(self):
  """Two back-to-back steps within one interaction period see equal pixels."""
  self._simulator.get_screenshot.side_effect = (
      MockScreenshotGetter().get_screenshot
  )

  def fake_rl_step(simulator_signals):
    # Forward the simulator signals unchanged and append a fixed extra.
    observation = {
        key: simulator_signals[key]
        for key in ('pixels', 'orientation', 'timedelta')
    }
    observation['extras'] = {'extra': [0.0]}
    return dm_env.transition(reward=10.0, observation=observation)

  self._task_manager.rl_step.side_effect = fake_rl_step

  # Screenshots are refreshed every 0.5s, far slower than the two steps below.
  slow_refresh_config = config_classes.CoordinatorConfig(
      num_fingers=1, interaction_rate_sec=0.5
  )
  coordinator = coordinator_lib.Coordinator(
      simulator=self._simulator,
      task_manager=self._task_manager,
      config=slow_refresh_config,
  )

  def lift_action():
    # Build a fresh action per step so no arrays are shared between calls.
    return {
        'action_type': np.array(action_type.ActionType.LIFT),
        'touch_position': np.array([0.5, 0.5]),
    }

  first_step = coordinator.rl_step(agent_action=lift_action())
  second_step = coordinator.rl_step(agent_action=lift_action())
  np.testing.assert_almost_equal(
      second_step.observation['pixels'], first_step.observation['pixels']
  )
  coordinator.close()

def test_async_step_slower_than_screenshot(self):
  """Steps spaced wider than the interaction period see different pixels."""
  self._simulator.get_screenshot.side_effect = (
      MockScreenshotGetter().get_screenshot
  )

  def fake_rl_step(simulator_signals):
    # Forward the simulator signals unchanged and append a fixed extra.
    observation = {
        key: simulator_signals[key]
        for key in ('pixels', 'orientation', 'timedelta')
    }
    observation['extras'] = {'extra': [0.0]}
    return dm_env.transition(reward=10.0, observation=observation)

  self._task_manager.rl_step.side_effect = fake_rl_step

  # Screenshots are refreshed every 0.01s, far faster than the 0.5s pause
  # between the two steps below.
  fast_refresh_config = config_classes.CoordinatorConfig(
      num_fingers=1, interaction_rate_sec=0.01
  )
  coordinator = coordinator_lib.Coordinator(
      simulator=self._simulator,
      task_manager=self._task_manager,
      config=fast_refresh_config,
  )

  def lift_action():
    # Build a fresh action per step so no arrays are shared between calls.
    return {
        'action_type': np.array(action_type.ActionType.LIFT),
        'touch_position': np.array([0.5, 0.5]),
    }

  first_step = coordinator.rl_step(agent_action=lift_action())
  time.sleep(0.5)
  second_step = coordinator.rl_step(agent_action=lift_action())

  # `assert_array_equal` must fail, i.e. the two screenshots must differ.
  np.testing.assert_raises(
      AssertionError,
      np.testing.assert_array_equal,
      second_step.observation['pixels'],
      first_step.observation['pixels'],
  )
  coordinator.close()

def test_interaction_thread_closes_upon_relaunch(self):
"""Async coordinator should kill the InteractionThread when relaunching."""
# Substitute an autospec'd mock for `InteractionThread` so its `stop()` and
# `join()` calls can be counted.
mock_interaction_thread = mock.create_autospec(
coordinator_lib.InteractionThread)
with mock.patch.object(
coordinator_lib,
'InteractionThread',
autospec=True,
return_value=mock_interaction_thread):
# NOTE(review): `periodic_restart_time_min=1e-6` presumably makes a restart
# due almost immediately; confirm against `CoordinatorConfig`.
coordinator = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(
num_fingers=1,
periodic_restart_time_min=1e-6,
interaction_rate_sec=0.5,
),
)
# Constructing the coordinator alone must not stop or join the thread.
mock_interaction_thread.stop.assert_not_called()
mock_interaction_thread.join.assert_not_called()
# Presumably gives the restart period time to elapse before resetting.
time.sleep(0.1)
coordinator.rl_reset()
# The reset is expected to relaunch the simulator, which stops and joins the
# existing interaction thread exactly once.
mock_interaction_thread.stop.assert_called_once()
mock_interaction_thread.join.assert_called_once()
# `close()` also stops and joins the thread, so it runs after the assertions.
coordinator.close()

@mock.patch.object(time, 'sleep', autospec=True)
def test_execute_action_touch(self, unused_mock_sleep):

Expand Down
Loading

0 comments on commit 0c5b733

Please sign in to comment.