Skip to content

Commit

Permalink
Resolve unsoundness caught by pytype --strict-none-binding.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 546944132
  • Loading branch information
DeepMind authored and copybara-github committed Jul 14, 2023
1 parent 5220c8b commit cdf19ac
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 12 deletions.
2 changes: 1 addition & 1 deletion android_env/components/app_screen_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, data: str | None = None):
self._data = data

@property
def data(self) -> str:
def data(self) -> Optional[str]:
return self._data

@property
Expand Down
9 changes: 7 additions & 2 deletions android_env/components/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,17 @@ def _gather_simulator_signals(self) -> dict[str, np.ndarray]:

# Get current timestamp and update the delta.
now = time.time()
timestamp_delta = (0 if self._latest_observation_time == 0 else
(now - self._latest_observation_time) * 1e6)
timestamp_delta = (
0
if self._latest_observation_time == 0
else (now - self._latest_observation_time) * 1e6
)
self._latest_observation_time = now

# Grab pixels.
if self._interaction_rate_sec > 0:
if self._interaction_thread is None:
raise ValueError('InteractionThread not initalized')
pixels = self._interaction_thread.screenshot() # Async mode.
else:
pixels = self._simulator.get_screenshot() # Sync mode.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ def load_state(
will be `ERROR` and the `error_message` field will be filled.
"""
snapshot_name = request.args.get('snapshot_name', _DEFAULT_SNAPSHOT_NAME)
if self._snapshot_stub is None:
raise ValueError('snapshot_stub is not initialized.')
snapshot_list = self._snapshot_stub.ListSnapshots(
snapshot_service_pb2.SnapshotFilter(
statusFilter=snapshot_service_pb2.SnapshotFilter.LoadStatus.All
Expand Down Expand Up @@ -318,6 +320,8 @@ def save_state(
will be `ERROR` and the `error_message` field will be filled.
"""
snapshot_name = request.args.get('snapshot_name', _DEFAULT_SNAPSHOT_NAME)
if self._snapshot_stub is None:
raise ValueError('snapshot_stub is not initialized.')
snapshot_result = self._snapshot_stub.SaveSnapshot(
snapshot_service_pb2.SnapshotPackage(snapshot_id=snapshot_name)
)
Expand Down
32 changes: 23 additions & 9 deletions android_env/components/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,20 @@ def __init__(

logging.info('Task config: %s', self._task)

@property
def setup_step_interpreter(
self,
) -> setup_step_interpreter.SetupStepInterpreter:
if self._setup_step_interpreter is None:
raise ValueError('Setup step interpreter not initialized')
return self._setup_step_interpreter

@property
def logcat_thread(self) -> logcat_thread.LogcatThread:
if self._logcat_thread is None:
raise ValueError('Logcat thread not initialized')
return self._logcat_thread

def stats(self) -> dict[str, Any]:
"""Returns a dictionary of stats.
Expand All @@ -108,7 +122,7 @@ def stats(self) -> dict[str, Any]:

def setup_task(self) -> None:
"""Performs one-off task setup.."""
self._setup_step_interpreter.interpret(self._task.setup_steps)
self.setup_step_interpreter.interpret(self._task.setup_steps)

def stop(self) -> None:
"""Suspends task processing."""
Expand All @@ -121,16 +135,15 @@ def start(
"""Starts task processing."""

self._start_logcat_thread(log_stream=log_stream)
self._logcat_thread.resume()
self.logcat_thread.resume()
self._start_dumpsys_thread(adb_call_parser_factory())
self._start_setup_step_interpreter(adb_call_parser_factory())

def reset_task(self) -> None:
"""Resets a task for a new run."""

self._logcat_thread.pause()
self._setup_step_interpreter.interpret(self._task.reset_steps)
self._logcat_thread.resume()
self.logcat_thread.pause()
self.setup_step_interpreter.interpret(self._task.reset_steps)
self.logcat_thread.resume()

# Reset some other variables.
if not self._is_bad_episode:
Expand All @@ -150,8 +163,7 @@ def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep:
"""Performs one RL step."""

self._stats['episode_steps'] = 0

self._logcat_thread.line_ready().wait()
self.logcat_thread.line_ready().wait()
with self._lock:
extras = self._get_current_extras()

Expand All @@ -168,7 +180,7 @@ def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep:

self._stats['episode_steps'] += 1

self._logcat_thread.line_ready().wait()
self.logcat_thread.line_ready().wait()
with self._lock:
reward = self._get_current_reward()
extras = self._get_current_extras()
Expand Down Expand Up @@ -196,6 +208,8 @@ def _determine_transition_fn(self) -> Callable[..., dm_env.TimeStep]:
"""Determines the type of RL transition will be used."""

# Check if user existed the task
if self._dumpsys_thread is None:
raise ValueError('DumpsysThread not initialized.')
if self._dumpsys_thread.check_user_exited():
self._increment_bad_state()
self._stats['reset_count_user_exited'] += 1
Expand Down

0 comments on commit cdf19ac

Please sign in to comment.