diff --git a/MaxText/convert_gemma2_chkpt.py b/MaxText/convert_gemma2_chkpt.py
index 33ddf73e2..a89746694 100644
--- a/MaxText/convert_gemma2_chkpt.py
+++ b/MaxText/convert_gemma2_chkpt.py
@@ -33,7 +33,7 @@
 import orbax
 import checkpointing
-from train import save_checkpoint
+from MaxText.train import save_checkpoint
 Params = dict[str, Any]
diff --git a/MaxText/convert_gemma_chkpt.py b/MaxText/convert_gemma_chkpt.py
index 38881ac43..34f95fc2c 100644
--- a/MaxText/convert_gemma_chkpt.py
+++ b/MaxText/convert_gemma_chkpt.py
@@ -33,7 +33,7 @@
 import orbax
 import checkpointing
-from train import save_checkpoint
+from MaxText.train import save_checkpoint
 Params = dict[str, Any]
diff --git a/MaxText/convert_gpt3_ckpt_from_paxml.py b/MaxText/convert_gpt3_ckpt_from_paxml.py
index 4363874be..b9423dd09 100644
--- a/MaxText/convert_gpt3_ckpt_from_paxml.py
+++ b/MaxText/convert_gpt3_ckpt_from_paxml.py
@@ -32,7 +32,7 @@
   --run-name=$RUN_NAME \
   --base-output-directory=$BASE_OUTPUT_DIR
 """
-import max_utils
+from MaxText import max_utils
 import optimizers
 import pyconfig
 import os
@@ -50,7 +50,7 @@
 import gc
 import max_logging
 from psutil import Process
-from train import save_checkpoint
+from MaxText.train import save_checkpoint
 import argparse
diff --git a/MaxText/decode.py b/MaxText/decode.py
index 8b114d16d..6dc70537f 100644
--- a/MaxText/decode.py
+++ b/MaxText/decode.py
@@ -16,8 +16,7 @@
 import jax
-import max_utils
-import maxengine
+from MaxText import max_utils, maxengine
 import os
 import pyconfig
diff --git a/MaxText/generate_param_only_checkpoint.py b/MaxText/generate_param_only_checkpoint.py
index b73f94be7..fbf9d34b2 100644
--- a/MaxText/generate_param_only_checkpoint.py
+++ b/MaxText/generate_param_only_checkpoint.py
@@ -25,9 +25,7 @@
 import checkpointing
 import jax
-import max_logging
-import max_utils
-import optimizers
+from MaxText import max_logging, max_utils, optimizers
 import pyconfig
 from absl import app
@@ -36,7 +34,7 @@
 from jax import random
 from typing import Sequence
 from layers import models, quantizations
-from train import save_checkpoint
+from MaxText.train import save_checkpoint
 Transformer = models.Transformer
diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index a30da7e7e..6c77b1302 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -24,7 +24,7 @@
 from jetstream.engine import token_utils
-import max_utils
+from MaxText import max_utils
 import maxengine
 import maxtext_utils
 import profiler
diff --git a/MaxText/input_pipeline/_tfds_data_processing.py b/MaxText/input_pipeline/_tfds_data_processing.py
index a39629ade..54bf66e81 100644
--- a/MaxText/input_pipeline/_tfds_data_processing.py
+++ b/MaxText/input_pipeline/_tfds_data_processing.py
@@ -24,10 +24,10 @@
 import tensorflow_datasets as tfds
 import jax
-import multihost_dataloading
-import tokenizer
-import sequence_packing
-from input_pipeline import _input_pipeline_utils
+from MaxText import multihost_dataloading
+from MaxText import tokenizer
+from MaxText import sequence_packing
+from MaxText.input_pipeline import _input_pipeline_utils
 AUTOTUNE = tf.data.experimental.AUTOTUNE
diff --git a/MaxText/kernels/ragged_attention.py b/MaxText/kernels/ragged_attention.py
index 8ddeb7214..381f416dd 100644
--- a/MaxText/kernels/ragged_attention.py
+++ b/MaxText/kernels/ragged_attention.py
@@ -24,7 +24,7 @@
 from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
 import numpy as np
-import common_types
+from MaxText import common_types
 from jax.experimental import shard_map
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 560825644..dacfad0ee 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -29,13 +29,13 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 import jax.numpy as jnp
-import common_types
-from kernels.ragged_attention import ragged_gqa
-from kernels.ragged_attention import ragged_mha
-from layers import embeddings
-from layers import initializers
-from layers import linears
-from layers import quantizations
+from MaxText import common_types
+from MaxText.kernels.ragged_attention import ragged_gqa
+from MaxText.kernels.ragged_attention import ragged_mha
+from MaxText.layers import embeddings
+from MaxText.layers import initializers
+from MaxText.layers import linears
+from MaxText.layers import quantizations
 # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py
index 1793e8cb3..fa34812ba 100644
--- a/MaxText/layers/embeddings.py
+++ b/MaxText/layers/embeddings.py
@@ -21,7 +21,7 @@
 import jax
 from jax import lax
 import jax.numpy as jnp
-from layers import initializers
+from MaxText.layers import initializers
 Config = Any
 Array = jnp.ndarray
diff --git a/MaxText/layers/initializers.py b/MaxText/layers/initializers.py
index 5916ecb0c..bf915e757 100644
--- a/MaxText/layers/initializers.py
+++ b/MaxText/layers/initializers.py
@@ -18,7 +18,7 @@
 from flax import linen as nn
 import jax
-import common_types
+from MaxText import common_types
 Array = common_types.Array
 DType = common_types.DType
diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py
index 27b9fb032..0f9e9fa0d 100644
--- a/MaxText/layers/linears.py
+++ b/MaxText/layers/linears.py
@@ -23,18 +23,18 @@
 import jax
 from jax import lax
 import jax.numpy as jnp
-import common_types
-from layers import initializers
-from layers import normalizations
-from layers import quantizations
+from MaxText import common_types
+from MaxText.layers import initializers
+from MaxText.layers import normalizations
+from MaxText.layers import quantizations
 import numpy as np
 from jax.ad_checkpoint import checkpoint_name
 from jax.experimental import shard_map
 import math
-import max_logging
-import max_utils
+from MaxText import max_logging
+from MaxText import max_utils
 from aqt.jax.v2 import aqt_tensor
-from kernels import megablox as mblx
+from MaxText.kernels import megablox as mblx
 Array = common_types.Array
diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py
index 9769edace..edc6c6c07 100644
--- a/MaxText/layers/llama2.py
+++ b/MaxText/layers/llama2.py
@@ -23,14 +23,14 @@
 import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name
 # from jax.experimental.pallas.ops.tpu import flash_attention
-from layers import attentions
-from layers import embeddings
-from layers import linears
-from layers import normalizations
-from layers import models
-from layers import quantizations
-
-import common_types
+from MaxText.layers import attentions
+from MaxText.layers import embeddings
+from MaxText.layers import linears
+from MaxText.layers import normalizations
+from MaxText.layers import models
+from MaxText.layers import quantizations
+
+from MaxText import common_types
 from typing import Optional
 Array = common_types.Array
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index 8f423e712..1f35b75a7 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -24,12 +24,12 @@
 import jax
 import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name
-import common_types
-from layers import attentions
-from layers import embeddings
-from layers import linears
-from layers import normalizations, quantizations
-from layers import pipeline
+from MaxText import common_types
+from MaxText.layers import attentions
+from MaxText.layers import embeddings
+from MaxText.layers import linears
+from MaxText.layers import normalizations, quantizations
+from MaxText.layers import pipeline
 Array = common_types.Array
 Config = common_types.Config
diff --git a/MaxText/layers/normalizations.py b/MaxText/layers/normalizations.py
index fcb8bf0e5..92e541a82 100644
--- a/MaxText/layers/normalizations.py
+++ b/MaxText/layers/normalizations.py
@@ -19,7 +19,7 @@
 from flax import linen as nn
 from jax import lax
 import jax.numpy as jnp
-from layers import initializers
+from MaxText.layers import initializers
 Initializer = initializers.Initializer
diff --git a/MaxText/layers/pipeline.py b/MaxText/layers/pipeline.py
index 577340d4c..3cf55e328 100644
--- a/MaxText/layers/pipeline.py
+++ b/MaxText/layers/pipeline.py
@@ -20,7 +20,7 @@
 from jax import numpy as jnp
 from flax.core import meta
 from flax import linen as nn
-import common_types
+from MaxText import common_types
 import functools
 from typing import Any
diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py
index 7bff6e6e4..ab355e235 100644
--- a/MaxText/layers/quantizations.py
+++ b/MaxText/layers/quantizations.py
@@ -23,7 +23,7 @@
 from aqt.jax.v2.flax import aqt_flax
 from aqt.jax.v2 import tiled_dot_general
 from aqt.jax.v2 import calibration
-import common_types
+from MaxText import common_types
 from dataclasses import dataclass
 from flax.linen import fp8_ops
 from flax.linen import initializers as flax_initializers
diff --git a/MaxText/llama_ckpt_conversion_inference_only.py b/MaxText/llama_ckpt_conversion_inference_only.py
index 19334c4fb..a76e95219 100644
--- a/MaxText/llama_ckpt_conversion_inference_only.py
+++ b/MaxText/llama_ckpt_conversion_inference_only.py
@@ -37,7 +37,7 @@
 import jax
 from flax.training import train_state
 import max_logging
-from train import save_checkpoint
+from MaxText.train import save_checkpoint
 import torch
 import sys
 import os
diff --git a/MaxText/llama_mistral_mixtral_orbax_to_hf.py b/MaxText/llama_mistral_mixtral_orbax_to_hf.py
index cc1844114..3c2fa2cce 100644
--- a/MaxText/llama_mistral_mixtral_orbax_to_hf.py
+++ b/MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -38,14 +38,14 @@
 from absl import app
 import numpy as np
 import pyconfig
-import max_utils
+from MaxText import max_utils
 from jax.sharding import Mesh
 import max_logging
 import checkpointing
 from generate_param_only_checkpoint import _read_train_checkpoint
 import llama_or_mistral_ckpt
 from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoConfig
-from max_utils import unpermute_from_match_maxtext_rope
+from MaxText.max_utils import unpermute_from_match_maxtext_rope
 def reverse_scale(arr, scale):
diff --git a/MaxText/llama_or_mistral_ckpt.py b/MaxText/llama_or_mistral_ckpt.py
index f2a40d3b0..5648855e0 100644
--- a/MaxText/llama_or_mistral_ckpt.py
+++ b/MaxText/llama_or_mistral_ckpt.py
@@ -49,10 +49,10 @@
 from tqdm import tqdm
 import max_logging
-from train import save_checkpoint
+from MaxText.train import save_checkpoint
 import checkpointing
 from safetensors import safe_open
-import max_utils
+from MaxText import max_utils
 MODEL_PARAMS_DICT = {
     "llama2-70b": {
diff --git a/MaxText/load_and_quantize_checkpoint.py b/MaxText/load_and_quantize_checkpoint.py
index e0ed9dc52..739442326 100644
--- a/MaxText/load_and_quantize_checkpoint.py
+++ b/MaxText/load_and_quantize_checkpoint.py
@@ -16,8 +16,7 @@
 import jax
-import max_utils
-import maxengine
+from MaxText import max_utils, maxengine
 import os
 import pyconfig
diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py
index d8a538ac9..f2c8a613e 100644
--- a/MaxText/max_utils.py
+++ b/MaxText/max_utils.py
@@ -20,8 +20,7 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import mesh_utils
-import checkpointing
-import common_types
+from MaxText import checkpointing, common_types
 import functools
 import time
 import optax
diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py
index f7702737f..bb7e84eaa 100644
--- a/MaxText/maxengine.py
+++ b/MaxText/maxengine.py
@@ -22,22 +22,21 @@
 from flax import linen as nn
 from flax.linen import partitioning as nn_partitioning
-from layers import models, quantizations
+from MaxText.layers import models, quantizations
 import jax
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
 from jax.experimental import layout as jax_layout
-import common_types
+from MaxText import common_types
 from jetstream.core import config_lib
 from jetstream.engine import engine_api
 from jetstream.engine import tokenizer_pb2
 from jetstream.engine import tokenizer_api
 from jetstream.engine import token_utils
-import max_utils
-import inference_utils
+from MaxText import max_utils, inference_utils
 import pyconfig
 import warnings
diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py
index c748ea981..45a18eb8b 100644
--- a/MaxText/maxtext_utils.py
+++ b/MaxText/maxtext_utils.py
@@ -19,7 +19,7 @@
 import jax
 import optax
-import max_utils
+from MaxText import max_utils
 from jax.sharding import PartitionSpec as P
 from jax.experimental.serialize_executable import deserialize_and_load
diff --git a/MaxText/metric_logger.py b/MaxText/metric_logger.py
index 9be1e5eb7..9b1f425bc 100644
--- a/MaxText/metric_logger.py
+++ b/MaxText/metric_logger.py
@@ -22,8 +22,7 @@
 import os
 import numpy as np
-import max_logging
-import max_utils
+from MaxText import max_utils, max_logging
 def _prepare_metrics_for_json(metrics, step, run_name):
diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py
index 5c0d5c007..5c877b97f 100644
--- a/MaxText/pyconfig.py
+++ b/MaxText/pyconfig.py
@@ -26,8 +26,7 @@
 from jax.experimental.compilation_cache import compilation_cache
 from layers.attentions import AttentionType
 import accelerator_to_spec_map
-import max_logging
-import max_utils
+from MaxText import max_utils, max_logging
 import omegaconf
 OmegaConf = omegaconf.OmegaConf
diff --git a/MaxText/standalone_checkpointer.py b/MaxText/standalone_checkpointer.py
index 4759cb50a..0441934fe 100644
--- a/MaxText/standalone_checkpointer.py
+++ b/MaxText/standalone_checkpointer.py
@@ -30,11 +30,10 @@
 from jax import numpy as jnp
 import numpy as np
-import checkpointing
-import max_utils
-import max_logging
+from MaxText import checkpointing
+from MaxText import max_utils, max_logging
 import pyconfig
-from train import setup_mesh_and_model, get_first_step, validate_train_config, save_checkpoint
+from MaxText.train import setup_mesh_and_model, get_first_step, validate_train_config, save_checkpoint
 from layers import models
diff --git a/MaxText/standalone_dataloader.py b/MaxText/standalone_dataloader.py
index a0c0f558c..c7d0ae461 100644
--- a/MaxText/standalone_dataloader.py
+++ b/MaxText/standalone_dataloader.py
@@ -27,7 +27,7 @@
 import numpy as np
 import pyconfig
-from train import validate_train_config, get_first_step, load_next_batch, setup_train_loop
+from MaxText.train import validate_train_config, get_first_step, load_next_batch, setup_train_loop
 def data_load_loop(config, state=None):
diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py
index 5938ac219..671149278 100644
--- a/MaxText/tests/attention_test.py
+++ b/MaxText/tests/attention_test.py
@@ -20,17 +20,17 @@
 import unittest
 from absl.testing import parameterized
-import common_types
-
+from MaxText import common_types, max_utils
 from flax.core import freeze
 import jax
 import jax.numpy as jnp
+
 import max_utils
 import pytest
 import pyconfig
-from layers import attentions
+from MaxText.layers import attentions
 Mesh = jax.sharding.Mesh
 Attention = attentions.Attention
diff --git a/MaxText/tests/forward_pass_logit_checker.py b/MaxText/tests/forward_pass_logit_checker.py
index d41b934f1..69bc4d7df 100644
--- a/MaxText/tests/forward_pass_logit_checker.py
+++ b/MaxText/tests/forward_pass_logit_checker.py
@@ -47,7 +47,7 @@
 import numpy as np
 import pyconfig
 import jsonlines
-import max_utils
+from MaxText import max_utils
 from layers import models
 from layers import quantizations
diff --git a/MaxText/tests/gpt3_test.py b/MaxText/tests/gpt3_test.py
index d0f87457a..a73a0ec41 100644
--- a/MaxText/tests/gpt3_test.py
+++ b/MaxText/tests/gpt3_test.py
@@ -18,11 +18,9 @@
 import sys
 import jax
 import unittest
-import max_utils
+from MaxText import max_utils
 from jax.sharding import Mesh
-from layers import models
-from layers import embeddings
-from layers import quantizations
+from MaxText.layers import models, embeddings, quantizations
 import jax.numpy as jnp
diff --git a/MaxText/tests/gradient_accumulation_test.py b/MaxText/tests/gradient_accumulation_test.py
index 29c0ab087..9152213e0 100644
--- a/MaxText/tests/gradient_accumulation_test.py
+++ b/MaxText/tests/gradient_accumulation_test.py
@@ -18,7 +18,7 @@
 import pytest
 import string
 import random
-from train import main as train_main
+from MaxText.train import main as train_main
 def generate_random_string(length=10):
diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py
index 54b516bca..ca5c83a11 100644
--- a/MaxText/tests/grain_data_processing_test.py
+++ b/MaxText/tests/grain_data_processing_test.py
@@ -23,8 +23,8 @@
 import unittest
 import pyconfig
-from input_pipeline import _grain_data_processing
-from input_pipeline import input_pipeline_interface
+from MaxText.input_pipeline import _grain_data_processing
+from MaxText.input_pipeline import input_pipeline_interface
 class GrainDataProcessingTest(unittest.TestCase):
diff --git a/MaxText/tests/hf_checkpoint_conversion_test.py b/MaxText/tests/hf_checkpoint_conversion_test.py
index 493096199..8a562b800 100644
--- a/MaxText/tests/hf_checkpoint_conversion_test.py
+++ b/MaxText/tests/hf_checkpoint_conversion_test.py
@@ -17,7 +17,7 @@
 """ Tests for kernels """
 import numpy as np
-from max_utils import permute_to_match_maxtext_rope, unpermute_from_match_maxtext_rope
+from MaxText.max_utils import permute_to_match_maxtext_rope, unpermute_from_match_maxtext_rope
 import unittest
diff --git a/MaxText/tests/hf_data_processing_test.py b/MaxText/tests/hf_data_processing_test.py
index 55a80ee46..7d0dc2a28 100644
--- a/MaxText/tests/hf_data_processing_test.py
+++ b/MaxText/tests/hf_data_processing_test.py
@@ -22,6 +22,8 @@
 import unittest
 import pyconfig
+from MaxText.input_pipeline import _hf_data_processing
+from MaxText.input_pipeline import input_pipeline_interface
 import pytest
 from input_pipeline import _hf_data_processing
 from input_pipeline import input_pipeline_interface
diff --git a/MaxText/tests/inference_microbenchmark_smoke_test.py b/MaxText/tests/inference_microbenchmark_smoke_test.py
index 139003993..e0fd6240a 100644
--- a/MaxText/tests/inference_microbenchmark_smoke_test.py
+++ b/MaxText/tests/inference_microbenchmark_smoke_test.py
@@ -18,7 +18,7 @@
 import pytest
 import unittest
 from absl.testing import absltest
-from inference_microbenchmark import run_benchmarks
+from MaxText.inference_microbenchmark import run_benchmarks
 class Inference_Microbenchmark(unittest.TestCase):
diff --git a/MaxText/tests/kernels_test.py b/MaxText/tests/kernels_test.py
index 3c73ca10d..e3ad539d8 100644
--- a/MaxText/tests/kernels_test.py
+++ b/MaxText/tests/kernels_test.py
@@ -21,7 +21,7 @@
 import unittest
 import jax
 import jax.numpy as jnp
-from kernels.ragged_attention import ragged_mqa, reference_mqa, ragged_mha, reference_mha, ragged_gqa, reference_gqa
+from MaxText.kernels.ragged_attention import ragged_mqa, reference_mqa, ragged_mha, reference_mha, ragged_gqa, reference_gqa
 class RaggedAttentionTest(unittest.TestCase):
diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py
index 9c9fa08b4..dca535708 100644
--- a/MaxText/tests/llama_test.py
+++ b/MaxText/tests/llama_test.py
@@ -19,7 +19,7 @@
 import unittest
 import jax.numpy as jnp
 from typing import Tuple
-from layers.llama2 import embeddings
+from MaxText.layers.llama2 import embeddings
 import numpy as np
diff --git a/MaxText/tests/max_utils_test.py b/MaxText/tests/max_utils_test.py
index df0770563..5d8dc1d5e 100644
--- a/MaxText/tests/max_utils_test.py
+++ b/MaxText/tests/max_utils_test.py
@@ -16,7 +16,7 @@
 """ Tests for the common Max Utils """
 import jax
-import max_utils
+from MaxText import max_utils
 from flax import linen as nn
 from flax.training import train_state
 from jax import numpy as jnp
@@ -25,8 +25,8 @@
 import optax
 import pyconfig
 import unittest
-from layers import models
-from layers import quantizations
+from MaxText.layers import models
+from MaxText.layers import quantizations
 Transformer = models.Transformer
diff --git a/MaxText/tests/maxengine_test.py b/MaxText/tests/maxengine_test.py
index aad9cbeed..17a741482 100644
--- a/MaxText/tests/maxengine_test.py
+++ b/MaxText/tests/maxengine_test.py
@@ -24,6 +24,7 @@
 import numpy as np
 import unittest
 import pyconfig
+from MaxText.maxengine import MaxEngine
 import max_utils
 from maxengine import MaxEngine
 from layers import quantizations
diff --git a/MaxText/tests/maxtext_utils_test.py b/MaxText/tests/maxtext_utils_test.py
index 9362a5326..03e45d7c0 100644
--- a/MaxText/tests/maxtext_utils_test.py
+++ b/MaxText/tests/maxtext_utils_test.py
@@ -18,7 +18,7 @@
 import unittest
 import jax.numpy as jnp
-import maxtext_utils
+from MaxText import maxtext_utils
 class TestGradientClipping(unittest.TestCase):
diff --git a/MaxText/tests/model_test.py b/MaxText/tests/model_test.py
index 053ae8665..3c7591a0c 100644
--- a/MaxText/tests/model_test.py
+++ b/MaxText/tests/model_test.py
@@ -15,12 +15,12 @@
 import sys
 import unittest
-import common_types
+from MaxText import common_types
 from flax.core import freeze
 import jax
 import jax.numpy as jnp
-import max_utils
+from MaxText import max_utils
 import numpy as np
 import pytest
diff --git a/MaxText/tests/moe_test.py b/MaxText/tests/moe_test.py
index e146671eb..60290f7d5 100644
--- a/MaxText/tests/moe_test.py
+++ b/MaxText/tests/moe_test.py
@@ -20,7 +20,7 @@
 import jax.numpy as jnp
 import pyconfig
-import max_utils
+from MaxText import max_utils
 from jax.sharding import Mesh
 import flax.linen as nn
 from typing import Tuple
diff --git a/MaxText/tests/multihost_dataloading_test.py b/MaxText/tests/multihost_dataloading_test.py
index 78b22d812..d6fb492c7 100644
--- a/MaxText/tests/multihost_dataloading_test.py
+++ b/MaxText/tests/multihost_dataloading_test.py
@@ -27,7 +27,7 @@
 import pytest
 import pyconfig
-import multihost_dataloading
+from MaxText import multihost_dataloading
 class MultihostDataloadingTest(unittest.TestCase):
diff --git a/MaxText/tests/pipeline_parallelism_test.py b/MaxText/tests/pipeline_parallelism_test.py
index d76052632..f83c2967d 100644
--- a/MaxText/tests/pipeline_parallelism_test.py
+++ b/MaxText/tests/pipeline_parallelism_test.py
@@ -24,20 +24,20 @@
 import pyconfig
-from layers import pipeline
+from MaxText.layers import pipeline
 import jax
 from jax import numpy as jnp
 from jax.sharding import Mesh
-import common_types
+from MaxText import common_types
 import pyconfig
-import max_utils
+from MaxText import max_utils
 from flax.core import meta
 import jax.numpy as jnp
 from flax import linen as nn
 from layers import simple_layer
-from train import main as train_main
+from MaxText.train import main as train_main
 def assert_same_output_and_grad(f1, f2, *inputs):
diff --git a/MaxText/tests/quantizations_test.py b/MaxText/tests/quantizations_test.py
index 1d24e420a..78b23bf8d 100644
--- a/MaxText/tests/quantizations_test.py
+++ b/MaxText/tests/quantizations_test.py
@@ -21,7 +21,7 @@
 import functools
 import numpy as np
 import pyconfig
-from layers import quantizations
+from MaxText.layers import quantizations
 import unittest
 from aqt.jax.v2 import aqt_tensor
 from aqt.jax.v2 import calibration
diff --git a/MaxText/tests/simple_decoder_layer_test.py b/MaxText/tests/simple_decoder_layer_test.py
index ba6fa7c3c..a4d18d19d 100644
--- a/MaxText/tests/simple_decoder_layer_test.py
+++ b/MaxText/tests/simple_decoder_layer_test.py
@@ -13,7 +13,7 @@
 import unittest
 import pytest
-from train import main as train_main
+from MaxText.train import main as train_main
 class SimpleDecoderLayerTest(unittest.TestCase):
diff --git a/MaxText/tests/standalone_dl_ckpt_test.py b/MaxText/tests/standalone_dl_ckpt_test.py
index 652b652e1..cbb8d955f 100644
--- a/MaxText/tests/standalone_dl_ckpt_test.py
+++ b/MaxText/tests/standalone_dl_ckpt_test.py
@@ -17,8 +17,8 @@
 """ Tests for the standalone_checkpointer.py """
 import unittest
 import pytest
-from standalone_checkpointer import main as sckpt_main
-from standalone_dataloader import main as sdl_main
+from MaxText.standalone_checkpointer import main as sckpt_main
+from MaxText.standalone_dataloader import main as sdl_main
 from datetime import datetime
 import random
 import string
diff --git a/MaxText/tests/tfds_data_processing_test.py b/MaxText/tests/tfds_data_processing_test.py
index 998bacee3..0f204987e 100644
--- a/MaxText/tests/tfds_data_processing_test.py
+++ b/MaxText/tests/tfds_data_processing_test.py
@@ -26,8 +26,8 @@
 import tensorflow_datasets as tfds
 import pyconfig
-from input_pipeline import _tfds_data_processing
-from input_pipeline import input_pipeline_interface
+from MaxText.input_pipeline import _tfds_data_processing
+from MaxText.input_pipeline import input_pipeline_interface
 class TfdsDataProcessingTest(unittest.TestCase):
diff --git a/MaxText/tests/tokenizer_test.py b/MaxText/tests/tokenizer_test.py
index 3a2030674..12b26e113 100644
--- a/MaxText/tests/tokenizer_test.py
+++ b/MaxText/tests/tokenizer_test.py
@@ -18,8 +18,8 @@
 """
 import numpy as np
