Skip to content

Commit

Permalink
Introduce AndroidEnvConfig.
Browse files Browse the repository at this point in the history
This change introduces `AndroidEnvConfig`, the topmost `dataclass` to configure
AndroidEnv's instantiation. This aligns the interface used in the open-source
version with Google's internal version.

The main diffs are in `loader.py` and `loader_test.py`, with two significant
differences:

* No automagic param manipulation for connecting to existing emulators.
  * Instead of trying to guess what clients want, `loader.load()` now just
    interprets the given settings, including the Emulator's console port and
    adb port. This is less error-prone and much simpler to use.
* `loader.load()` now supports loading `AndroidEnv` with a `FakeSimulator`.
  * This is very convenient when setting up `AndroidEnv` in a more complex
    infrastructure since `FakeSimulator` boots up instantaneously.

PiperOrigin-RevId: 605303997
  • Loading branch information
kenjitoyama authored and copybara-github committed Feb 8, 2024
1 parent c08dd62 commit 2636392
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 121 deletions.
23 changes: 19 additions & 4 deletions android_env/components/config_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class AdbControllerConfig:
# NOTE: This must be a full path and must not contain environment variables
# or user folder shorthands (e.g. `~/some/path/to/adb`) since they will not be
# expanded internally by AndroidEnv.
adb_path: str = 'adb'
adb_path: str = '~/Android/Sdk/platform-tools/adb'
# Port for adb server.
adb_server_port: int = 5037
# Default timeout in seconds for internal commands.
Expand Down Expand Up @@ -80,13 +80,13 @@ class EmulatorLauncherConfig:
# exists already and EmulatorLauncher will be skipped.

# Filesystem path to the `emulator` binary.
emulator_path: str = ''
emulator_path: str = '~/Android/Sdk/emulator/emulator'
# Filesystem path to the Android SDK root.
android_sdk_root: str = ''
android_sdk_root: str = '~/Android/Sdk'
# Name of the AVD.
avd_name: str = ''
# Local directory for AVDs.
android_avd_home: str = ''
android_avd_home: str = '~/.android/avd'
# Name of the snapshot to load.
snapshot_name: str = ''
# Path to the KVM device.
Expand Down Expand Up @@ -178,3 +178,18 @@ class FilesystemTaskConfig(TaskConfig):

# Filesystem path to `.binarypb` or `.textproto` protobuf Task.
path: str = ''


@dataclasses.dataclass
class AndroidEnvConfig:
"""Config class for AndroidEnv."""

# Configs for main components.
task: TaskConfig = dataclasses.field(default_factory=TaskConfig)
task_manager: TaskManagerConfig = dataclasses.field(
default_factory=TaskManagerConfig
)
coordinator: CoordinatorConfig = dataclasses.field(
default_factory=CoordinatorConfig
)
simulator: SimulatorConfig = dataclasses.field(default_factory=EmulatorConfig)
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@ def setUp(self):
base_lib_dir + 'gles_swiftshader/', base_lib_dir
])

