diff --git a/src/jax_flash_attn2/__init__.py b/jax_flash_attn2/__init__.py similarity index 97% rename from src/jax_flash_attn2/__init__.py rename to jax_flash_attn2/__init__.py index ca1a95c..397e566 100644 --- a/src/jax_flash_attn2/__init__.py +++ b/jax_flash_attn2/__init__.py @@ -25,3 +25,5 @@ "FlashAttention", "AttentionConfig", ] + +__version__ = "0.0.1" diff --git a/src/jax_flash_attn2/_custom_call_lib/__init__.py b/jax_flash_attn2/_custom_call_lib/__init__.py similarity index 100% rename from src/jax_flash_attn2/_custom_call_lib/__init__.py rename to jax_flash_attn2/_custom_call_lib/__init__.py diff --git a/src/jax_flash_attn2/_custom_call_lib/lib.py b/jax_flash_attn2/_custom_call_lib/lib.py similarity index 100% rename from src/jax_flash_attn2/_custom_call_lib/lib.py rename to jax_flash_attn2/_custom_call_lib/lib.py diff --git a/src/jax_flash_attn2/cpu_calls/__init__.py b/jax_flash_attn2/cpu_calls/__init__.py similarity index 100% rename from src/jax_flash_attn2/cpu_calls/__init__.py rename to jax_flash_attn2/cpu_calls/__init__.py diff --git a/src/jax_flash_attn2/cpu_calls/mha.py b/jax_flash_attn2/cpu_calls/mha.py similarity index 100% rename from src/jax_flash_attn2/cpu_calls/mha.py rename to jax_flash_attn2/cpu_calls/mha.py diff --git a/src/jax_flash_attn2/flash_attention.py b/jax_flash_attn2/flash_attention.py similarity index 100% rename from src/jax_flash_attn2/flash_attention.py rename to jax_flash_attn2/flash_attention.py diff --git a/src/jax_flash_attn2/pallas_kernels/__init__.py b/jax_flash_attn2/pallas_kernels/__init__.py similarity index 100% rename from src/jax_flash_attn2/pallas_kernels/__init__.py rename to jax_flash_attn2/pallas_kernels/__init__.py diff --git a/src/jax_flash_attn2/pallas_kernels/gpu_mha_kernel.py b/jax_flash_attn2/pallas_kernels/gpu_mha_kernel.py similarity index 100% rename from src/jax_flash_attn2/pallas_kernels/gpu_mha_kernel.py rename to jax_flash_attn2/pallas_kernels/gpu_mha_kernel.py diff --git a/src/jax_flash_attn2/triton_kernels/__init__.py b/jax_flash_attn2/triton_kernels/__init__.py similarity index 100% rename from src/jax_flash_attn2/triton_kernels/__init__.py rename to jax_flash_attn2/triton_kernels/__init__.py diff --git a/src/jax_flash_attn2/triton_kernels/gqa_kernel.py b/jax_flash_attn2/triton_kernels/gqa_kernel.py similarity index 100% rename from src/jax_flash_attn2/triton_kernels/gqa_kernel.py rename to jax_flash_attn2/triton_kernels/gqa_kernel.py diff --git a/src/jax_flash_attn2/triton_kernels/mha_kernel.py b/jax_flash_attn2/triton_kernels/mha_kernel.py similarity index 100% rename from src/jax_flash_attn2/triton_kernels/mha_kernel.py rename to jax_flash_attn2/triton_kernels/mha_kernel.py diff --git a/src/jax_flash_attn2/utils.py b/jax_flash_attn2/utils.py similarity index 100% rename from src/jax_flash_attn2/utils.py rename to jax_flash_attn2/utils.py diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 0000000..0fa6ea4 --- /dev/null +++ b/poetry.lock @@ -0,0 +1,355 @@ +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. + +[[package]] +name = "absl-py" +version = "2.1.0" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +optional = false +python-versions = ">=3.7" +files = [ + {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, + {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, +] + +[[package]] +name = "chex" +version = "0.1.87" +description = "Chex: Testing made fun, in JAX!" +optional = false +python-versions = ">=3.9" +files = [ + {file = "chex-0.1.87-py3-none-any.whl", hash = "sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16"}, + {file = "chex-0.1.87.tar.gz", hash = "sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d"}, +] + +[package.dependencies] +absl-py = ">=0.9.0" +jax = ">=0.4.27" +jaxlib = ">=0.4.27" +numpy = ">=1.24.1" +setuptools = {version = "*", markers = "python_version >= \"3.12\""} +toolz = ">=0.9.0" +typing-extensions = ">=4.2.0" + +[[package]] +name = "einops" +version = "0.8.0" +description = "A new flavour of deep learning operations" +optional = false +python-versions = ">=3.8" +files = [ + {file = "einops-0.8.0-py3-none-any.whl", hash = "sha256:9572fb63046264a862693b0a87088af3bdc8c068fde03de63453cbbde245465f"}, + {file = "einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85"}, +] + +[[package]] +name = "filelock" +version = "3.16.1" +description = "A platform independent file lock." +optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, + {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] +typing = ["typing-extensions (>=4.12.2)"] + +[[package]] +name = "jax" +version = "0.4.35" +description = "Differentiate, compile, and transform Numpy code." +optional = false +python-versions = ">=3.10" +files = [ + {file = "jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325"}, + {file = "jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e"}, +] + +[package.dependencies] +jaxlib = ">=0.4.34,<=0.4.35" +ml-dtypes = ">=0.4.0" +numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.24", markers = "python_version < \"3.12\""}, +] +opt-einsum = "*" +scipy = [ + {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, + {version = ">=1.10", markers = "python_version < \"3.12\""}, +] + +[package.extras] +ci = ["jaxlib (==0.4.34)"] +cuda = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"] +cuda12 = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"] +cuda12-local = ["jax-cuda12-plugin (==0.4.34)", "jaxlib (==0.4.34)"] +cuda12-pip = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"] +k8s = ["kubernetes"] +minimum-jaxlib = ["jaxlib (==0.4.34)"] +tpu = ["jaxlib (>=0.4.34,<=0.4.35)", "libtpu (==0.0.2)", "libtpu-nightly (==0.1.dev20241010+nightly.cleanup)", "requests"] + +[[package]] +name = "jaxlib" +version = "0.4.35" +description = "XLA library for JAX" +optional = false +python-versions = ">=3.10" +files = [ + {file = "jaxlib-0.4.35-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:907e548ad6ce53b242a55c5f36c2a2a4c37d38f6cd8c356fc550a2f18ab0e82f"}, + {file = "jaxlib-0.4.35-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f8c499644660aefd0ae2ee31039da6d4df0f26d0ee67ba9fb316183a5304288"}, + {file = "jaxlib-0.4.35-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5d2d8a5b89d334b875ede98d7fcee946bebef1a1b5abd118ff543bcef4ab09f5"}, + {file = "jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:91a283a72263feebe0d110d1136df96950744e47530f12df42c03f36888c971e"}, + {file = "jaxlib-0.4.35-cp310-cp310-win_amd64.whl", hash = "sha256:d210bab7e1ce0b2f2e568548b3903ea6aec349019fc1398cd2a0c069e8342e62"}, + {file = "jaxlib-0.4.35-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:7f8bfc90f68857b223b7e38a9bdf466a4f1cb405c9a4aa11698dc9ab7b35c29b"}, + {file = "jaxlib-0.4.35-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:261570c94b169dc90f3af903282eeec856b52736c0944d243504ced93d19b217"}, + {file = "jaxlib-0.4.35-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e1cee6dc291251f3fb6b0127fdd96c0439ac1ea97e01571d06910df72d6ac6e1"}, + {file = "jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc9eafba001ff8569cfa252fe7f04ba553622702b4b473b656dd0866edf6b8d4"}, + {file = "jaxlib-0.4.35-cp311-cp311-win_amd64.whl", hash = "sha256:0fd990354d5623d3a34493fcd7213493390dbf5039bea19b62e2aaee1049eda9"}, + {file = "jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f"}, + {file = "jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad"}, + {file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74"}, + {file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a"}, + {file = "jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d"}, + {file = "jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18"}, + {file = "jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb"}, + {file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5"}, + {file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309"}, + {file = "jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43"}, +] + +[package.dependencies] +ml-dtypes = ">=0.2.0" +numpy = ">=1.24" +scipy = [ + {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, + {version = ">=1.10", markers = "python_version < \"3.12\""}, +] + +[[package]] +name = "ml-dtypes" +version = "0.5.0" +description = "" +optional = false +python-versions = ">=3.9" +files = [ + {file = "ml_dtypes-0.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4"}, + {file = "ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8"}, + {file = "ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856"}, + {file = "ml_dtypes-0.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8"}, + {file = "ml_dtypes-0.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215"}, + {file = "ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f"}, + {file = "ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f"}, + {file = "ml_dtypes-0.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7"}, + {file = "ml_dtypes-0.5.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0"}, + {file = "ml_dtypes-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff"}, + {file = "ml_dtypes-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a"}, + {file = "ml_dtypes-0.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02"}, + {file = "ml_dtypes-0.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f"}, + {file = "ml_dtypes-0.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599"}, + {file = "ml_dtypes-0.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab"}, + {file = "ml_dtypes-0.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07"}, + {file = "ml_dtypes-0.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15"}, + {file = "ml_dtypes-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe"}, + {file = "ml_dtypes-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859"}, + {file = "ml_dtypes-0.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf"}, + {file = "ml_dtypes-0.5.0.tar.gz", hash = "sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128"}, +] + +[package.dependencies] +numpy = [ + {version = ">=2.1.0", markers = "python_version >= \"3.13\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\" and python_version < \"3.13\""}, + {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + +[[package]] +name = "numpy" +version = "2.1.2" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.10" +files = [ + {file = "numpy-2.1.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:30d53720b726ec36a7f88dc873f0eec8447fbc93d93a8f079dfac2629598d6ee"}, + {file = "numpy-2.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8d3ca0a72dd8846eb6f7dfe8f19088060fcb76931ed592d29128e0219652884"}, + {file = "numpy-2.1.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:fc44e3c68ff00fd991b59092a54350e6e4911152682b4782f68070985aa9e648"}, + {file = "numpy-2.1.2-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:7c1c60328bd964b53f8b835df69ae8198659e2b9302ff9ebb7de4e5a5994db3d"}, + {file = "numpy-2.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6cdb606a7478f9ad91c6283e238544451e3a95f30fb5467fbf715964341a8a86"}, + {file = "numpy-2.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d666cb72687559689e9906197e3bec7b736764df6a2e58ee265e360663e9baf7"}, + {file = "numpy-2.1.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c6eef7a2dbd0abfb0d9eaf78b73017dbfd0b54051102ff4e6a7b2980d5ac1a03"}, + {file = "numpy-2.1.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:12edb90831ff481f7ef5f6bc6431a9d74dc0e5ff401559a71e5e4611d4f2d466"}, + {file = "numpy-2.1.2-cp310-cp310-win32.whl", hash = "sha256:a65acfdb9c6ebb8368490dbafe83c03c7e277b37e6857f0caeadbbc56e12f4fb"}, + {file = "numpy-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:860ec6e63e2c5c2ee5e9121808145c7bf86c96cca9ad396c0bd3e0f2798ccbe2"}, + {file = "numpy-2.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b42a1a511c81cc78cbc4539675713bbcf9d9c3913386243ceff0e9429ca892fe"}, + {file = "numpy-2.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:faa88bc527d0f097abdc2c663cddf37c05a1c2f113716601555249805cf573f1"}, + {file = "numpy-2.1.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:c82af4b2ddd2ee72d1fc0c6695048d457e00b3582ccde72d8a1c991b808bb20f"}, + {file = "numpy-2.1.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:13602b3174432a35b16c4cfb5de9a12d229727c3dd47a6ce35111f2ebdf66ff4"}, + {file = "numpy-2.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ebec5fd716c5a5b3d8dfcc439be82a8407b7b24b230d0ad28a81b61c2f4659a"}, + {file = "numpy-2.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2b49c3c0804e8ecb05d59af8386ec2f74877f7ca8fd9c1e00be2672e4d399b1"}, + {file = "numpy-2.1.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2cbba4b30bf31ddbe97f1c7205ef976909a93a66bb1583e983adbd155ba72ac2"}, + {file = "numpy-2.1.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8e00ea6fc82e8a804433d3e9cedaa1051a1422cb6e443011590c14d2dea59146"}, + {file = "numpy-2.1.2-cp311-cp311-win32.whl", hash = "sha256:5006b13a06e0b38d561fab5ccc37581f23c9511879be7693bd33c7cd15ca227c"}, + {file = "numpy-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:f1eb068ead09f4994dec71c24b2844f1e4e4e013b9629f812f292f04bd1510d9"}, + {file = "numpy-2.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7bf0a4f9f15b32b5ba53147369e94296f5fffb783db5aacc1be15b4bf72f43b"}, + {file = "numpy-2.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b1d0fcae4f0949f215d4632be684a539859b295e2d0cb14f78ec231915d644db"}, + {file = "numpy-2.1.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f751ed0a2f250541e19dfca9f1eafa31a392c71c832b6bb9e113b10d050cb0f1"}, + {file = "numpy-2.1.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:bd33f82e95ba7ad632bc57837ee99dba3d7e006536200c4e9124089e1bf42426"}, + {file = "numpy-2.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b8cde4f11f0a975d1fd59373b32e2f5a562ade7cde4f85b7137f3de8fbb29a0"}, + {file = "numpy-2.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d95f286b8244b3649b477ac066c6906fbb2905f8ac19b170e2175d3d799f4df"}, + {file = "numpy-2.1.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ab4754d432e3ac42d33a269c8567413bdb541689b02d93788af4131018cbf366"}, + {file = "numpy-2.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e585c8ae871fd38ac50598f4763d73ec5497b0de9a0ab4ef5b69f01c6a046142"}, + {file = "numpy-2.1.2-cp312-cp312-win32.whl", hash = "sha256:9c6c754df29ce6a89ed23afb25550d1c2d5fdb9901d9c67a16e0b16eaf7e2550"}, + {file = "numpy-2.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:456e3b11cb79ac9946c822a56346ec80275eaf2950314b249b512896c0d2505e"}, + {file = "numpy-2.1.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a84498e0d0a1174f2b3ed769b67b656aa5460c92c9554039e11f20a05650f00d"}, + {file = "numpy-2.1.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4d6ec0d4222e8ffdab1744da2560f07856421b367928026fb540e1945f2eeeaf"}, + {file = "numpy-2.1.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:259ec80d54999cc34cd1eb8ded513cb053c3bf4829152a2e00de2371bd406f5e"}, + {file = "numpy-2.1.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:675c741d4739af2dc20cd6c6a5c4b7355c728167845e3c6b0e824e4e5d36a6c3"}, + {file = "numpy-2.1.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b2d4e667895cc55e3ff2b56077e4c8a5604361fc21a042845ea3ad67465aa8"}, + {file = "numpy-2.1.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43cca367bf94a14aca50b89e9bc2061683116cfe864e56740e083392f533ce7a"}, + {file = "numpy-2.1.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:76322dcdb16fccf2ac56f99048af32259dcc488d9b7e25b51e5eca5147a3fb98"}, + {file = "numpy-2.1.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:32e16a03138cabe0cb28e1007ee82264296ac0983714094380b408097a418cfe"}, + {file = "numpy-2.1.2-cp313-cp313-win32.whl", hash = "sha256:242b39d00e4944431a3cd2db2f5377e15b5785920421993770cddb89992c3f3a"}, + {file = "numpy-2.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:f2ded8d9b6f68cc26f8425eda5d3877b47343e68ca23d0d0846f4d312ecaa445"}, + {file = "numpy-2.1.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2ffef621c14ebb0188a8633348504a35c13680d6da93ab5cb86f4e54b7e922b5"}, + {file = "numpy-2.1.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad369ed238b1959dfbade9018a740fb9392c5ac4f9b5173f420bd4f37ba1f7a0"}, + {file = "numpy-2.1.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d82075752f40c0ddf57e6e02673a17f6cb0f8eb3f587f63ca1eaab5594da5b17"}, + {file = "numpy-2.1.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:1600068c262af1ca9580a527d43dc9d959b0b1d8e56f8a05d830eea39b7c8af6"}, + {file = "numpy-2.1.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a26ae94658d3ba3781d5e103ac07a876b3e9b29db53f68ed7df432fd033358a8"}, + {file = "numpy-2.1.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13311c2db4c5f7609b462bc0f43d3c465424d25c626d95040f073e30f7570e35"}, + {file = "numpy-2.1.2-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:2abbf905a0b568706391ec6fa15161fad0fb5d8b68d73c461b3c1bab6064dd62"}, + {file = "numpy-2.1.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:ef444c57d664d35cac4e18c298c47d7b504c66b17c2ea91312e979fcfbdfb08a"}, + {file = "numpy-2.1.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:bdd407c40483463898b84490770199d5714dcc9dd9b792f6c6caccc523c00952"}, + {file = "numpy-2.1.2-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:da65fb46d4cbb75cb417cddf6ba5e7582eb7bb0b47db4b99c9fe5787ce5d91f5"}, + {file = "numpy-2.1.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c193d0b0238638e6fc5f10f1b074a6993cb13b0b431f64079a509d63d3aa8b7"}, + {file = "numpy-2.1.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a7d80b2e904faa63068ead63107189164ca443b42dd1930299e0d1cb041cec2e"}, + {file = "numpy-2.1.2.tar.gz", hash = "sha256:13532a088217fa624c99b843eeb54640de23b3414b14aa66d023805eb731066c"}, +] + +[[package]] +name = "opt-einsum" +version = "3.4.0" +description = "Path optimization of einsum functions." +optional = false +python-versions = ">=3.8" +files = [ + {file = "opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd"}, + {file = "opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac"}, +] + +[[package]] +name = "scipy" +version = "1.13.1" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"}, + {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"}, + {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"}, + {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"}, + {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"}, + {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"}, + {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"}, + {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"}, + {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"}, + {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"}, + {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"}, + {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"}, + {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"}, + {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"}, + {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"}, + {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"}, + {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"}, + {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"}, + {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"}, + {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"}, + {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"}, +] + +[package.dependencies] +numpy = ">=1.22.4,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] +test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + +[[package]] +name = "setuptools" +version = "75.2.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"}, + {file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"] + +[[package]] +name = "toolz" +version = "1.0.0" +description = "List processing tools and functional utilities" +optional = false +python-versions = ">=3.8" +files = [ + {file = "toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236"}, + {file = "toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02"}, +] + +[[package]] +name = "triton" +version = "3.0.0" +description = "A language and compiler for custom Deep Learning operations" +optional = false +python-versions = "*" +files = [ + {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, + {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, + {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, + {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, + {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + +[[package]] +name = "typing-extensions" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, +] + +[metadata] +lock-version = "2.0" +python-versions = ">=3.10" +content-hash = "3e16c751bf056781ad6406640f06b72c8d32e81104cb9bd20a300dd34aeb77a5" diff --git a/pyproject.toml b/pyproject.toml index f2bd053..2f4daba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,46 +1,37 @@ -[project] +[tool.poetry] name = "jax-flash-attn2" -version = "0.0.0" -authors = [{ name = "Erfan Zare Chavoshi", email = "erfanzare810@gmail.com" }] +version = "0.0.1" description = "Flash Attention Implementation with Multiple Backend Support and Sharding This module provides a flexible implementation of Flash Attention with support for different backends (GPU, TPU, CPU) and platforms (Triton, Pallas, JAX)." +authors = ["Erfan Zare Chavoshi "] +license = "Apache-2.0" readme = "README.md" -requires-python = ">=3.8" -license = { text = "Apache-2.0" } -dependencies = [ - "jax>=0.4.33", - "jaxlib>=0.4.33", - "triton~=3.0.0", - "scipy==1.13.1", - "einops", - "chex", -] +homepage = "https://github.com/erfanzar/jax-flash-attn2" +repository = "https://github.com/erfanzar/jax-flash-attn2" +documentation = "https://erfanzar.github.io/jax-flash-attn2" +keywords = ["JAX", "Deep Learning", "Machine Learning", "XLA"] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Topic :: Scientific/Engineering :: Artificial Intelligence", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", ] -keywords = ["JAX", "Deep Learning", "Machine Learning", "XLA"] - -[project.urls] -Homepage = "https://github.com/erfanzar/jax-flash-attn2" -Issues = "https://github.com/erfanzar/jax-flash-attn2/issues" -Documentation = "https://erfanzar.github.io/jax-flash-attn2" - -[build-system] -requires = ["flit_core >=3.2,<4"] -build-backend = "flit_core.buildapi" +[tool.poetry.dependencies] +python = ">=3.10" +jax = ">=0.4.33" +jaxlib = ">=0.4.33" +triton = "~=3.0.0" +scipy = "1.13.1" +einops = "*" +chex = "*" [tool.ruff.lint] select = ["E4", "E7", "E9", "F", "B"] - ignore = ["E501", "B905", "B007", "E741"] unfixable = ["B"] @@ -56,7 +47,6 @@ quote-style = "double" indent-style = "tab" docstring-code-format = true - [tool.ruff] target-version = "py311" line-length = 88 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_triton.py b/tests/test_triton.py index 59a2f8f..6ac3a30 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -6,7 +6,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["JAX_TRACEBACK_FILTERING"] = "off" -sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../src")) +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) import jax from jax import numpy as jnp @@ -109,8 +109,8 @@ def test_forward(): def test_backward(): - """Tests the backward pass of the attention mechanism.""" - + """Tests the backward pass of the attention mechanism.""" + q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3) B, QH, KVH, QS, KS, D = 1, 32, 32, 1024, 1024, 128 q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)