molax (Molecular Active Learning with JAX) - a lightweight framework for active learning in molecular property prediction.
Install from source:
git clone https://github.com/hfooladi/molax
cd molax
pip install -r requirements.txt
Required dependencies:
- jax
- flax
- optax
- rdkit
- pandas
- numpy
Basic usage with SMILES data:
from molax.utils.data import MolecularDataset
from molax.models import UncertaintyGCN
from molax.acquisition import combined_acquisition
# Load your data
dataset = MolecularDataset(
    'data/molecules.csv',
    smiles_col='smiles',
    label_col='property'
)
# Split dataset
train_data, test_data = dataset.split_train_test(test_size=0.2)
# Initialize model
model = UncertaintyGCN(
    hidden_features=(64, 64),
    output_features=1,
    dropout_rate=0.1
)
# Run active learning loop
# See examples/active_learning.py for complete implementation
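The selection step of that loop can be sketched in a few self-contained lines. The snippet below is a toy illustration: random features stand in for molecular descriptors, and a norm-based placeholder score stands in for model uncertainty (in molax the score would come from the MC-dropout predictions of UncertaintyGCN, e.g. through combined_acquisition; see examples/active_learning.py for the actual interface).

import jax
import jax.numpy as jnp

# Toy stand-ins: random "features" and a placeholder score.
# In practice the features come from MolecularDataset and the score
# from the trained UncertaintyGCN (e.g. via combined_acquisition).
key = jax.random.PRNGKey(0)
features = jax.random.normal(key, (100, 16))

def acquisition_scores(feats):
    # Placeholder: rank candidates by feature norm.
    return jnp.linalg.norm(feats, axis=1)

labeled = list(range(10))        # indices with known labels
pool = list(range(10, 100))      # unlabeled candidate pool
batch_size = 8

for al_round in range(5):
    # 1. (Re)train the model on the labeled set -- omitted in this sketch.
    # 2. Score the remaining candidates.
    scores = acquisition_scores(features[jnp.array(pool)])
    # 3. Move the top-scoring batch from the pool to the labeled set.
    top = jnp.argsort(-scores)[:batch_size]
    picked = [pool[int(i)] for i in top]
    labeled.extend(picked)
    pool = [i for i in pool if i not in picked]
    print(f"round {al_round}: labeled={len(labeled)}, pool={len(pool)}")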
- Graph neural networks implemented in JAX/Flax
- Uncertainty estimation via MC dropout (sketched after this list)
- Multiple acquisition functions (e.g. combined_acquisition)
- Efficient batch selection (the loop sketch above shows simple top-k selection)
- RDKit-based molecular processing (see the SMILES-to-graph sketch below)
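Uncertainty estimation via MC dropout means keeping dropout active at prediction time and summarizing the spread of repeated stochastic forward passes. The sketch below shows the pattern with a tiny illustrative Flax module; TinyRegressor and its hyperparameters are stand-ins, not molax's UncertaintyGCN.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyRegressor(nn.Module):
    # Illustrative stand-in for any Flax model that uses dropout.
    hidden: int = 64
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, *, train: bool):
        x = nn.relu(nn.Dense(self.hidden)(x))
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
        return nn.Dense(1)(x)

model = TinyRegressor()
x = jnp.ones((8, 16))                        # batch of 8 dummy inputs
params = model.init(jax.random.PRNGKey(0), x, train=False)

def mc_dropout_predict(params, x, key, n_samples=20):
    # Run the network n_samples times with dropout enabled and
    # summarize the predictions by their mean and standard deviation.
    keys = jax.random.split(key, n_samples)
    preds = jnp.stack([
        model.apply(params, x, train=True, rngs={'dropout': k}) for k in keys
    ])
    return preds.mean(axis=0), preds.std(axis=0)

mean, std = mc_dropout_predict(params, x, jax.random.PRNGKey(1))
print(mean.shape, std.shape)                 # (8, 1) (8, 1)

A per-molecule standard deviation like std is the natural uncertainty signal to feed into the acquisition step.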
Check examples/active_learning.py
for a complete active learning pipeline.
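On the RDKit side, the core preprocessing step is turning a SMILES string into arrays a graph network can consume. Below is a minimal sketch with a deliberately simple featurization (one atomic-number feature per atom plus a directed edge list); molax's actual featurizer may use richer atom and bond features.

import numpy as np
from rdkit import Chem

def smiles_to_graph(smiles: str):
    # Parse the SMILES string into an RDKit molecule.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles}")
    # One feature per atom: its atomic number (a real featurizer adds more).
    atom_features = np.array(
        [[atom.GetAtomicNum()] for atom in mol.GetAtoms()], dtype=np.float32
    )
    # Each bond contributes two directed edges (i -> j and j -> i).
    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.extend([(i, j), (j, i)])
    return atom_features, np.array(edges, dtype=np.int32)

nodes, edges = smiles_to_graph("CCO")   # ethanol
print(nodes.shape, edges.shape)         # (3, 1) (4, 2)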
- Fork the repository
- Create your feature branch (git checkout -b feature/amazing-feature)
- Commit changes (git commit -m 'Add amazing feature')
- Push to branch (git push origin feature/amazing-feature)
- Open a Pull Request
MIT License
If you use molax in your research, please cite:
@software{molax2025,
  title={molax: Molecular Active Learning with JAX},
  author={Hosein Fooladi},
  year={2025},
  url={https://github.com/hfooladi/molax}
}