-import train_tokenizer
-from input_pipeline import _input_pipeline_utils
+from MaxText import train_tokenizer
+from MaxText.input_pipeline import _input_pipeline_utils
 import unittest
 import pytest
 import tensorflow_datasets as tfds
diff --git a/MaxText/tests/train_compile_test.py b/MaxText/tests/train_compile_test.py
index 80139be45..82cd00341 100644
--- a/MaxText/tests/train_compile_test.py
+++ b/MaxText/tests/train_compile_test.py
@@ -17,8 +17,7 @@
 """ Tests for the common Max Utils """
 import unittest
 import pytest
-from train_compile import main as train_compile_main
-from train import main as train_main
+from MaxText.train_compile import main as train_compile_main
 class TrainCompile(unittest.TestCase):
diff --git a/MaxText/tests/train_gpu_smoke_test.py b/MaxText/tests/train_gpu_smoke_test.py
index d54831c78..cc92ed9f5 100644
--- a/MaxText/tests/train_gpu_smoke_test.py
+++ b/MaxText/tests/train_gpu_smoke_test.py
@@ -17,7 +17,7 @@
 import os
 import unittest
 from absl.testing import absltest
-from train import main as train_main
+from MaxText.train import main as train_main
 class Train(unittest.TestCase):
diff --git a/MaxText/tests/train_int8_smoke_test.py b/MaxText/tests/train_int8_smoke_test.py
index 3bc4c31e7..7eab6c08c 100644
--- a/MaxText/tests/train_int8_smoke_test.py
+++ b/MaxText/tests/train_int8_smoke_test.py
@@ -17,7 +17,7 @@
 """Smoke test for int8"""
 import os
 import unittest
