1 change: 1 addition & 0 deletions examples/notebooks/XLA_jax.ipynb
@@ -33,6 +33,7 @@
"outputs": [],
"source": [
"from jax import __version__\n",
"\n",
"print(__version__)"
]
},
1 change: 1 addition & 0 deletions examples/notebooks/XLA_torch.ipynb
@@ -48,6 +48,7 @@
],
"source": [
"from torch import __version__\n",
"\n",
"print(__version__)"
]
},
147 changes: 147 additions & 0 deletions examples/torch_ppo/evaluate.py
@@ -0,0 +1,147 @@
import argparse
import os
import time

import gym
import torch

from animus import set_global_seed

from src.acmodel import ACModel
from src.agent import Agent
from src.settings import LOGDIR
from src.utils import ParallelEnv, synthesize

if __name__ == "__main__":
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument(
"--env", required=True, help="name of the environment (REQUIRED)"
)
# parser.add_argument(
# "--model", required=True, help="name of the trained model (REQUIRED)"
# )
parser.add_argument(
"--episodes",
type=int,
default=100,
help="number of episodes of evaluation (default: 100)",
)
parser.add_argument("--seed", type=int, default=0, help="random seed (default: 0)")
parser.add_argument(
"--procs", type=int, default=16, help="number of processes (default: 16)"
)
parser.add_argument(
"--argmax",
action="store_true",
default=False,
help="action with highest probability is selected",
)
parser.add_argument(
"--worst-episodes-to-show",
type=int,
default=10,
help="how many worst episodes to show",
)
parser.add_argument(
"--recurrent", action="store_true", default=False, help="add a LSTM to the model"
)
args = parser.parse_args()

# Set seed for all randomness sources
set_global_seed(args.seed)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")

    # Load environments
envs = []
for i in range(args.procs):
env = gym.make(args.env)
envs.append(env)
env = ParallelEnv(envs)
print("Environments loaded\n")

# Load agent
acmodel = ACModel(
observation_space=env.observation_space,
action_space=env.action_space,
recurrent=args.recurrent,
)
checkpoint = torch.load(
os.path.join(LOGDIR, "acmodel.best.pth"),
map_location=lambda storage, loc: storage,
)
acmodel.load_state_dict(checkpoint)
agent = Agent(
acmodel=acmodel,
device=device,
argmax=args.argmax,
)
print("Agent loaded\n")

# Initialize logs
logs = {"num_steps_per_episode": [], "return_per_episode": []}

# Run agent
start_time = time.time()
obss = env.reset()

log_done_counter = 0
log_episode_return = torch.zeros(args.procs, device=device)
log_episode_num_steps = torch.zeros(args.procs, device=device)

while log_done_counter < args.episodes:
actions = agent.get_actions(obss)
obss, rewards, dones, _ = env.step(actions)
agent.analyze_feedbacks(rewards, dones)

log_episode_return += torch.tensor(rewards, device=device, dtype=torch.float)
log_episode_num_steps += torch.ones(args.procs, device=device)

for i, done in enumerate(dones):
if done:
log_done_counter += 1
logs["return_per_episode"].append(log_episode_return[i].item())
logs["num_steps_per_episode"].append(log_episode_num_steps[i].item())

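        # Zero out the running stats of environments whose episode just finished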
mask = 1 - torch.tensor(dones, device=device, dtype=torch.float)
log_episode_return *= mask
log_episode_num_steps *= mask

end_time = time.time()

# Print logs
num_steps = sum(logs["num_steps_per_episode"])
fps = num_steps / (end_time - start_time)
duration = int(end_time - start_time)
return_per_episode = synthesize(logs["return_per_episode"])
num_steps_per_episode = synthesize(logs["num_steps_per_episode"])

print(
"S {} | FPS {:.0f} | D {} | R:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {}".format(
num_steps,
fps,
duration,
*return_per_episode.values(),
*num_steps_per_episode.values(),
)
)

# Print worst episodes
n = args.worst_episodes_to_show
if n > 0:
print("\n{} worst episodes:".format(n))

indexes = sorted(
range(len(logs["return_per_episode"])),
key=lambda k: logs["return_per_episode"][k],
)
for i in indexes[:n]:
print(
"- episode {}: R={}, F={}".format(
i, logs["return_per_episode"][i], logs["num_steps_per_episode"][i]
)
)
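The script imports ParallelEnv and synthesize from src.utils, which are not included in this diff. Below is a minimal sketch of compatible helpers, assuming synthesize reports mean/std/min/max in that order (matching the R:μσmM log line above) and ParallelEnv simply steps a list of gym environments in sequence, resetting each one when its episode ends:

import numpy as np


def synthesize(values):
    # Summary statistics in the μ, σ, m, M order used by the log line above
    values = np.array(values, dtype=np.float64)
    return {
        "mean": values.mean(),
        "std": values.std(),
        "min": values.min(),
        "max": values.max(),
    }


class ParallelEnv:
    # Sequential stand-in for a vectorized environment
    def __init__(self, envs):
        self.envs = envs
        self.observation_space = envs[0].observation_space
        self.action_space = envs[0].action_space

    def reset(self):
        return [env.reset() for env in self.envs]

    def step(self, actions):
        results = []
        for env, action in zip(self.envs, actions):
            obs, reward, done, info = env.step(action)
            if done:
                obs = env.reset()
            results.append((obs, reward, done, info))
        obss, rewards, dones, infos = map(list, zip(*results))
        return obss, rewards, dones, infos

With helpers like these in place, the script can be run as, for example, python evaluate.py --env CartPole-v1 --procs 4 --episodes 20 (the environment name is only an illustration).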
94 changes: 94 additions & 0 deletions examples/torch_ppo/src/acmodel.py
@@ -0,0 +1,94 @@
import torch
from torch.distributions.categorical import Categorical
import torch.nn as nn
import torch.nn.functional as F


# Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
def init_params(m):
classname = m.__class__.__name__
if classname.find("Linear") != -1:
m.weight.data.normal_(0, 1)
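        # Rescale each row of the weight matrix to unit L2 norm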
m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
if m.bias is not None:
m.bias.data.fill_(0)


class ACModel(nn.Module):
def __init__(self, observation_space, action_space, recurrent=False):
super().__init__()

# Decide which components are enabled
self.recurrent = recurrent

# Define embedder
self.embedder = nn.Sequential(
nn.Linear(observation_space.shape[0], 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
)
self.embedding_size = 128
# TODO: add image support
# self.embedder = nn.Sequential(
# nn.Conv2d(3, 16, (2, 2)),
# nn.ReLU(),
# nn.MaxPool2d((2, 2)),
# nn.Conv2d(16, 32, (2, 2)),
# nn.ReLU(),
# nn.Conv2d(32, 64, (2, 2)),
# nn.ReLU(),
# )
# n = obs_space["image"][0]
# m = obs_space["image"][1]
# self.embedding_size = ((n - 1) // 2 - 2) * ((m - 1) // 2 - 2) * 64

# Define memory
if self.recurrent:
self.memory_rnn = nn.LSTMCell(self.embedding_size, self.semi_memory_size)

# Resize embedding
self.embedding_size = self.semi_memory_size

# Define actor's model
self.actor = nn.Sequential(
nn.Linear(self.embedding_size, 64), nn.Tanh(), nn.Linear(64, action_space.n)
)

# Define critic's model
self.critic = nn.Sequential(
nn.Linear(self.embedding_size, 64), nn.Tanh(), nn.Linear(64, 1)
)

# Initialize parameters correctly
self.apply(init_params)

@property
def memory_size(self):
return 2 * self.semi_memory_size

@property
def semi_memory_size(self):
return self.embedding_size

def forward(self, x, memory=None):
x = self.embedder(x)

if self.recurrent:
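            # Split the flat memory tensor into the LSTM hidden and cell states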
hidden = (
memory[:, : self.semi_memory_size],
memory[:, self.semi_memory_size :],
)
hidden = self.memory_rnn(x, hidden)
embedding = hidden[0]
memory = torch.cat(hidden, dim=1)
else:
embedding = x

x = self.actor(embedding)
dist = Categorical(logits=F.log_softmax(x, dim=1))

x = self.critic(embedding)
value = x.squeeze(1)

return dist, value, memory
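A quick smoke test of the model in both feed-forward and recurrent modes; the CartPole environment here is only an illustration and is not part of this example:

import gym
import torch

from src.acmodel import ACModel

env = gym.make("CartPole-v1")  # illustrative environment, not part of this PR
obs = torch.randn(8, env.observation_space.shape[0])

# Feed-forward variant: the returned memory is None
model = ACModel(env.observation_space, env.action_space)
dist, value, _ = model(obs)
print(dist.probs.shape, value.shape)  # (8, n_actions), (8,)

# Recurrent variant: memory stores the LSTM (h, c) pair flattened along dim 1
rmodel = ACModel(env.observation_space, env.action_space, recurrent=True)
memory = torch.zeros(8, rmodel.memory_size)
dist, value, memory = rmodel(obs, memory)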
51 changes: 51 additions & 0 deletions examples/torch_ppo/src/agent.py
@@ -0,0 +1,51 @@
import numpy as np
import torch


class Agent:
"""An agent. Used for model inference.

It is able:
- to choose an action given an observation,
- to analyze the feedback (i.e. reward and done state) of its action."""

def __init__(self, acmodel, device, argmax=False, num_envs=1):
self.acmodel = acmodel
self.device = device
self.argmax = argmax
self.num_envs = num_envs

if self.acmodel.recurrent:
self.memories = torch.zeros(
self.num_envs, self.acmodel.memory_size, device=self.device
)
else:
self.memories = None

self.acmodel.to(self.device)
self.acmodel.eval()

def get_actions(self, obss):
with torch.no_grad():
obss = torch.tensor(np.array(obss), device=self.device, dtype=torch.float)
dist, _, self.memories = self.acmodel(obss, self.memories)

if self.argmax:
actions = dist.probs.max(1, keepdim=True)[1]
else:
actions = dist.sample()

return actions.cpu().numpy()

def get_action(self, obs):
return self.get_actions([obs])[0]

def analyze_feedbacks(self, rewards, dones):
if self.acmodel.recurrent:
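            # Reset the memories of environments whose episode has just ended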
masks = 1 - torch.tensor(
dones, dtype=torch.float, device=self.device
).unsqueeze(1)
self.memories *= masks

def analyze_feedback(self, reward, done):
return self.analyze_feedbacks([reward], [done])
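For a single environment, the get_action and analyze_feedback wrappers can be used directly. A sketch assuming an untrained model and the old 4-tuple gym step API used throughout this example (CartPole is again only an illustration):

import gym
import torch

from src.acmodel import ACModel
from src.agent import Agent

env = gym.make("CartPole-v1")  # illustrative environment, not part of this PR
acmodel = ACModel(env.observation_space, env.action_space)
agent = Agent(acmodel=acmodel, device=torch.device("cpu"))

obs = env.reset()
done = False
while not done:
    action = agent.get_action(obs)  # samples from the policy distribution
    obs, reward, done, _ = env.step(action)
    agent.analyze_feedback(reward, done)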
1 change: 1 addition & 0 deletions examples/torch_ppo/src/settings.py
@@ -0,0 +1 @@
LOGDIR = "./logs_ppo"