diff --git a/android_env/components/config_classes.py b/android_env/components/config_classes.py index 8e94db8..6828771 100644 --- a/android_env/components/config_classes.py +++ b/android_env/components/config_classes.py @@ -26,7 +26,7 @@ class AdbControllerConfig: # NOTE: This must be a full path and must not contain environment variables # or user folder shorthands (e.g. `~/some/path/to/adb`) since they will not be # expanded internally by AndroidEnv. - adb_path: str = 'adb' + adb_path: str = '~/Android/Sdk/platform-tools/adb' # Port for adb server. adb_server_port: int = 5037 # Default timeout in seconds for internal commands. @@ -80,13 +80,13 @@ class EmulatorLauncherConfig: # exists already and EmulatorLauncher will be skipped. # Filesystem path to the `emulator` binary. - emulator_path: str = '' + emulator_path: str = '~/Android/Sdk/emulator/emulator' # Filesystem path to the Android SDK root. - android_sdk_root: str = '' + android_sdk_root: str = '~/Android/Sdk' # Name of the AVD. avd_name: str = '' # Local directory for AVDs. - android_avd_home: str = '' + android_avd_home: str = '~/.android/avd' # Name of the snapshot to load. snapshot_name: str = '' # Path to the KVM device. @@ -178,3 +178,18 @@ class FilesystemTaskConfig(TaskConfig): # Filesystem path to `.binarypb` or `.textproto` protobuf Task. path: str = '' + + +@dataclasses.dataclass +class AndroidEnvConfig: + """Config class for AndroidEnv.""" + + # Configs for main components. + task: TaskConfig = dataclasses.field(default_factory=TaskConfig) + task_manager: TaskManagerConfig = dataclasses.field( + default_factory=TaskManagerConfig + ) + coordinator: CoordinatorConfig = dataclasses.field( + default_factory=CoordinatorConfig + ) + simulator: SimulatorConfig = dataclasses.field(default_factory=EmulatorConfig) diff --git a/android_env/components/simulators/emulator/emulator_launcher_test.py b/android_env/components/simulators/emulator/emulator_launcher_test.py index 12b8d6f..3bfd44f 100644 --- a/android_env/components/simulators/emulator/emulator_launcher_test.py +++ b/android_env/components/simulators/emulator/emulator_launcher_test.py @@ -61,10 +61,12 @@ def setUp(self): base_lib_dir + 'gles_swiftshader/', base_lib_dir ]) + # Instantiate the config to extract default values. + config = config_classes.EmulatorLauncherConfig() self._expected_env_vars = { 'ANDROID_HOME': '', - 'ANDROID_SDK_ROOT': '', - 'ANDROID_AVD_HOME': '', + 'ANDROID_SDK_ROOT': config.android_sdk_root, + 'ANDROID_AVD_HOME': config.android_avd_home, 'ANDROID_EMULATOR_KVM_DEVICE': '/dev/kvm', 'ANDROID_ADB_SERVER_PORT': '1234', 'LD_LIBRARY_PATH': ld_library_path, diff --git a/android_env/loader.py b/android_env/loader.py index ee6d138..92666f6 100644 --- a/android_env/loader.py +++ b/android_env/loader.py @@ -23,6 +23,7 @@ from android_env.components import coordinator as coordinator_lib from android_env.components import task_manager as task_manager_lib from android_env.components.simulators.emulator import emulator_simulator +from android_env.components.simulators.fake import fake_simulator from android_env.proto import task_pb2 from google.protobuf import text_format @@ -42,64 +43,41 @@ def _load_task(task_config: config_classes.TaskConfig) -> task_pb2.Task: return task -def load( - task_path: str, - avd_name: str | None = None, - android_avd_home: str = '~/.android/avd', - android_sdk_root: str = '~/Android/Sdk', - emulator_path: str = '~/Android/Sdk/emulator/emulator', - adb_path: str = '~/Android/Sdk/platform-tools/adb', - run_headless: bool = False, - console_port: int | None = None, -) -> environment.AndroidEnv: - """Loads an AndroidEnv instance. - - Args: - task_path: Path to the task textproto file. - avd_name: Name of the AVD (Android Virtual Device). - android_avd_home: Path to the AVD (Android Virtual Device). - android_sdk_root: Root directory of the SDK. - emulator_path: Path to the emulator binary. - adb_path: Path to the ADB (Android Debug Bridge). - run_headless: If True, the emulator display is turned off. - console_port: The console port number; for connecting to an already running - device/emulator. - - Returns: - env: An AndroidEnv instance. - """ - connect_to_existing_device = console_port is not None - if not connect_to_existing_device and avd_name is None: - raise ValueError('An avd name must be provided if launching an emulator.') - - if connect_to_existing_device: - launcher_config = config_classes.EmulatorLauncherConfig( - emulator_console_port=console_port, - adb_port=console_port + 1, - grpc_port=8554, - ) - else: - launcher_config = config_classes.EmulatorLauncherConfig( - avd_name=avd_name, - android_avd_home=os.path.expanduser(android_avd_home), - android_sdk_root=os.path.expanduser(android_sdk_root), - emulator_path=os.path.expanduser(emulator_path), - run_headless=run_headless, - gpu_mode='swiftshader_indirect', - ) - - # Create simulator. - simulator = emulator_simulator.EmulatorSimulator( - config=config_classes.EmulatorConfig( - emulator_launcher=launcher_config, - adb_controller=config_classes.AdbControllerConfig( - adb_path=os.path.expanduser(adb_path), - adb_server_port=5037, - ), - ) - ) +def load(config: config_classes.AndroidEnvConfig) -> environment.AndroidEnv: + """Loads an AndroidEnv instance.""" - task = _load_task(config_classes.FilesystemTaskConfig(path=task_path)) + task = _load_task(config.task) task_manager = task_manager_lib.TaskManager(task) + + match config.simulator: + case config_classes.EmulatorConfig(): + _process_emulator_launcher_config(config.simulator) + simulator = emulator_simulator.EmulatorSimulator(config=config.simulator) + case config_classes.FakeSimulatorConfig(): + simulator = fake_simulator.FakeSimulator(config=config.simulator) + case _: + raise ValueError('Unsupported simulator config: {config.simulator}') + coordinator = coordinator_lib.Coordinator(simulator, task_manager) return environment.AndroidEnv(coordinator=coordinator) + + +def _process_emulator_launcher_config( + emulator_config: config_classes.EmulatorConfig, +) -> None: + """Adjusts the configuration of the emulator depending on some conditions.""" + + # Expand the user directory if specified. + launcher_config = emulator_config.emulator_launcher + launcher_config.android_avd_home = os.path.expanduser( + launcher_config.android_avd_home + ) + launcher_config.android_sdk_root = os.path.expanduser( + launcher_config.android_sdk_root + ) + launcher_config.emulator_path = os.path.expanduser( + launcher_config.emulator_path + ) + emulator_config.adb_controller.adb_path = os.path.expanduser( + emulator_config.adb_controller.adb_path + ) diff --git a/android_env/loader_test.py b/android_env/loader_test.py index 4435e91..9ca2a85 100644 --- a/android_env/loader_test.py +++ b/android_env/loader_test.py @@ -20,12 +20,13 @@ from unittest import mock from absl.testing import absltest -from android_env import environment +from android_env import env_interface from android_env import loader from android_env.components import config_classes from android_env.components import coordinator as coordinator_lib from android_env.components import task_manager as task_manager_lib from android_env.components.simulators.emulator import emulator_simulator +from android_env.components.simulators.fake import fake_simulator from android_env.proto import task_pb2 @@ -35,24 +36,34 @@ class LoaderTest(absltest.TestCase): @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True) @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True) @mock.patch.object(builtins, 'open', autospec=True) - def test_load( + def test_load_emulator( self, mock_open, mock_coordinator, mock_simulator_class, mock_task_manager ): + # Arrange. mock_open.return_value.__enter__ = mock_open mock_open.return_value.read.return_value = '' - - env = loader.load( - task_path='some/path/', - avd_name='my_avd', - android_avd_home='~/.android/avd', - android_sdk_root='~/Android/Sdk', - emulator_path='~/Android/Sdk/emulator/emulator', - adb_path='~/Android/Sdk/platform-tools/adb', - run_headless=False, + config = config_classes.AndroidEnvConfig( + task=config_classes.FilesystemTaskConfig(path='some/path/'), + simulator=config_classes.EmulatorConfig( + emulator_launcher=config_classes.EmulatorLauncherConfig( + avd_name='my_avd', + android_avd_home='~/.android/avd', + android_sdk_root='~/Android/Sdk', + emulator_path='~/Android/Sdk/emulator/emulator', + run_headless=False, + ), + adb_controller=config_classes.AdbControllerConfig( + adb_path='~/Android/Sdk/platform-tools/adb', + ), + ), ) - self.assertIsInstance(env, environment.AndroidEnv) + # Act. + env = loader.load(config) + + # Assert. + self.assertIsInstance(env, env_interface.AndroidEnvInterface) mock_simulator_class.assert_called_with( config=config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( @@ -63,7 +74,7 @@ def test_load( '~/Android/Sdk/emulator/emulator' ), run_headless=False, - gpu_mode='swiftshader_indirect', + gpu_mode='swangle_indirect', ), adb_controller=config_classes.AdbControllerConfig( adb_path=os.path.expanduser('~/Android/Sdk/platform-tools/adb'), @@ -77,31 +88,31 @@ def test_load( ) @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True) - @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True) + @mock.patch.object(fake_simulator, 'FakeSimulator', autospec=True) @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True) @mock.patch.object(builtins, 'open', autospec=True) - def test_load_existing_device( + def test_load_fake_simulator( self, mock_open, mock_coordinator, mock_simulator_class, mock_task_manager ): + + # Arrange. mock_open.return_value.__enter__ = mock_open mock_open.return_value.read.return_value = '' - - env = loader.load( - task_path='some/path/', - console_port=5554, - adb_path='~/Android/Sdk/platform-tools/adb', + config = config_classes.AndroidEnvConfig( + task=config_classes.FilesystemTaskConfig(path='some/path/'), + simulator=config_classes.FakeSimulatorConfig( + screen_dimensions=(1234, 5678) + ), ) - self.assertIsInstance(env, environment.AndroidEnv) + # Act. + env = loader.load(config) + + # Assert. + self.assertIsInstance(env, env_interface.AndroidEnvInterface) mock_simulator_class.assert_called_with( - config=config_classes.EmulatorConfig( - emulator_launcher=config_classes.EmulatorLauncherConfig( - emulator_console_port=5554, adb_port=5555, grpc_port=8554 - ), - adb_controller=config_classes.AdbControllerConfig( - adb_path=os.path.expanduser('~/Android/Sdk/platform-tools/adb'), - adb_server_port=5037, - ), + config=config_classes.FakeSimulatorConfig( + screen_dimensions=(1234, 5678) ) ) mock_coordinator.assert_called_with( @@ -116,6 +127,8 @@ def test_load_existing_device( def test_task( self, mock_open, mock_coordinator, mock_simulator, mock_task_manager ): + + # Arrange. del mock_coordinator, mock_simulator mock_open.return_value.__enter__ = mock_open mock_open.return_value.read.return_value = r''' @@ -124,12 +137,22 @@ def test_task( description: "Task for testing loader." max_episode_sec: 0 ''' - - env = loader.load( - task_path='some/path/', - avd_name='my_avd', + config = config_classes.AndroidEnvConfig( + task=config_classes.FilesystemTaskConfig(path='some/path/'), + simulator=config_classes.EmulatorConfig( + emulator_launcher=config_classes.EmulatorLauncherConfig( + avd_name='my_avd' + ), + adb_controller=config_classes.AdbControllerConfig( + adb_path='~/Android/Sdk/platform-tools/adb', + ), + ), ) + # Act. + env = loader.load(config) + + # Assert. expected_task = task_pb2.Task() expected_task.id = 'fake_task' expected_task.name = 'Fake Task' @@ -137,7 +160,7 @@ def test_task( expected_task.max_episode_sec = 0 mock_task_manager.assert_called_with(expected_task) - assert isinstance(env, environment.AndroidEnv) + self.assertIsInstance(env, env_interface.AndroidEnvInterface) if __name__ == '__main__': diff --git a/examples/run_acme_agent.py b/examples/run_acme_agent.py index 3ee0961..754a043 100644 --- a/examples/run_acme_agent.py +++ b/examples/run_acme_agent.py @@ -24,6 +24,7 @@ from acme.agents.tf import dqn from acme.tf import networks from android_env import loader +from android_env.components import config_classes from android_env.wrappers import discrete_action_wrapper from android_env.wrappers import float_pixels_wrapper from android_env.wrappers import image_rescale_wrapper @@ -58,14 +59,22 @@ def apply_wrappers(env): def main(_): - with loader.load( - emulator_path=FLAGS.emulator_path, - android_sdk_root=FLAGS.android_sdk_root, - android_avd_home=FLAGS.android_avd_home, - avd_name=FLAGS.avd_name, - adb_path=FLAGS.adb_path, - task_path=FLAGS.task_path, - run_headless=False) as env: + config = config_classes.AndroidEnvConfig( + task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path), + simulator=config_classes.EmulatorConfig( + emulator_launcher=config_classes.EmulatorLauncherConfig( + emulator_path=FLAGS.emulator_path, + android_sdk_root=FLAGS.android_sdk_root, + android_avd_home=FLAGS.android_avd_home, + avd_name=FLAGS.avd_name, + run_headless=FLAGS.run_headless, + ), + adb_controller=config_classes.AdbControllerConfig( + adb_path=FLAGS.adb_path + ), + ), + ) + with loader.load(config) as env: env = apply_wrappers(env) env_spec = specs.make_environment_spec(env) diff --git a/examples/run_human_agent.py b/examples/run_human_agent.py index dfc6fa7..66c944b 100644 --- a/examples/run_human_agent.py +++ b/examples/run_human_agent.py @@ -23,6 +23,7 @@ from absl import logging from android_env import loader from android_env.components import action_type +from android_env.components import config_classes from android_env.components import utils import dm_env import numpy as np @@ -129,14 +130,22 @@ def main(_): pygame.init() pygame.display.set_caption('android_human_agent') - with loader.load( - emulator_path=FLAGS.emulator_path, - android_sdk_root=FLAGS.android_sdk_root, - android_avd_home=FLAGS.android_avd_home, - avd_name=FLAGS.avd_name, - adb_path=FLAGS.adb_path, - task_path=FLAGS.task_path, - run_headless=FLAGS.run_headless) as env: + config = config_classes.AndroidEnvConfig( + task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path), + simulator=config_classes.EmulatorConfig( + emulator_launcher=config_classes.EmulatorLauncherConfig( + emulator_path=FLAGS.emulator_path, + android_sdk_root=FLAGS.android_sdk_root, + android_avd_home=FLAGS.android_avd_home, + avd_name=FLAGS.avd_name, + run_headless=FLAGS.run_headless, + ), + adb_controller=config_classes.AdbControllerConfig( + adb_path=FLAGS.adb_path + ), + ), + ) + with loader.load(config) as env: # Reset environment. first_timestep = env.reset() diff --git a/examples/run_random_agent.py b/examples/run_random_agent.py index 04590e0..fc8017b 100644 --- a/examples/run_random_agent.py +++ b/examples/run_random_agent.py @@ -19,6 +19,7 @@ from absl import flags from absl import logging from android_env import loader +from android_env.components import config_classes from dm_env import specs import numpy as np @@ -44,14 +45,22 @@ def main(_): - with loader.load( - emulator_path=FLAGS.emulator_path, - android_sdk_root=FLAGS.android_sdk_root, - android_avd_home=FLAGS.android_avd_home, - avd_name=FLAGS.avd_name, - adb_path=FLAGS.adb_path, - task_path=FLAGS.task_path, - run_headless=FLAGS.run_headless) as env: + config = config_classes.AndroidEnvConfig( + task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path), + simulator=config_classes.EmulatorConfig( + emulator_launcher=config_classes.EmulatorLauncherConfig( + emulator_path=FLAGS.emulator_path, + android_sdk_root=FLAGS.android_sdk_root, + android_avd_home=FLAGS.android_avd_home, + avd_name=FLAGS.avd_name, + run_headless=FLAGS.run_headless, + ), + adb_controller=config_classes.AdbControllerConfig( + adb_path=FLAGS.adb_path + ), + ), + ) + with loader.load(config) as env: action_spec = env.action_spec()