# Instantiate the config to extract default values.
config = config_classes.EmulatorLauncherConfig()
self._expected_env_vars = {
'ANDROID_HOME': '',
'ANDROID_SDK_ROOT': '',
'ANDROID_AVD_HOME': '',
'ANDROID_SDK_ROOT': config.android_sdk_root,
'ANDROID_AVD_HOME': config.android_avd_home,
'ANDROID_EMULATOR_KVM_DEVICE': '/dev/kvm',
'ANDROID_ADB_SERVER_PORT': '1234',
'LD_LIBRARY_PATH': ld_library_path,
Expand Down
92 changes: 35 additions & 57 deletions android_env/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from android_env.components import coordinator as coordinator_lib
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators.emulator import emulator_simulator
from android_env.components.simulators.fake import fake_simulator
from android_env.proto import task_pb2

from google.protobuf import text_format
Expand All @@ -42,64 +43,41 @@ def _load_task(task_config: config_classes.TaskConfig) -> task_pb2.Task:
return task


def load(
task_path: str,
avd_name: str | None = None,
android_avd_home: str = '~/.android/avd',
android_sdk_root: str = '~/Android/Sdk',
emulator_path: str = '~/Android/Sdk/emulator/emulator',
adb_path: str = '~/Android/Sdk/platform-tools/adb',
run_headless: bool = False,
console_port: int | None = None,
) -> environment.AndroidEnv:
"""Loads an AndroidEnv instance.
Args:
task_path: Path to the task textproto file.
avd_name: Name of the AVD (Android Virtual Device).
android_avd_home: Path to the AVD (Android Virtual Device).
android_sdk_root: Root directory of the SDK.
emulator_path: Path to the emulator binary.
adb_path: Path to the ADB (Android Debug Bridge).
run_headless: If True, the emulator display is turned off.
console_port: The console port number; for connecting to an already running
device/emulator.
Returns:
env: An AndroidEnv instance.
"""
connect_to_existing_device = console_port is not None
if not connect_to_existing_device and avd_name is None:
raise ValueError('An avd name must be provided if launching an emulator.')

if connect_to_existing_device:
launcher_config = config_classes.EmulatorLauncherConfig(
emulator_console_port=console_port,
adb_port=console_port + 1,
grpc_port=8554,
)
else:
launcher_config = config_classes.EmulatorLauncherConfig(
avd_name=avd_name,
android_avd_home=os.path.expanduser(android_avd_home),
android_sdk_root=os.path.expanduser(android_sdk_root),
emulator_path=os.path.expanduser(emulator_path),
run_headless=run_headless,
gpu_mode='swiftshader_indirect',
)

# Create simulator.
simulator = emulator_simulator.EmulatorSimulator(
config=config_classes.EmulatorConfig(
emulator_launcher=launcher_config,
adb_controller=config_classes.AdbControllerConfig(
adb_path=os.path.expanduser(adb_path),
adb_server_port=5037,
),
)
)
def load(config: config_classes.AndroidEnvConfig) -> environment.AndroidEnv:
"""Loads an AndroidEnv instance."""

task = _load_task(config_classes.FilesystemTaskConfig(path=task_path))
task = _load_task(config.task)
task_manager = task_manager_lib.TaskManager(task)

match config.simulator:
case config_classes.EmulatorConfig():
_process_emulator_launcher_config(config.simulator)
simulator = emulator_simulator.EmulatorSimulator(config=config.simulator)
case config_classes.FakeSimulatorConfig():
simulator = fake_simulator.FakeSimulator(config=config.simulator)
case _:
raise ValueError('Unsupported simulator config: {config.simulator}')

coordinator = coordinator_lib.Coordinator(simulator, task_manager)
return environment.AndroidEnv(coordinator=coordinator)


def _process_emulator_launcher_config(
emulator_config: config_classes.EmulatorConfig,
) -> None:
"""Adjusts the configuration of the emulator depending on some conditions."""

# Expand the user directory if specified.
launcher_config = emulator_config.emulator_launcher
launcher_config.android_avd_home = os.path.expanduser(
launcher_config.android_avd_home
)
launcher_config.android_sdk_root = os.path.expanduser(
launcher_config.android_sdk_root
)
launcher_config.emulator_path = os.path.expanduser(
launcher_config.emulator_path
)
emulator_config.adb_controller.adb_path = os.path.expanduser(
emulator_config.adb_controller.adb_path
)
91 changes: 57 additions & 34 deletions android_env/loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
from unittest import mock

from absl.testing import absltest
from android_env import environment
from android_env import env_interface
from android_env import loader
from android_env.components import config_classes
from android_env.components import coordinator as coordinator_lib
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators.emulator import emulator_simulator
from android_env.components.simulators.fake import fake_simulator
from android_env.proto import task_pb2


Expand All @@ -35,24 +36,34 @@ class LoaderTest(absltest.TestCase):
@mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True)
@mock.patch.object(coordinator_lib, 'Coordinator', autospec=True)
@mock.patch.object(builtins, 'open', autospec=True)
def test_load(
def test_load_emulator(
self, mock_open, mock_coordinator, mock_simulator_class, mock_task_manager
):

# Arrange.
mock_open.return_value.__enter__ = mock_open
mock_open.return_value.read.return_value = ''

env = loader.load(
task_path='some/path/',
avd_name='my_avd',
android_avd_home='~/.android/avd',
android_sdk_root='~/Android/Sdk',
emulator_path='~/Android/Sdk/emulator/emulator',
adb_path='~/Android/Sdk/platform-tools/adb',
run_headless=False,
config = config_classes.AndroidEnvConfig(
task=config_classes.FilesystemTaskConfig(path='some/path/'),
simulator=config_classes.EmulatorConfig(
emulator_launcher=config_classes.EmulatorLauncherConfig(
avd_name='my_avd',
android_avd_home='~/.android/avd',
android_sdk_root='~/Android/Sdk',
emulator_path='~/Android/Sdk/emulator/emulator',
run_headless=False,
),
adb_controller=config_classes.AdbControllerConfig(
adb_path='~/Android/Sdk/platform-tools/adb',
),
),
)

self.assertIsInstance(env, environment.AndroidEnv)
# Act.
env = loader.load(config)

# Assert.
self.assertIsInstance(env, env_interface.AndroidEnvInterface)
mock_simulator_class.assert_called_with(
config=config_classes.EmulatorConfig(
emulator_launcher=config_classes.EmulatorLauncherConfig(
Expand All @@ -63,7 +74,7 @@ def test_load(
'~/Android/Sdk/emulator/emulator'
),
run_headless=False,
gpu_mode='swiftshader_indirect',
gpu_mode='swangle_indirect',
),
adb_controller=config_classes.AdbControllerConfig(
adb_path=os.path.expanduser('~/Android/Sdk/platform-tools/adb'),
Expand All @@ -77,31 +88,31 @@ def test_load(
)

@mock.patch.object(task_manager_lib, 'TaskManager', autospec=True)
@mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True)
@mock.patch.object(fake_simulator, 'FakeSimulator', autospec=True)
@mock.patch.object(coordinator_lib, 'Coordinator', autospec=True)
@mock.patch.object(builtins, 'open', autospec=True)
def test_load_existing_device(
def test_load_fake_simulator(
self, mock_open, mock_coordinator, mock_simulator_class, mock_task_manager
):

# Arrange.
mock_open.return_value.__enter__ = mock_open
mock_open.return_value.read.return_value = ''

env = loader.load(
task_path='some/path/',
console_port=5554,
adb_path='~/Android/Sdk/platform-tools/adb',
config = config_classes.AndroidEnvConfig(
task=config_classes.FilesystemTaskConfig(path='some/path/'),
simulator=config_classes.FakeSimulatorConfig(
screen_dimensions=(1234, 5678)
),
)

self.assertIsInstance(env, environment.AndroidEnv)
# Act.
env = loader.load(config)

# Assert.
self.assertIsInstance(env, env_interface.AndroidEnvInterface)
mock_simulator_class.assert_called_with(
config=config_classes.EmulatorConfig(
emulator_launcher=config_classes.EmulatorLauncherConfig(
emulator_console_port=5554, adb_port=5555, grpc_port=8554
),
adb_controller=config_classes.AdbControllerConfig(
adb_path=os.path.expanduser('~/Android/Sdk/platform-tools/adb'),
adb_server_port=5037,
),
config=config_classes.FakeSimulatorConfig(
screen_dimensions=(1234, 5678)
)
)
mock_coordinator.assert_called_with(
Expand All @@ -116,6 +127,8 @@ def test_load_existing_device(
def test_task(
self, mock_open, mock_coordinator, mock_simulator, mock_task_manager
):

# Arrange.
del mock_coordinator, mock_simulator
mock_open.return_value.__enter__ = mock_open
mock_open.return_value.read.return_value = r'''
Expand All @@ -124,20 +137,30 @@ def test_task(
description: "Task for testing loader."
max_episode_sec: 0
'''

env = loader.load(
task_path='some/path/',
avd_name='my_avd',
config = config_classes.AndroidEnvConfig(
task=config_classes.FilesystemTaskConfig(path='some/path/'),
simulator=config_classes.EmulatorConfig(
emulator_launcher=config_classes.EmulatorLauncherConfig(
avd_name='my_avd'
),
adb_controller=config_classes.AdbControllerConfig(
adb_path='~/Android/Sdk/platform-tools/adb',
),
),
)

# Act.
env = loader.load(config)

# Assert.
expected_task = task_pb2.Task()
expected_task.id = 'fake_task'
expected_task.name = 'Fake Task'
expected_task.description = 'Task for testing loader.'
expected_task.max_episode_sec = 0

mock_task_manager.assert_called_with(expected_task)
assert isinstance(env, environment.AndroidEnv)
self.assertIsInstance(env, env_interface.AndroidEnvInterface)


if __name__ == '__main__':
Expand Down
25 changes: 17 additions & 8 deletions examples/run_acme_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from acme.agents.tf import dqn
from acme.tf import networks
from android_env import loader
from android_env.components import config_classes
from android_env.wrappers import discrete_action_wrapper
from android_env.wrappers import float_pixels_wrapper
from android_env.wrappers import image_rescale_wrapper
Expand Down Expand Up @@ -58,14 +59,22 @@ def apply_wrappers(env):

def main(_):

with loader.load(
emulator_path=FLAGS.emulator_path,
android_sdk_root=FLAGS.android_sdk_root,
android_avd_home=FLAGS.android_avd_home,
avd_name=FLAGS.avd_name,
adb_path=FLAGS.adb_path,
task_path=FLAGS.task_path,
run_headless=False) as env:
config = config_classes.AndroidEnvConfig(
task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path),
simulator=config_classes.EmulatorConfig(
emulator_launcher=config_classes.EmulatorLauncherConfig(
emulator_path=FLAGS.emulator_path,
android_sdk_root=FLAGS.android_sdk_root,
android_avd_home=FLAGS.android_avd_home,
avd_name=FLAGS.avd_name,
run_headless=FLAGS.run_headless,
),
adb_controller=config_classes.AdbControllerConfig(
adb_path=FLAGS.adb_path
),
),
)
with loader.load(config) as env:

env = apply_wrappers(env)
env_spec = specs.make_environment_spec(env)
Expand Down
Loading

0 comments on commit 2636392

Please sign in to comment.