Skip to content

Commit

Permalink
Fix notebook.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 716037590
Change-Id: I28a90b1c0a048941ff26dc17d8d51d4955ef6508
  • Loading branch information
btaba authored and copybara-github committed Jan 16, 2025
1 parent beca9ce commit 806be2e
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions learning/notebooks/locomotion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
"id": "_UbO9uhtBSX5"
},
"source": [
"\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eCopyright 2025 DeepMind Technologies Limited.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e\n",
"\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at \u003ca href=\"http://www.apache.org/licenses/LICENSE-2.0\"\u003ehttp://www.apache.org/licenses/LICENSE-2.0\u003c/a\u003e.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e\n",
"\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e"
"> <p><small><small>Copyright 2025 DeepMind Technologies Limited.</small></small></p>\n",
"> <p><small><small>Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at <a href=\"http://www.apache.org/licenses/LICENSE-2.0\">http://www.apache.org/licenses/LICENSE-2.0</a>.</small></small></p>\n",
"> <p><small><small>Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.</small></small></p>"
]
},
{
Expand All @@ -40,11 +40,11 @@
"id": "dNIJkb_FM2Ux"
},
"source": [
"# Locomotion in The Playground! \u003ca href=\"https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/locomotion.ipynb\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" width=\"140\" align=\"center\"/\u003e\u003c/a\u003e\n",
"# Locomotion in The Playground! <a href=\"https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/locomotion.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" width=\"140\" align=\"center\"/></a>\n",
"\n",
"In this notebook, we'll walk through a few locomotion environments available in MuJoCo Playground.\n",
"\n",
"**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu \"Runtime \u003e Change runtime type\".\n"
"**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu \"Runtime > Change runtime type\".\n"
]
},
{
Expand Down Expand Up @@ -107,7 +107,7 @@
" print('Checking that the installation succeeded:')\n",
" import mujoco\n",
"\n",
" mujoco.MjModel.from_xml_string('\u003cmujoco/\u003e')\n",
" mujoco.MjModel.from_xml_string('<mujoco/>')\n",
"except Exception as e:\n",
" raise e from RuntimeError(\n",
" 'Something went wrong during installation. Check the shell output above '\n",
Expand Down Expand Up @@ -142,7 +142,7 @@
"\n",
"# Graphics and plotting.\n",
"print(\"Installing mediapy:\")\n",
"!command -v ffmpeg \u003e/dev/null || (apt update \u0026\u0026 apt install -y ffmpeg)\n",
"!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
"!pip install -q mediapy\n",
"import mediapy as media\n",
"import matplotlib.pyplot as plt\n",
Expand Down Expand Up @@ -476,11 +476,11 @@
"command = jp.array([x_vel, y_vel, yaw_vel])\n",
"\n",
"state = jit_reset(rng)\n",
"if state.info[\"steps_since_last_pert\"] \u003c state.info[\"steps_until_next_pert\"]:\n",
"if state.info[\"steps_since_last_pert\"] < state.info[\"steps_until_next_pert\"]:\n",
" rng = sample_pert(rng)\n",
"state.info[\"command\"] = command\n",
"for i in range(env_cfg.episode_length):\n",
" if state.info[\"steps_since_last_pert\"] \u003c state.info[\"steps_until_next_pert\"]:\n",
" if state.info[\"steps_since_last_pert\"] < state.info[\"steps_until_next_pert\"]:\n",
" rng = sample_pert(rng)\n",
" act_rng, rng = jax.random.split(rng)\n",
" ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
Expand Down Expand Up @@ -651,7 +651,7 @@
" print(f\"Setting x to {x}\")\n",
" command = jp.array([x, 0, 0])\n",
" state.info[\"command\"] = command\n",
" if state.info[\"steps_since_last_pert\"] \u003c state.info[\"steps_until_next_pert\"]:\n",
" if state.info[\"steps_since_last_pert\"] < state.info[\"steps_until_next_pert\"]:\n",
" rng = sample_pert(rng)\n",
" act_rng, rng = jax.random.split(rng)\n",
" ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
Expand Down Expand Up @@ -1103,7 +1103,7 @@
"id": "yCyibqGMiAca"
},
"source": [
"The final policy should exhibit smoother behavior and have less power output! Feel free to finetune the policy some more using different reward terms to demonstrate even smoother behavior."
"The final policy should exhibit smoother behavior and have less power output! Feel free to finetune the policy some more using different reward terms to get the best behavior."
]
},
{
Expand Down Expand Up @@ -1185,6 +1185,9 @@
},
"outputs": [],
"source": [
"#@title Rollout and Render\n",
"from mujoco_playground._src.gait import draw_joystick_command\n",
"\n",
"env = registry.load(env_name)\n",
"eval_env = registry.load(env_name)\n",
"jit_reset = jax.jit(eval_env.reset)\n",
Expand All @@ -1209,18 +1212,13 @@
" state = jit_reset(rng)\n",
" state.info[\"phase_dt\"] = phase_dt\n",
" state.info[\"phase\"] = phase\n",
" ep_rews = []\n",
" for i in range(env_cfg.episode_length):\n",
" act_rng, rng = jax.random.split(rng)\n",
" ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
" state = jit_step(state, ctrl)\n",
" ep_rews.append(state.reward)\n",
" if state.done:\n",
" break\n",
" state.info[\"command\"] = command\n",
" rews.append(\n",
" {k: v for k, v in state.metrics.items() if k.startswith(\"reward/\")}\n",
" )\n",
" rollout.append(state)\n",
"\n",
" xyz = np.array(state.data.xpos[eval_env.mj_model.body(\"torso\").id])\n",
Expand All @@ -1236,7 +1234,6 @@
" scl=np.linalg.norm(state.info[\"command\"]),\n",
" )\n",
" )\n",
" rews_ep.append(ep_rews)\n",
"\n",
"render_every = 1\n",
"fps = 1.0 / eval_env.dt / render_every\n",
Expand Down Expand Up @@ -1301,4 +1298,3 @@
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit 806be2e

Please sign in to comment.