Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit d2da1ad

Browse files
committed
update profile
1 parent d03d16b commit d2da1ad

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

benchmarks/profile_linear_float8.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ class LinearParams:
8787
torch_compile: Optional[bool] = False
8888

8989

90-
def main(profile_path: Path, compile: bool, linear_type: str):
90+
def main(
91+
profile_path: Path, compile: bool, linear_type: str, recompute_weight_cast: bool
92+
):
9193
profile_path = Path(profile_path)
9294
assert profile_path.is_dir(), f"Path {profile_path} must be a directory"
9395
params = LinearParams(
@@ -110,7 +112,9 @@ def main(profile_path: Path, compile: bool, linear_type: str):
110112
dtype=params.ref_dtype,
111113
)
112114
linear_type = LinearType[linear_type.upper()]
113-
linear_float8 = get_float8_linear(linear_type, linear_ref)
115+
linear_float8 = get_float8_linear(
116+
linear_type, linear_ref, recompute_weight_cast=recompute_weight_cast
117+
)
114118

115119
input_tensor = torch.randn(
116120
params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True

0 commit comments

Comments
 (0)