diff --git a/learning/README.md b/learning/README.md
index 0a24ad8..4244539 100644
--- a/learning/README.md
+++ b/learning/README.md
@@ -10,6 +10,21 @@ For more detailed tutorials on using MuJoCo Playground for RL, see:
 4. Training CartPole from Vision [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_1.ipynb)
 5. Robotic Manipulation from Vision [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_2.ipynb)
 
+## Training with brax PPO
+
+To train with brax PPO, you can use the `train_jax_ppo.py` script. This script uses the brax PPO algorithm to train an agent on a given environment.
+
+```bash
+python train_jax_ppo.py --env_name=CartpoleBalance
+```
+
+To train a vision-based policy using pixel observations:
+```bash
+python train_jax_ppo.py --env_name=CartpoleBalance --vision
+```
+
+Use `python train_jax_ppo.py --help` to see the available options. Logs and checkpoints are saved in the `logs` directory.
+
 ## Training with RSL-RL
 
 To train with RSL-RL, you can use the `train_rsl_rl.py` script. This script uses the RSL-RL algorithm to train an agent on a given environment.
@@ -18,7 +33,7 @@ To train with RSL-RL, you can use the `train_rsl_rl.py` script. This script uses
 python train_rsl_rl.py --env_name=LeapCubeReorient
 ```
 
-to render the behaviour from the resulting policy:
+To render the behaviour from the resulting policy:
 ```bash
 python learning/train_rsl_rl.py --env_name LeapCubeReorient --play_only --load_run_name
 ```
diff --git a/learning/train_jax_ppo.py b/learning/train_jax_ppo.py
index a251427..ef4a44a 100644
--- a/learning/train_jax_ppo.py
+++ b/learning/train_jax_ppo.py
@@ -198,12 +198,12 @@ def main(argv):
   if _CLIPPING_EPSILON.present:
     ppo_params.clipping_epsilon = _CLIPPING_EPSILON.value
   if _POLICY_HIDDEN_LAYER_SIZES.present:
-    ppo_params.network_factory.policy_hidden_layer_sizes = tuple(
-        _POLICY_HIDDEN_LAYER_SIZES.value
+    ppo_params.network_factory.policy_hidden_layer_sizes = list(
+        map(int, _POLICY_HIDDEN_LAYER_SIZES.value)
     )
   if _VALUE_HIDDEN_LAYER_SIZES.present:
-    ppo_params.network_factory.value_hidden_layer_sizes = tuple(
-        _VALUE_HIDDEN_LAYER_SIZES.value
+    ppo_params.network_factory.value_hidden_layer_sizes = list(
+        map(int, _VALUE_HIDDEN_LAYER_SIZES.value)
     )
   if _POLICY_OBS_KEY.present:
     ppo_params.network_factory.policy_obs_key = _POLICY_OBS_KEY.value
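
Why `list(map(int, ...))` in the `train_jax_ppo.py` hunk: absl list flags parse their comma-separated argument into a list of strings, so the previous `tuple(_POLICY_HIDDEN_LAYER_SIZES.value)` forwarded values such as `('64', '64')` to the network factory instead of integers. A minimal sketch of the behaviour, assuming the flag is declared with `flags.DEFINE_list` (the actual declaration in the script may differ):

```python
# Minimal, self-contained sketch of the flag coercion fixed by this diff.
# Assumption: the layer-size flag is declared with flags.DEFINE_list.
from absl import app, flags

_POLICY_HIDDEN_LAYER_SIZES = flags.DEFINE_list(
    "policy_hidden_layer_sizes",
    None,
    "Comma-separated hidden layer sizes for the policy network.",
)


def main(argv):
  del argv  # Unused.
  if _POLICY_HIDDEN_LAYER_SIZES.present:
    # DEFINE_list parses "--policy_hidden_layer_sizes=64,64" into
    # ['64', '64'], i.e. strings. tuple(...) alone would keep them as
    # strings; map(int, ...) coerces them to the integers the network
    # factory expects.
    sizes = list(map(int, _POLICY_HIDDEN_LAYER_SIZES.value))
    print(sizes)  # e.g. [64, 64]


if __name__ == "__main__":
  app.run(main)
```

Running the sketch with `--policy_hidden_layer_sizes=64,64` prints `[64, 64]`; without the `int` coercion the sizes would remain strings and break layer construction downstream.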