Commit

Merge branch 'main' into main
btaba authored Jan 21, 2025
2 parents 4075c5c + 7dc538b commit 0c6812c
Showing 16 changed files with 449 additions and 58 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/ci.yml
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
@@ -24,9 +24,9 @@ jobs:
python -m pip install --upgrade pip
pip install uv
uv pip install --system -e ".[test]"
# - name: Install submodules
# run: |
# git submodule update --init --recursive
# - name: Test with pytest
# run: |
# pytest mujoco_playground
- name: Trigger git clone
run: |
python -c "import mujoco_playground"
- name: Test with pytest
run: |
pytest -n auto mujoco_playground
21 changes: 21 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,21 @@
# Changelog

All notable changes to this project will be documented in this file.

## Unreleased

### Added

- Added ALOHA handover task (thanks to @Andrew-Luo1).

### Changed

## [0.0.3] - 2025-01-18

### Changed

- Updated supported Python versions to 3.10-3.12.

## [0.0.2] - 2025-01-16

Initial release.
File renamed without changes.
10 changes: 5 additions & 5 deletions README.md
@@ -1,8 +1,8 @@
# MuJoCo Playground

<h1>
<a href="#"><img alt="MuJoCo Playground" src="assets/banner.png" width="100%"></a>
</h1>
[![Build](https://img.shields.io/github/actions/workflow/status/google-deepmind/mujoco_playground/ci.yml?branch=main)](https://github.com/google-deepmind/mujoco_playground/actions)
[![PyPI version](https://img.shields.io/pypi/v/playground)](https://pypi.org/project/playground/)
![Banner for playground](https://github.com/google-deepmind/mujoco_playground/blob/main/assets/banner.png?raw=true)

A comprehensive suite of GPU-accelerated environments for robot learning research and sim-to-real, built with [MuJoCo MJX](https://github.com/google-deepmind/mujoco/tree/main/mjx).

@@ -26,7 +26,7 @@ pip install playground
### From Source

> [!IMPORTANT]
> Requires Python 3.9 or later.
> Requires Python 3.10 or later.
1. `pip install -U "jax[cuda12]"`
* Verify GPU backend: `python -c "import jax; print(jax.default_backend())"` should print `gpu`
@@ -71,7 +71,7 @@ Two additional colabs require local runtimes with Madrona-MJX installed locally

## How can I contribute?

Get started by installing the library and exploring its features! Found a bug? Report it in the issue tracker. Interested in contributing? If you’re a developer with robotics experience, we’d love your help—check out the [contribution guidelines](CONTRIBUTING) for more details.
Get started by installing the library and exploring its features! Found a bug? Report it in the issue tracker. Interested in contributing? If you’re a developer with robotics experience, we’d love your help—check out the [contribution guidelines](CONTRIBUTING.md) for more details.

## Citation

17 changes: 16 additions & 1 deletion learning/README.md
@@ -10,6 +10,21 @@ For more detailed tutorials on using MuJoCo Playground for RL, see:
4. Training CartPole from Vision [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_1.ipynb)
5. Robotic Manipulation from Vision [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_2.ipynb)

## Training with brax PPO

To train with brax PPO, you can use the `train_jax_ppo.py` script. This script uses the brax PPO algorithm to train an agent on a given environment.

```bash
python train_jax_ppo.py --env_name=CartpoleBalance
```

To train a vision-based policy using pixel observations:
```bash
python train_jax_ppo.py --env_name=CartpoleBalance --vision
```

Use `python train_jax_ppo.py --help` to see the available options and usage. Logs and checkpoints are saved in the `logs` directory.

## Training with RSL-RL

To train with RSL-RL, you can use the `train_rsl_rl.py` script. This script uses the RSL-RL algorithm to train an agent on a given environment.
@@ -18,7 +33,7 @@ To train with RSL-RL, you can use the `train_rsl_rl.py` script. This script uses
python train_rsl_rl.py --env_name=LeapCubeReorient
```

to render the behaviour from the resulting policy:
To render the behaviour from the resulting policy:
```bash
python learning/train_rsl_rl.py --env_name LeapCubeReorient --play_only --load_run_name <run_name>
```
8 changes: 4 additions & 4 deletions learning/train_jax_ppo.py
@@ -196,12 +196,12 @@ def main(argv):
if _CLIPPING_EPSILON.present:
ppo_params.clipping_epsilon = _CLIPPING_EPSILON.value
if _POLICY_HIDDEN_LAYER_SIZES.present:
ppo_params.network_factory.policy_hidden_layer_sizes = tuple(
_POLICY_HIDDEN_LAYER_SIZES.value
ppo_params.network_factory.policy_hidden_layer_sizes = list(
map(int, _POLICY_HIDDEN_LAYER_SIZES.value)
)
if _VALUE_HIDDEN_LAYER_SIZES.present:
ppo_params.network_factory.value_hidden_layer_sizes = tuple(
_VALUE_HIDDEN_LAYER_SIZES.value
ppo_params.network_factory.value_hidden_layer_sizes = list(
map(int, _VALUE_HIDDEN_LAYER_SIZES.value)
)
if _POLICY_OBS_KEY.present:
ppo_params.network_factory.policy_obs_key = _POLICY_OBS_KEY.value
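
The change from `tuple(...)` to `list(map(int, ...))` in this hunk matters if the flag values arrive as strings. That is an assumption here, since the flag definitions are not shown in the diff. A minimal sketch of the difference, using a hypothetical helper `parse_layer_sizes` and a hand-written comma-separated value in place of the real flag:

```python
def parse_layer_sizes(raw_values):
    """Convert string layer sizes such as ["128", "64"] into a list of ints."""
    return list(map(int, raw_values))


# Hypothetical raw flag value, split the way a list-style flag would be.
raw = "128,128,64".split(",")

as_tuple = tuple(raw)             # old behaviour: elements stay strings
as_ints = parse_layer_sizes(raw)  # new behaviour: elements become ints

print(as_tuple)
print(as_ints)
```

With string inputs, the old form would hand strings to the network factory, while the new form passes integers that can be used directly as layer widths.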
3 changes: 3 additions & 0 deletions mujoco_playground/_src/manipulation/__init__.py
@@ -20,6 +20,7 @@
from mujoco import mjx

from mujoco_playground._src import mjx_env
from mujoco_playground._src.manipulation.aloha import handover as aloha_handover
from mujoco_playground._src.manipulation.aloha import single_peg_insertion as aloha_peg
from mujoco_playground._src.manipulation.franka_emika_panda import open_cabinet as panda_open_cabinet
from mujoco_playground._src.manipulation.franka_emika_panda import pick as panda_pick
@@ -29,6 +30,7 @@
from mujoco_playground._src.manipulation.leap_hand import rotate_z as leap_rotate_z

_envs = {
"AlohaHandOver": aloha_handover.HandOver,
"AlohaSinglePegInsertion": aloha_peg.SinglePegInsertion,
"PandaPickCube": panda_pick.PandaPickCube,
"PandaPickCubeOrientation": panda_pick.PandaPickCubeOrientation,
@@ -40,6 +42,7 @@
}

_cfgs = {
"AlohaHandOver": aloha_handover.default_config,
"AlohaSinglePegInsertion": aloha_peg.default_config,
"PandaPickCube": panda_pick.default_config,
"PandaPickCubeOrientation": panda_pick.default_config,
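
The two dictionaries above pair each environment name with a class and a config factory, so registering `AlohaHandOver` only requires one entry in each. The sketch below shows how such a pair of tables might be consumed. The `make_env` function, the stand-in `HandOver` class, and the config contents are hypothetical and are not taken from the repository.

```python
from typing import Any, Callable, Dict, Optional


class HandOver:
    """Stand-in for an environment class; the real one lives in aloha/handover.py."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config


def default_config() -> Dict[str, Any]:
    # Hypothetical config contents, for illustration only.
    return {"episode_length": 100}


_envs: Dict[str, type] = {"AlohaHandOver": HandOver}
_cfgs: Dict[str, Callable[[], Dict[str, Any]]] = {"AlohaHandOver": default_config}


def make_env(env_name: str, config: Optional[Dict[str, Any]] = None):
    """Look up the class and default config by name, then instantiate."""
    if env_name not in _envs:
        raise ValueError(f"Unknown environment: {env_name}")
    cfg = config if config is not None else _cfgs[env_name]()
    return _envs[env_name](cfg)


env = make_env("AlohaHandOver")
print(type(env).__name__, env.config)
```

Keeping the class table and the config table keyed by the same string means a missing entry in either one surfaces as a lookup error at load time rather than later during training.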
15 changes: 8 additions & 7 deletions mujoco_playground/_src/manipulation/aloha/aloha_constants.py
@@ -16,13 +16,7 @@

from mujoco_playground._src import mjx_env

XML_PATH = (
mjx_env.ROOT_PATH
/ "manipulation"
/ "aloha"
/ "xmls"
/ "mjx_single_peg_insertion.xml"
)
XML_PATH = mjx_env.ROOT_PATH / "manipulation" / "aloha" / "xmls"

ARM_JOINTS = [
"left/waist",
@@ -49,3 +43,10 @@
"right/right_finger_top",
"right/right_finger_bottom",
]

FINGER_JOINTS = [
"left/left_finger",
"left/right_finger",
"right/left_finger",
"right/right_finger",
]
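
In this file, `XML_PATH` changes from a single XML file to the `xmls` directory, presumably so that each ALOHA task can pick its own file. A minimal sketch of that pattern with `pathlib`, assuming a hypothetical `ROOT_PATH` value (the real one comes from `mjx_env.ROOT_PATH`) and reusing the file name from the removed lines:

```python
from pathlib import PurePosixPath

# Hypothetical root; in the repository this is mjx_env.ROOT_PATH.
ROOT_PATH = PurePosixPath("mujoco_playground/_src")

# After the change, XML_PATH names a directory rather than one file.
XML_PATH = ROOT_PATH / "manipulation" / "aloha" / "xmls"

# Each task then joins its own file name onto the shared directory.
peg_xml = XML_PATH / "mjx_single_peg_insertion.xml"

print(peg_xml)
```

Code that previously used `XML_PATH` as a file would need to append a file name in this way after the change.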
32 changes: 32 additions & 0 deletions mujoco_playground/_src/manipulation/aloha/base.py
@@ -17,11 +17,15 @@
from typing import Any, Dict, Optional, Union

from etils import epath
import jax.numpy as jp
from ml_collections import config_dict
import mujoco
from mujoco import mjx
import numpy as np

from mujoco_playground._src import collision
from mujoco_playground._src import mjx_env
from mujoco_playground._src.manipulation.aloha import aloha_constants as consts


def get_assets() -> Dict[str, bytes]:
@@ -58,6 +62,26 @@ def __init__(
self._mjx_model = mjx.put_model(self._mj_model)
self._xml_path = xml_path

def _post_init_aloha(self, keyframe: str = "home"):
"""Initializes helpful robot properties."""
self._left_gripper_site = self._mj_model.site("left/gripper").id
self._right_gripper_site = self._mj_model.site("right/gripper").id
self._table_geom = self._mj_model.geom("table").id
self._finger_geoms = [
self._mj_model.geom(geom_id).id for geom_id in consts.FINGER_GEOMS
]
self._init_q = jp.array(self._mj_model.keyframe(keyframe).qpos)
self._init_ctrl = jp.array(self._mj_model.keyframe(keyframe).ctrl)
self._lowers, self._uppers = self.mj_model.actuator_ctrlrange.T
arm_joint_ids = [self._mj_model.joint(j).id for j in consts.ARM_JOINTS]
self._arm_qadr = jp.array(
[self._mj_model.jnt_qposadr[joint_id] for joint_id in arm_joint_ids]
)
self._finger_qposadr = np.array([
self._mj_model.jnt_qposadr[self._mj_model.joint(j).id]
for j in consts.FINGER_JOINTS
])

@property
def xml_path(self) -> str:
return self._xml_path
@@ -73,3 +97,11 @@ def mj_model(self) -> mujoco.MjModel:
@property
def mjx_model(self) -> mjx.Model:
return self._mjx_model

def hand_table_collision(self, data) -> jp.ndarray:
# Check for collisions between the finger geoms and the table.
hand_table_collisions = [
collision.geoms_colliding(data, self._table_geom, g)
for g in self._finger_geoms
]
return (sum(hand_table_collisions) > 0).astype(float)
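
`hand_table_collision` reduces one boolean per finger geom to a single 0.0 or 1.0 flag by summing and comparing against zero. The sketch below reproduces that reduction with NumPy standing in for `jax.numpy`, and with hand-written booleans in place of `collision.geoms_colliding` results. Both helper function names are hypothetical.

```python
import numpy as np


def any_collision(flags):
    """Return 1.0 if at least one flag is True, else 0.0 (sum-based form)."""
    return (sum(flags) > 0).astype(float)


def any_collision_vectorized(flags):
    """Equivalent reduction written with np.any over a stacked array."""
    return np.any(np.stack(flags)).astype(float)


# Hand-written per-finger results, one entry per finger geom.
touching = [np.bool_(False), np.bool_(True), np.bool_(False), np.bool_(False)]
clear = [np.bool_(False)] * 4

print(any_collision(touching), any_collision(clear))
print(any_collision_vectorized(touching), any_collision_vectorized(clear))
```

The sum-based form works on a plain Python list of scalar arrays without stacking them first, which may be why it is used in the method above.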