diff --git a/jetstream_pt/cli.py b/jetstream_pt/cli.py index 513ceae..cd27046 100644 --- a/jetstream_pt/cli.py +++ b/jetstream_pt/cli.py @@ -55,6 +55,7 @@ def create_engine(devices): """Create Pytorch engine from flags""" torch.set_default_dtype(torch.bfloat16) quant_config = config.create_quantization_config_from_flags() + config.set_jax_compilation_cache_config() env_data = fetch_models.construct_env_data_from_model_id( FLAGS.model_id, FLAGS.override_batch_size, diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index d1e21d9..c1dbce4 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -13,7 +13,9 @@ # limitations under the License. +import os from absl import flags +import jax from jetstream_pt.environment import QuantizationConfig FLAGS = flags.FLAGS @@ -154,17 +156,17 @@ "page size per page", ) flags.DEFINE_string( - "jax_compilation_cache_dir", + "internal_jax_compilation_cache_dir", "~/jax_cache", "Jax compilation cache directory", ) flags.DEFINE_integer( - "jax_persistent_cache_min_entry_size_bytes", + "internal_jax_persistent_cache_min_entry_size_bytes", 0, "Minimum size (in bytes) of an entry that will be cached in the persistent compilation cache", ) flags.DEFINE_integer( - "jax_persistent_cache_min_compile_time_secs", + "internal_jax_persistent_cache_min_compile_time_secs", 1, "Minimum compilation time for a computation to be written to persistent cache", ) @@ -190,3 +192,19 @@ def create_quantization_config_from_flags(): else FLAGS.quantize_weights ) return config + + +def set_jax_compilation_cache_config(): + """Sets the jax compilation cache configuration""" + jax.config.update( + "jax_compilation_cache_dir", + os.path.expanduser(FLAGS.internal_jax_compilation_cache_dir), + ) + jax.config.update( + "jax_persistent_cache_min_entry_size_bytes", + FLAGS.internal_jax_persistent_cache_min_entry_size_bytes, + ) + jax.config.update( + "jax_persistent_cache_min_compile_time_secs", + FLAGS.internal_jax_persistent_cache_min_compile_time_secs, + )