Skip to content

Commit

Permalink
using poetry, moving base path and change builder package
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Oct 23, 2024
1 parent d2a40c0 commit dda2760
Show file tree
Hide file tree
Showing 16 changed files with 376 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@
"FlashAttention",
"AttentionConfig",
]

__version__ = "0.0.1"
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
355 changes: 355 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

42 changes: 16 additions & 26 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,46 +1,37 @@
[project]
[tool.poetry]
name = "jax-flash-attn2"
version = "0.0.0"
authors = [{ name = "Erfan Zare Chavoshi", email = "[email protected]" }]
version = "0.0.1"
description = "Flash Attention Implementation with Multiple Backend Support and Sharding This module provides a flexible implementation of Flash Attention with support for different backends (GPU, TPU, CPU) and platforms (Triton, Pallas, JAX)."
authors = ["Erfan Zare Chavoshi <[email protected]>"]
license = "Apache-2.0"
readme = "README.md"
requires-python = ">=3.8"
license = { text = "Apache-2.0" }
dependencies = [
"jax>=0.4.33",
"jaxlib>=0.4.33",
"triton~=3.0.0",
"scipy==1.13.1",
"einops",
"chex",
]
homepage = "https://github.com/erfanzar/jax-flash-attn2"
repository = "https://github.com/erfanzar/jax-flash-attn2"
documentation = "https://erfanzar.github.io/jax-flash-attn2"
keywords = ["JAX", "Deep Learning", "Machine Learning", "XLA"]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
keywords = ["JAX", "Deep Learning", "Machine Learning", "XLA"]


[project.urls]
Homepage = "https://github.com/erfanzar/jax-flash-attn2"
Issues = "https://github.com/erfanzar/jax-flash-attn2/issues"
Documentation = "https://erfanzar.github.io/jax-flash-attn2"

[build-system]
requires = ["flit_core >=3.2,<4"]
build-backend = "flit_core.buildapi"
[tool.poetry.dependencies]
python = ">=3.10"
jax = ">=0.4.33"
jaxlib = ">=0.4.33"
triton = "~=3.0.0"
scipy = "1.13.1"
einops = "*"
chex = "*"

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "B"]

ignore = ["E501", "B905", "B007", "E741"]
unfixable = ["B"]

Expand All @@ -56,7 +47,6 @@ quote-style = "double"
indent-style = "tab"
docstring-code-format = true


[tool.ruff]
target-version = "py311"
line-length = 88
Expand Down
Empty file added tests/__init__.py
Empty file.
6 changes: 3 additions & 3 deletions tests/test_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../src"))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

import jax
from jax import numpy as jnp
Expand Down Expand Up @@ -109,8 +109,8 @@ def test_forward():


def test_backward():
"""Tests the backward pass of the attention mechanism."""
"""Tests the backward pass of the attention mechanism."""

q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
B, QH, KVH, QS, KS, D = 1, 32, 32, 1024, 1024, 128
q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
Expand Down

0 comments on commit dda2760

Please sign in to comment.