ShaneFlandermeyer/tdmpc2-jax

tdmpc2-jax

A re-implementation of TD-MPC2 in JAX/Flax. JIT-compiling the planning and update steps makes training 5-10x faster than the original PyTorch implementation, while maintaining similar or better performance in challenging continuous control environments.
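To illustrate why JIT compilation pays off here, the sketch below jit-compiles one simplified MPPI-style planning iteration. TD-MPC2 plans with MPPI, but this version leaves out parts of the full planner. The `dynamics` and `reward` functions are made-up placeholders, not this repository's world model, and `plan_step` is a hypothetical name. `jax.vmap` batches the sampled rollouts and `jax.lax.scan` unrolls the horizon, so the whole step compiles into one XLA program instead of many small operations dispatched from Python.

```python
import jax
import jax.numpy as jnp


# Toy stand-ins (hypothetical) for a learned latent dynamics model and reward model.
def dynamics(z, a):
    return jnp.tanh(z + jnp.sum(a, axis=-1, keepdims=True))


def reward(z, a):
    return -jnp.sum(z**2, axis=-1) - 0.1 * jnp.sum(a**2, axis=-1)


@jax.jit
def plan_step(key, z0, mean, std):
    """One simplified MPPI-style iteration over action sequences of shape (horizon, action_dim)."""
    num_samples, temperature = 256, 0.5
    horizon, action_dim = mean.shape  # shapes are static under jit
    noise = jax.random.normal(key, (num_samples, horizon, action_dim))
    actions = jnp.clip(mean + std * noise, -1.0, 1.0)

    def rollout(action_seq):
        # lax.scan keeps the time loop inside the compiled program.
        def step(z, a):
            return dynamics(z, a), reward(z, a)

        _, rewards = jax.lax.scan(step, z0, action_seq)
        return jnp.sum(rewards)

    returns = jax.vmap(rollout)(actions)  # shape (num_samples,)
    weights = jax.nn.softmax(returns / temperature)  # higher return -> larger weight
    new_mean = jnp.einsum("n,nha->ha", weights, actions)
    new_std = jnp.sqrt(jnp.einsum("n,nha->ha", weights, (actions - new_mean) ** 2))
    return new_mean, new_std


key = jax.random.PRNGKey(0)
mean, std = jnp.zeros((5, 2)), jnp.ones((5, 2))
for _ in range(3):
    key, subkey = jax.random.split(key)
    # Same shapes and dtypes on every call, so the compiled program is reused.
    mean, std = plan_step(subkey, jnp.zeros(4), mean, std)
```

The first call pays the compilation cost; later calls with the same input shapes reuse the compiled program, which is where the savings come from in a planning loop that runs at every environment step.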

This repository also supports vectorized environments (see the env field of config.yaml) and finite-horizon environments (see world_model.predict_continues and tdmpc.continue_coef in config.yaml).
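When episodes can end early, a model-based planner should stop accumulating reward and bootstrapped value once the episode is predicted to be over. The sketch below assumes a learned head outputs a per-step probability that the episode continues (which is what the `predict_continues` option suggests; this README does not say exactly how tdmpc2-jax uses it) and uses that probability to cut off the discounted return of an imagined trajectory. `jax.vmap` then batches the computation over a leading axis, as one would across vectorized environments. `imagined_return` and its arguments are hypothetical names.

```python
import jax
import jax.numpy as jnp


def imagined_return(rewards, continues, terminal_value, discount):
    """Discounted return of one imagined trajectory, cut off by continue probabilities.

    rewards, continues: shape (horizon,). continues[t] is a (hypothetical)
    predicted probability that the episode is still running after step t.
    terminal_value: bootstrapped value estimate at the end of the horizon.
    """

    def step(carry, inputs):
        ret, weight = carry
        r, c = inputs
        ret = ret + weight * r
        # Later terms are discounted and scaled by the chance of still being alive.
        weight = weight * discount * c
        return (ret, weight), None

    init = (jnp.zeros((), rewards.dtype), jnp.ones((), rewards.dtype))
    (ret, weight), _ = jax.lax.scan(step, init, (rewards, continues))
    return ret + weight * terminal_value


# Batch over a leading axis (e.g. one trajectory per vectorized environment).
batched_return = jax.jit(jax.vmap(imagined_return, in_axes=(0, 0, 0, None)))

rewards = jnp.array([[1.0, 1.0], [1.0, 1.0]])
continues = jnp.array([[1.0, 1.0], [1.0, 0.0]])  # second episode ends after step 1
values = jnp.array([10.0, 10.0])
out = batched_return(rewards, continues, values, 0.5)  # values 4.0 and 1.5
```

In the second trajectory the zero continue probability removes the bootstrapped value entirely, so only the two observed rewards count.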

Usage

To install the dependencies for this project (tested on Ubuntu 22.04), run

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

pip install --upgrade tqdm numpy flax optax jaxtyping einops "gymnasium[mujoco]" hydra-core tensorflow orbax-checkpoint dm_control
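After installing, a quick check (a generic JAX snippet, not part of this repository) shows whether JAX picked up the CUDA backend or fell back to the CPU:

```python
import jax

# Devices JAX can use; a working CUDA install should list GPU devices here.
devices = jax.devices()
# Name of the default platform, e.g. "gpu" or "cpu".
backend = jax.default_backend()
print(backend, devices)
```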

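Then, edit config.yaml and run train.py from the main project directory. Individual config values can also be overridden on the command line using Hydra's key=value syntax, for example: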

# gymnasium 
python train.py env.backend=gymnasium env.env_id=HalfCheetah-v4 
# dmcs
python train.py env.backend=dmc env.env_id=cheetah-run   

Installation

Install the package from the base directory with

pip install -e .

Contributing

If you enjoy this project and would like to help improve it, feel free to open an issue or pull request! While the core algorithm is fully implemented, the following feature still needs to be added:

  • Multi-task operation through task embeddings and replay buffer

About

Jax/Flax Implementation of TD-MPC2
