Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
update linear bench
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Jan 18, 2024
1 parent 2ffcbe9 commit 48f21f6
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions benchmarks/bench_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class Experiment:
dtype: torch.dtype
compiled: bool = False
float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
recompute_weight_cast: bool = False

# 3 Times since we are calculating forward backward
@property
Expand Down Expand Up @@ -95,9 +96,14 @@ def main(
}
input_bias = False
ref_dtypes = [torch.bfloat16, torch.float16]
recompute_weight_casts = [True, False]
experiment_list: List[Experiment] = []
for idx, (dtype, (name, (K, N))) in enumerate(
tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
for idx, (dtype, (name, (K, N)), recompute_weight_cast) in enumerate(
tqdm(
list(
product(ref_dtypes, name_to_shapes_70b.items(), recompute_weight_casts)
)
)
):
if n_limit is not None and idx >= n_limit:
break
Expand All @@ -106,7 +112,9 @@ def main(
)

linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref), emulate=False
copy.deepcopy(linear_ref),
emulate=False,
recompute_weight_cast=recompute_weight_cast,
)

bsz, seq_len = 4, 4096
Expand Down Expand Up @@ -155,6 +163,7 @@ def wrapper(*args, **kwargs):
float8_time,
dtype,
compile,
recompute_weight_cast=recompute_weight_cast,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
Expand All @@ -169,6 +178,7 @@ def wrapper(*args, **kwargs):
"ref_dtype",
"compiled",
"fp8_dtype",
"recompute_weight_cast",
"ref_time_sec",
"pt_fp8_time_sec",
"ref_tops_sec",
Expand All @@ -187,6 +197,7 @@ def wrapper(*args, **kwargs):
experiment.dtype,
experiment.compiled,
experiment.float_8_dtype,
experiment.recompute_weight_cast,
experiment.ref_time_sec,
experiment.float8_time_sec,
experiment.ref_tops_sec,
Expand Down Expand Up @@ -214,6 +225,7 @@ def wrapper(*args, **kwargs):
"shape",
"ref_dtype",
"compiled",
"recompute_weight_cast",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_speedup",
Expand Down

0 comments on commit 48f21f6

Please sign in to comment.