Skip to content

Commit

Permalink
Update code
Browse files Browse the repository at this point in the history
  • Loading branch information
Faraz9877 committed Dec 11, 2024
1 parent a27ca81 commit 81c4360
Show file tree
Hide file tree
Showing 27 changed files with 299 additions and 3,965 deletions.
5 changes: 1 addition & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR};${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/
USE_SABI 3
WITH_SOABI)

Expand All @@ -418,9 +418,6 @@ define_gpu_extension_target(
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

# include(nm_cutlass_c.cmake)
# build_nm_cutlass_c()

#
# _moe_C extension
#
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import argparse
import copy
import itertools
import pickle as pkl
import time
import dataclasses
import multiprocessing as mp
import os
Expand All @@ -10,9 +15,15 @@
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES
from vllm.utils import FlexibleArgumentParser
import vllm._custom_ops as ops
from utils import make_n_rand_sparse_tensors

import vllm._custom_ops as ops

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]


@dataclasses.dataclass
Expand Down Expand Up @@ -565,10 +576,202 @@ def bench_fp8(dtype: torch.dtype, with_cuda_graph: Optional[int],
return timers


def bench_v2(dtype: torch.dtype, with_cuda_graph: Optional[int],
def bench(dtype: torch.dtype, with_cuda_graph: Optional[int],
          with_arg_pool: Optional[int], m: int, k: int, n: int, label: str,
          sub_label: str) -> Iterable[TMeasurement]:
    """Dispatch to the dtype-specific benchmark implementation.

    Only fp8 (torch.float8_e4m3fn) is currently supported; any other
    dtype raises ValueError.
    """
    if dtype != torch.float8_e4m3fn:
        raise ValueError("unsupported type")
    return bench_fp8(dtype, with_cuda_graph, with_arg_pool, m, k, n,
                     label, sub_label)


# runner
def print_timers(timers: Iterable[TMeasurement]):
    """Render a comparison table of the collected measurements to stdout."""
    TBenchmark.Compare(timers).print()


def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
    """Benchmark every (m, k, n) problem size and return all measurements.

    Results for each shape are printed as they complete; the flat list of
    all measurements is returned for later aggregation.
    """
    dtype = args.dtype
    all_timers = []

    for m, k, n in MKNs:
        # Encode the benchmark configuration into the label so results from
        # different cudagraph/arg-pool settings are distinguishable.
        parts = [f"scaled-sparse-{dtype}-gemm"]
        if args.with_cuda_graph:
            parts.append(f"cugraph_{args.with_cuda_graph}")
        if args.with_arg_pool:
            parts.append(f"argpool_{args.with_arg_pool}")
        label = "-".join(parts)

        timers = bench(args.dtype, args.with_cuda_graph, args.with_arg_pool,
                       m, k, n, label, f"MKN=({m}x{k}x{n})")
        print_timers(timers)
        all_timers.extend(timers)

    return all_timers


# output makers
def make_output(data: Iterable[TMeasurement],
                MKNs: Iterable[Tuple[int, int, int]],
                base_description: str,
                timestamp=None):
    """Print all measurements and pickle them to a timestamped .pkl file.

    If *timestamp* is None, the current Unix time is used in the filename.
    """
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # Persist the raw measurements for later offline analysis.
    if timestamp is None:
        timestamp = int(time.time())
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    """Benchmark square GEMMs (m == k == n) over the configured size range."""
    sizes = range(args.dim_start, args.dim_end + 1, args.dim_increment)
    MKNs = [(s, s, s) for s in sizes]
    data = run(args, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    """Benchmark GEMMs sweeping the unpinned dimensions over a range.

    Any of m/k/n can be pinned with --m-constant/--k-constant/--n-constant;
    dimensions left unpinned sweep [dim_start, dim_end) by dim_increment.
    """
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
    count = len(dim_sizes)

    def axis(constant):
        # A pinned axis repeats its constant; otherwise it sweeps the range.
        return dim_sizes if constant is None else [constant] * count

    MKNs = list(
        zip(axis(args.m_constant), axis(args.k_constant),
            axis(args.n_constant)))
    data = run(args, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):
    """Benchmark GEMM shapes drawn from real model weight matrices.

    For every (model, tp_size) pair, each weight shape is crossed with the
    requested batch sizes; per-pair results are printed and all measurements
    are pickled to a timestamped file.
    """
    print("Benchmarking models:")
    for idx, name in enumerate(args.models):
        print(f"[{idx}] {name}")

    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
        # Shrink the tensor-parallel split dimension (if any) by tp_size.
        shapes = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            if tp_split_dim is not None:
                KN[tp_split_dim] //= tp_size
            shapes.append(KN)
        return shapes

    models_tps = list(itertools.product(args.models, args.tp_sizes))
    model_bench_data = []
    for model, tp_size in models_tps:
        MKNs = [(m, k, n)
                for m in args.batch_sizes
                for k, n in model_shapes(model, tp_size)]
        model_bench_data.append(run(args, MKNs))

    # Print all results, grouped per model/TP configuration.
    for data, (model, tp_size) in zip(model_bench_data, models_tps):
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    # Pickle the flattened measurement list for offline analysis.
    timestamp = int(time.time())
    all_data = [timer for per_model in model_bench_data for timer in per_model]
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


if __name__ == '__main__':

    def to_torch_dtype(dt):
        """Map a CLI dtype string to the corresponding torch dtype."""
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        if dt == "fp16":
            return torch.float16
        if dt == "bf16":
            return torch.bfloat16
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Cutlass GEMM.
To run square GEMMs:
    python3 ./benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
    python3 ./benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
    python3 ./benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
    - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument(
        "--dtype",
        type=to_torch_dtype,
        required=True,
        help="Available options are ['int8', 'fp8', 'fp16', 'bf16']")
    parser.add_argument(
        '--with-cuda-graph',
        type=int,
        default=32,
        # BUGFIX: the implicit string concatenation was missing a space,
        # rendering "When setcuda-graphs" in --help output.
        help="Number of ops/matmuls in a cudagraph execution. When set "
        "cuda-graphs is enabled")
    parser.add_argument(
        '--with-arg-pool',
        type=int,
        default=None,
        # BUGFIX: same missing-space concatenation ("set,it" in --help).
        help="Number of A and B tensors to use as arg-pool. When not set, "
        "it defaults to 1")

    subparsers = parser.add_subparsers(dest="cmd")

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument("--models",
                              nargs="+",
                              type=str,
                              default=DEFAULT_MODELS,
                              choices=WEIGHT_SHAPES.keys())
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    # ROBUSTNESS: subparsers are optional by default, so invoking the script
    # without a subcommand used to crash with AttributeError on args.func.
    # Emit a proper usage error instead.
    if not hasattr(args, "func"):
        parser.error("missing benchmark command "
                     "(expected one of: square_bench, range_bench,"
                     " model_bench)")
    args.func(args)
Loading

0 comments on commit 81c4360

Please sign in to comment.