|
20 | 20 | "import sys, os\n",
|
21 | 21 | "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
|
22 | 22 | " !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
|
23 |
| - " !pip install -q gym[classic_control]==0.18.0\n", |
| 23 | + " !pip install -q gym[classic_control]==0.26.0\n", |
| 24 | + " !pip install moviepy\n", |
| 25 | + " !apt install ffmpeg\n", |
| 26 | + " !pip install imageio-ffmpeg\n", |
24 | 27 | " !touch .setup_complete\n",
|
25 | 28 | "\n",
|
26 | 29 | "# This code creates a virtual display to draw game images on.\n",
|
|
37 | 40 | "outputs": [],
|
38 | 41 | "source": [
|
39 | 42 | "import gym\n",
|
| 43 | + "print(f'gym version: {gym.__version__}')\n", |
| 44 | + "assert tuple(map(int, gym.__version__.split('.')))[:2] >= (0, 26)\n", |
40 | 45 | "import numpy as np\n",
|
41 | 46 | "import matplotlib.pyplot as plt\n",
|
42 | 47 | "%matplotlib inline"
|
43 | 48 | ]
|
44 | 49 | },
|
| 50 | + { |
| 51 | + "cell_type": "code", |
| 52 | + "execution_count": null, |
| 53 | + "metadata": {}, |
| 54 | + "outputs": [], |
| 55 | + "source": [ |
| 56 | + "# You also need to install ffmpeg if it is not already installed.\n", |
| 57 | + "# On macOS: ! brew install ffmpeg" |
| 58 | + ] |
| 59 | + }, |
45 | 60 | {
|
46 | 61 | "cell_type": "markdown",
|
47 | 62 | "metadata": {},
|
|
55 | 70 | "metadata": {},
|
56 | 71 | "outputs": [],
|
57 | 72 | "source": [
|
58 |
| - "env = gym.make(\"CartPole-v0\")\n", |
| 73 | + "env = gym.make(\"CartPole-v0\", render_mode = \"rgb_array\")\n", |
59 | 74 | "\n",
|
60 | 75 | "# gym compatibility: unwrap TimeLimit\n",
|
61 | 76 | "if hasattr(env, '_max_episode_steps'):\n",
|
|
65 | 80 | "n_actions = env.action_space.n\n",
|
66 | 81 | "state_dim = env.observation_space.shape\n",
|
67 | 82 | "\n",
|
68 |
| - "plt.imshow(env.render(\"rgb_array\"))" |
| 83 | + "plt.imshow(env.render())" |
69 | 84 | ]
|
70 | 85 | },
|
71 | 86 | {
|
|
156 | 171 | "metadata": {},
|
157 | 172 | "outputs": [],
|
158 | 173 | "source": [
|
159 |
| - "test_states = np.array([env.reset() for _ in range(5)])\n", |
| 174 | + "test_states = np.array([env.reset()[0] for _ in range(5)])\n", |
160 | 175 | "test_probas = predict_probs(test_states)\n",
|
161 | 176 | "assert isinstance(test_probas, np.ndarray), \\\n",
|
162 | 177 | " \"you must return np array and not %s\" % type(test_probas)\n",
|
|
180 | 195 | "metadata": {},
|
181 | 196 | "outputs": [],
|
182 | 197 | "source": [
|
183 |
| - "def generate_session(env, t_max=1000):\n", |
| 198 | + "def generate_session(env, t_max=1000, video_file = None):\n", |
184 | 199 | " \"\"\" \n",
|
185 | 200 | " Play a full session with REINFORCE agent.\n",
|
186 | 201 | " Returns sequences of states, actions, and rewards.\n",
|
187 | 202 | " \"\"\"\n",
|
188 | 203 | " # arrays to record session\n",
|
189 | 204 | " states, actions, rewards = [], [], []\n",
|
190 |
| - " s = env.reset()\n", |
| 205 | + " \n", |
| 206 | + " s = env.reset()[0]\n", |
| 207 | + " \n", |
| 208 | + " if video_file is not None:\n", |
| 209 | + " video = VideoRecorder(env, video_file)\n", |
191 | 210 | "\n",
|
192 | 211 | " for t in range(t_max):\n",
|
193 | 212 | " # action probabilities array aka pi(a|s)\n",
|
194 | 213 | " action_probs = predict_probs(np.array([s]))[0]\n",
|
195 | 214 | "\n",
|
196 | 215 | " # Sample action with given probabilities.\n",
|
197 | 216 | " a = <YOUR CODE>\n",
|
198 |
| - " new_s, r, done, info = env.step(a)\n", |
| 217 | + " \n", |
| 218 | + " if video_file is not None:\n", |
| 219 | + " env.render()\n", |
| 220 | + " video.capture_frame()\n", |
| 221 | + " \n", |
| 222 | + " new_s, r, done, truncated, info = env.step(a)\n", |
199 | 223 | "\n",
|
200 | 224 | " # record session history to train later\n",
|
201 | 225 | " states.append(s)\n",
|
|
205 | 229 | " s = new_s\n",
|
206 | 230 | " if done:\n",
|
207 | 231 | " break\n",
|
| 232 | + " \n", |
| 233 | + " if video_file is not None:\n", |
| 234 | + " video.close()\n", |
208 | 235 | "\n",
|
209 | 236 | " return states, actions, rewards"
|
210 | 237 | ]
|
|
399 | 426 | "source": [
|
400 | 427 | "# Record sessions\n",
|
401 | 428 | "\n",
|
402 |
| - "import gym.wrappers\n", |
| 429 | + "from gym.wrappers.monitoring.video_recorder import VideoRecorder\n", |
| 430 | + " \n", |
| 431 | + "if not os.path.exists('videos'):\n", |
| 432 | + " os.makedirs('videos')\n", |
403 | 433 | "\n",
|
404 |
| - "with gym.wrappers.Monitor(gym.make(\"CartPole-v0\"), directory=\"videos\", force=True) as env_monitor:\n", |
405 |
| - " sessions = [generate_session(env_monitor) for _ in range(100)]" |
| 434 | + "sessions = [generate_session(gym.make(\"CartPole-v0\", render_mode = \"rgb_array\"),\n", |
| 435 | + " video_file = f'videos/CartPole_session_{i}.mp4')\n", |
| 436 | + " for i in range(10)]" |
406 | 437 | ]
|
407 | 438 | },
|
408 | 439 | {
|
|
438 | 469 | }
|
439 | 470 | ],
|
440 | 471 | "metadata": {
|
| 472 | + "kernelspec": { |
| 473 | + "display_name": "Python 3 (ipykernel)", |
| 474 | + "language": "python", |
| 475 | + "name": "python3" |
| 476 | + }, |
441 | 477 | "language_info": {
|
| 478 | + "codemirror_mode": { |
| 479 | + "name": "ipython", |
| 480 | + "version": 3 |
| 481 | + }, |
| 482 | + "file_extension": ".py", |
| 483 | + "mimetype": "text/x-python", |
442 | 484 | "name": "python",
|
443 |
| - "pygments_lexer": "ipython3" |
| 485 | + "nbconvert_exporter": "python", |
| 486 | + "pygments_lexer": "ipython3", |
| 487 | + "version": "3.9.16" |
444 | 488 | }
|
445 | 489 | },
|
446 | 490 | "nbformat": 4,
|
|
0 commit comments