Zbranch #14 (Open)

wants to merge 3 commits into base: master
72 changes: 70 additions & 2 deletions .gitignore
@@ -1,2 +1,70 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so
*.c

# Logs
runs/
checkpoints/

# Other
.DS_Store

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints
6 changes: 6 additions & 0 deletions README.md
@@ -22,6 +22,12 @@ Install most recent nightly build (version '0.1.10+2fd4d08' or later) of PyTorch
pip install git+https://github.com/pytorch/pytorch
`

## Dependencies
* pytorch
* torchvision
* universe (for now)
* [tensorboard logger](https://github.com/TeamHG-Memex/tensorboard_logger)
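
The non-PyTorch dependencies can usually be installed from PyPI; a hedged example, assuming the package names match the list above:

`pip install torchvision tensorboard_logger universe`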

## Results

With 16 processes it converges for PongDeterministic-v3 in 15 minutes.
54 changes: 42 additions & 12 deletions main.py
@@ -3,17 +3,26 @@
import argparse
import os
import sys
import math
import time

import torch
import torch.optim as optim
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import tensorboard_logger as tb

import my_optim
from envs import create_atari_env
from model import ActorCritic
from train import train
from test import test
from utils import logger
from utils.shared_memory import SharedCounter


logger = logger.getLogger('main')

# Based on
# https://github.com/pytorch/examples/tree/master/mnist_hogwild
@@ -37,11 +46,25 @@
                    help='environment to train on (default: PongDeterministic-v3)')
parser.add_argument('--no-shared', default=False, metavar='O',
                    help='use an optimizer without shared momentum.')
parser.add_argument('--max-episode-count', type=int, default=math.inf,
                    help='maximum number of episodes to run per process.')
parser.add_argument('--debug', action='store_true', default=False,
                    help='run in a way that is easier to debug')
parser.add_argument('--short-description', default='no_descr',
                    help='short description of the run params (used in tensorboard)')
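
A note on `--max-episode-count`: `type=int` with `default=math.inf` works because argparse applies `type` only to values that actually come from the command line; the float default passes through unconverted. A quick sanity check, assuming the parser defined above:

```python
# argparse leaves defaults un-coerced, so a float default survives type=int
args = parser.parse_args([])
assert args.max_episode_count == math.inf

# values given on the command line are converted by type=int
args = parser.parse_args(['--max-episode-count', '5'])
assert args.max_episode_count == 5
```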

def setup_loggings(args):
    logger.debug('CONFIGURATION: {}'.format(args))

    cur_path = os.path.dirname(os.path.realpath(__file__))
    args.summ_base_dir = (cur_path + '/runs/{}/{}({})').format(
        args.env_name, time.strftime('%d.%m-%H.%M'), args.short_description)
    logger.info('logging run logs to {}'.format(args.summ_base_dir))
    tb.configure(args.summ_base_dir)
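
For readers unfamiliar with tensorboard_logger: `configure()` points the module-level logger at a run directory, and `log_value()` records a named scalar against a step; that is how test.py reports reward and throughput below. A minimal standalone sketch (the 'runs/demo' path is made up):

```python
import tensorboard_logger as tb

tb.configure('runs/demo')                      # directory for the event file
for step in range(3):
    tb.log_value('reward', float(step), step)  # scalar name, value, global step
```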

if __name__ == '__main__':
    args = parser.parse_args()

    setup_loggings(args)
    torch.manual_seed(args.seed)

    env = create_atari_env(args.env_name)
@@ -54,16 +77,23 @@
    else:
        optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
        optimizer.share_memory()

    gl_step_cnt = SharedCounter()

    if not args.debug:
        processes = []

        p = mp.Process(target=test, args=(args.num_processes, args,
                                          shared_model, gl_step_cnt))
        p.start()
        processes.append(p)
        for rank in range(0, args.num_processes):
            p = mp.Process(target=train, args=(rank, args, shared_model,
                                               gl_step_cnt, optimizer))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
    else:  # debug is enabled
        # run only one process in main; easier to debug
        train(0, args, shared_model, gl_step_cnt, optimizer)
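
utils/shared_memory.py is not part of this diff, so SharedCounter's definition is not visible here. A hypothetical sketch consistent with the `increment_by()` and `get_value()` calls this PR relies on:

```python
# Hypothetical sketch of utils/shared_memory.py (not included in this diff).
from multiprocessing import Lock, Value

class SharedCounter(object):
    """A process-safe integer counter kept in shared memory."""

    def __init__(self, initval=0):
        self.val = Value('l', initval)  # signed long in shared memory
        self.lock = Lock()

    def increment_by(self, n):
        # serialize updates so concurrent workers do not lose increments
        with self.lock:
            self.val.value += n

    def get_value(self):
        with self.lock:
            return self.val.value
```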
2 changes: 0 additions & 2 deletions model.py
@@ -46,7 +46,6 @@ def __init__(self, num_inputs, action_space):
        num_outputs = action_space.n
        self.critic_linear = nn.Linear(256, 1)
        self.actor_linear = nn.Linear(256, num_outputs)
        self.apply(weights_init)
        self.actor_linear.weight.data = normalized_columns_initializer(
            self.actor_linear.weight.data, 0.01)
@@ -66,7 +65,6 @@ def forward(self, inputs):
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.conv4(x))
        x = x.view(-1, 32 * 3 * 3)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx
43 changes: 42 additions & 1 deletion my_optim.py
@@ -1,4 +1,5 @@
import math
import torch
import torch.optim as optim

class SharedAdam(optim.Adam):
@@ -12,13 +13,53 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(1)
                state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
                state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()

    def share_memory(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)

                # Decay the first and second moment running average coefficients
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step'][0]
                bias_correction2 = 1 - beta2 ** state['step'][0]
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(-step_size, exp_avg, denom)

        return loss
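
The only real change from stock Adam here is that `state['step']` becomes a one-element tensor instead of a Python int: tensors can be moved into shared memory with `share_memory_()`, so every worker advances (and bias-corrects with) the same step count. A small sanity check of that mechanism:

```python
import torch

# a one-element tensor can live in shared memory; a plain Python int cannot
step = torch.zeros(1)
step.share_memory_()   # storage becomes visible to torch.multiprocessing workers
step += 1              # in-place update, seen through the shared storage
print(int(step[0]))    # -> 1
```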
21 changes: 17 additions & 4 deletions test.py
@@ -1,19 +1,23 @@
import math
import os
import sys
import time

import torch
import torch.nn.functional as F
import torch.optim as optim
import tensorboard_logger as tb

from envs import create_atari_env
from model import ActorCritic
from torch.autograd import Variable
from torchvision import datasets, transforms
from collections import deque
from utils import logger

logger = logger.getLogger('test')

def test(rank, args, shared_model, gl_step_cnt):
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
@@ -30,6 +34,8 @@ def test(rank, args, shared_model):

    start_time = time.time()

    local_episode_num = 0

    # a quick hack to prevent the agent from getting stuck
    actions = deque(maxlen=100)
    episode_length = 0
@@ -59,10 +65,17 @@ def test(rank, args, shared_model):
            done = True

        if done:
            passed_time = time.time() - start_time
            local_episode_num += 1
            global_step_count = gl_step_cnt.get_value()

            logger.info("Time {}, episode reward {}, episode length {}".format(
                time.strftime("%Hh %Mm %Ss", time.gmtime(passed_time)),
                reward_sum, episode_length))
            tb.log_value('steps_second', global_step_count / passed_time, global_step_count)
            tb.log_value('reward', reward_sum, global_step_count)

            reward_sum = 0
            episode_length = 0
            actions.clear()
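
The folded lines above maintain `actions` so the test process can bail out of degenerate episodes; the exact check is elided from this diff. A hypothetical sketch of how such a deque-based stuck detector typically works:

```python
from collections import deque

actions = deque(maxlen=100)

def is_stuck(actions):
    # stuck if the window is full and every recent action is identical
    return (len(actions) == actions.maxlen
            and actions.count(actions[0]) == actions.maxlen)

for a in [3] * 100:
    actions.append(a)
assert is_stuck(actions)
```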
31 changes: 22 additions & 9 deletions train.py
@@ -5,20 +5,22 @@
import torch
import torch.nn.functional as F
import torch.optim as optim

from envs import create_atari_env
from model import ActorCritic
from torch.autograd import Variable
from torchvision import datasets, transforms
from utils import logger

logger = logger.getLogger('train')

def ensure_shared_grads(model, shared_model):
    # point the shared model's gradients at this worker's grad tensors;
    # after the first call the binding persists, so return early
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad


def train(rank, args, shared_model, gl_step_count, optimizer=None):
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
@@ -31,16 +33,27 @@ def train(rank, args, shared_model, optimizer=None):

    model.train()

    state = env.reset()
    state = torch.from_numpy(state)
    done = True

    episode_length = 0
    episode_count = 0

    while True:

        values = []
        log_probs = []
        rewards = []
        entropies = []

        if episode_count == args.max_episode_count:
            logger.info('Maximum episode count {} reached.'.format(args.max_episode_count))
            # TODO make sure that test.py also closes once no train process is running
            break

        episode_length += 1

        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
@@ -50,10 +63,6 @@ def train(rank, args, shared_model, optimizer=None):
            cx = Variable(cx.data)
            hx = Variable(hx.data)

        for step in range(args.num_steps):
            value, logit, (hx, cx) = model(
@@ -72,6 +81,7 @@

            if done:
                episode_length = 0
                episode_count += 1
                state = env.reset()

            state = torch.from_numpy(state)
@@ -82,6 +92,9 @@
            if done:
                break

        # increment global step count
        gl_step_count.increment_by(step)

        R = torch.zeros(1, 1)
        if not done:
            value, _, _ = model((Variable(state.unsqueeze(0)), (hx, cx)))