diff --git a/android_env/components/config_classes.py b/android_env/components/config_classes.py index 72d243b..a3d2b41 100644 --- a/android_env/components/config_classes.py +++ b/android_env/components/config_classes.py @@ -41,10 +41,6 @@ class CoordinatorConfig: # Number of virtual "fingers" of the agent. num_fingers: int = 1 - # How often to (asynchronously) grab the screenshot from the simulator. - # If <= 0, stepping the environment blocks on fetching the screenshot (the - # environment is synchronous). - interaction_rate_sec: float = 0.0 # Whether to enable keyboard key events. enable_key_events: bool = False # Whether to show circles on the screen indicating touch position. @@ -67,6 +63,10 @@ class SimulatorConfig: # If true, the log stream of the simulator will be verbose. verbose_logs: bool = False + # How often to (asynchronously) grab the screenshot from the simulator. + # If <= 0, stepping the environment blocks on fetching the screenshot (the + # environment is synchronous). + interaction_rate_sec: float = 0.0 @dataclasses.dataclass diff --git a/android_env/components/coordinator.py b/android_env/components/coordinator.py index c6d8610..4279bb3 100644 --- a/android_env/components/coordinator.py +++ b/android_env/components/coordinator.py @@ -17,7 +17,6 @@ import copy import socket -import threading import time from typing import Any @@ -56,7 +55,6 @@ def __init__( self._config = config or config_classes.CoordinatorConfig() self._adb_call_parser: adb_call_parser.AdbCallParser = None self._orientation = np.zeros(4, dtype=np.uint8) - self._interaction_thread: InteractionThread | None = None # The size of the device screen in pixels (H x W). self._screen_size = np.array([0, 0], dtype=np.int32) @@ -171,11 +169,6 @@ def _launch_simulator(self, max_retries: int = 3): self._simulator_healthy = False - # Stop screenshot thread. - if self._interaction_thread is not None: - self._interaction_thread.stop() - self._interaction_thread.join() - # Attempt to restart the system a given number of times. num_tries = 1 latest_error = None @@ -221,11 +214,6 @@ def _launch_simulator(self, max_retries: int = 3): self._simulator_healthy = True self._stats['relaunch_count'] += 1 break - if self._config.interaction_rate_sec > 0: - self._interaction_thread = InteractionThread( - self._simulator, self._config.interaction_rate_sec - ) - self._interaction_thread.start() def _update_settings(self) -> None: """Updates some internal state and preferences given in the constructor.""" @@ -345,15 +333,8 @@ def _gather_simulator_signals(self) -> dict[str, np.ndarray]: ) self._latest_observation_time = now - # Grab pixels. - if self._config.interaction_rate_sec > 0: - assert self._interaction_thread is not None - pixels = self._interaction_thread.screenshot() # Async mode. - else: - pixels = self._simulator.get_screenshot() # Sync mode. - return { - 'pixels': pixels, + 'pixels': self._simulator.get_screenshot(), 'orientation': self._orientation, 'timedelta': np.array(timestamp_delta, dtype=np.int64), } @@ -455,44 +436,8 @@ def stats(self) -> dict[str, Any]: def close(self): """Cleans up the state of this Coordinator.""" - if self._interaction_thread is not None: - self._interaction_thread.stop() - self._interaction_thread.join() if hasattr(self, '_task_manager'): self._task_manager.stop() if hasattr(self, '_simulator'): self._simulator.close() - - -class InteractionThread(threading.Thread): - """A thread that interacts with a simulator.""" - - def __init__( - self, simulator: base_simulator.BaseSimulator, interaction_rate_sec: float - ): - super().__init__() - self._simulator = simulator - self._interaction_rate_sec = interaction_rate_sec - self._should_stop = threading.Event() - self._screenshot = self._simulator.get_screenshot() - - def run(self): - last_read = time.time() - while not self._should_stop.is_set(): - self._screenshot = self._simulator.get_screenshot() - - now = time.time() - elapsed = now - last_read - last_read = now - sleep_time = self._interaction_rate_sec - elapsed - if sleep_time > 0.0: - time.sleep(sleep_time) - logging.info('InteractionThread.run() finished.') - - def stop(self): - logging.info('Stopping InteractionThread.') - self._should_stop.set() - - def screenshot(self): - return self._screenshot diff --git a/android_env/components/coordinator_test.py b/android_env/components/coordinator_test.py index a9908b7..67f124e 100644 --- a/android_env/components/coordinator_test.py +++ b/android_env/components/coordinator_test.py @@ -35,15 +35,6 @@ import numpy as np -class MockScreenshotGetter: - def __init__(self): - self._screenshot_index = 0 - - def get_screenshot(self): - self._screenshot_index += 1 - return np.array(self._screenshot_index, ndmin=3) - - class CoordinatorTest(parameterized.TestCase): def setUp(self): @@ -146,145 +137,6 @@ def fake_rl_step(simulator_signals): self.assertEqual(timestep.reward, 0.0) self.assertTrue(timestep.last()) - @mock.patch.object(time, 'sleep', autospec=True) - def test_process_action_error_async(self, unused_mock_sleep): - mock_interaction_thread = mock.create_autospec( - coordinator_lib.InteractionThread) - with mock.patch.object( - coordinator_lib, - 'InteractionThread', - autospec=True, - return_value=mock_interaction_thread): - coordinator = coordinator_lib.Coordinator( - simulator=self._simulator, - task_manager=self._task_manager, - config=config_classes.CoordinatorConfig( - num_fingers=1, interaction_rate_sec=0.016 - ), - ) - - def fake_rl_step(agent_action, simulator_signals): - del agent_action - self.assertFalse(simulator_signals['simulator_healthy']) - return dm_env.truncation(reward=0.0, observation=None) - - self._task_manager.rl_step.side_effect = fake_rl_step - mock_interaction_thread.screenshot.side_effect = errors.ReadObservationError( - ) - timestep = coordinator.rl_step( - agent_action={ - 'action_type': np.array(action_type.ActionType.LIFT), - 'touch_position': np.array([0.5, 0.5]), - }) - self.assertIsNone(timestep.observation) - self.assertEqual(timestep.reward, 0.0) - self.assertTrue(timestep.last()) - coordinator.close() - - def test_async_step_faster_than_screenshot(self): - """Return same screenshot when step is faster than the interaction rate.""" - screenshot_getter = MockScreenshotGetter() - self._simulator.get_screenshot.side_effect = screenshot_getter.get_screenshot - def fake_rl_step(simulator_signals): - return dm_env.transition( - reward=10.0, - observation={ - 'pixels': simulator_signals['pixels'], - 'orientation': simulator_signals['orientation'], - 'timedelta': simulator_signals['timedelta'], - 'extras': { - 'extra': [0.0] - } - }) - self._task_manager.rl_step.side_effect = fake_rl_step - coordinator = coordinator_lib.Coordinator( - simulator=self._simulator, - task_manager=self._task_manager, - config=config_classes.CoordinatorConfig( - num_fingers=1, interaction_rate_sec=0.5 - ), - ) - ts1 = coordinator.rl_step( - agent_action={ - 'action_type': np.array(action_type.ActionType.LIFT), - 'touch_position': np.array([0.5, 0.5]), - }) - ts2 = coordinator.rl_step( - agent_action={ - 'action_type': np.array(action_type.ActionType.LIFT), - 'touch_position': np.array([0.5, 0.5]), - }) - np.testing.assert_almost_equal(ts2.observation['pixels'], - ts1.observation['pixels']) - coordinator.close() - - def test_async_step_slower_than_screenshot(self): - """Return different screenshots when step slower than the interaction rate.""" - screenshot_getter = MockScreenshotGetter() - self._simulator.get_screenshot.side_effect = screenshot_getter.get_screenshot - - def fake_rl_step(simulator_signals): - return dm_env.transition( - reward=10.0, - observation={ - 'pixels': simulator_signals['pixels'], - 'orientation': simulator_signals['orientation'], - 'timedelta': simulator_signals['timedelta'], - 'extras': { - 'extra': [0.0] - } - }) - - self._task_manager.rl_step.side_effect = fake_rl_step - coordinator = coordinator_lib.Coordinator( - simulator=self._simulator, - task_manager=self._task_manager, - config=config_classes.CoordinatorConfig( - num_fingers=1, interaction_rate_sec=0.01 - ), - ) - ts1 = coordinator.rl_step( - agent_action={ - 'action_type': np.array(action_type.ActionType.LIFT), - 'touch_position': np.array([0.5, 0.5]), - }) - time.sleep(0.5) - ts2 = coordinator.rl_step( - agent_action={ - 'action_type': np.array(action_type.ActionType.LIFT), - 'touch_position': np.array([0.5, 0.5]), - }) - np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, - ts2.observation['pixels'], - ts1.observation['pixels']) - coordinator.close() - - def test_interaction_thread_closes_upon_relaunch(self): - """Async coordinator should kill the InteractionThread when relaunching.""" - mock_interaction_thread = mock.create_autospec( - coordinator_lib.InteractionThread) - with mock.patch.object( - coordinator_lib, - 'InteractionThread', - autospec=True, - return_value=mock_interaction_thread): - coordinator = coordinator_lib.Coordinator( - simulator=self._simulator, - task_manager=self._task_manager, - config=config_classes.CoordinatorConfig( - num_fingers=1, - periodic_restart_time_min=1e-6, - interaction_rate_sec=0.5, - ), - ) - mock_interaction_thread.stop.assert_not_called() - mock_interaction_thread.join.assert_not_called() - time.sleep(0.1) - coordinator.rl_reset() - mock_interaction_thread.stop.assert_called_once() - mock_interaction_thread.join.assert_called_once() - coordinator.close() - @mock.patch.object(time, 'sleep', autospec=True) def test_execute_action_touch(self, unused_mock_sleep): diff --git a/android_env/components/simulators/base_simulator.py b/android_env/components/simulators/base_simulator.py index 10be141..672e462 100644 --- a/android_env/components/simulators/base_simulator.py +++ b/android_env/components/simulators/base_simulator.py @@ -16,9 +16,13 @@ """A base class for talking to different types of Android simulators.""" import abc +from collections.abc import Callable +import threading +import time from absl import logging from android_env.components import adb_controller +from android_env.components import config_classes from android_env.components import errors from android_env.components import log_stream from android_env.proto import state_pb2 @@ -28,7 +32,7 @@ class BaseSimulator(metaclass=abc.ABCMeta): """An interface for communicating with an Android simulator.""" - def __init__(self, verbose_logs: bool = False): + def __init__(self, config: config_classes.SimulatorConfig): """Instantiates a BaseSimulator object. The simulator may be an emulator, virtual machine or even a physical device. @@ -36,10 +40,11 @@ def __init__(self, verbose_logs: bool = False): bookkeeping. Args: - verbose_logs: If true, the log stream of the simulator will be verbose. + config: Settings for this simulator. """ - self._verbose_logs = verbose_logs + self._config = config + self._interaction_thread: InteractionThread | None = None # An increasing number that tracks the attempt at launching the simulator. self._num_launch_attempts: int = 0 @@ -74,6 +79,13 @@ def launch(self) -> None: 'above for more details.' ) from error + # Start interaction thread. + if self._config.interaction_rate_sec > 0: + self._interaction_thread = InteractionThread( + self._get_screenshot_impl, self._config.interaction_rate_sec + ) + self._interaction_thread.start() + @abc.abstractmethod def _launch_impl(self) -> None: """Platform specific launch implementation.""" @@ -131,9 +143,18 @@ def save_state( """ raise NotImplementedError('This simulator does not support save_state()') - @abc.abstractmethod def get_screenshot(self) -> np.ndarray: - """Returns pixels representing the current screenshot of the simulator. + """Returns pixels representing the current screenshot of the simulator.""" + + if self._config.interaction_rate_sec > 0: + assert self._interaction_thread is not None + return self._interaction_thread.screenshot() # Async mode. + else: + return self._get_screenshot_impl() # Sync mode. + + @abc.abstractmethod + def _get_screenshot_impl(self) -> np.ndarray: + """Actual implementation of `get_screenshot()`. The output numpy array should have shape [height, width, num_channels] and can be loaded into PIL using Image.fromarray(img, mode='RGB') and be saved @@ -142,3 +163,41 @@ def get_screenshot(self) -> np.ndarray: def close(self): """Frees up resources allocated by this object.""" + + if self._interaction_thread is not None: + self._interaction_thread.stop() + self._interaction_thread.join() + + +class InteractionThread(threading.Thread): + """A thread that gets screenshot in the background.""" + + def __init__( + self, + get_screenshot_fn: Callable[[], np.ndarray], + interaction_rate_sec: float, + ): + super().__init__() + self._get_screenshot_fn = get_screenshot_fn + self._interaction_rate_sec = interaction_rate_sec + self._should_stop = threading.Event() + self._screenshot = self._get_screenshot_fn() + + def run(self): + last_read = time.time() + while not self._should_stop.is_set(): + self._screenshot = self._get_screenshot_fn() + now = time.time() + elapsed = now - last_read + last_read = now + sleep_time = self._interaction_rate_sec - elapsed + if sleep_time > 0.0: + time.sleep(sleep_time) + logging.info('InteractionThread.run() finished.') + + def stop(self): + logging.info('Stopping InteractionThread.') + self._should_stop.set() + + def screenshot(self) -> np.ndarray: + return self._screenshot diff --git a/android_env/components/simulators/base_simulator_test.py b/android_env/components/simulators/base_simulator_test.py index ad260ba..22785eb 100644 --- a/android_env/components/simulators/base_simulator_test.py +++ b/android_env/components/simulators/base_simulator_test.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for base_simulator.""" - +import itertools +import time from unittest import mock from absl.testing import absltest @@ -22,6 +22,7 @@ from android_env.components import errors # fake_simulator.FakeSimulator inherits from BaseSimulator, so there's no need # to import it here explicitly. +from android_env.components.simulators import base_simulator from android_env.components.simulators.fake import fake_simulator import numpy as np @@ -58,11 +59,97 @@ def test_print_logs_on_exception(self): simulator = fake_simulator.FakeSimulator( config_classes.FakeSimulatorConfig() ) - with mock.patch.object(simulator, 'get_logs') as mock_get_logs, \ - mock.patch.object(simulator, '_launch_impl', autospec=True) as mock_launch: + with mock.patch.object( + simulator, 'get_logs' + ) as mock_get_logs, mock.patch.object( + simulator, '_launch_impl', autospec=True + ) as mock_launch: mock_launch.side_effect = ValueError('Oh no!') self.assertRaises(errors.SimulatorError, simulator.launch) mock_get_logs.assert_called_once() + def test_get_screenshot_error_async(self): + """An exception in the underlying interaction thread should bubble up.""" + + # Arrange. + mock_interaction_thread = mock.create_autospec( + base_simulator.InteractionThread + ) + mock_interaction_thread.screenshot.side_effect = ( + errors.ReadObservationError() + ) + simulator = fake_simulator.FakeSimulator( + config_classes.FakeSimulatorConfig(interaction_rate_sec=0.5) + ) + with mock.patch.object( + base_simulator, + 'InteractionThread', + autospec=True, + return_value=mock_interaction_thread, + ): + simulator.launch() + + # Act & Assert. + self.assertRaises(errors.ReadObservationError, simulator.get_screenshot) + + # Cleanup. + simulator.close() + + def test_get_screenshot_faster_than_screenshot_impl(self): + """Return same screenshot when step is faster than the interaction rate.""" + + # Arrange. + slow_rate = 0.5 + simulator = fake_simulator.FakeSimulator( + config_classes.FakeSimulatorConfig(interaction_rate_sec=slow_rate) + ) + + # Act. + with mock.patch.object( + simulator, '_get_screenshot_impl', autospec=True + ) as mock_get_screenshot_impl: + mock_get_screenshot_impl.side_effect = ( + np.array(i, ndmin=3) for i in itertools.count(0, 1) + ) + simulator.launch() + # Get two screenshots one after the other without pausing. + screenshot1 = simulator.get_screenshot() + screenshot2 = simulator.get_screenshot() + + # Assert. + self.assertAlmostEqual(screenshot1[0][0][0], screenshot2[0][0][0]) + + # Cleanup. + simulator.close() + + def test_get_screenshot_slower_than_screenshot_impl(self): + """Return different screenshots when step slower than the interaction rate.""" + + # Arrange. + fast_rate = 0.01 + simulator = fake_simulator.FakeSimulator( + config_classes.FakeSimulatorConfig(interaction_rate_sec=fast_rate) + ) + + # Act. + with mock.patch.object( + simulator, '_get_screenshot_impl', autospec=True + ) as mock_get_screenshot_impl: + mock_get_screenshot_impl.side_effect = ( + np.array(i, ndmin=3) for i in itertools.count(0, 1) + ) + simulator.launch() + # Sleep for 500ms between two screenshots. + screenshot1 = simulator.get_screenshot() + time.sleep(0.5) + screenshot2 = simulator.get_screenshot() + + # Assert. + self.assertNotEqual(screenshot1[0][0][0], screenshot2[0][0][0]) + + # Cleanup. + simulator.close() + + if __name__ == '__main__': absltest.main() diff --git a/android_env/components/simulators/emulator/emulator_simulator.py b/android_env/components/simulators/emulator/emulator_simulator.py index 6293733..d7f6187 100644 --- a/android_env/components/simulators/emulator/emulator_simulator.py +++ b/android_env/components/simulators/emulator/emulator_simulator.py @@ -100,7 +100,7 @@ class EmulatorSimulator(base_simulator.BaseSimulator): def __init__(self, config: config_classes.EmulatorConfig): """Instantiates an EmulatorSimulator.""" - super().__init__(verbose_logs=config.verbose_logs) + super().__init__(config) self._config = config # If adb_port, console_port and grpc_port are all already provided, @@ -196,7 +196,8 @@ def create_adb_controller(self): def create_log_stream(self) -> log_stream.LogStream: return adb_log_stream.AdbLogStream( adb_command_prefix=self._adb_controller.command_prefix(), - verbose=self._verbose_logs) + verbose=self._config.verbose_logs, + ) def _launch_impl(self) -> None: """Prepares an Android Emulator for RL interaction. @@ -442,7 +443,7 @@ def send_key(self, keycode: np.int32, event_type: str) -> None: ) @_reconnect_on_grpc_error - def get_screenshot(self) -> np.ndarray: + def _get_screenshot_impl(self) -> np.ndarray: """Fetches the latest screenshot from the emulator.""" assert ( @@ -472,6 +473,8 @@ def _shutdown_emulator(self): self._launcher.confirm_shutdown() def close(self): + super().close() + if self._launcher is not None: self._shutdown_emulator() logging.info('Closing emulator (%r)', self.adb_device_name()) diff --git a/android_env/components/simulators/fake/fake_simulator.py b/android_env/components/simulators/fake/fake_simulator.py index fb66964..79996c8 100644 --- a/android_env/components/simulators/fake/fake_simulator.py +++ b/android_env/components/simulators/fake/fake_simulator.py @@ -108,7 +108,7 @@ class FakeSimulator(base_simulator.BaseSimulator): def __init__(self, config: config_classes.FakeSimulatorConfig): """FakeSimulator class that can replace EmulatorSimulator in AndroidEnv.""" - super().__init__(verbose_logs=config.verbose_logs) + super().__init__(config) self._screen_dimensions = np.array(config.screen_dimensions) logging.info('Created FakeSimulator.') @@ -133,6 +133,6 @@ def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None: def send_key(self, keycode: np.int32, event_type: str) -> None: del keycode, event_type - def get_screenshot(self) -> np.ndarray: + def _get_screenshot_impl(self) -> np.ndarray: return np.random.randint( low=0, high=255, size=(*self._screen_dimensions, 3), dtype=np.uint8)