fixes to matmul and linear benchmarks (#320)
Summary:
Pull Request resolved: #320

For the matmul benchmarks, this unbreaks them: the scales passed to torch._scaled_mm must be fp32 tensors, not integers.

For the linear benchmarks, this aligns the default settings with the current best supported path (compile on, dynamic scaling).

Reviewed By: awgu

Differential Revision: D59877198

fbshipit-source-id: 092daaffeb0096f9fbd12ca407701bc3aa80c97c
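
For context on the matmul fix: torch.tensor([1]) infers an integer dtype, and torch._scaled_mm only accepts float32 scale tensors, so the old scales broke the benchmark. A minimal sketch of the distinction (plain PyTorch, nothing repo-specific assumed):

    import torch

    # torch.tensor([1]) infers int64; the scaled-mm kernel expects float32 scales
    print(torch.tensor([1]).dtype)    # torch.int64 -- rejected as a scale
    print(torch.tensor([1.0]).dtype)  # torch.float32 -- accepted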
vkuzo authored and facebook-github-bot committed Jul 18, 2024
1 parent e6bb1eb commit ec8b46c
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions benchmarks/bench_linear_float8.py
@@ -91,13 +91,13 @@ def float8_pct_top_peak(self):

 def main(
     sweep_path: Optional[Path] = None,
-    compile: bool = False,
+    compile: bool = True,
     n_limit: Optional[int] = None,
     fast_accum_filter: Optional[bool] = None,
     shape_name_filter: Optional[str] = None,
-    scaling_type_x: str = "delayed",
-    scaling_type_w: str = "delayed",
-    scaling_type_dL_dY: str = "delayed",
+    scaling_type_x: str = "dynamic",
+    scaling_type_w: str = "dynamic",
+    scaling_type_dL_dY: str = "dynamic",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
@@ -274,7 +274,7 @@ def wrapper(*args, **kwargs):
 def invoke_main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output_path", type=str, required=False)
-    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--disable_compile", action="store_true")
     parser.add_argument("-n", "--n_limit", type=int, required=False)
     parser.add_argument("--fast_accum_filter", type=bool, required=False)
     parser.add_argument("--shape_name_filter", type=str, required=False)
@@ -292,7 +292,7 @@ def invoke_main() -> None:
     kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY
     main(
         output_path,
-        args.compile,
+        not args.disable_compile,
         args.n_limit,
         args.fast_accum_filter,
         args.shape_name_filter,
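
With the new defaults, a bare invocation exercises the compile-on, dynamic-scaling path; the old behavior now requires explicit opt-out. Hypothetical invocations (the --scaling_type_x and --scaling_type_w flags are assumed to be wired up the same way as --scaling_type_dL_dY above):

    # new defaults: compile on, dynamic scaling for x, w, and dL_dY
    python benchmarks/bench_linear_float8.py

    # reproduce the old defaults: compile off, delayed scaling
    python benchmarks/bench_linear_float8.py --disable_compile \
        --scaling_type_x delayed --scaling_type_w delayed --scaling_type_dL_dY delayed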
4 changes: 2 additions & 2 deletions benchmarks/bench_matmul.py
@@ -101,8 +101,8 @@ def run(n_limit: Optional[int] = None):
     B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()

     def do_matmul(A, B):
-        scale_a = torch.tensor([1], device=device)
-        scale_b = torch.tensor([1], device=device)
+        scale_a = torch.tensor([1.0], device=device)
+        scale_b = torch.tensor([1.0], device=device)
         return torch._scaled_mm(
             A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
         )
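
The fixed call pattern, end to end: a self-contained sketch (assumptions: an fp8-capable GPU such as H100, a PyTorch build where torch._scaled_mm takes positional scale tensors as in the diff, and e4m3 operands standing in for d1/d2/d3):

    import torch

    device = "cuda"
    M, K, N = 1024, 1024, 1024

    # fp8 operands; mat2 is made column-major, as the scaled-mm kernel requires
    A = torch.zeros(M, K, device=device, dtype=torch.float8_e4m3fn)
    B = torch.zeros(K, N, device=device, dtype=torch.float8_e4m3fn).t().contiguous().t()

    # scales must be float32 tensors; integer scales raise an error
    scale_a = torch.tensor([1.0], device=device)
    scale_b = torch.tensor([1.0], device=device)

    out = torch._scaled_mm(
        A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=False
    )
    print(out.shape, out.dtype)  # torch.Size([1024, 1024]) torch.bfloat16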
