feat: support policy gradient
Gaiejj authored Aug 21, 2023
2 parents b43ef9a + 5930bac commit c55d57a
Showing 1 changed file with 397 additions and 0 deletions.
safepo/single_agent/pg.py (397 additions, 0 deletions)
@@ -0,0 +1,397 @@
# Copyright 2023 OmniSafeAI Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from __future__ import annotations

import os
import random
import sys
import time
from collections import deque

import numpy as np
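# isaacgym is an optional dependency; its modules generally have to be imported
# before torch, which is why this guarded import sits ahead of the torch imports.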
try:
from isaacgym import gymutil
except ImportError:
pass
import torch
import torch.nn as nn
import torch.optim
from rich.progress import track
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader, TensorDataset

from safepo.common.buffer import VectorizedOnPolicyBuffer
from safepo.common.env import make_sa_mujoco_env, make_sa_isaac_env
from safepo.common.lagrange import Lagrange
from safepo.common.logger import EpochLogger
from safepo.common.model import ActorVCritic
from safepo.utils.config import single_agent_args, isaac_gym_map, parse_sim_params


def main(args, cfg_env=None):
# set the random seed, device and number of threads
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.set_num_threads(4)
device = torch.device(args.device)


if args.task not in isaac_gym_map.keys():
env, obs_space, act_space = make_sa_mujoco_env(
num_envs=args.num_envs, env_id=args.task, seed=args.seed
)
eval_env, _, _ = make_sa_mujoco_env(num_envs=1, env_id=args.task, seed=None)

else:
sim_params = parse_sim_params(args, cfg_env, None)
env = make_sa_isaac_env(args=args, cfg=cfg_env, sim_params=sim_params)
eval_env = env
obs_space = env.observation_space
act_space = env.action_space
args.num_envs = env.num_envs
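# Isaac Gym environments are already vectorized, so the training env doubles
# as the eval env and determines the number of parallel envs.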

# set training steps
local_steps_per_epoch = args.steps_per_epoch // args.num_envs
epochs = args.total_steps // args.steps_per_epoch
# create the actor-critic module
policy = ActorVCritic(
obs_dim=obs_space.shape[0],
act_dim=act_space.shape[0],
).to(device)
actor_optimizer = torch.optim.Adam(policy.actor.parameters(), lr=3e-4)
actor_scheduler = LinearLR(
actor_optimizer,
start_factor=1.0,
end_factor=0.0,
total_iters=epochs,
verbose=False,
)
reward_critic_optimizer = torch.optim.Adam(
policy.reward_critic.parameters(), lr=3e-4
)
cost_critic_optimizer = torch.optim.Adam(
policy.cost_critic.parameters(), lr=3e-4
)

# create the vectorized on-policy buffer
buffer = VectorizedOnPolicyBuffer(
obs_space=obs_space,
act_space=act_space,
size=local_steps_per_epoch,
device=device,
num_envs=args.num_envs,
)
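# the buffer tracks reward and cost signals separately (values, advantages,
# value targets) for each parallel environment.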

# set up the logger
dict_args = vars(args)
logger = EpochLogger(
log_dir=args.log_dir,
seed=str(args.seed),
)
rew_deque = deque(maxlen=50)
cost_deque = deque(maxlen=50)
len_deque = deque(maxlen=50)
eval_rew_deque = deque(maxlen=50)
eval_cost_deque = deque(maxlen=50)
eval_len_deque = deque(maxlen=50)
logger.save_config(dict_args)
logger.setup_torch_saver(policy.actor)
logger.log("Start with training.")

# training loop
for epoch in range(epochs):
rollout_start_time = time.time()
obs, _ = env.reset()
obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
ep_ret, ep_cost, ep_len = (
np.zeros(args.num_envs),
np.zeros(args.num_envs),
np.zeros(args.num_envs),
)
# collect samples until we have enough to update
for steps in range(local_steps_per_epoch):
with torch.no_grad():
act, log_prob, value_r, value_c = policy.step(obs, deterministic=False)
next_obs, reward, cost, terminated, truncated, info = env.step(
act.detach().squeeze().cpu().numpy()
)
ep_ret += reward.cpu().numpy() if args.task in isaac_gym_map.keys() else reward
ep_cost += cost.cpu().numpy() if args.task in isaac_gym_map.keys() else cost
ep_len += 1
next_obs, reward, cost, terminated, truncated = (
torch.as_tensor(x, dtype=torch.float32, device=device)
for x in (next_obs, reward, cost, terminated, truncated)
)
if "final_observation" in info:
info["final_observation"] = np.array(
[
array if array is not None else np.zeros(obs.shape[-1])
for array in info["final_observation"]
],
)
info["final_observation"] = torch.as_tensor(
info["final_observation"],
dtype=torch.float32,
device=device,
)
buffer.store(
obs=obs,
act=act,
reward=reward,
cost=cost,
value_r=value_r,
value_c=value_c,
log_prob=log_prob,
)

obs = next_obs
epoch_end = steps >= local_steps_per_epoch - 1
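# close paths at episode end or at the epoch boundary: truly terminal states
# bootstrap with a zero value, while epoch-end and time-limit truncations
# bootstrap from the critic's estimate (advantage/return computation is
# assumed to happen inside buffer.finish_path).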
for idx, (done, time_out) in enumerate(zip(terminated, truncated)):
if epoch_end or done or time_out:
last_value_r = torch.zeros(1, device=device)
last_value_c = torch.zeros(1, device=device)
if not done:
if epoch_end:
with torch.no_grad():
_, _, last_value_r, last_value_c = policy.step(
obs[idx], deterministic=False
)
if time_out:
with torch.no_grad():
_, _, last_value_r, last_value_c = policy.step(
info["final_observation"][idx], deterministic=False
)
last_value_r = last_value_r.unsqueeze(0)
last_value_c = last_value_c.unsqueeze(0)
if done or time_out:
rew_deque.append(ep_ret[idx])
cost_deque.append(ep_cost[idx])
len_deque.append(ep_len[idx])
logger.store(
**{
"Metrics/EpRet": np.mean(rew_deque),
"Metrics/EpCost": np.mean(cost_deque),
"Metrics/EpLen": np.mean(len_deque),
}
)
ep_ret[idx] = 0.0
ep_cost[idx] = 0.0
ep_len[idx] = 0.0
logger.logged = False

buffer.finish_path(
last_value_r=last_value_r, last_value_c=last_value_c, idx=idx
)
rollout_end_time = time.time()

eval_start_time = time.time()
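# evaluation: deterministic rollouts on eval_env; one episode per epoch,
# ten at the final epoch.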

eval_episodes = 1 if epoch < epochs - 1 else 10
if args.use_eval:
for _ in range(eval_episodes):
eval_done = False
eval_obs, _ = eval_env.reset()
eval_obs = torch.as_tensor(eval_obs, dtype=torch.float32, device=device)
eval_rew, eval_cost, eval_len = 0.0, 0.0, 0.0
while not eval_done:
with torch.no_grad():
act, log_prob, value_r, value_c = policy.step(
eval_obs, deterministic=True
)
next_obs, reward, cost, terminated, truncated, info = eval_env.step(
act.detach().squeeze().cpu().numpy()
)
next_obs = torch.as_tensor(
next_obs, dtype=torch.float32, device=device
)
eval_rew += reward
eval_cost += cost
eval_len += 1
eval_done = terminated[0] or truncated[0]
eval_obs = next_obs
eval_rew_deque.append(eval_rew)
eval_cost_deque.append(eval_cost)
eval_len_deque.append(eval_len)
logger.store(
**{
"Metrics/EvalEpRet": np.mean(eval_rew),
"Metrics/EvalEpCost": np.mean(eval_cost),
"Metrics/EvalEpLen": np.mean(eval_len),
}
)

eval_end_time = time.time()

# episodic cost statistics (plain policy gradient performs no Lagrange
# multiplier update; ep_costs is not used in the update below)
ep_costs = logger.get_stats("Metrics/EpCost")

# update policy
data = buffer.get()
old_distribution = policy.actor(data["obs"])

# compute advantage
advantage = data["adv_r"]

dataloader = DataLoader(
dataset=TensorDataset(
data["obs"],
data["act"],
data["log_prob"],
data["target_value_r"],
data["target_value_c"],
advantage,
),
batch_size=64,
shuffle=True,
)
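# optimization: up to 40 passes over shuffled minibatches of 64; the policy
# surrogate clips the importance ratio to [0.8, 1.2], and the policy epochs
# stop early once the KL divergence from the pre-update policy exceeds 0.02.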
update_counts = 0
final_kl = torch.ones_like(old_distribution.loc)
for i in range(40):
for (
obs_b,
act_b,
log_prob_b,
target_value_r_b,
target_value_c_b,
adv_b,
) in dataloader:
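# critic updates: MSE against the buffer's value targets plus a small L2
# weight penalty (1e-3), with gradients clipped at norm 40.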
reward_critic_optimizer.zero_grad()
loss_r = nn.functional.mse_loss(policy.reward_critic(obs_b), target_value_r_b)
for param in policy.reward_critic.parameters():
loss_r += param.pow(2).sum() * 0.001
loss_r.backward()
clip_grad_norm_(policy.reward_critic.parameters(), 40.0)
reward_critic_optimizer.step()

cost_critic_optimizer.zero_grad()
loss_c = nn.functional.mse_loss(policy.cost_critic(obs_b), target_value_c_b)
for param in policy.cost_critic.parameters():
loss_c += param.pow(2).sum() * 0.001
loss_c.backward()
clip_grad_norm_(policy.cost_critic.parameters(), 40.0)
cost_critic_optimizer.step()

distribution = policy.actor(obs_b)
log_prob = distribution.log_prob(act_b).sum(dim=-1)
ratio = torch.exp(log_prob - log_prob_b)
ratio_clipped = torch.clamp(ratio, 0.8, 1.2)
loss_pi = -torch.min(ratio * adv_b, ratio_clipped * adv_b).mean()
actor_optimizer.zero_grad()
loss_pi.backward()
clip_grad_norm_(policy.actor.parameters(), 40.0)
actor_optimizer.step()

logger.store(
**{
"Loss/Loss_reward_critic": loss_r.mean().item(),
"Loss/Loss_cost_critic": loss_c.mean().item(),
"Loss/Loss_actor": loss_pi.mean().item(),
}
)

new_distribution = policy.actor(data["obs"])
kl = (
torch.distributions.kl.kl_divergence(old_distribution, new_distribution)
.sum(-1, keepdim=True)
.mean()
.item()
)
final_kl = kl
update_counts += 1
if kl > 0.02:
break
update_end_time = time.time()
actor_scheduler.step()

if not logger.logged:
# log data
logger.log_tabular("Metrics/EpRet")
logger.log_tabular("Metrics/EpCost")
logger.log_tabular("Metrics/EpLen")
if args.use_eval:
logger.log_tabular("Metrics/EvalEpRet")
logger.log_tabular("Metrics/EvalEpCost")
logger.log_tabular("Metrics/EvalEpLen")
logger.log_tabular("Train/Epoch", epoch + 1)
logger.log_tabular("Train/TotalSteps", (epoch + 1) * args.steps_per_epoch)
logger.log_tabular("Train/StopIter", update_counts)
logger.log_tabular("Train/KL", final_kl)
logger.log_tabular("Train/LR", actor_scheduler.get_last_lr()[0])
logger.log_tabular("Loss/Loss_reward_critic")
logger.log_tabular("Loss/Loss_cost_critic")
logger.log_tabular("Loss/Loss_actor")
logger.log_tabular("Time/Rollout", rollout_end_time - rollout_start_time)
if args.use_eval:
logger.log_tabular("Time/Eval", eval_end_time - eval_start_time)
logger.log_tabular("Time/Update", update_end_time - eval_end_time)
logger.log_tabular("Time/Total", update_end_time - rollout_start_time)
logger.log_tabular("Value/RewardAdv", data["adv_r"].mean().item())
logger.log_tabular("Value/CostAdv", data["adv_c"].mean().item())

logger.dump_tabular()
if (epoch + 1) % 100 == 0 or epoch == 0:
logger.torch_save(itr=epoch)
if args.task not in isaac_gym_map.keys():
logger.save_state(
state_dict={
"Normalizer": env.obs_rms,
},
itr=epoch,
)
logger.close()


if __name__ == "__main__":
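# hypothetical invocation (flag spellings and task name are assumed here,
# not taken from single_agent_args):
#   python safepo/single_agent/pg.py --task SafetyPointGoal1-v0 --seed 0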
args, cfg_env = single_agent_args()
relpath = time.strftime("%Y-%m-%d-%H-%M-%S")
subfolder = "-".join(["seed", str(args.seed).zfill(3)])
relpath = "-".join([subfolder, relpath])
algo = os.path.basename(__file__).split(".")[0]
args.log_dir = os.path.join(args.log_dir, args.experiment, args.task, algo, relpath)
if not args.write_terminal:
terminal_log_name = "terminal.log"
error_log_name = "error.log"
terminal_log_name = f"seed{args.seed}_{terminal_log_name}"
error_log_name = f"seed{args.seed}_{error_log_name}"
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
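# when not writing to the terminal, redirect stdout/stderr into per-seed
# log files under the run directory.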
if not os.path.exists(args.log_dir):
os.makedirs(args.log_dir, exist_ok=True)
with open(
os.path.join(
f"{args.log_dir}",
terminal_log_name,
),
"w",
encoding="utf-8",
) as f_out:
sys.stdout = f_out
with open(
os.path.join(
f"{args.log_dir}",
error_log_name,
),
"w",
encoding="utf-8",
) as f_error:
sys.stderr = f_error
main(args, cfg_env)
else:
main(args, cfg_env)
