Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions scripts/rsl_rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,29 @@
torch.backends.cudnn.benchmark = False


class Tee:
"""Duplicates output to both terminal and a log file."""

def __init__(self, file_path, mode="w", original_stream=None):
self.file = open(file_path, mode)
self.original_stream = original_stream

def write(self, message):
if self.original_stream:
self.original_stream.write(message)
self.original_stream.flush()
self.file.write(message)
self.file.flush()

def flush(self):
if self.original_stream:
self.original_stream.flush()
self.file.flush()

def close(self):
self.file.close()


@hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg):
"""Train with RSL-RL agent."""
Expand Down Expand Up @@ -154,6 +177,15 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
log_dir += f"_{agent_cfg.run_name}"
log_dir = os.path.join(log_root_path, log_dir)

# set up terminal output logging
os.makedirs(log_dir, exist_ok=True)
log_file_path = os.path.join(log_dir, "training.log")
stdout_tee = Tee(log_file_path, mode="w", original_stream=sys.stdout)
stderr_tee = Tee(log_file_path, mode="a", original_stream=sys.stderr)
sys.stdout = stdout_tee
sys.stderr = stderr_tee
print(f"[INFO] Terminal output is being logged to: {log_file_path}")

# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

Expand Down Expand Up @@ -208,6 +240,12 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# close the simulator
env.close()

# restore original stdout/stderr and close log files
sys.stdout = stdout_tee.original_stream
sys.stderr = stderr_tee.original_stream
stdout_tee.close()
stderr_tee.close()


if __name__ == "__main__":
# run the main function
Expand Down