Tired of having to handle asynchronous processes for neuroevolution? Do you want to leverage massive vectorization and high-throughput accelerators for Evolution Strategies? evosax
provides a comprehensive, high-performance library that implements Evolution Strategies (ES) in JAX. By leveraging XLA compilation and JAX's transformation primitives, evosax
enables researchers and practitioners to efficiently scale evolutionary algorithms to modern hardware accelerators without the traditional overhead of distributed implementations.
The API follows the classical `ask`-`eval`-`tell` cycle of ES, with full support for JAX's transformations (`jit`, `vmap`, `lax.scan`). The library includes 30+ evolution strategies, from classics like CMA-ES and Differential Evolution to modern approaches like OpenAI-ES and Diffusion Evolution.
```python
import jax
import jax.numpy as jnp

from evosax.algorithms import CMA_ES

# Instantiate the search strategy
dummy_solution = jnp.zeros((4,))  # template defining the solution shape, e.g. a 4-dim problem
es = CMA_ES(population_size=32, solution=dummy_solution)
params = es.default_params

# Initialize state
key = jax.random.key(0)
state = es.init(key, params)

# Ask-Eval-Tell loop
num_generations = 100
for i in range(num_generations):
    key, key_ask, key_eval = jax.random.split(key, 3)

    # Generate a set of candidate solutions to evaluate
    population, state = es.ask(key_ask, state, params)

    # Evaluate the fitness of the population
    fitness = ...

    # Update the evolution strategy
    state = es.tell(population, fitness, state, params)

# Get best solution
state.best_solution, state.best_fitness
```
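The `fitness = ...` placeholder in the loop above is where your objective goes. As a minimal sketch (plain JAX, independent of evosax itself; the sphere objective and shapes here are illustrative), a whole population can be scored in one vectorized, compiled call:

```python
import jax
import jax.numpy as jnp

def sphere(x):
    # Sphere objective: sum of squared coordinates, minimized at the origin
    return jnp.sum(x ** 2)

# vmap vectorizes over the population axis; jit compiles the whole evaluation
evaluate = jax.jit(jax.vmap(sphere))

# A stand-in population: 32 candidates of dimension 4
population = jax.random.normal(jax.random.key(0), (32, 4))
fitness = evaluate(population)  # shape (32,)
```

Because the evaluation is a pure function of the population array, it runs unchanged on CPU, GPU, or TPU.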
You will need Python 3.10 or later and a working JAX installation. Then, install evosax from PyPI:

```bash
pip install evosax
```

To install the latest development version from GitHub, use:

```bash
pip install git+https://github.com/RobertTLange/evosax.git@main
```
- Getting Started - Introduction to the library
- Black Box Optimization Benchmark - Optimization of common test functions
- Reinforcement Learning - Learning MLP control policies
- Vision - Training CNNs for classification
- Restart ES - Implementing restart strategies
- Diffusion Evolution - Optimization with diffusion evolution
- Stein Variational ES - Using SV-ES on BBOB problems
- Persistent/Noise-Reuse ES - ES for meta-learning problems
- Comprehensive Algorithm Collection: 30+ classic and modern evolution strategies with a unified API
- JAX Acceleration: Fully compatible with JAX transformations for speed and scalability
- Vectorization & Parallelization: Fast execution on CPUs, GPUs, and TPUs
- Production Ready: Well-tested, documented, and used in research environments
- Batteries Included: Comes with optimizers like ClipUp, fitness shaping, and restart strategies
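To illustrate the JAX-acceleration point without depending on evosax internals, the ask-eval-tell cycle of a toy ES (simple random search on the sphere function; all names here are illustrative, not evosax API) can be rolled out with `lax.scan` and compiled end-to-end:

```python
import jax
import jax.numpy as jnp

def sphere(x):
    # Sphere objective: sum of squared coordinates, minimized at the origin
    return jnp.sum(x ** 2)

def generation(state, key):
    # One toy "ask-eval-tell" step: perturb the current mean, keep the best sample
    mean, best_fitness = state
    candidates = mean + 0.1 * jax.random.normal(key, (16, mean.shape[0]))  # ask
    fitness = jax.vmap(sphere)(candidates)                                 # eval
    best = jnp.argmin(fitness)                                             # tell
    new_mean = jnp.where(fitness[best] < best_fitness, candidates[best], mean)
    new_best = jnp.minimum(fitness[best], best_fitness)
    return (new_mean, new_best), new_best

keys = jax.random.split(jax.random.key(0), 100)
init = (jnp.ones(4), jnp.float32(jnp.inf))
(final_mean, final_best), history = jax.lax.scan(generation, init, keys)
```

Because the loop body is a pure function, `lax.scan` unrolls all 100 generations into a single compiled program, avoiding Python-loop overhead on accelerators.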
- Rob's MLC Research Jam Talk - Overview at the ML Collective Research Jam
- Rob's 02/2021 Blog - Blog post on implementing CMA-ES in JAX
- Evojax - Hardware-accelerated neuroevolution with great rollout wrappers
- QDax - Quality-Diversity algorithms in JAX
If you use evosax
in your research, please cite the following paper:
```bibtex
@article{evosax2022github,
  author  = {Robert Tjarko Lange},
  title   = {evosax: JAX-based Evolution Strategies},
  journal = {arXiv preprint arXiv:2212.04180},
  year    = {2022},
}
```
We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.
Contributions are welcome! If you find a bug or are missing your favorite feature, please open an issue or submit a pull request following our contribution guidelines.
This repository contains independent reimplementations of LES and DES and is unrelated to Google DeepMind. The implementations have been tested to reproduce the official results on a range of tasks.