Skip to content

Commit

Permalink
Add jax compilation cache config (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
vivianrwu authored Nov 7, 2024
1 parent fe22a9f commit a12698d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
1 change: 1 addition & 0 deletions jetstream_pt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def create_engine(devices):
"""Create Pytorch engine from flags"""
torch.set_default_dtype(torch.bfloat16)
quant_config = config.create_quantization_config_from_flags()
config.set_jax_compilation_cache_config()
env_data = fetch_models.construct_env_data_from_model_id(
FLAGS.model_id,
FLAGS.override_batch_size,
Expand Down
24 changes: 21 additions & 3 deletions jetstream_pt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.


import os
from absl import flags
import jax
from jetstream_pt.environment import QuantizationConfig

FLAGS = flags.FLAGS
Expand Down Expand Up @@ -154,17 +156,17 @@
"page size per page",
)
flags.DEFINE_string(
"jax_compilation_cache_dir",
"internal_jax_compilation_cache_dir",
"~/jax_cache",
"Jax compilation cache directory",
)
flags.DEFINE_integer(
"jax_persistent_cache_min_entry_size_bytes",
"internal_jax_persistent_cache_min_entry_size_bytes",
0,
"Minimum size (in bytes) of an entry that will be cached in the persistent compilation cache",
)
flags.DEFINE_integer(
"jax_persistent_cache_min_compile_time_secs",
"internal_jax_persistent_cache_min_compile_time_secs",
1,
"Minimum compilation time for a computation to be written to persistent cache",
)
Expand All @@ -190,3 +192,19 @@ def create_quantization_config_from_flags():
else FLAGS.quantize_weights
)
return config


def set_jax_compilation_cache_config():
"""Sets the jax compilation cache configuration"""
jax.config.update(
"jax_compilation_cache_dir",
os.path.expanduser(FLAGS.internal_jax_compilation_cache_dir),
)
jax.config.update(
"jax_persistent_cache_min_entry_size_bytes",
FLAGS.internal_jax_persistent_cache_min_entry_size_bytes,
)
jax.config.update(
"jax_persistent_cache_min_compile_time_secs",
FLAGS.internal_jax_persistent_cache_min_compile_time_secs,
)

0 comments on commit a12698d

Please sign in to comment.