
Replaced deprecated scipy and torch functions. #4


Open · wants to merge 3 commits into master
99 changes: 37 additions & 62 deletions jacobian-vs-perturbation.ipynb

Large diffs are not rendered by default.
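Why cv2: scipy.misc.imresize was deprecated in SciPy 1.0 and removed in 1.3, which is what this PR works around by switching to cv2.resize. Below is a minimal sketch of the swap, using a synthetic frame as a stand-in for a real Atari observation. The two calls are close but not guaranteed bit-identical: imresize ran float inputs through bytescale (per-frame min/max rescaling to uint8) while cv2.resize preserves the input dtype and range, so the change is worth double-checking downstream.

import cv2
import numpy as np

# synthetic stand-in for a 210x160 RGB Atari frame (not from the repo)
frame = np.random.randint(0, 256, size=(210, 160, 3)).astype(np.uint8)
gray = frame[35:195].mean(2)  # crop the playfield, average the channels

# old (SciPy < 1.3): small = imresize(gray, (80, 80))
# new: cv2.resize; dsize is (width, height), harmless here since 80 == 80
small = cv2.resize(src=gray, dsize=(80, 80))
obs = small.astype(np.float32).reshape(1, 80, 80) / 255.0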

58 changes: 39 additions & 19 deletions overfit_atari.py
@@ -1,32 +1,42 @@
# Visualizing and Understanding Atari Agents | Sam Greydanus | 2017 | MIT License

from __future__ import print_function
import warnings ; warnings.filterwarnings('ignore') # mute warnings, live dangerously ;)
import warnings

warnings.filterwarnings("ignore") # mute warnings, live dangerously ;)

import torch
from torch.autograd import Variable
import torch.nn.functional as F

import gym, sys
import numpy as np
from scipy.misc import imresize # preserves single-pixel info _unlike_ img = img[::2,::2]
import cv2

sys.path.append('..')
sys.path.append("..")
from visualize_atari import *

prepro = lambda img: imresize(img[35:195].mean(2), (80,80)).astype(np.float32).reshape(1,80,80)/255.
prepro = (
lambda img: cv2.resize(src=img[35:195].mean(2), dsize=(80, 80))
.astype(np.float32)
.reshape(1, 80, 80)
/ 255.0
)


class OverfitAtari():
class OverfitAtari:
def __init__(self, env_name, expert_dir, seed=0):
self.atari = gym.make(env_name) ; self.atari.seed(seed)
self.atari = gym.make(env_name)
self.atari.seed(seed)
self.action_space = self.atari.action_space
self.expert = NNPolicy(channels=1, num_actions=self.action_space.n)
self.expert.try_load(expert_dir)
self.cx = Variable(torch.zeros(1, 256)) # lstm memory vector
self.hx = Variable(torch.zeros(1, 256)) # lstm activation vector
self.cx = Variable(torch.zeros(1, 256)) # lstm memory vector
self.hx = Variable(torch.zeros(1, 256)) # lstm activation vector

def seed(self, s):
self.atari.seed(s) ; torch.manual_seed(s)
self.atari.seed(s)
torch.manual_seed(s)

def reset(self):
self.cx = Variable(torch.zeros(1, 256))
@@ -35,15 +45,25 @@ def reset(self):

def step(self, action):
state, reward, done, info = self.atari.step(action)

expert_state = torch.Tensor(prepro(state)) # get expert policy and incorporate it into environment
_, logit, (hx, cx) = self.expert((Variable(expert_state.view(1,1,80,80)), (self.hx, self.cx)))

expert_state = torch.Tensor(
prepro(state)
) # get expert policy and incorporate it into environment
_, logit, (hx, cx) = self.expert(
(Variable(expert_state.view(1, 1, 80, 80)), (self.hx, self.cx))
)
self.hx, self.cx = Variable(hx.data), Variable(cx.data)

expert_action = int(F.softmax(logit).data.max(1)[1][0,0])
target = torch.zeros(logit.size()) ; target[0,expert_action] = 1
j = 72 ; k = 5
expert_action = expert_action if False else np.random.randint(self.atari.action_space.n)

expert_action = int(F.softmax(logit).data.max(1)[1][0, 0])
target = torch.zeros(logit.size())
target[0, expert_action] = 1
j = 72
k = 5
expert_action = (
expert_action if False else np.random.randint(self.atari.action_space.n)
)
for i in range(self.atari.action_space.n):
state[37:41, j + k*i: j+1+k*i,:] = 250 if expert_action == i else 50
state[37:41, j + k * i : j + 1 + k * i, :] = (
250 if expert_action == i else 50
)
return state, reward, done, target
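One deprecation the reformatted code still carries: F.softmax(logit) with no dim argument, which PyTorch has warned about since 0.4 because the reduction axis is ambiguous. A hedged sketch of the explicit form for a (1, num_actions) logit tensor; the names here are illustrative, not the repo's.

import torch
import torch.nn.functional as F

logit = torch.randn(1, 6)                   # e.g. six discrete Atari actions
prob = F.softmax(logit, dim=1)              # softmax over the action axis
expert_action = int(prob.argmax(dim=1)[0])  # replaces .data.max(1)[1][0, 0]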
26 changes: 17 additions & 9 deletions policy.py
@@ -1,7 +1,9 @@
# Visualizing and Understanding Atari Agents | Sam Greydanus | 2017 | MIT License

from __future__ import print_function
import warnings ; warnings.filterwarnings('ignore') # mute warnings, live dangerously ;)
import warnings

warnings.filterwarnings("ignore") # mute warnings, live dangerously ;)

import torch
from torch.autograd import Variable
@@ -10,17 +12,19 @@

import glob
import numpy as np
from scipy.misc import imresize # preserves single-pixel info _unlike_ img = img[::2,::2]

class NNPolicy(torch.nn.Module): # an actor-critic neural network

class NNPolicy(torch.nn.Module): # an actor-critic neural network
def __init__(self, channels, num_actions):
super(NNPolicy, self).__init__()
self.conv1 = nn.Conv2d(channels, 32, 3, stride=2, padding=1)
self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
self.lstm = nn.LSTMCell(32 * 5 * 5, 256)
self.critic_linear, self.actor_linear = nn.Linear(256, 1), nn.Linear(256, num_actions)
self.critic_linear, self.actor_linear = nn.Linear(256, 1), nn.Linear(
256, num_actions
)

def forward(self, inputs):
inputs, (hx, cx) = inputs
@@ -32,11 +36,15 @@ def forward(self, inputs):
hx, cx = self.lstm(x, (hx, cx))
return self.critic_linear(hx), self.actor_linear(hx), (hx, cx)

def try_load(self, save_dir, checkpoint='*.tar'):
paths = glob.glob(save_dir + checkpoint) ; step = 0
def try_load(self, save_dir, checkpoint="*.tar"):
paths = glob.glob(save_dir + checkpoint)
step = 0
if len(paths) > 0:
ckpts = [int(s.split('.')[-2]) for s in paths]
ix = np.argmax(ckpts) ; step = ckpts[ix]
ckpts = [int(s.split(".")[-2]) for s in paths]
ix = np.argmax(ckpts)
step = ckpts[ix]
self.load_state_dict(torch.load(paths[ix]))
print("\tno saved models") if step is 0 else print("\tloaded model: {}".format(paths[ix]))
print("\tno saved models") if step is 0 else print(
"\tloaded model: {}".format(paths[ix])
)
return step
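try_load also keeps the expression "step is 0", an identity comparison that only passes because CPython caches small integers and that raises a SyntaxWarning on Python 3.8+. A sketch of the equality form, assuming the same step, paths, and ix as in try_load above:

if step == 0:
    print("\tno saved models")
else:
    print("\tloaded model: {}".format(paths[ix]))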
51 changes: 31 additions & 20 deletions rollout.py
@@ -1,43 +1,54 @@
# Visualizing and Understanding Atari Agents | Sam Greydanus | 2017 | MIT License

from __future__ import print_function
import warnings ; warnings.filterwarnings('ignore') # mute warnings, live dangerously ;)
import warnings

warnings.filterwarnings("ignore") # mute warnings, live dangerously ;)

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

import numpy as np
from scipy.misc import imresize # preserves single-pixel info _unlike_ img = img[::2,::2]
import cv2

prepro = (
lambda img: cv2.resize(src=img[35:195].mean(2), dsize=(80, 80))
.astype(np.float32)
.reshape(1, 80, 80)
/ 255.0
)

prepro = lambda img: imresize(img[35:195].mean(2), (80,80)).astype(np.float32).reshape(1,80,80)/255.

def rollout(model, env, max_ep_len=3e3, render=False):
history = {'ins': [], 'logits': [], 'values': [], 'outs': [], 'hx': [], 'cx': []}
state = torch.Tensor(prepro(env.reset())) # get first state
episode_length, epr, eploss, done = 0, 0, 0, False # bookkeeping
hx, cx = Variable(torch.zeros(1, 256)), Variable(torch.zeros(1, 256))
history = {"ins": [], "logits": [], "values": [], "outs": [], "hx": [], "cx": []}

state = torch.Tensor(prepro(env.reset())) # get first state
episode_length, epr, eploss, done = 0, 0, 0, False # bookkeeping
hx, cx = torch.zeros(1, 256), torch.zeros(1, 256)

while not done and episode_length <= max_ep_len:
episode_length += 1
value, logit, (hx, cx) = model((Variable(state.view(1,1,80,80)), (hx, cx)))
hx, cx = Variable(hx.data), Variable(cx.data)
model_inp = (state.view(1, 1, 80, 80), (hx, cx))
value, logit, (hx, cx) = model(model_inp)
hx, cx = hx.data, cx.data
prob = F.softmax(logit)

action = prob.max(1)[1].data # prob.multinomial().data[0] #
action = prob.max(1)[1].data # prob.multinomial().data[0] #
obs, reward, done, expert_policy = env.step(action.numpy()[0])
if render: env.render()
state = torch.Tensor(prepro(obs)) ; epr += reward
if render:
env.render()
state = torch.Tensor(prepro(obs))
epr += reward

# save info!
history['ins'].append(obs)
history['hx'].append(hx.squeeze(0).data.numpy())
history['cx'].append(cx.squeeze(0).data.numpy())
history['logits'].append(logit.data.numpy()[0])
history['values'].append(value.data.numpy()[0])
history['outs'].append(prob.data.numpy()[0])
print('\tstep # {}, reward {:.0f}'.format(episode_length, epr), end='\r')
history["ins"].append(obs)
history["hx"].append(hx.squeeze(0).data.numpy())
history["cx"].append(cx.squeeze(0).data.numpy())
history["logits"].append(logit.data.numpy()[0])
history["values"].append(value.data.numpy()[0])
history["outs"].append(prob.data.numpy()[0])
print("\tstep # {}, reward {:.0f}".format(episode_length, epr), end="\r")

return history
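For context on the Variable removals in this file: since PyTorch 0.4, torch.autograd.Variable is a thin alias for Tensor, the volatile=True flag is gone in favor of torch.no_grad(), and cutting the graph between timesteps is spelled .detach() rather than wrapping .data. A minimal sketch, with the model call left commented out because it assumes the repo's NNPolicy:

import torch

hx, cx = torch.zeros(1, 256), torch.zeros(1, 256)  # plain tensors, no Variable

with torch.no_grad():                              # replaces volatile=True
    state = torch.zeros(1, 1, 80, 80)              # stand-in preprocessed frame
    # value, logit, (hx, cx) = model((state, (hx, cx)))

hx, cx = hx.detach(), cx.detach()  # prefer .detach() over re-wrapping .data

Running the whole rollout under no_grad() would also avoid building a graph that the visualization never backpropagates through.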
128 changes: 87 additions & 41 deletions saliency.py
@@ -1,76 +1,122 @@
# Visualizing and Understanding Atari Agents | Sam Greydanus | 2017 | MIT License

from __future__ import print_function
import warnings ; warnings.filterwarnings('ignore') # mute warnings, live dangerously ;)
import warnings

warnings.filterwarnings("ignore") # mute warnings, live dangerously ;)

import torch
from torch.autograd import Variable
import torch.nn.functional as F

import numpy as np
from scipy.ndimage.filters import gaussian_filter
from scipy.misc import imresize # preserves single-pixel info _unlike_ img = img[::2,::2]
import cv2

# [210, 160, 3] -> [1, 80, 80]
prepro = (
lambda img: cv2.resize(src=img[35:195].mean(2), dsize=(80, 80))
.astype(np.float32)
.reshape(1, 80, 80)
/ 255.0
)
searchlight = lambda im, mask: im * mask + gaussian_filter(im, sigma=3) * (
1 - mask
) # choose an area NOT to blur
occlude = (
lambda im, mask: im * (1 - mask) + gaussian_filter(im, sigma=3) * mask
) # choose an area to blur

prepro = lambda img: imresize(img[35:195].mean(2), (80,80)).astype(np.float32).reshape(1,80,80)/255.
searchlight = lambda I, mask: I*mask + gaussian_filter(I, sigma=3)*(1-mask) # choose an area NOT to blur
occlude = lambda I, mask: I*(1-mask) + gaussian_filter(I, sigma=3)*mask # choose an area to blur

def get_mask(center, size, r):
y,x = np.ogrid[-center[0]:size[0]-center[0], -center[1]:size[1]-center[1]]
keep = x*x + y*y <= 1
mask = np.zeros(size) ; mask[keep] = 1 # select a circle of pixels
mask = gaussian_filter(mask, sigma=r) # blur the circle of pixels. this is a 2D Gaussian for r=r^2=1
return mask/mask.max()
y, x = np.ogrid[-center[0] : size[0] - center[0], -center[1] : size[1] - center[1]]
keep = x * x + y * y <= 1
mask = np.zeros(size)
mask[keep] = 1 # select a circle of pixels
mask = gaussian_filter(
mask, sigma=r
) # blur the circle of pixels. this is a 2D Gaussian for r=r^2=1
return mask / mask.max()

def run_through_model(model, history, ix, interp_func=None, mask=None, blur_memory=None, mode='actor'):

def run_through_model(
model,
history,
ix,
interp_func=None,
mask=None,
blur_memory=None,
mode="actor",
):
# [210, 160, 3] -> [1, 80, 80]
if mask is None:
im = prepro(history['ins'][ix])
im = prepro(history["ins"][ix])
else:
assert(interp_func is not None, "interp func cannot be none")
im = interp_func(prepro(history['ins'][ix]).squeeze(), mask).reshape(1,80,80) # perturb input I -> I'
assert interp_func is not None, "interp func cannot be none"
# [210, 160, 3] -> [1, 80, 80]
im = prepro(history["ins"][ix]).squeeze()
# -> [1, 80, 80]
im = interp_func(im, mask).reshape(1, 80, 80) # perturb input im -> im'
tens_state = torch.Tensor(im)
state = Variable(tens_state.unsqueeze(0), volatile=True)
hx = Variable(torch.Tensor(history['hx'][ix-1]).view(1,-1))
cx = Variable(torch.Tensor(history['cx'][ix-1]).view(1,-1))
if blur_memory is not None: cx.mul_(1-blur_memory) # perturb memory vector
return model((state, (hx, cx)))[0] if mode == 'critic' else model((state, (hx, cx)))[1]
state = tens_state.unsqueeze(0)
hx = torch.tensor(history["hx"][ix - 1]).view(1, -1)
cx = torch.tensor(history["cx"][ix - 1]).view(1, -1)
if blur_memory is not None:
cx.mul_(1 - blur_memory) # perturb memory vector
model_inp = (state, (hx, cx))
if mode == "critic":
return model(model_inp)[0]
else:
return model(model_inp)[1]

def score_frame(model, history, ix, r, d, interp_func, mode='actor'):

def score_frame(model, history, ix, r, d, interp_func, mode="actor"):
# r: radius of blur
# d: density of scores (if d==1, then get a score for every pixel...
# if d==2 then every other, which is 25% of total pixels for a 2D image)
assert mode in ['actor', 'critic'], 'mode must be either "actor" or "critic"'
assert mode in ["actor", "critic"], 'mode must be either "actor" or "critic"'
L = run_through_model(model, history, ix, interp_func, mask=None, mode=mode)
scores = np.zeros((int(80/d)+1,int(80/d)+1)) # saliency scores S(t,i,j)
for i in range(0,80,d):
for j in range(0,80,d):
mask = get_mask(center=[i,j], size=[80,80], r=r)
scores = np.zeros((int(80 / d) + 1, int(80 / d) + 1)) # saliency scores S(t,i,j)
for i in range(0, 80, d):
for j in range(0, 80, d):
mask = get_mask(center=[i, j], size=[80, 80], r=r)
l = run_through_model(model, history, ix, interp_func, mask=mask, mode=mode)
scores[int(i/d),int(j/d)] = (L-l).pow(2).sum().mul_(.5).data[0]
scores[int(i / d), int(j / d)] = (L - l).pow(2).sum().mul_(0.5).item()
pmax = scores.max()
scores = imresize(scores, size=[80,80], interp='bilinear').astype(np.float32)
return pmax * scores / scores.max()
scores = cv2.resize(
src=scores, dsize=(80, 80), interpolation=cv2.INTER_LINEAR
).astype(np.float32)
return scores


def saliency_on_atari_frame(saliency, atari, fudge_factor, channel=2, sigma=0):
# sometimes saliency maps are a bit clearer if you blur them
# slightly...sigma adjusts the radius of that blur
pmax = saliency.max()
S = imresize(saliency, size=[160,160], interp='bilinear').astype(np.float32)
S = cv2.resize(
src=saliency, dsize=(160, 160), interpolation=cv2.INTER_LINEAR
).astype(np.float32)
S = S if sigma == 0 else gaussian_filter(S, sigma=sigma)
S -= S.min() ; S = fudge_factor*pmax * S / S.max()
I = atari.astype('uint16')
I[35:195,:,channel] += S.astype('uint16')
I = I.clip(1,255).astype('uint8')
return I
S -= S.min()
# S = fudge_factor * pmax * S / S.max()
S = fudge_factor * S
im = atari.astype("uint16")
im[35:195, :, channel] += S.astype("uint16")
im = im.clip(1, 255).astype("uint8")
return im


def get_env_meta(env_name):
meta = {}
if env_name=="Pong-v0":
meta['critic_ff'] = 600 ; meta['actor_ff'] = 500
elif env_name=="Breakout-v0":
meta['critic_ff'] = 600 ; meta['actor_ff'] = 300
elif env_name=="SpaceInvaders-v0":
meta['critic_ff'] = 400 ; meta['actor_ff'] = 400
if env_name == "Pong-v0":
meta["critic_ff"] = 600
meta["actor_ff"] = 500
elif env_name == "Breakout-v0":
meta["critic_ff"] = 600
meta["actor_ff"] = 300
elif env_name == "SpaceInvaders-v0":
meta["critic_ff"] = 400
meta["actor_ff"] = 400
else:
print('environment "{}" not supported'.format(env_name))
return meta
return meta
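Finally, the quantity score_frame accumulates is the perturbation saliency from the paper, S(t, i, j) = 0.5 * ||pi(I) - pi(I')||^2, where I' is the frame blurred around pixel (i, j). A self-contained sketch under the same 80x80 geometry; policy here is a toy stand-in rather than the repo's NNPolicy, and gaussian_filter is imported from scipy.ndimage directly since the scipy.ndimage.filters path used above is itself deprecated:

import numpy as np
from scipy.ndimage import gaussian_filter

def get_mask(center, size, r):
    # soft circular mask centered at (i, j), as in saliency.py
    y, x = np.ogrid[-center[0]:size[0]-center[0], -center[1]:size[1]-center[1]]
    mask = np.zeros(size)
    mask[x * x + y * y <= 1] = 1
    mask = gaussian_filter(mask, sigma=r)
    return mask / mask.max()

def occlude(im, mask):
    # blur inside the mask, keep the rest of the image
    return im * (1 - mask) + gaussian_filter(im, sigma=3) * mask

frame = np.random.rand(80, 80).astype(np.float32)    # stand-in observation
policy = lambda im: np.array([im.mean(), im.std()])  # toy 2-logit "policy"

L = policy(frame)
l = policy(occlude(frame, get_mask(center=[40, 40], size=[80, 80], r=5)))
score = 0.5 * np.sum((L - l) ** 2)                   # S(t, 40, 40)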