From c2238f0485e8be7ac4c65e1c614b297b30c542e8 Mon Sep 17 00:00:00 2001 From: Lionel Miller Date: Sun, 2 Apr 2023 19:55:15 +0100 Subject: [PATCH] [Week 08] Update & clean up the notebook --- week08_pomdp/practice_pytorch.ipynb | 86 +++++++++++++++++------------ 1 file changed, 51 insertions(+), 35 deletions(-) diff --git a/week08_pomdp/practice_pytorch.ipynb b/week08_pomdp/practice_pytorch.ipynb index b37cdc06..8ce4e85a 100644 --- a/week08_pomdp/practice_pytorch.ipynb +++ b/week08_pomdp/practice_pytorch.ipynb @@ -6,19 +6,25 @@ "metadata": {}, "outputs": [], "source": [ - "import sys\n", - "if 'google.colab' in sys.modules:\n", - " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/0ccb0673965dd650d9b284e1ec90c2bfd82c8a94/week08_pomdp/atari_util.py\n", - " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/0ccb0673965dd650d9b284e1ec90c2bfd82c8a94/week08_pomdp/env_pool.py\n", - "\n", - " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n", - " !touch .setup_complete\n", - "# If you are running on a server, launch xvfb to record game videos\n", - "# Please make sure you have xvfb installed\n", - "import os\n", - "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", - " !bash ../xvfb start\n", - " os.environ['DISPLAY'] = ':1'" + "import sys, os\n", + "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n", + " # Install xvfb and our launcher script for it\n", + " !apt-get install -y xvfb\n", + " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/xvfb -O ../xvfb\n", + "\n", + " !pip install gym[atari,accept-rom-license]\n", + "\n", + " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/atari_util.py\n", + " !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/env_pool.py\n", + "\n", + " !touch .setup_complete\n", + "\n", + "# This code creates a virtual display to draw game images on.\n", + "# It will have no effect if your machine has a monitor.\n", + "import os\n", + "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n", + " !bash ../xvfb start\n", + " os.environ['DISPLAY'] = ':1'" ] }, { @@ -53,7 +59,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "Observation shape: (1, 42, 42)\n", "Num actions: 14\n", "Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']\n" @@ -70,6 +75,7 @@ " env = PreprocessAtari(env, height=42, width=42,\n", " crop=lambda img: img[60:-30, 15:],\n", " color=False, n_frames=1)\n", + " env.metadata['render_fps'] = 30\n", " return env\n", "\n", "\n", @@ -204,13 +210,15 @@ " return new_state, (logits, state_value)\n", "\n", " def get_initial_state(self, batch_size):\n", - " \"\"\"Return a list of agent memory states at game start. Each state is a np array of shape [batch_size, ...]\"\"\"\n", - " return torch.zeros((batch_size, 128)), torch.zeros((batch_size, 128))\n", + " \"\"\"Return the agent memory state at the beginning of the game. Each state is a np array of shape [batch_size, ...]\"\"\"\n", + " h0 = torch.zeros((batch_size, 128))\n", + " c0 = torch.zeros((batch_size, 128))\n", + " return h0, c0\n", "\n", " def sample_actions(self, agent_outputs):\n", " \"\"\"pick actions given numeric agent outputs (np arrays)\"\"\"\n", " logits, state_values = agent_outputs\n", - " probs = F.softmax(logits)\n", + " probs = F.softmax(logits, dim=-1)\n", " return torch.multinomial(probs, 1)[:, 0].data.numpy()\n", "\n", " def step(self, prev_state, obs_t):\n", @@ -258,11 +266,13 @@ "metadata": {}, "outputs": [], "source": [ + "import tqdm\n", + "\n", "def evaluate(agent, env, n_games=1):\n", " \"\"\"Plays an entire game start to end, returns session rewards.\"\"\"\n", "\n", " game_rewards = []\n", - " for _ in range(n_games):\n", + " for _ in tqdm.notebook.trange(n_games):\n", " # initial observation and memory\n", " observation = env.reset()\n", " prev_memories = agent.get_initial_state(1)\n", @@ -292,7 +302,7 @@ "source": [ "import gym.wrappers\n", "\n", - "with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n", + "with gym.wrappers.RecordVideo(make_env(), video_folder=\"videos\") as env_monitor:\n", " rewards = evaluate(agent, env_monitor, n_games=3)\n", "\n", "print(rewards)" @@ -446,7 +456,7 @@ "source": [ "def to_one_hot(y, n_dims=None):\n", " \"\"\" Take an integer tensor and convert it to 1-hot matrix. \"\"\"\n", - " y_tensor = y.to(dtype=torch.int64).view(-1, 1)\n", + " y_tensor = y.to(dtype=torch.int64).reshape(-1, 1)\n", " n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1\n", " y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)\n", " return y_one_hot" @@ -472,7 +482,7 @@ " states = torch.tensor(np.asarray(states), dtype=torch.float32)\n", " actions = torch.tensor(np.array(actions), dtype=torch.int64) # shape: [batch_size, time]\n", " rewards = torch.tensor(np.array(rewards), dtype=torch.float32) # shape: [batch_size, time]\n", - " is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.float32) # shape: [batch_size, time]\n", + " is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.bool) # shape: [batch_size, time]\n", " rollout_length = rewards.shape[1] - 1\n", "\n", " # predict logits, probas and log-probas using an agent.\n", @@ -483,7 +493,7 @@ " for t in range(rewards.shape[1]):\n", " obs_t = states[:, t]\n", "\n", - " # use agent to comute logits_t and state values_t.\n", + " # use agent to compute logits_t and state values_t.\n", " # append them to logits and state_values array\n", "\n", " memory, (logits_t, values_t) = \n", @@ -521,9 +531,10 @@ " V_next = state_values[:, t + 1].detach() # next state values\n", " # log-probability of a_t in s_t\n", " logpi_a_s_t = logprobas_for_actions[:, t]\n", + " is_not_done_t = is_not_done[:, t]\n", "\n", " # update G_t = r_t + gamma * G_{t+1} as we did in week6 reinforce\n", - " cumulative_returns = G_t = r_t + gamma * cumulative_returns\n", + " cumulative_returns = G_t = r_t + torch.where(is_not_done_t, gamma * cumulative_returns, 0)\n", "\n", " # Compute temporal difference error (MSE for V(s))\n", " value_loss += \n", @@ -579,7 +590,6 @@ "outputs": [], "source": [ "from IPython.display import clear_output\n", - "from tqdm import trange\n", "from pandas import DataFrame\n", "moving_average = lambda x, **kw: DataFrame(\n", " {'x': np.asarray(x)}).x.ewm(**kw).mean().values\n", @@ -593,21 +603,27 @@ "metadata": {}, "outputs": [], "source": [ - "for i in trange(15000):\n", + "log_every = 100\n", + "\n", + "for i in tqdm.trange(15000):\n", + " # tqdm.notebook.tqdm is not trivial to use here because clear_output(True)\n", + " # also removes the tqdm widget\n", "\n", " memory = list(pool.prev_memory_states)\n", - " rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(\n", - " 10)\n", - " train_on_rollout(rollout_obs, rollout_actions,\n", - " rollout_rewards, rollout_mask, memory)\n", + " rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(10)\n", + " train_on_rollout(rollout_obs, rollout_actions, rollout_rewards, rollout_mask, memory)\n", "\n", - " if i % 100 == 0:\n", + " if i % log_every == 0:\n", " rewards_history.append(np.mean(evaluate(agent, env, n_games=1)))\n", " clear_output(True)\n", - " plt.plot(rewards_history, label='rewards')\n", - " plt.plot(moving_average(np.array(rewards_history),\n", - " span=10), label='rewards ewma@10')\n", + " plt.plot(\n", + " np.arange(len(rewards_history)) * log_every,\n", + " rewards_history, label='rewards')\n", + " plt.plot(\n", + " np.arange(len(rewards_history)) * log_every,\n", + " moving_average(np.array(rewards_history), span=10), label='rewards ewma@10')\n", " plt.legend()\n", + " plt.grid()\n", " plt.show()\n", " if rewards_history[-1] >= 10000:\n", " print(\"Your agent has just passed the minimum homework threshold\")\n", @@ -628,7 +644,7 @@ "Since we use a policy-based method, we also keep track of __policy entropy__ - the same one you used as a regularizer. The only important thing about it is that your entropy shouldn't drop too low (`< 0.1`) before your agent gets the yellow belt. Or at least it can drop there, but _it shouldn't stay there for long_.\n", "\n", "If it does, the culprit is likely:\n", - "* Some bug in entropy computation. Remember that it is $ - \\sum p(a_i) \\cdot log p(a_i) $\n", + "* Some bug in entropy computation. Remember that it is $ - \\sum p(a_i) \\cdot \\log p(a_i) $\n", "* Your agent architecture converges too fast. Increase entropy coefficient in actor loss. \n", "* Gradient explosion - just [clip gradients](https://stackoverflow.com/a/56069467) and maybe use a smaller network\n", "* Us. Or PyTorch developers. Or aliens. Or lizardfolk. Contact us on forums before it's too late!\n", @@ -651,7 +667,7 @@ "source": [ "import gym.wrappers\n", "\n", - "with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n", + "with gym.wrappers.RecordVideo(make_env(), video_folder=\"videos\") as env_monitor:\n", " final_rewards = evaluate(agent, env_monitor, n_games=20)\n", "\n", "print(\"Final mean reward\", np.mean(final_rewards))"