jaxKAN is a Python package for training Kolmogorov-Arnold Networks (KANs) with the JAX framework. Built on Flax's NNX module, jaxKAN provides a collection of KAN layers that serve as building blocks for various KAN architectures, such as EfficientKAN and ChebyKAN. Besides standard features like initialization and forward-pass methods, the KAN class in jaxKAN introduces an extend_grids method, which extends the grids of all layers in the network, regardless of how those grids are defined. For instance, in ChebyKAN, where no traditional grid exists, the method instead increases the order of the Chebyshev polynomials used in the model.
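The Chebyshev case shows most clearly what "extending the grid" can mean when there is no grid. The sketch below is a minimal, standalone illustration in plain JAX, not jaxKAN's implementation or API. The functions chebyshev_basis, cheby_layer and extend_order are hypothetical names, the tanh input squashing is an assumption, and zero-padding the new coefficients is simply one way to raise the polynomial order while leaving the function computed by the layer unchanged.

```python
# Illustrative sketch only; these functions are NOT part of the jaxKAN API.
import jax
import jax.numpy as jnp


def chebyshev_basis(x, degree):
    """Return T_0(x), ..., T_degree(x) stacked on the last axis.

    Uses the recurrence T_{k+1}(x) = 2x T_k(x) - T_{k-1}(x).
    """
    terms = [jnp.ones_like(x), x]
    for _ in range(2, degree + 1):
        terms.append(2.0 * x * terms[-1] - terms[-2])
    return jnp.stack(terms[: degree + 1], axis=-1)


def cheby_layer(coeffs, x):
    """Hypothetical ChebyKAN-style layer.

    coeffs: (n_in, n_out, degree + 1), x: (batch, n_in).
    Each edge (i, j) applies its own learnable polynomial to x_i, and
    output node j sums the contributions from all inputs.
    """
    degree = coeffs.shape[-1] - 1
    # Assumption: squash inputs into [-1, 1], the domain of Chebyshev polynomials.
    basis = chebyshev_basis(jnp.tanh(x), degree)  # (batch, n_in, degree + 1)
    return jnp.einsum("bid,iod->bo", basis, coeffs)


def extend_order(coeffs, new_degree):
    """Hypothetical order extension: zero-pad the coefficients of new T_k terms.

    The added terms start with zero weight, so the layer computes the same
    function immediately after extension and can then be trained further.
    """
    pad = new_degree - (coeffs.shape[-1] - 1)
    return jnp.pad(coeffs, ((0, 0), (0, 0), (0, pad)))


# Usage: a layer with 3 inputs, 2 outputs and degree-3 polynomials.
key_c, key_x = jax.random.split(jax.random.PRNGKey(0))
coeffs = 0.1 * jax.random.normal(key_c, (3, 2, 4))
x = jax.random.normal(key_x, (5, 3))

y_before = cheby_layer(coeffs, x)
coeffs_ext = extend_order(coeffs, 6)  # raise the order from 3 to 6
y_after = cheby_layer(coeffs_ext, x)

print(coeffs_ext.shape)  # (3, 2, 7)
print(jnp.allclose(y_before, y_after, atol=1e-6))
```

Zero-padding is the simplest function-preserving choice for a polynomial basis. For grid-based layers such as B-splines, an extension would instead need to refit the coefficients on the finer grid, which this sketch does not cover.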
Extensive documentation on jaxKAN, including installation and contributing guidelines, an API reference, and tutorials, can be found here.
We warmly welcome community contributions to jaxKAN! For details on the types of contributions that will help jaxKAN evolve, as well as guidelines on how to contribute, visit this page of our documentation.
If you use jaxKAN in your academic work, please cite it as follows:
@article{Rigas2025,
author = {Rigas, Spyros and Papachristou, Michalis},
title = {jax{KAN}: A unified {JAX} framework for {K}olmogorov-{A}rnold Networks},
journal = {Journal of Open Source Software},
year = {2025},
volume = {10},
number = {108},
pages = {7830},
doi = {10.21105/joss.07830}
}
If you have used jaxKAN in your research for physics-informed KAN (PIKAN) applications or theoretical developments, please also consider citing the paper that originally introduced jaxKAN for these tasks:
@article{10763509,
author = {Rigas, Spyros and Papachristou, Michalis and Papadopoulos, Theofilos and Anagnostopoulos, Fotios and Alexandridis, Georgios},
title = {Adaptive Training of Grid-Dependent Physics-Informed {K}olmogorov-{A}rnold Networks},
journal = {IEEE Access},
year = {2024},
volume = {12},
pages = {176982--176998},
doi = {10.1109/ACCESS.2024.3504962}
}