Skip to content

Commit

Permalink
Fix gemlite import (#2553)
Browse files: browse the repository at this point in the history
  • Loading branch information
merrymercy authored Dec 23, 2024
1 parent 25e5d58 commit 23e5e50
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 39 deletions.
12 changes: 0 additions & 12 deletions python/sglang/bench_offline_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,18 +322,6 @@ def throughput_test(
)
time.sleep(0.5)

try:
import os
import pwd

from gemlite.core import GemLiteLinearTriton

GemLiteLinearTriton.cache_config(
f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
)
except ImportError:
pass

logging.info("\nBenchmark...")
result = throughput_test_once(
backend_name=bench_args.backend,
Expand Down
12 changes: 0 additions & 12 deletions python/sglang/bench_one_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,18 +386,6 @@ def latency_test(
server_args.device,
)

try:
import os
import pwd

from gemlite.core import GemLiteLinearTriton

GemLiteLinearTriton.cache_config(
f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
)
except ImportError:
pass

rank_print("Benchmark ...")

# Run the sweep
Expand Down
26 changes: 11 additions & 15 deletions python/sglang/srt/layers/torchao_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
Common utilities for torchao.
"""

import logging
import os
import pwd

import torch

logger = logging.getLogger(__name__)


def apply_torchao_config_to_model(
model: torch.nn.Module, torchao_config: str, filter_fn=None
Expand Down Expand Up @@ -50,27 +56,17 @@ def filter_fn(module, fqn):
elif "gemlite" in torchao_config:
# gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
# gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
import os
import pwd

import gemlite
from gemlite.core import GemLiteLinearTriton, set_autotune

try:
from torchao.quantization import gemlite_uintx_weight_only
except:
print(
f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
)
return model
from gemlite.core import GemLiteLinearTriton
from torchao.quantization import gemlite_uintx_weight_only

_quant_args = torchao_config.split("-")
bit_width = int(_quant_args[-2])
group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])

try:
packing_bitwidth = int(_quant_args[-3])
except:
# if only 2 inputs found, use default value
except (ValueError, IndexError):
# if only 2 inputs found or conversion fails, use default value
packing_bitwidth = 32

quantize_(
Expand Down

0 comments on commit 23e5e50

Please sign in to comment.