Commit 55af19f: bigger sweep
drisspg committed Jun 24, 2024 (1 parent: 017e858)
This repository was archived by the owner on Aug 7, 2024 and is now read-only.

1 changed file: benchmarks/bench_padding.py (27 additions, 19 deletions)

@@ -4,9 +4,10 @@
 import fire
 
 import torch
-import torch.utils.benchmark as benchmark
 from float8_experimental.float8_utils import pad_tensor_for_matmul
 from tabulate import tabulate
+from torch._inductor.utils import do_bench_using_profiling
+from tqdm import tqdm
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N
@@ -26,14 +27,9 @@
 
 
 def benchmark_fn_in_usec(f, *args, **kwargs):
-    # Manual warmup
-    for _ in range(4):
-        f(*args, **kwargs)
-    t0 = benchmark.Timer(
-        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-    )
-    measurement = t0.blocked_autorange()
-    return measurement.mean * 1e6
+    no_args = lambda: f(*args, **kwargs)
+    time = do_bench_using_profiling(no_args)
+    return time * 1e3
 
 
 def get_tops_info(tops, time, peak_tops):
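
The rewritten helper hands a zero-argument lambda to do_bench_using_profiling, which does its own warm-up (so the manual warm-up loop is gone) and reports a time in milliseconds; the * 1e3 converts that to microseconds, keeping the _in_usec contract. A minimal standalone sketch of the same pattern, assuming a CUDA device and a PyTorch build that ships the private utility torch._inductor.utils.do_bench_using_profiling; the shapes below are illustrative, not taken from the sweep:

import torch
from torch._inductor.utils import do_bench_using_profiling


def benchmark_fn_in_usec(f, *args, **kwargs):
    # do_bench_using_profiling expects a no-argument callable and returns a
    # mean time in milliseconds, measured via the profiler.
    no_args = lambda: f(*args, **kwargs)
    time_ms = do_bench_using_profiling(no_args)
    return time_ms * 1e3  # microseconds


if __name__ == "__main__":
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    print(f"bf16 matmul: {benchmark_fn_in_usec(torch.mm, a, b):.1f} us")
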
@@ -51,16 +47,17 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
 
     A_pad = pad_tensor_for_matmul(A_fp8, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t()  # mem copy
+    B_pad = pad_tensor_for_matmul(B_fp8, dims=0).contiguous().t()  # mem copy
 
-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    return torch._scaled_mm(
+        A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
+    )
 
 
 def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
     # We are only going to test the shape preserving
     A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B, dims=[0, 1])  # mem copy
+    B_pad = pad_tensor_for_matmul(B, dims=0)  # mem copy
 
     scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
@@ -70,9 +67,9 @@ def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
 
     B_pad = B_pad.t().contiguous().t()  # mem copy
 
-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    return torch._scaled_mm(
+        A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
+    )
 
 
 def do_hp_matmul(A, B):
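
Both fp8 paths now pad B along a single dimension instead of two, drop the output slice that went with the double-sided padding, and pass use_fast_accum=True to torch._scaled_mm to select the faster fp8 accumulation mode. As a reminder of what the padding helper does, here is a small sketch, assuming pad_tensor_for_matmul zero-pads the requested dims up to the next multiple of 16 (the alignment torch._scaled_mm expects for fp8 operands); the shape is illustrative:

import torch
from float8_experimental.float8_utils import pad_tensor_for_matmul

# A deliberately misaligned 65 x 253 tensor; dtype does not matter for padding.
t = torch.randn(65, 253)

print(pad_tensor_for_matmul(t, dims=1).shape)       # torch.Size([65, 256]) if padded to multiples of 16
print(pad_tensor_for_matmul(t, dims=0).shape)       # torch.Size([80, 253])
print(pad_tensor_for_matmul(t, dims=[0, 1]).shape)  # torch.Size([80, 256])
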
@@ -92,7 +89,18 @@ def __iter__(self):
 
 
 def gen_configs():
-    shapes = [(8192, 2500, 5000), (64, 255, 4096)]
+    shapes = [
+        (8193, 2501, 5008),
+        (65, 253, 4096),
+        (1023, 1029, 2512),
+        (4095, 511, 10000),
+        (2047, 3073, 8192),
+        (511, 769, 7504),
+        (127, 4097, 12288),
+        (32769, 15, 15024),
+        (9217, 8191, 20480),
+        (16385, 1025, 25008),
+    ]
     output_dtype = torch.bfloat16
     fp8_dtype = torch.float8_e4m3fn
     return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
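
In the bigger sweep every M and K is deliberately misaligned (all odd), while every N is already a multiple of 16, consistent with the kernels above no longer padding the output dimension. A hypothetical helper, not part of the repo, shows how much padding the misaligned dims pick up under the same multiple-of-16 assumption:

def round_up_to_16(x: int) -> int:
    # Next multiple of 16, the alignment assumed for fp8 matmul operands.
    return ((x + 15) // 16) * 16

for shape in [(8193, 2501, 5008), (65, 253, 4096), (32769, 15, 15024)]:
    print(shape, "->", tuple(round_up_to_16(d) for d in shape))
# (8193, 2501, 5008) -> (8208, 2512, 5008)
# (65, 253, 4096) -> (80, 256, 4096)
# (32769, 15, 15024) -> (32784, 16, 15024)
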
@@ -112,7 +120,7 @@ def run(compile: bool = False, n_limit: Optional[int] = None):
         "Ref % Peak",
         "FP8 % Peak",
     ]
-    for experiment in experiments:
+    for experiment in tqdm(experiments):
        M, K, N, output_dtype, fp8_dtype = experiment
        tops = 2 * M * N * K
 
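
With ten shapes in the sweep, the run loop is wrapped in tqdm to show progress. The script is driven through fire; assuming it keeps a fire.Fire(run) entry point (not shown in this diff), a shortened run can be kicked off like this, with n_limit capping how many shapes are benchmarked:

# Hypothetical invocation from the benchmarks/ directory; roughly equivalent to
# running the script with --n_limit 3 on the command line.
from bench_padding import run

run(compile=False, n_limit=3)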