From bdf26c49d5bf9babb42e38b48d865471875bb72d Mon Sep 17 00:00:00 2001 From: Shambhuraj Sawant Date: Wed, 21 Aug 2024 15:53:10 +0200 Subject: [PATCH] rmse fix for ppo and sac --- safe_control_gym/controllers/ppo/ppo.py | 2 +- safe_control_gym/controllers/sac/sac.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/safe_control_gym/controllers/ppo/ppo.py b/safe_control_gym/controllers/ppo/ppo.py index f6b438e2e..d702f5fa9 100644 --- a/safe_control_gym/controllers/ppo/ppo.py +++ b/safe_control_gym/controllers/ppo/ppo.py @@ -252,7 +252,7 @@ def run(self, if done: assert 'episode' in info ep_rmse_mean.append(np.array(mse).mean()**0.5) - ep_rmse_std.append(np.array(mse).std()**0.5) + ep_rmse_std.append(np.array(mse).std()) mse = [] ep_returns.append(info['episode']['r']) ep_lengths.append(info['episode']['l']) diff --git a/safe_control_gym/controllers/sac/sac.py b/safe_control_gym/controllers/sac/sac.py index 2b2db7a8f..4ff5f28ec 100644 --- a/safe_control_gym/controllers/sac/sac.py +++ b/safe_control_gym/controllers/sac/sac.py @@ -251,7 +251,7 @@ def run(self, env=None, render=False, n_episodes=10, verbose=False, **kwargs): if done: assert 'episode' in info ep_rmse_mean.append(np.array(mse).mean()**0.5) - ep_rmse_std.append(np.array(mse).std()**0.5) + ep_rmse_std.append(np.array(mse).std()) mse = [] ep_returns.append(info['episode']['r']) ep_lengths.append(info['episode']['l'])