diff --git a/rsl_rl/runners/runner.py b/rsl_rl/runners/runner.py index b461a88..816f304 100644 --- a/rsl_rl/runners/runner.py +++ b/rsl_rl/runners/runner.py @@ -365,7 +365,7 @@ def _update(self) -> None: def load(self, path: str) -> Any: """Restores the agent and runner state from a file.""" - content = torch.load(path) + content = torch.load(path, map_location=self.device) assert "agent" in content assert "data" in content