diff --git a/tools/run_text_generation_server.py b/tools/run_text_generation_server.py index 861d8d6d73..5c99bf2908 100644 --- a/tools/run_text_generation_server.py +++ b/tools/run_text_generation_server.py @@ -122,6 +122,8 @@ def add_text_generate_args(parser): assert len(model) == 1, "Above condition should have caught this" model = model[0] + model.eval() + if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: server = MegatronServer(model) server.run("0.0.0.0",port=args.port)