Skip to content

Commit c30f677

Browse files
committed
[Environment] Fix lib CI failures
ghstack-source-id: 26531b2b414910ff56eea5e1d08f9c6c627ff72f Pull-Request-resolved: #2923
1 parent e4733b8 commit c30f677

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

torchrl/data/datasets/atari_dqn.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,12 @@ def _download_and_preproc(self):
508508
if not os.listdir(tempdir):
509509
os.makedirs(tempdir, exist_ok=True)
510510
# get the list of runs
511+
try:
512+
subprocess.run(
513+
["gsutil", "version"], check=True, capture_output=True
514+
)
515+
except subprocess.CalledProcessError:
516+
raise RuntimeError("gsutil is not installed or not found in PATH.")
511517
command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/{self.dataset_id}/replay_logs"
512518
output = subprocess.run(
513519
command, shell=True, capture_output=True
@@ -520,9 +526,7 @@ def _download_and_preproc(self):
520526
self.remote_gz_files = self._list_runs(None, files)
521527
remote_gz_files = list(self.remote_gz_files)
522528
if not len(remote_gz_files):
523-
raise RuntimeError(
524-
"Could not load the file list. Did you install gsutil?"
525-
)
529+
raise RuntimeError("No files in file list.")
526530

527531
total_runs = remote_gz_files[-1]
528532
if self.num_procs == 0:

torchrl/envs/libs/smacv2.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: # noqa: F821
228228
dtype=torch.bool,
229229
device=self.device,
230230
)
231-
self.action_spec = self._make_action_spec()
231+
self.full_action_spec = self._make_action_spec()
232232
self.observation_spec = self._make_observation_spec()
233233

234234
def _init_env(self) -> None:
@@ -356,7 +356,7 @@ def _reset(
356356
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
357357
# perform actions
358358
action = tensordict.get(("agents", "action"))
359-
action_np = self.action_spec.to_numpy(action)
359+
action_np = self.full_action_spec[self.action_key].to_numpy(action)
360360

361361
# Actions are validated by the environment.
362362
try:

0 commit comments

Comments
 (0)