BatchJAX is a library that allows JAX's vmap to be used over lists and Objax ModuleLists.

defaultobject/batchjax

BatchJAX

Description

BatchJAX is a library that allows JAX's vmap to be used over lists and objax.ModuleList.
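
To illustrate the general idea, here is a minimal sketch in plain JAX. It does not use or assume anything about BatchJAX's own API. It only shows one common way to apply vmap to a list: stack the list's items into a single batched structure, then map over the leading axis. The names predict, stack_pytrees, and params_list are hypothetical and chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-item parameters. In practice these might be the
# variables of the modules held in a list or ModuleList.
params_list = [
    {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)},
    {"w": jnp.array([3.0, 4.0]), "b": jnp.array(-1.0)},
    {"w": jnp.array([0.0, 1.0]), "b": jnp.array(2.0)},
]


def predict(params, x):
    """Apply one item's parameters to a single input vector."""
    return jnp.dot(params["w"], x) + params["b"]


def stack_pytrees(trees):
    """Stack a list of identically structured pytrees along a new axis 0."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


# vmap cannot map over a Python list directly, so stack the items first.
stacked = stack_pytrees(params_list)
x = jnp.array([1.0, 1.0])

# Map over axis 0 of the stacked parameters; share x across all items.
batched = jax.vmap(predict, in_axes=(0, None))(stacked, x)

# Reference result from an ordinary Python loop, for comparison.
looped = jnp.stack([predict(p, x) for p in params_list])
```

Stacking requires every item in the list to have the same structure and leaf shapes. Both batched and looped hold one output per list item; the vmap version evaluates them in a single vectorized call.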

Installation

pip install batchjax

Example

See batchjax_example.ipynb.
