Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get all tests to pass locally with no special configuration #1108

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion MaxText/convert_gemma2_chkpt.py
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@
import orbax

import checkpointing
from train import save_checkpoint
from MaxText.train import save_checkpoint

Params = dict[str, Any]

2 changes: 1 addition & 1 deletion MaxText/convert_gemma_chkpt.py
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@
import orbax

import checkpointing
from train import save_checkpoint
from MaxText.train import save_checkpoint

Params = dict[str, Any]

4 changes: 2 additions & 2 deletions MaxText/convert_gpt3_ckpt_from_paxml.py
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@
--run-name=$RUN_NAME \
--base-output-directory=$BASE_OUTPUT_DIR
"""
import max_utils
from MaxText import max_utils
import optimizers
import pyconfig
import os
@@ -50,7 +50,7 @@
import gc
import max_logging
from psutil import Process
from train import save_checkpoint
from MaxText.train import save_checkpoint
import argparse


3 changes: 1 addition & 2 deletions MaxText/decode.py
Original file line number Diff line number Diff line change
@@ -16,8 +16,7 @@

import jax

import max_utils
import maxengine
from MaxText import max_utils, maxengine

import os
import pyconfig
6 changes: 2 additions & 4 deletions MaxText/generate_param_only_checkpoint.py
Original file line number Diff line number Diff line change
@@ -25,9 +25,7 @@

import checkpointing
import jax
import max_logging
import max_utils
import optimizers
from MaxText import max_logging, max_utils, optimizers
import pyconfig

from absl import app
@@ -36,7 +34,7 @@
from jax import random
from typing import Sequence
from layers import models, quantizations
from train import save_checkpoint
from MaxText.train import save_checkpoint

Transformer = models.Transformer

2 changes: 1 addition & 1 deletion MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@

from jetstream.engine import token_utils

import max_utils
from MaxText import max_utils
import maxengine
import maxtext_utils
import profiler
8 changes: 4 additions & 4 deletions MaxText/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
@@ -24,10 +24,10 @@
import tensorflow_datasets as tfds
import jax

import multihost_dataloading
import tokenizer
import sequence_packing
from input_pipeline import _input_pipeline_utils
from MaxText import multihost_dataloading
from MaxText import tokenizer
from MaxText import sequence_packing
from MaxText.input_pipeline import _input_pipeline_utils

AUTOTUNE = tf.data.experimental.AUTOTUNE

2 changes: 1 addition & 1 deletion MaxText/kernels/ragged_attention.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
import common_types
from MaxText import common_types

from jax.experimental import shard_map

14 changes: 7 additions & 7 deletions MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
@@ -29,13 +29,13 @@
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
import jax.numpy as jnp
import common_types
from kernels.ragged_attention import ragged_gqa
from kernels.ragged_attention import ragged_mha
from layers import embeddings
from layers import initializers
from layers import linears
from layers import quantizations
from MaxText import common_types
from MaxText.kernels.ragged_attention import ragged_gqa
from MaxText.kernels.ragged_attention import ragged_mha
from MaxText.layers import embeddings
from MaxText.layers import initializers
from MaxText.layers import linears
from MaxText.layers import quantizations


# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
2 changes: 1 addition & 1 deletion MaxText/layers/embeddings.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@
import jax
from jax import lax
import jax.numpy as jnp
from layers import initializers
from MaxText.layers import initializers

Config = Any
Array = jnp.ndarray
2 changes: 1 addition & 1 deletion MaxText/layers/initializers.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@

from flax import linen as nn
import jax
import common_types
from MaxText import common_types

Array = common_types.Array
DType = common_types.DType
14 changes: 7 additions & 7 deletions MaxText/layers/linears.py
Original file line number Diff line number Diff line change
@@ -23,18 +23,18 @@
import jax
from jax import lax
import jax.numpy as jnp
import common_types
from layers import initializers
from layers import normalizations
from layers import quantizations
from MaxText import common_types
from MaxText.layers import initializers
from MaxText.layers import normalizations
from MaxText.layers import quantizations
import numpy as np
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import shard_map
import math
import max_logging
import max_utils
from MaxText import max_logging
from MaxText import max_utils
from aqt.jax.v2 import aqt_tensor
from kernels import megablox as mblx
from MaxText.kernels import megablox as mblx


Array = common_types.Array
16 changes: 8 additions & 8 deletions MaxText/layers/llama2.py
Original file line number Diff line number Diff line change
@@ -23,14 +23,14 @@
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
# from jax.experimental.pallas.ops.tpu import flash_attention
from layers import attentions
from layers import embeddings
from layers import linears
from layers import normalizations
from layers import models
from layers import quantizations

import common_types
from MaxText.layers import attentions
from MaxText.layers import embeddings
from MaxText.layers import linears
from MaxText.layers import normalizations
from MaxText.layers import models
from MaxText.layers import quantizations

from MaxText import common_types
from typing import Optional

Array = common_types.Array
12 changes: 6 additions & 6 deletions MaxText/layers/models.py
Original file line number Diff line number Diff line change
@@ -24,12 +24,12 @@
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
import common_types
from layers import attentions
from layers import embeddings
from layers import linears
from layers import normalizations, quantizations
from layers import pipeline
from MaxText import common_types
from MaxText.layers import attentions
from MaxText.layers import embeddings
from MaxText.layers import linears
from MaxText.layers import normalizations, quantizations
from MaxText.layers import pipeline

Array = common_types.Array
Config = common_types.Config
2 changes: 1 addition & 1 deletion MaxText/layers/normalizations.py
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@
from flax import linen as nn
from jax import lax
import jax.numpy as jnp
from layers import initializers
from MaxText.layers import initializers

Initializer = initializers.Initializer

2 changes: 1 addition & 1 deletion MaxText/layers/pipeline.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@
from jax import numpy as jnp
from flax.core import meta
from flax import linen as nn
import common_types
from MaxText import common_types
import functools
from typing import Any

2 changes: 1 addition & 1 deletion MaxText/layers/quantizations.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
from aqt.jax.v2.flax import aqt_flax
from aqt.jax.v2 import tiled_dot_general
from aqt.jax.v2 import calibration
import common_types
from MaxText import common_types
from dataclasses import dataclass
from flax.linen import fp8_ops
from flax.linen import initializers as flax_initializers
2 changes: 1 addition & 1 deletion MaxText/llama_ckpt_conversion_inference_only.py
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@
import jax
from flax.training import train_state
import max_logging
from train import save_checkpoint
from MaxText.train import save_checkpoint
import torch
import sys
import os
4 changes: 2 additions & 2 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
Original file line number Diff line number Diff line change
@@ -38,14 +38,14 @@
from absl import app
import numpy as np
import pyconfig
import max_utils
from MaxText import max_utils
from jax.sharding import Mesh
import max_logging
import checkpointing
from generate_param_only_checkpoint import _read_train_checkpoint
import llama_or_mistral_ckpt
from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoConfig
from max_utils import unpermute_from_match_maxtext_rope
from MaxText.max_utils import unpermute_from_match_maxtext_rope


def reverse_scale(arr, scale):
4 changes: 2 additions & 2 deletions MaxText/llama_or_mistral_ckpt.py
Original file line number Diff line number Diff line change
@@ -49,10 +49,10 @@
from tqdm import tqdm

import max_logging
from train import save_checkpoint
from MaxText.train import save_checkpoint
import checkpointing
from safetensors import safe_open
import max_utils
from MaxText import max_utils

MODEL_PARAMS_DICT = {
"llama2-70b": {
3 changes: 1 addition & 2 deletions MaxText/load_and_quantize_checkpoint.py
Original file line number Diff line number Diff line change
@@ -16,8 +16,7 @@

import jax

import max_utils
import maxengine
from MaxText import max_utils, maxengine

import os
import pyconfig
3 changes: 1 addition & 2 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
@@ -20,8 +20,7 @@
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
import checkpointing
import common_types
from MaxText import checkpointing, common_types
import functools
import time
import optax
7 changes: 3 additions & 4 deletions MaxText/maxengine.py
Original file line number Diff line number Diff line change
@@ -22,22 +22,21 @@
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning

from layers import models, quantizations
from MaxText.layers import models, quantizations

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental import layout as jax_layout

import common_types
from MaxText import common_types
from jetstream.core import config_lib
from jetstream.engine import engine_api
from jetstream.engine import tokenizer_pb2
from jetstream.engine import tokenizer_api
from jetstream.engine import token_utils

import max_utils
import inference_utils
from MaxText import max_utils, inference_utils
import pyconfig

import warnings
2 changes: 1 addition & 1 deletion MaxText/maxtext_utils.py
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@

import jax
import optax
import max_utils
from MaxText import max_utils
from jax.sharding import PartitionSpec as P
from jax.experimental.serialize_executable import deserialize_and_load

3 changes: 1 addition & 2 deletions MaxText/metric_logger.py
Original file line number Diff line number Diff line change
@@ -22,8 +22,7 @@
import os
import numpy as np

import max_logging
import max_utils
from MaxText import max_utils, max_logging


def _prepare_metrics_for_json(metrics, step, run_name):
3 changes: 1 addition & 2 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
@@ -26,8 +26,7 @@
from jax.experimental.compilation_cache import compilation_cache
from layers.attentions import AttentionType
import accelerator_to_spec_map
import max_logging
import max_utils
from MaxText import max_utils, max_logging
import omegaconf

OmegaConf = omegaconf.OmegaConf
7 changes: 3 additions & 4 deletions MaxText/standalone_checkpointer.py
Original file line number Diff line number Diff line change
@@ -30,11 +30,10 @@
from jax import numpy as jnp
import numpy as np

import checkpointing
import max_utils
import max_logging
from MaxText import checkpointing
from MaxText import max_utils, max_logging
import pyconfig
from train import setup_mesh_and_model, get_first_step, validate_train_config, save_checkpoint
from MaxText.train import setup_mesh_and_model, get_first_step, validate_train_config, save_checkpoint

from layers import models

2 changes: 1 addition & 1 deletion MaxText/standalone_dataloader.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@
import numpy as np

import pyconfig
from train import validate_train_config, get_first_step, load_next_batch, setup_train_loop
from MaxText.train import validate_train_config, get_first_step, load_next_batch, setup_train_loop


def data_load_loop(config, state=None):
6 changes: 3 additions & 3 deletions MaxText/tests/attention_test.py
Original file line number Diff line number Diff line change
@@ -20,17 +20,17 @@
import unittest
from absl.testing import parameterized

import common_types

from MaxText import common_types, max_utils
from flax.core import freeze
import jax
import jax.numpy as jnp

import max_utils
import pytest

import pyconfig

from layers import attentions
from MaxText.layers import attentions

Mesh = jax.sharding.Mesh
Attention = attentions.Attention
Loading