From 23e5e50fd5fba7f315e04294f55060a8171fcc69 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 22 Dec 2024 20:21:17 -0800 Subject: [PATCH] Fix gemlite import (#2553) --- python/sglang/bench_offline_throughput.py | 12 ----------- python/sglang/bench_one_batch.py | 12 ----------- python/sglang/srt/layers/torchao_utils.py | 26 ++++++++++------------- 3 files changed, 11 insertions(+), 39 deletions(-) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index fbb5a3fb30..f840ee878a 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -322,18 +322,6 @@ def throughput_test( ) time.sleep(0.5) - try: - import os - import pwd - - from gemlite.core import GemLiteLinearTriton - - GemLiteLinearTriton.cache_config( - f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" - ) - except ImportError: - pass - logging.info("\nBenchmark...") result = throughput_test_once( backend_name=bench_args.backend, diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 1a4dab7400..8ab3e45243 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -386,18 +386,6 @@ def latency_test( server_args.device, ) - try: - import os - import pwd - - from gemlite.core import GemLiteLinearTriton - - GemLiteLinearTriton.cache_config( - f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" - ) - except ImportError: - pass - rank_print("Benchmark ...") # Run the sweep diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index f911fe0a7c..e2d557fb73 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -2,8 +2,14 @@ Common utilities for torchao. 
""" +import logging +import os +import pwd + import torch +logger = logging.getLogger(__name__) + def apply_torchao_config_to_model( model: torch.nn.Module, torchao_config: str, filter_fn=None @@ -50,27 +56,17 @@ def filter_fn(module, fqn): elif "gemlite" in torchao_config: # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32) - import os - import pwd - - import gemlite - from gemlite.core import GemLiteLinearTriton, set_autotune - - try: - from torchao.quantization import gemlite_uintx_weight_only - except: - print( - f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization" - ) - return model + from gemlite.core import GemLiteLinearTriton + from torchao.quantization import gemlite_uintx_weight_only _quant_args = torchao_config.split("-") bit_width = int(_quant_args[-2]) group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1]) + try: packing_bitwidth = int(_quant_args[-3]) - except: - # if only 2 inputs found, use default value + except (ValueError, IndexError): + # if only 2 inputs found or conversion fails, use default value packing_bitwidth = 32 quantize_(