-from train import main as train_main
+from MaxText.train import main as train_main
 from absl.testing import absltest
diff --git a/MaxText/tests/train_smoke_test.py b/MaxText/tests/train_smoke_test.py
index 74da43509..65d4e3f56 100644
--- a/MaxText/tests/train_smoke_test.py
+++ b/MaxText/tests/train_smoke_test.py
@@ -17,7 +17,7 @@
 """ Smoke test """
 import os
 import unittest
-from train import main as train_main
+from MaxText.train import main as train_main
 from absl.testing import absltest
diff --git a/MaxText/tests/train_tests.py b/MaxText/tests/train_tests.py
index 8e70d8de1..f68687407 100644
--- a/MaxText/tests/train_tests.py
+++ b/MaxText/tests/train_tests.py
@@ -18,7 +18,7 @@
 import os
 import unittest
 import pytest
-from train import main as train_main
+from MaxText.train import main as train_main
 from absl.testing import absltest
diff --git a/MaxText/tests/train_using_ragged_dot_smoke_test.py b/MaxText/tests/train_using_ragged_dot_smoke_test.py
index 7f2d6457c..e2ab3a7e9 100644
--- a/MaxText/tests/train_using_ragged_dot_smoke_test.py
+++ b/MaxText/tests/train_using_ragged_dot_smoke_test.py
@@ -18,7 +18,7 @@
 import unittest
 from absl.testing import absltest
