From fe22a9f7054558f87df5e6738f9935dea2bdc8a1 Mon Sep 17 00:00:00 2001 From: vivianrwu Date: Wed, 6 Nov 2024 11:54:52 -0800 Subject: [PATCH] Add model warmup flag into cli (#197) add model warmup flag into cli --- jetstream_pt/cli.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jetstream_pt/cli.py b/jetstream_pt/cli.py index 110d32c..513ceae 100644 --- a/jetstream_pt/cli.py +++ b/jetstream_pt/cli.py @@ -36,6 +36,7 @@ flags.DEFINE_bool( "internal_use_local_tokenizer", 0, "Use local tokenizer if set to True" ) +flags.DEFINE_bool("enable_model_warmup", False, "enable model warmup") def shard_weights(env, weights, weight_shardings): @@ -111,6 +112,7 @@ def serve(): config=server_config, devices=devices, metrics_server_config=metrics_server_config, + enable_model_warmup=FLAGS.enable_model_warmup, ) print("Started jetstream_server....") jetstream_server.wait_for_termination()