Semi structured v2 #32

Open · wants to merge 39 commits into base: main

Commits (39)
96ae0ea
[doc] fix location of runllm widget (#10266)
youkaichao Nov 12, 2024
1808145
[doc] improve debugging doc (#10270)
youkaichao Nov 12, 2024
377b74f
Revert "[ci][build] limit cmake version" (#10271)
youkaichao Nov 12, 2024
112fa0b
[V1] Fix CI tests on V1 engine (#10272)
WoosukKwon Nov 13, 2024
0d4ea3f
[core][distributed] use tcp store directly (#10275)
youkaichao Nov 13, 2024
bbd3e86
[V1] Support VLMs with fine-grained scheduling (#9871)
WoosukKwon Nov 13, 2024
56a955e
Bump to compressed-tensors v0.8.0 (#10279)
dsikka Nov 13, 2024
032fcf1
[Doc] Fix typo in arg_utils.py (#10264)
xyang16 Nov 13, 2024
3945c82
[Model] Add support for Qwen2-VL video embeddings input & multiple im…
imkero Nov 13, 2024
1b886aa
[Model] Adding Support for Qwen2VL as an Embedding Model. Using MrLig…
FurtherAI Nov 13, 2024
b6dde33
[Core] Flashinfer - Remove advance step size restriction (#10282)
pavanimajety Nov 13, 2024
d909acf
[Model][LoRA]LoRA support added for idefics3 (#10281)
B-201 Nov 13, 2024
bb7991a
[V1] Add missing tokenizer options for `Detokenizer` (#10288)
ywang96 Nov 13, 2024
78eea7b
semi_structured for fp16 and bf16 and int8
ilmarkov Oct 1, 2024
331e9c5
Fix A100 int8 tests
ilmarkov Oct 2, 2024
381a6b4
Add fp8 cusparseLt
ilmarkov Oct 9, 2024
b146a79
wip
ilmarkov Oct 9, 2024
0ac01cc
Fix signatures
ilmarkov Oct 9, 2024
7472af2
Fix compilation and tests
ilmarkov Oct 13, 2024
3fe8bd4
Update for older platforms
ilmarkov Oct 15, 2024
e736027
Add benchmarks
ilmarkov Oct 16, 2024
ae66f77
Fix typo
ilmarkov Oct 23, 2024
59ee24d
Added scaled_mm for fp8.
ilmarkov Oct 24, 2024
3367704
Add docstrings
ilmarkov Oct 28, 2024
368beec
Update for torch 2.5
ilmarkov Oct 30, 2024
5be53f3
Add handling contiguous dense input for int8 and fp8
ilmarkov Oct 30, 2024
f9546a8
Add fp8 cusparseLt
ilmarkov Oct 9, 2024
2187236
Fix compilation and tests
ilmarkov Oct 13, 2024
f45a83b
Add caching of cusparseLT meta
ilmarkov Oct 23, 2024
b1aaea5
Cached cusparseLt
ilmarkov Oct 25, 2024
c36401c
Fix destroy function
ilmarkov Oct 25, 2024
9f6a469
Prepare for reproduce
ilmarkov Oct 25, 2024
2e56de9
Fix cusparseLt caching
ilmarkov Oct 30, 2024
9ad83cb
Make cached version default function
ilmarkov Nov 5, 2024
1f6a05b
Fixes and polishing after rebase
ilmarkov Nov 6, 2024
3a2c258
Add output_dtype option, fix non-padded inputs case
ilmarkov Nov 12, 2024
31cf482
Fix and polish
ilmarkov Nov 13, 2024
68512d4
Formatting
ilmarkov Nov 13, 2024
72d6cd3
Minor test and benchmarks updates
ilmarkov Nov 15, 2024
10 changes: 10 additions & 0 deletions CMakeLists.txt
@@ -198,6 +198,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/quantization/fp8_semi_structured/cusparseLt.cpp"
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
@@ -400,6 +401,15 @@ define_gpu_extension_target(
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

# If cuSparseLt is not installed we skip 2:4 optimizations
find_path(CUSPARSELT_INCLUDE_PATH cusparseLt.h
HINTS ${CUSPARSELT_INCLUDE_DIR}
PATH_SUFFIXES cuda/include cuda include)

if(CUSPARSELT_INCLUDE_PATH)
message(STATUS "CuSparseLt header file found ${CUSPARSELT_INCLUDE_PATH}")
target_compile_definitions(_C PRIVATE VLLM_CUSPARSELT_ENABLED=1)
endif()
#
# _moe_C extension
#
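Since the extension can be built without cuSparseLt, calling code needs a dense fallback. Below is a minimal, hypothetical sketch (not part of this PR's diff) that gates on is_semi_structured_supported() from the new cusparse_2_4_utils module; whether that check also reflects the VLLM_CUSPARSELT_ENABLED build flag above is an assumption.

import torch

from vllm.model_executor.layers.sparsity.utils.cusparse_2_4_utils import (
    compress_to_torch_sparse_semi_structured_mat, is_semi_structured_supported,
    semi_structured_sparse_dense_gemm)


def maybe_sparse_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: use the 2:4 sparse kernel when available,
    # otherwise fall back to a plain dense matmul.
    # Assumes `a` already follows a 2:4 sparsity pattern.
    if is_semi_structured_supported():
        a_compressed = compress_to_torch_sparse_semi_structured_mat(a)
        return semi_structured_sparse_dense_gemm(a_compressed, b)
    return torch.mm(a, b)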
2 changes: 1 addition & 1 deletion Dockerfile.neuron
@@ -31,7 +31,7 @@ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi

RUN python3 -m pip install -U \
'cmake>=3.26,<=3.30' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements-neuron.txt

ENV VLLM_TARGET_DEVICE neuron
2 changes: 1 addition & 1 deletion Dockerfile.ppc64le
@@ -21,7 +21,7 @@ RUN --mount=type=bind,source=.git,target=.git \
# These packages will be in rocketce eventually
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
'cmake>=3.26,<=3.30' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
torch==2.3.1 \
-r requirements-cpu.txt \
xformers uvloop==0.20.0
281 changes: 281 additions & 0 deletions benchmarks/cusparseLt_benchmarks/benchmark_24.py
@@ -0,0 +1,281 @@
import argparse
import copy
import itertools
import pickle
import time
from typing import Callable, Iterable, List, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from vllm.model_executor.layers.sparsity.utils.cusparse_2_4_utils import (
compress_to_torch_sparse_semi_structured_mat, dense_matmul, get_random_mat,
is_semi_structured_supported, semi_structured_sparse_dense_gemm,
semi_structured_sparse_dense_gemm_scaled)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]


# helpers
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
a = get_random_mat(n, k, dtype)
b = get_random_mat(m, k, dtype).t()
return a, b


# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
**kwargs) -> TMeasurement:
min_run_time = 1

globals = {
"args": args,
"kwargs": kwargs,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(*args, **kwargs)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)


def bench(m: int, k: int, n: int, label: str, sub_label: str,
use_fp8: bool) -> Iterable[TMeasurement]:
a, b = make_rand_tensors(torch.float16, m, n, k)

timers = []
# pytorch float16
timers.append(
bench_fn(label, sub_label, "pytorch_fp16_fp16_matmul", torch.mm,
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))

# cusparseLt fp16
timers.append(
bench_fn(label, sub_label, "cusparseLt_fp16_fp16_2_4",
semi_structured_sparse_dense_gemm,
compress_to_torch_sparse_semi_structured_mat(a), b))

# timers.append(
# bench_fn(label,
# sub_label,
# "cusparseLt_fp16_fp16_2_4_noncached",
# semi_structured_sparse_dense_gemm,
# compress_to_torch_sparse_semi_structured_mat(a),
# b,
# cached=False))

# pytorch bf16
a, b = make_rand_tensors(torch.bfloat16, m, n, k)
timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_matmul", torch.mm, a, b))

# cusparseLt bf16
a_compressed = compress_to_torch_sparse_semi_structured_mat(a)

timers.append(
bench_fn(label, sub_label, "cusparseLt_bf16_bf16_2_4",
semi_structured_sparse_dense_gemm, a_compressed, b))

# timers.append(
# bench_fn(label,
# sub_label,
# "cusparseLt_bf16_bf16_2_4_noncached",
# semi_structured_sparse_dense_gemm,
# a_compressed,
# b,
# cached=False))

a, b = make_rand_tensors(torch.int8, m, n, k)
# cutlass i8
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_matmul_scaled", dense_matmul,
a, b, torch.int8))

# cusparseLt i8
a_compressed = compress_to_torch_sparse_semi_structured_mat(a)
# warmup
scale = torch.tensor(1.0, dtype=torch.float32, device='cuda')

semi_structured_sparse_dense_gemm(a_compressed, b)
timers.append(
bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4",
semi_structured_sparse_dense_gemm, a_compressed, b))

semi_structured_sparse_dense_gemm_scaled(a_compressed,
b,
scale_a=scale,
scale_b=scale)
timers.append(
bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4_scaled",
semi_structured_sparse_dense_gemm_scaled, a_compressed, b,
scale, scale))

# scale_vec = scale.repeat(a_compressed.shape[0])
# semi_structured_sparse_dense_gemm_scaled(a_compressed,
# b,
# scale_a=scale_vec,
# scale_b=scale)
# timers.append(
# bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4_scaled_channel",
# semi_structured_sparse_dense_gemm_scaled, a_compressed, b,
# scale_vec, scale))

# timers.append(
# bench_fn(label,
# sub_label,
# "cusparseLt_i8_i8_2_4_noncached",
# semi_structured_sparse_dense_gemm,
# a_compressed,
# b,
# cached=False))

if use_fp8:
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
# cutlass fp8
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_matmul_scaled",
dense_matmul, a, b, torch.float8_e4m3fn))

# cusparseLt fp8
a_compressed = compress_to_torch_sparse_semi_structured_mat(a)

# warmup
semi_structured_sparse_dense_gemm(a_compressed, b)

timers.append(
bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4",
semi_structured_sparse_dense_gemm, a_compressed, b))

semi_structured_sparse_dense_gemm_scaled(a_compressed, b, scale, scale)
timers.append(
bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4_scaled",
semi_structured_sparse_dense_gemm_scaled, a_compressed, b,
scale, scale))

# timers.append(
# bench_fn(label,
# sub_label,
# "cusparseLt_fp8_fp8_2_4_noncached",
# semi_structured_sparse_dense_gemm,
# a_compressed,
# b,
# cached=False))

return timers


# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()


def run(MKNs: Iterable[Tuple[int, int, int]],
use_fp8: bool) -> Iterable[TMeasurement]:
results = []

for m, k, n in MKNs:
timers = bench(m, k, n, "gemm", f"MKN=({m}x{k}x{n})", use_fp8)
print_timers(timers)
results.extend(timers)

return results


def run_model_bench(args):
if not is_semi_structured_supported():
raise ValueError("Device does not support semi-structured sparsity")

print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")

def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs

model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
assert m % 16 == 0, "Batch size has to be a multiple of 16"
for k, n in KNs:
if k % 32 or n % 32:
continue
MKNs.append((m, k, n))

data = run(MKNs, args.use_fp8)
model_bench_data.append(data)

# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results cuSparseLt {model}-TP{tp_size} ====")
print_timers(data)

if args.save_results:
timestamp = int(time.time())

all_data = []
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{timestamp}.pkl", "wb") as f:
pickle.dump(all_data, f)


if __name__ == '__main__':

parser = FlexibleArgumentParser(
description="""
Benchmark cuSparseLt 2:4 GEMMs.

To run dimensions from a model:
python3 ./benchmarks/cusparseLt_benchmarks/benchmark_24.py --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1


Output if --save-results:
- a .pkl file, that is a list of raw torch.utils.benchmark.Measurement objects for the pytorch, cutlass and cusparseLt implementations of the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter)

parser.add_argument("--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys())
parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument(
'--use-fp8',
action='store_true',
help='Add benchmarking fp8 matmul (on supporting fp8 platforms)')

parser.add_argument(
'--save-results',
action='store_true',
help='Save results to a pickle file named model_bench-{timestamp}.pkl')

args = parser.parse_args()
run_model_bench(args)
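For reference, the pickle written by --save-results is a flat list of torch.utils.benchmark Measurement objects, so it can be re-inspected later with the same Compare helper the script uses. A small sketch (the filename/timestamp is illustrative):

import pickle

from torch.utils.benchmark import Compare

with open("model_bench-1731600000.pkl", "rb") as f:  # illustrative timestamp
    measurements = pickle.load(f)

Compare(measurements).print()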
43 changes: 43 additions & 0 deletions benchmarks/cusparseLt_benchmarks/weight_shapes.py
@@ -0,0 +1,43 @@
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
"mistralai/Mistral-7B-v0.1": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-7b-hf": [
([4096, 12288], 1),
([4096, 4096], 0),
([4096, 22016], 1),
([11008, 4096], 0),
],
"meta-llama/Llama-3-8b": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-13b-hf": [
([5120, 15360], 1),
([5120, 5120], 0),
([5120, 27648], 1),
([13824, 5120], 0),
],
"meta-llama/Llama-2-70b-hf": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
}
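As a quick illustration of the ([K, N], TP_SPLIT_DIM) convention in the header comment (and of model_shapes() in benchmark_24.py above), this is how a shape entry maps to a per-rank GEMM shape:

def shard_shape(kn, tp_split_dim, tp_size):
    # Divide the tensor-parallel split dimension by the TP size.
    kn = list(kn)
    kn[tp_split_dim] //= tp_size
    return kn

# ([14336, 4096], 0) at TP2 -> K = 7168, N = 4096
assert shard_shape([14336, 4096], 0, 2) == [7168, 4096]
# ([4096, 6144], 1) at TP4 -> K = 4096, N = 1536
assert shard_shape([4096, 6144], 1, 4) == [4096, 1536]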
19 changes: 19 additions & 0 deletions csrc/ops.h
@@ -222,3 +222,22 @@ void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#endif

#ifndef USE_ROCM
torch::Tensor cslt_compress_fp8_semi_structured(const torch::Tensor& input);

torch::Tensor cslt_mm_semi_structured(
const torch::Tensor& compressed_A, const torch::Tensor& dense_B,
const c10::optional<torch::Tensor>& scale_opt,
const c10::optional<torch::Tensor>& bias_opt,
const std::optional<torch::ScalarType> out_dtype_opt);

torch::Tensor cslt_mm_semi_structured2(
const torch::Tensor& compressed_A, const torch::Tensor& dense_B,
const c10::optional<torch::Tensor>& scale_opt,
const c10::optional<torch::Tensor>& bias_opt,
const std::optional<torch::ScalarType> out_dtype_opt);

void cslt_clear_cache();

#endif
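These declarations are bound in csrc/torch_bindings.cpp, which is not shown in this hunk, so the exact Python entry points are not visible here. A hedged sketch of how they might be reached, assuming they are registered under torch.ops._C with the same names and that the weight has already been pruned to a 2:4 pattern:

import torch

# [N, K] fp8 weight, assumed already pruned to a 2:4 sparsity pattern.
weight = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)
# [K, M] dense fp8 activations.
act = torch.randn(4096, 128, device="cuda").to(torch.float8_e4m3fn)

compressed = torch.ops._C.cslt_compress_fp8_semi_structured(weight)
out = torch.ops._C.cslt_mm_semi_structured(
    compressed, act,
    None,           # optional scale
    None,           # optional bias
    torch.float16,  # optional output dtype
)
torch.ops._C.cslt_clear_cache()  # drop cached cusparseLt descriptors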