-from train import main as train_main
+from MaxText.train import main as train_main
 class Train(unittest.TestCase):
diff --git a/MaxText/tests/weight_dtypes_test.py b/MaxText/tests/weight_dtypes_test.py
index ac4525e7c..a285f29ae 100644
--- a/MaxText/tests/weight_dtypes_test.py
+++ b/MaxText/tests/weight_dtypes_test.py
@@ -20,10 +20,9 @@
 import pyconfig
-import optimizers
-from layers import models
-from layers import quantizations
-import max_utils
+from MaxText import optimizers
+from MaxText.layers import models, quantizations
+from MaxText import max_utils
 import jax
 from jax.sharding import Mesh
 import jax.numpy as jnp
diff --git a/MaxText/train.py b/MaxText/train.py
index 58898329f..5a2a1da05 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -39,11 +39,7 @@
 import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
 import checkpointing
-import max_utils
-import maxtext_utils
-import max_logging
-import optimizers
-import profiler
+from MaxText import max_utils, maxtext_utils, max_logging, optimizers, profiler
 import pyconfig
 import pathwaysutils  # pylint: disable=unused-import
 import tensorflow as tf
diff --git a/MaxText/train_compile.py b/MaxText/train_compile.py
index ee9c3d627..f9cb9dd7b 100644
--- a/MaxText/train_compile.py
+++ b/MaxText/train_compile.py
@@ -28,12 +28,9 @@
 from jax.sharding import Mesh
 from jax.experimental.serialize_executable import serialize
 from flax.linen import partitioning as nn_partitioning
