From 55af19fbfea1d600e57c68928f0edc533c58a86a Mon Sep 17 00:00:00 2001 From: drisspg Date: Sat, 22 Jun 2024 17:20:41 -0700 Subject: [PATCH] bigger sweep --- benchmarks/bench_padding.py | 46 ++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/benchmarks/bench_padding.py b/benchmarks/bench_padding.py index 3a10b8e..e86865f 100644 --- a/benchmarks/bench_padding.py +++ b/benchmarks/bench_padding.py @@ -4,9 +4,10 @@ import fire import torch -import torch.utils.benchmark as benchmark from float8_experimental.float8_utils import pad_tensor_for_matmul from tabulate import tabulate +from torch._inductor.utils import do_bench_using_profiling +from tqdm import tqdm # estimating TOPs for matmuls in fp32, fp16, fp8 # assuming A * B = C, with A being M * K, B being K * N, C being M * N @@ -26,14 +27,9 @@ def benchmark_fn_in_usec(f, *args, **kwargs): - # Manual warmup - for _ in range(4): - f(*args, **kwargs) - t0 = benchmark.Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - measurement = t0.blocked_autorange() - return measurement.mean * 1e6 + no_args = lambda: f(*args, **kwargs) + time = do_bench_using_profiling(no_args) + return time * 1e3 def get_tops_info(tops, time, peak_tops): @@ -51,16 +47,17 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype): scale_b = torch.tensor([1], device="cuda", dtype=torch.float32) A_pad = pad_tensor_for_matmul(A_fp8, dims=1) # mem copy - B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t() # mem copy + B_pad = pad_tensor_for_matmul(B_fp8, dims=0).contiguous().t() # mem copy - return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[ - : A.shape[0], : B.shape[1] - ] + return torch._scaled_mm( + A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True + ) def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype): + # We are only going to test the shape preserving A_pad = pad_tensor_for_matmul(A, dims=1) # mem copy - B_pad = pad_tensor_for_matmul(B, dims=[0, 1]) # mem copy + B_pad = pad_tensor_for_matmul(B, dims=0) # mem copy scale_a = torch.tensor([1], device="cuda", dtype=torch.float32) scale_b = torch.tensor([1], device="cuda", dtype=torch.float32) @@ -70,9 +67,9 @@ def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype): B_pad = B_pad.t().contiguous().t() # mem copy - return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[ - : A.shape[0], : B.shape[1] - ] + return torch._scaled_mm( + A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True + ) def do_hp_matmul(A, B): @@ -92,7 +89,18 @@ def __iter__(self): def gen_configs(): - shapes = [(8192, 2500, 5000), (64, 255, 4096)] + shapes = shapes = [ + (8193, 2501, 5008), + (65, 253, 4096), + (1023, 1029, 2512), + (4095, 511, 10000), + (2047, 3073, 8192), + (511, 769, 7504), + (127, 4097, 12288), + (32769, 15, 15024), + (9217, 8191, 20480), + (16385, 1025, 25008), + ] output_dtype = torch.bfloat16 fp8_dtype = torch.float8_e4m3fn return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes] @@ -112,7 +120,7 @@ def run(compile: bool = False, n_limit: Optional[int] = None): "Ref % Peak", "FP8 % Peak", ] - for experiment in experiments: + for experiment in tqdm(experiments): M, K, N, output_dtype, fp8_dtype = experiment tops = 2 * M * N * K