36 changes: 33 additions & 3 deletions .github/container/Dockerfile.axlearn
@@ -1,19 +1,49 @@
# syntax=docker/dockerfile:1-labs
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
ARG URLREF_AXLEARN=https://github.com/Steboss/axlearn.git#main
ARG SRC_MANIFEST_FILE=manifest.yaml
ARG DEST_MANIFEST_DIR=/opt/manifest.d
ARG SRC_PATH_AXLEARN=/opt/axlearn
ARG GIT_USER_NAME=NVIDIA
ARG [email protected]

###############################################################################
## Download source and configure dependencies
###############################################################################
FROM ${BASE_IMAGE} AS mealkit
ARG URLREF_AXLEARN
ARG DEST_MANIFEST_DIR
ARG SRC_MANIFEST_FILE
ARG SRC_PATH_AXLEARN
ARG GIT_USER_NAME
ARG GIT_USER_EMAIL

RUN git-clone.sh "${URLREF_AXLEARN}" "${SRC_PATH_AXLEARN}"
# Use create-distribution.sh because axlearn needs a local patch applied
# ADD --chmod=777 create-distribution.sh ${DEST_MANIFEST_DIR}/
# COPY ${SRC_MANIFEST_FILE} ${DEST_MANIFEST_DIR}/${SRC_MANIFEST_FILE}
# COPY patches/ ${DEST_MANIFEST_DIR}/patches/
# Apply the patch on top of the cloned sources
RUN --mount=target=/mnt/jax-toolbox,from=jax-toolbox <<"EOF" bash -exu
# Copy the manifest, patches, and distribution tooling from the jax-toolbox mount
cp -r /mnt/jax-toolbox/.github/container/patches ${DEST_MANIFEST_DIR}/
cp /mnt/jax-toolbox/.github/container/manifest.yaml ${DEST_MANIFEST_DIR}/manifest.yaml
cp /mnt/jax-toolbox/.github/container/create-distribution.sh ${DEST_MANIFEST_DIR}/create-distribution.sh
# TODO: remove
cp /mnt/jax-toolbox/.github/container/pip-finalize.sh /usr/local/bin/
# Set up the git identity used when applying the patch
git config --global user.email "${GIT_USER_EMAIL}"
git config --global user.name "${GIT_USER_NAME}"
# Apply the patch
bash ${DEST_MANIFEST_DIR}/create-distribution.sh \
--manifest ${DEST_MANIFEST_DIR}/manifest.yaml \
--package axlearn

# Clean up the temporary git identity
rm -f ~/.gitconfig
EOF

# These packages are needed to run the axlearn tests;
# see https://github.com/apple/axlearn/blob/main/pyproject.toml for reference.
WORKDIR /opt/axlearn
RUN <<"EOF" bash -ex
echo "-e ${SRC_PATH_AXLEARN}" > /opt/pip-tools.d/requirements-axlearn.in
cat <<REQUIREMENTS >> /opt/pip-tools.d/requirements-axlearn.in
3 changes: 1 addition & 2 deletions .github/container/git-clone.sh
@@ -41,7 +41,7 @@ while [ : ]; do
;;
--)
shift;
break
break
;;
esac
done
@@ -80,7 +80,6 @@ git submodule update --init --recursive
popd

## update the manifest file

mkdir -p $(dirname ${MANIFEST})
touch ${MANIFEST}
PACKAGE=$(basename "${DESTINATION}")
5 changes: 4 additions & 1 deletion .github/container/manifest.yaml
@@ -96,9 +96,12 @@ pathwaysutils:
latest_verified_commit: 359776d454940ffaa337c36d1df16308d44a95a9
mode: pip-vcs
axlearn:
url: https://github.com/Steboss/axlearn.git
url: https://github.com/apple/axlearn.git
mirror_url: https://github.com/nvjax-svc-0/axlearn.git
tracking_ref: main
mode: git-clone
patches:
pull/1339/head: file://patches/axlearn/PR-1339.patch
qwix:
url: https://github.com/google/qwix.git
tracking_ref: main
280 changes: 280 additions & 0 deletions .github/container/patches/axlearn/PR-1339.patch
@@ -0,0 +1,280 @@
diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py
index 369926e..5bd271e 100644
--- a/axlearn/common/array_serialization.py
+++ b/axlearn/common/array_serialization.py
@@ -306,10 +306,12 @@ async def _async_serialize(
and arr_inp.is_fully_addressable
)
# pylint: disable=protected-access
- spec_has_metadata = {
- "0.6.2": lambda: serialization.ts_impl._spec_has_metadata,
- "0.5.3": lambda: serialization._spec_has_metadata,
- }[jax.__version__]()
+ if jax.__version__.startswith("0.8.0") or jax.__version__ == "0.6.2":
+ spec_has_metadata = serialization.ts_impl._spec_has_metadata
+ elif jax.__version__ == "0.5.3":
+ spec_has_metadata = serialization._spec_has_metadata
+ else:
+ raise ValueError(f"Unsupported JAX version for spec_has_metadata: {jax.__version__}")
if not spec_has_metadata(tensorstore_spec):
# pylint: disable-next=protected-access
tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp)
diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py
index 9026f3a..1d47a26 100644
--- a/axlearn/common/flash_attention/gpu_attention.py
+++ b/axlearn/common/flash_attention/gpu_attention.py
@@ -44,7 +44,6 @@ from jax._src.cudnn.fused_attention_stablehlo import (
)
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import pallas as pl
-from jax.experimental.pallas.triton import TritonCompilerParams

from axlearn.common.attention_bias import (
NEG_INF,
@@ -69,7 +68,14 @@ from axlearn.common.flash_attention.remat import FLASH_ATTN_RESIDUAL_NAME
from axlearn.common.kv_cache.base_kv_cache import BaseKVCache
from axlearn.common.kv_cache.kv_cache import KVCache
from axlearn.common.layers import get_dropout_mask
-from axlearn.common.utils import Nested, Tensor
+from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Nested, Tensor
+
+# pylint: disable=ungrouped-imports
+if _JAX_MEMORY_SPACE_SUPPORT:
+ from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
+else:
+ from jax.experimental.pallas.triton import TritonCompilerParams
+# pylint: enable=ungrouped-imports


def _segment_mask(
diff --git a/axlearn/common/flash_attention/gpu_decoding.py b/axlearn/common/flash_attention/gpu_decoding.py
index a29bdcc..1b5b07b 100644
--- a/axlearn/common/flash_attention/gpu_decoding.py
+++ b/axlearn/common/flash_attention/gpu_decoding.py
@@ -49,7 +49,6 @@ from absl import logging
from jax import lax
from jax._src.cudnn.fused_attention_stablehlo import check_compute_capability
from jax.experimental import pallas as pl
-from jax.experimental.pallas.triton import TritonCompilerParams

from axlearn.common.attention_bias import (
NEG_INF,
@@ -61,7 +60,14 @@ from axlearn.common.attention_bias import (
from axlearn.common.flash_attention.common import BaseSingleStepDecoding, get_gpu_dot_precision
from axlearn.common.kv_cache.base_kv_cache import BaseKVCache
from axlearn.common.kv_cache.kv_cache import KVCache
-from axlearn.common.utils import Nested, Tensor
+from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Nested, Tensor
+
+# pylint: disable=ungrouped-imports
+if _JAX_MEMORY_SPACE_SUPPORT:
+ from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
+else:
+ from jax.experimental.pallas.triton import TritonCompilerParams
+# pylint: enable=ungrouped-imports


# Note: split_k_seq_len must be a multiple of block_k.
diff --git a/axlearn/common/flash_attention/gpu_paged_attention.py b/axlearn/common/flash_attention/gpu_paged_attention.py
index a2600de..9a3dd9d 100644
--- a/axlearn/common/flash_attention/gpu_paged_attention.py
+++ b/axlearn/common/flash_attention/gpu_paged_attention.py
@@ -19,7 +19,6 @@ import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
-from jax.experimental.pallas.triton import TritonCompilerParams

from axlearn.common.attention_bias import (
NEG_INF,
@@ -31,7 +30,16 @@ from axlearn.common.attention_bias import (
from axlearn.common.flash_attention.common import BasePagedAttention, get_gpu_dot_precision
from axlearn.common.flash_attention.gpu_decoding import _get_sm_count as get_sm_count
from axlearn.common.kv_cache.base_kv_cache import BaseKVCache
-from axlearn.common.utils import Nested, Tensor
+from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Nested, Tensor
+
+# pylint: disable=ungrouped-imports
+if _JAX_MEMORY_SPACE_SUPPORT:
+ from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
+else:
+ from jax.experimental.pallas.triton import ( # isort: skip
+ TritonCompilerParams,
+ )
+# pylint: enable=ungrouped-imports


def _paged_attention_kernel(
diff --git a/axlearn/common/kv_cache/paged_kv_cache_gpu_kernel.py b/axlearn/common/kv_cache/paged_kv_cache_gpu_kernel.py
index cbddb27..5358d90 100644
--- a/axlearn/common/kv_cache/paged_kv_cache_gpu_kernel.py
+++ b/axlearn/common/kv_cache/paged_kv_cache_gpu_kernel.py
@@ -8,9 +8,15 @@ This kernel is a temporary workaround of occasional performance problems with
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
-from jax.experimental.pallas.triton import TritonCompilerParams

-from axlearn.common.utils import Tensor
+from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Tensor
+
+# pylint: disable=ungrouped-imports
+if _JAX_MEMORY_SPACE_SUPPORT:
+ from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
+else:
+ from jax.experimental.pallas.triton import TritonCompilerParams
+# pylint: enable=ungrouped-imports


def _scatter_pages_kernel(
diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py
index 54e4f0b..c5c1b7b 100644
--- a/axlearn/common/optimizers.py
+++ b/axlearn/common/optimizers.py
@@ -37,7 +37,6 @@ import optax
import typing_extensions
from absl import logging
from jax import numpy as jnp
-from jax._src.sharding_impls import TransferToMemoryKind
from optax._src import numerics

from axlearn.common import flax_struct, schedule
@@ -53,6 +52,8 @@ from axlearn.common.optimizer_base import (
TransformPartitionSpecFn,
)
from axlearn.common.utils import (
+ DEVICE_MEMORY,
+ HOST_MEMORY,
MemoryKind,
Nested,
NestedTensor,
@@ -62,6 +63,7 @@ from axlearn.common.utils import (
expand_vdicts,
flatten_items,
register_per_param_settings,
+ transfer_to_memory_kind,
tree_paths,
vectorized_tree_map,
)
@@ -2072,8 +2074,8 @@ def offload_optimizer(
optimizer: ConfigOr[PartitionedGradientTransformation],
*,
pattern: Union[str, re.Pattern] = ".*",
- offload_src: MemoryKind = "device",
- offload_dst: MemoryKind = "pinned_host",
+ offload_src: MemoryKind = DEVICE_MEMORY,
+ offload_dst: MemoryKind = HOST_MEMORY,
) -> PartitionedGradientTransformation:
"""Offload the state of the wrapped optimizer that matches `pattern` to `offload_dst`.

@@ -2145,9 +2147,7 @@ def offload_optimizer(
# released, so we have less memory pressure at that point in time.
return jax.tree.map(
lambda path, tensor: (
- jax.device_put(tensor, TransferToMemoryKind(dst))
- if re.fullmatch(pattern, path)
- else tensor
+ transfer_to_memory_kind(tensor, dst) if re.fullmatch(pattern, path) else tensor
),
tree_paths(state),
state,
diff --git a/axlearn/common/optimizers_test.py b/axlearn/common/optimizers_test.py
index 6fe4082..761ddee 100644
--- a/axlearn/common/optimizers_test.py
+++ b/axlearn/common/optimizers_test.py
@@ -60,6 +60,7 @@ from axlearn.common.optimizers import (
from axlearn.common.schedule import Schedule, adafactor_decay_rate, decay_bias_correction
from axlearn.common.test_utils import TestCase, assert_allclose
from axlearn.common.utils import (
+ _JAX_MEMORY_SPACE_SUPPORT,
NestedPartitionSpec,
PartitionSpec,
Tensor,
@@ -427,10 +428,17 @@ class OptimizerTest(TestCase):
return loss, compute_loss(updated_params)

if offload:
- self.assertIn(
- "TransferToMemoryKind(memory_kind='pinned_host')",
- str(jax.make_jaxpr(jit_fn)(params, state)),
- )
+ jaxpr_str = str(jax.make_jaxpr(jit_fn)(params, state))
+ if _JAX_MEMORY_SPACE_SUPPORT:
+ self.assertIn(
+ "memory_kind=host",
+ jaxpr_str,
+ )
+ else:
+ self.assertIn(
+ "TransferToMemoryKind(memory_kind='pinned_host')",
+ jaxpr_str,
+ )
loss, new_loss = jit_fn(params, state)
self.assertLess(new_loss, loss)

diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 16a5fe3..776340a 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -54,6 +54,7 @@ from jax.ad_checkpoint import Offloadable, Recompute, Saveable
from jax.experimental import mesh_utils, multihost_utils
from jax.extend.core import Primitive
from jax.sharding import PartitionSpec
+from packaging import version

from axlearn.common import serialization
from axlearn.common.config import (
@@ -66,6 +67,9 @@ from axlearn.common.config import (
register_validator,
)

+# Whether this JAX version provides first-class memory spaces (jax.memory.Space).
+_JAX_MEMORY_SPACE_SUPPORT = version.parse(jax.__version__) >= version.parse("0.7.0")
+
# New code should use Nested[XX] instead of NestedXX.
# Old definitions are provided for backwards compatibility.
_NestedT = TypeVar("_NestedT")
@@ -118,7 +122,23 @@ class HybridMeshShape:
# "pinned_host" = Page locked memory on CPU, which can be address directly by accelerators by
# direct memory access (DMA). For TPU, "pinned_host" memory layout follows TPU device tile
# layout and usually cannot be zero-copy converted to a CPU-tensor.
-MemoryKind = Literal["device", "pinned_host"]
+if _JAX_MEMORY_SPACE_SUPPORT:
+ MemoryKind = jax.memory.Space
+ DEVICE_MEMORY = jax.memory.Space.Device
+ HOST_MEMORY = jax.memory.Space.Host
+
+ def transfer_to_memory_kind(tensor: Tensor, memory_kind: MemoryKind) -> Tensor:
+ return jax.device_put(tensor, memory_kind)
+
+else:
+ from jax._src.sharding_impls import TransferToMemoryKind # pylint: disable=ungrouped-imports
+
+ MemoryKind = Literal["device", "pinned_host"]
+ DEVICE_MEMORY = "device"
+ HOST_MEMORY = "pinned_host"
+
+ def transfer_to_memory_kind(tensor: Tensor, memory_kind: MemoryKind) -> Tensor:
+ return jax.device_put(tensor, TransferToMemoryKind(memory_kind))


@dataclasses.dataclass
diff --git a/axlearn/experiments/text/gpt/c4_trainer.py b/axlearn/experiments/text/gpt/c4_trainer.py
index b9103cb..b8e8b22 100644
--- a/axlearn/experiments/text/gpt/c4_trainer.py
+++ b/axlearn/experiments/text/gpt/c4_trainer.py
@@ -49,7 +49,7 @@ from axlearn.common.config import (
from axlearn.common.input_lm import lm_text_preprocessor
from axlearn.common.utils import get_data_dir
from axlearn.experiments.text.common import DataMixtureComponent, vocab
-from axlearn.experiments.text.gpt import envy, fuji, gspmd
+from axlearn.experiments.text.gpt import fuji, gspmd
from axlearn.experiments.text.gpt.common import mixture_train_input_source, tfds_input
from axlearn.experiments.text.gpt.vocabulary_fuji_v3 import FujiV3Vocabulary

@@ -109,5 +109,4 @@ def named_trainer_configs() -> dict[str, TrainerConfigFn]:
config_map = {}
config_map.update(fuji.trainer_configs(_train_input_source, _eval_input_sources))
config_map.update(gspmd.trainer_configs(_train_input_source, _eval_input_sources))
- config_map.update(envy.trainer_configs(_train_input_source, _eval_input_sources))
return config_map
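
For quick reference, below is a condensed, stand-alone sketch of the version-compatibility shim the patch adds to axlearn/common/utils.py. It only mirrors names that appear in the patch above (jax.memory.Space, TransferToMemoryKind, transfer_to_memory_kind); treat it as an illustration of the pattern, not the authoritative implementation.

import jax
from packaging import version

# Feature gate, as in the patch: JAX >= 0.7.0 exposes first-class memory spaces.
_JAX_MEMORY_SPACE_SUPPORT = version.parse(jax.__version__) >= version.parse("0.7.0")

if _JAX_MEMORY_SPACE_SUPPORT:
    # Newer JAX: memory kinds are jax.memory.Space objects that device_put accepts directly.
    DEVICE_MEMORY = jax.memory.Space.Device
    HOST_MEMORY = jax.memory.Space.Host

    def transfer_to_memory_kind(tensor, memory_kind):
        return jax.device_put(tensor, memory_kind)

else:
    # Older JAX: memory kinds are strings wrapped in the private TransferToMemoryKind helper.
    from jax._src.sharding_impls import TransferToMemoryKind

    DEVICE_MEMORY = "device"
    HOST_MEMORY = "pinned_host"

    def transfer_to_memory_kind(tensor, memory_kind):
        return jax.device_put(tensor, TransferToMemoryKind(memory_kind))

# Callers such as offload_optimizer then stay version-agnostic, e.g.:
#   offloaded = transfer_to_memory_kind(tensor, HOST_MEMORY)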