-import maxtext_utils
-import optimizers
-import max_utils
+from MaxText import maxtext_utils, optimizers, max_utils
 import pyconfig
-from layers import models
-from layers import quantizations
+from MaxText.layers import models, quantizations
 from typing import Sequence
 from absl import app
 import os
diff --git a/MaxText/vertex_tensorboard.py b/MaxText/vertex_tensorboard.py
index 35c8ecc5e..f48481865 100644
--- a/MaxText/vertex_tensorboard.py
+++ b/MaxText/vertex_tensorboard.py
@@ -20,8 +20,7 @@
 import jax
-import max_logging
-import max_utils
+from MaxText import max_logging, max_utils
 from cloud_accelerator_diagnostics import tensorboard
 from cloud_accelerator_diagnostics import uploader
diff --git a/benchmarks/mmlu/mmlu_eval.py b/benchmarks/mmlu/mmlu_eval.py
index 257608289..6fab164b3 100644
--- a/benchmarks/mmlu/mmlu_eval.py
+++ b/benchmarks/mmlu/mmlu_eval.py
@@ -45,9 +45,7 @@
 from absl import flags
 import datasets
 import jax
-import max_logging
-import max_utils
-import maxengine
+from MaxText import max_logging, max_utils, maxengine
 from mmlu_categories import categories
 from mmlu_categories import subcategories
 import pyconfig