Skip to content

Commit

Permalink
Support gym 0.26 in week6 seminar (#519)
Browse files Browse the repository at this point in the history
* Support gym 0.26

There are changes in gym API between 0.25 and 0.26
These changes break backward compatibility.

In this commit we updated the code of week6 seminar to use gym 0.26

What changed:
- env initializer;
- env.render;
- env.reset;
- env.step;
- video recording

Note: didn't test in Colab, only locally

* week6 seminar: Update video dependencies for Colab

* week6 seminar: minor fix in installing
  • Loading branch information
q0o0p authored Feb 22, 2023
1 parent 9e0017f commit 9137d39
Showing 1 changed file with 55 additions and 11 deletions.
66 changes: 55 additions & 11 deletions week06_policy_based/reinforce_pytorch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
"import sys, os\n",
"if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
" !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
" !pip install -q gym[classic_control]==0.18.0\n",
" !pip install -q gym[classic_control]==0.26.0\n",
" # Video-recording dependencies; quiet and non-interactive so the Colab setup cannot stall on a prompt.\n",
" !pip install -q moviepy\n",
" !apt-get install -y -qq ffmpeg\n",
" !pip install -q imageio-ffmpeg\n",
" !touch .setup_complete\n",
"\n",
"# This code creates a virtual display to draw game images on.\n",
Expand All @@ -37,11 +40,23 @@
"outputs": [],
"source": [
"import gym\n",
"print(f'gym version: {gym.__version__}')\n",
"# Compare only (major, minor): slicing before int() keeps a non-numeric patch part from raising ValueError.\n",
"assert tuple(map(int, gym.__version__.split('.')[:2])) >= (0, 26), 'this notebook requires gym >= 0.26'\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# also you need to install ffmpeg if not installed\n",
"# for MacOS: ! brew install ffmpeg"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -55,7 +70,7 @@
"metadata": {},
"outputs": [],
"source": [
"env = gym.make(\"CartPole-v0\")\n",
"env = gym.make(\"CartPole-v0\", render_mode = \"rgb_array\")\n",
"\n",
"# gym compatibility: unwrap TimeLimit\n",
"if hasattr(env, '_max_episode_steps'):\n",
Expand All @@ -65,7 +80,7 @@
"n_actions = env.action_space.n\n",
"state_dim = env.observation_space.shape\n",
"\n",
"plt.imshow(env.render(\"rgb_array\"))"
"plt.imshow(env.render())"
]
},
{
Expand Down Expand Up @@ -156,7 +171,7 @@
"metadata": {},
"outputs": [],
"source": [
"test_states = np.array([env.reset() for _ in range(5)])\n",
"test_states = np.array([env.reset()[0] for _ in range(5)])\n",
"test_probas = predict_probs(test_states)\n",
"assert isinstance(test_probas, np.ndarray), \\\n",
" \"you must return np array and not %s\" % type(test_probas)\n",
Expand All @@ -180,22 +195,31 @@
"metadata": {},
"outputs": [],
"source": [
"def generate_session(env, t_max=1000):\n",
"def generate_session(env, t_max=1000, video_file = None):\n",
" \"\"\" \n",
" Play a full session with REINFORCE agent.\n",
" Returns sequences of states, actions, and rewards.\n",
" \"\"\"\n",
" # arrays to record session\n",
" states, actions, rewards = [], [], []\n",
" s = env.reset()\n",
" \n",
" s = env.reset()[0]\n",
" \n",
" if video_file is not None:\n",
" video = VideoRecorder(env, video_file)\n",
"\n",
" for t in range(t_max):\n",
" # action probabilities array aka pi(a|s)\n",
" action_probs = predict_probs(np.array([s]))[0]\n",
"\n",
" # Sample action with given probabilities.\n",
" a = <YOUR CODE>\n",
" new_s, r, done, info = env.step(a)\n",
" \n",
" if video_file is not None:\n",
" env.render()\n",
" video.capture_frame()\n",
" \n",
" new_s, r, done, truncated, info = env.step(a)\n",
"\n",
" # record session history to train later\n",
" states.append(s)\n",
Expand All @@ -205,6 +229,9 @@
" s = new_s\n",
" # gym>=0.26 splits episode end into two flags: `done` holds `terminated`, `truncated` marks a time-limit stop.\n",
" if done or truncated:\n",
" break\n",
" \n",
" if video_file is not None:\n",
" video.close()\n",
"\n",
" return states, actions, rewards"
]
Expand Down Expand Up @@ -399,10 +426,14 @@
"source": [
"# Record sessions\n",
"\n",
"import gym.wrappers\n",
"from gym.wrappers.monitoring.video_recorder import VideoRecorder\n",
" \n",
"# exist_ok=True makes this a no-op when the directory is already there.\n",
"os.makedirs('videos', exist_ok=True)\n",
"\n",
"with gym.wrappers.Monitor(gym.make(\"CartPole-v0\"), directory=\"videos\", force=True) as env_monitor:\n",
" sessions = [generate_session(env_monitor) for _ in range(100)]"
"sessions = [generate_session(gym.make(\"CartPole-v0\", render_mode = \"rgb_array\"),\n",
" video_file = f'videos/CartPole_session_{i}.mp4')\n",
" for i in range(10)]"
]
},
{
Expand Down Expand Up @@ -438,9 +469,22 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"pygments_lexer": "ipython3"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 9137d39

Please sign in to comment.