Skip to content

Commit 9137d39

Browse files
authored
Support gym 0.26 in week6 seminar (#519)
* Support gym 0.26 There are changes in gym API between 0.25 and 0.26 These changes break backward compatibility. In this commit we updated the code of week6 seminar to use gym 0.26 What is changed: - env initializer; - env.render; - env.reset; - env.step; - video recording Note: didn't test in Colab, only locally * week6 seminar: Uppdate video dependencies for Colab * week6 seminar: minor fix in installing
1 parent 9e0017f commit 9137d39

File tree

1 file changed

+55
-11
lines changed

1 file changed

+55
-11
lines changed

week06_policy_based/reinforce_pytorch.ipynb

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
"import sys, os\n",
2121
"if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
2222
" !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
23-
" !pip install -q gym[classic_control]==0.18.0\n",
23+
" !pip install -q gym[classic_control]==0.26.0\n",
24+
" !pip install moviepy\n",
25+
" !apt install ffmpeg\n",
26+
" !pip install imageio-ffmpeg\n",
2427
" !touch .setup_complete\n",
2528
"\n",
2629
"# This code creates a virtual display to draw game images on.\n",
@@ -37,11 +40,23 @@
3740
"outputs": [],
3841
"source": [
3942
"import gym\n",
43+
"print(f'gym version: {gym.__version__}')\n",
44+
"assert tuple(map(int, gym.__version__.split('.')))[:2] >= (0, 26)\n",
4045
"import numpy as np\n",
4146
"import matplotlib.pyplot as plt\n",
4247
"%matplotlib inline"
4348
]
4449
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"# also you need to install ffmpeg if not installed\n",
57+
"# for MacOS: ! brew install ffmpeg"
58+
]
59+
},
4560
{
4661
"cell_type": "markdown",
4762
"metadata": {},
@@ -55,7 +70,7 @@
5570
"metadata": {},
5671
"outputs": [],
5772
"source": [
58-
"env = gym.make(\"CartPole-v0\")\n",
73+
"env = gym.make(\"CartPole-v0\", render_mode = \"rgb_array\")\n",
5974
"\n",
6075
"# gym compatibility: unwrap TimeLimit\n",
6176
"if hasattr(env, '_max_episode_steps'):\n",
@@ -65,7 +80,7 @@
6580
"n_actions = env.action_space.n\n",
6681
"state_dim = env.observation_space.shape\n",
6782
"\n",
68-
"plt.imshow(env.render(\"rgb_array\"))"
83+
"plt.imshow(env.render())"
6984
]
7085
},
7186
{
@@ -156,7 +171,7 @@
156171
"metadata": {},
157172
"outputs": [],
158173
"source": [
159-
"test_states = np.array([env.reset() for _ in range(5)])\n",
174+
"test_states = np.array([env.reset()[0] for _ in range(5)])\n",
160175
"test_probas = predict_probs(test_states)\n",
161176
"assert isinstance(test_probas, np.ndarray), \\\n",
162177
" \"you must return np array and not %s\" % type(test_probas)\n",
@@ -180,22 +195,31 @@
180195
"metadata": {},
181196
"outputs": [],
182197
"source": [
183-
"def generate_session(env, t_max=1000):\n",
198+
"def generate_session(env, t_max=1000, video_file = None):\n",
184199
" \"\"\" \n",
185200
" Play a full session with REINFORCE agent.\n",
186201
" Returns sequences of states, actions, and rewards.\n",
187202
" \"\"\"\n",
188203
" # arrays to record session\n",
189204
" states, actions, rewards = [], [], []\n",
190-
" s = env.reset()\n",
205+
" \n",
206+
" s = env.reset()[0]\n",
207+
" \n",
208+
" if video_file is not None:\n",
209+
" video = VideoRecorder(env, video_file)\n",
191210
"\n",
192211
" for t in range(t_max):\n",
193212
" # action probabilities array aka pi(a|s)\n",
194213
" action_probs = predict_probs(np.array([s]))[0]\n",
195214
"\n",
196215
" # Sample action with given probabilities.\n",
197216
" a = <YOUR CODE>\n",
198-
" new_s, r, done, info = env.step(a)\n",
217+
" \n",
218+
" if video_file is not None:\n",
219+
" env.render()\n",
220+
" video.capture_frame()\n",
221+
" \n",
222+
" new_s, r, done, truncated, info = env.step(a)\n",
199223
"\n",
200224
" # record session history to train later\n",
201225
" states.append(s)\n",
@@ -205,6 +229,9 @@
205229
" s = new_s\n",
206230
" if done:\n",
207231
" break\n",
232+
" \n",
233+
" if video_file is not None:\n",
234+
" video.close()\n",
208235
"\n",
209236
" return states, actions, rewards"
210237
]
@@ -399,10 +426,14 @@
399426
"source": [
400427
"# Record sessions\n",
401428
"\n",
402-
"import gym.wrappers\n",
429+
"from gym.wrappers.monitoring.video_recorder import VideoRecorder\n",
430+
" \n",
431+
"if not os.path.exists('videos'):\n",
432+
" os.makedirs('videos')\n",
403433
"\n",
404-
"with gym.wrappers.Monitor(gym.make(\"CartPole-v0\"), directory=\"videos\", force=True) as env_monitor:\n",
405-
" sessions = [generate_session(env_monitor) for _ in range(100)]"
434+
"sessions = [generate_session(gym.make(\"CartPole-v0\", render_mode = \"rgb_array\"),\n",
435+
" video_file = f'videos/CartPole_session_{i}.mp4')\n",
436+
" for i in range(10)]"
406437
]
407438
},
408439
{
@@ -438,9 +469,22 @@
438469
}
439470
],
440471
"metadata": {
472+
"kernelspec": {
473+
"display_name": "Python 3 (ipykernel)",
474+
"language": "python",
475+
"name": "python3"
476+
},
441477
"language_info": {
478+
"codemirror_mode": {
479+
"name": "ipython",
480+
"version": 3
481+
},
482+
"file_extension": ".py",
483+
"mimetype": "text/x-python",
442484
"name": "python",
443-
"pygments_lexer": "ipython3"
485+
"nbconvert_exporter": "python",
486+
"pygments_lexer": "ipython3",
487+
"version": "3.9.16"
444488
}
445489
},
446490
"nbformat": 4,

0 commit comments

Comments
 (0)