[Week 08] Update & clean up the notebook #522

Open · wants to merge 1 commit into master
86 changes: 51 additions & 35 deletions week08_pomdp/practice_pytorch.ipynb
@@ -6,19 +6,25 @@
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"if 'google.colab' in sys.modules:\n",
" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/0ccb0673965dd650d9b284e1ec90c2bfd82c8a94/week08_pomdp/atari_util.py\n",
" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/0ccb0673965dd650d9b284e1ec90c2bfd82c8a94/week08_pomdp/env_pool.py\n",
"\n",
" !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
" !touch .setup_complete\n",
"# If you are running on a server, launch xvfb to record game videos\n",
"# Please make sure you have xvfb installed\n",
"import os\n",
"if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
" !bash ../xvfb start\n",
" os.environ['DISPLAY'] = ':1'"
"import sys, os\n",
"if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
" # Install xvfb and our launcher script for it\n",
" !apt-get install -y xvfb\n",
" !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/xvfb -O ../xvfb\n",
"\n",
" !pip install gym[atari,accept-rom-license]\n",
"\n",
" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/atari_util.py\n",
" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/env_pool.py\n",
"\n",
" !touch .setup_complete\n",
"\n",
"# This code creates a virtual display to draw game images on.\n",
"# It will have no effect if your machine has a monitor.\n",
"import os\n",
"if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
" !bash ../xvfb start\n",
" os.environ['DISPLAY'] = ':1'"
]
},
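The display check at the end of this cell reads: if no `DISPLAY` is set, start a virtual framebuffer and point rendering at it. A minimal standalone sketch of the same logic, assuming the course's `../xvfb` helper script is in place as downloaded above:

    import os
    import subprocess

    # No attached display (typical for Colab or a headless server):
    # start Xvfb through the helper script and route rendering to it.
    if not os.environ.get("DISPLAY"):
        subprocess.run(["bash", "../xvfb", "start"], check=False)
        os.environ["DISPLAY"] = ":1"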
{
@@ -53,7 +59,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
"Observation shape: (1, 42, 42)\n",
"Num actions: 14\n",
"Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']\n"
@@ -70,6 +75,7 @@
" env = PreprocessAtari(env, height=42, width=42,\n",
" crop=lambda img: img[60:-30, 15:],\n",
" color=False, n_frames=1)\n",
" env.metadata['render_fps'] = 30\n",
" return env\n",
"\n",
"\n",
@@ -204,13 +210,15 @@
" return new_state, (logits, state_value)\n",
"\n",
" def get_initial_state(self, batch_size):\n",
" \"\"\"Return a list of agent memory states at game start. Each state is a np array of shape [batch_size, ...]\"\"\"\n",
" return torch.zeros((batch_size, 128)), torch.zeros((batch_size, 128))\n",
" \"\"\"Return the agent memory state at the beginning of the game. Each state is a np array of shape [batch_size, ...]\"\"\"\n",
" h0 = torch.zeros((batch_size, 128))\n",
" c0 = torch.zeros((batch_size, 128))\n",
" return h0, c0\n",
"\n",
" def sample_actions(self, agent_outputs):\n",
" \"\"\"pick actions given numeric agent outputs (np arrays)\"\"\"\n",
" logits, state_values = agent_outputs\n",
" probs = F.softmax(logits)\n",
" probs = F.softmax(logits, dim=-1)\n",
" return torch.multinomial(probs, 1)[:, 0].data.numpy()\n",
"\n",
" def step(self, prev_state, obs_t):\n",
@@ -258,11 +266,13 @@
"metadata": {},
"outputs": [],
"source": [
"import tqdm\n",
"\n",
"def evaluate(agent, env, n_games=1):\n",
" \"\"\"Plays an entire game start to end, returns session rewards.\"\"\"\n",
"\n",
" game_rewards = []\n",
" for _ in range(n_games):\n",
" for _ in tqdm.notebook.trange(n_games):\n",
" # initial observation and memory\n",
" observation = env.reset()\n",
" prev_memories = agent.get_initial_state(1)\n",
@@ -292,7 +302,7 @@
"source": [
"import gym.wrappers\n",
"\n",
"with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n",
"with gym.wrappers.RecordVideo(make_env(), video_folder=\"videos\") as env_monitor:\n",
" rewards = evaluate(agent, env_monitor, n_games=3)\n",
"\n",
"print(rewards)"
@@ -446,7 +456,7 @@
"source": [
"def to_one_hot(y, n_dims=None):\n",
" \"\"\" Take an integer tensor and convert it to 1-hot matrix. \"\"\"\n",
" y_tensor = y.to(dtype=torch.int64).view(-1, 1)\n",
" y_tensor = y.to(dtype=torch.int64).reshape(-1, 1)\n",
" n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1\n",
" y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)\n",
" return y_one_hot"
@@ -472,7 +482,7 @@
" states = torch.tensor(np.asarray(states), dtype=torch.float32)\n",
" actions = torch.tensor(np.array(actions), dtype=torch.int64) # shape: [batch_size, time]\n",
" rewards = torch.tensor(np.array(rewards), dtype=torch.float32) # shape: [batch_size, time]\n",
" is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.float32) # shape: [batch_size, time]\n",
" is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.bool) # shape: [batch_size, time]\n",
" rollout_length = rewards.shape[1] - 1\n",
"\n",
" # predict logits, probas and log-probas using an agent.\n",
@@ -483,7 +493,7 @@
" for t in range(rewards.shape[1]):\n",
" obs_t = states[:, t]\n",
"\n",
" # use agent to comute logits_t and state values_t.\n",
" # use agent to compute logits_t and state values_t.\n",
" # append them to logits and state_values array\n",
"\n",
" memory, (logits_t, values_t) = <YOUR CODE>\n",
@@ -521,9 +531,10 @@
" V_next = state_values[:, t + 1].detach() # next state values\n",
" # log-probability of a_t in s_t\n",
" logpi_a_s_t = logprobas_for_actions[:, t]\n",
" is_not_done_t = is_not_done[:, t]\n",
"\n",
" # update G_t = r_t + gamma * G_{t+1} as we did in week6 reinforce\n",
" cumulative_returns = G_t = r_t + gamma * cumulative_returns\n",
" cumulative_returns = G_t = r_t + torch.where(is_not_done_t, gamma * cumulative_returns, 0)\n",
"\n",
" # Compute temporal difference error (MSE for V(s))\n",
" value_loss += <YOUR CODE>\n",
@@ -579,7 +590,6 @@
"outputs": [],
"source": [
"from IPython.display import clear_output\n",
"from tqdm import trange\n",
"from pandas import DataFrame\n",
"moving_average = lambda x, **kw: DataFrame(\n",
" {'x': np.asarray(x)}).x.ewm(**kw).mean().values\n",
@@ -593,21 +603,27 @@
"metadata": {},
"outputs": [],
"source": [
"for i in trange(15000):\n",
"log_every = 100\n",
"\n",
"for i in tqdm.trange(15000):\n",
" # tqdm.notebook.tqdm is not trivial to use here because clear_output(True)\n",
" # also removes the tqdm widget\n",
"\n",
" memory = list(pool.prev_memory_states)\n",
" rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(\n",
" 10)\n",
" train_on_rollout(rollout_obs, rollout_actions,\n",
" rollout_rewards, rollout_mask, memory)\n",
" rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(10)\n",
" train_on_rollout(rollout_obs, rollout_actions, rollout_rewards, rollout_mask, memory)\n",
"\n",
" if i % 100 == 0:\n",
" if i % log_every == 0:\n",
" rewards_history.append(np.mean(evaluate(agent, env, n_games=1)))\n",
" clear_output(True)\n",
" plt.plot(rewards_history, label='rewards')\n",
" plt.plot(moving_average(np.array(rewards_history),\n",
" span=10), label='rewards ewma@10')\n",
" plt.plot(\n",
" np.arange(len(rewards_history)) * log_every,\n",
" rewards_history, label='rewards')\n",
" plt.plot(\n",
" np.arange(len(rewards_history)) * log_every,\n",
" moving_average(np.array(rewards_history), span=10), label='rewards ewma@10')\n",
" plt.legend()\n",
" plt.grid()\n",
" plt.show()\n",
" if rewards_history[-1] >= 10000:\n",
" print(\"Your agent has just passed the minimum homework threshold\")\n",
@@ -628,7 +644,7 @@
"Since we use a policy-based method, we also keep track of __policy entropy__ - the same one you used as a regularizer. The only important thing about it is that your entropy shouldn't drop too low (`< 0.1`) before your agent gets the yellow belt. Or at least it can drop there, but _it shouldn't stay there for long_.\n",
"\n",
"If it does, the culprit is likely:\n",
"* Some bug in entropy computation. Remember that it is $ - \\sum p(a_i) \\cdot log p(a_i) $\n",
"* Some bug in entropy computation. Remember that it is $ - \\sum p(a_i) \\cdot \\log p(a_i) $\n",
"* Your agent architecture converges too fast. Increase entropy coefficient in actor loss. \n",
"* Gradient explosion - just [clip gradients](https://stackoverflow.com/a/56069467) and maybe use a smaller network\n",
"* Us. Or PyTorch developers. Or aliens. Or lizardfolk. Contact us on forums before it's too late!\n",
@@ -651,7 +667,7 @@
"source": [
"import gym.wrappers\n",
"\n",
"with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n",
"with gym.wrappers.RecordVideo(make_env(), video_folder=\"videos\") as env_monitor:\n",
" final_rewards = evaluate(agent, env_monitor, n_games=20)\n",
"\n",
"print(\"Final mean reward\", np.mean(final_rewards))"