diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py
index 4736cd29..265fa359 100644
--- a/benchmarks/bench_linear_float8.py
+++ b/benchmarks/bench_linear_float8.py
@@ -56,6 +56,7 @@ class Experiment:
     dtype: torch.dtype
     compiled: bool = False
     float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
+    recompute_weight_cast: bool = False

     # 3 Times since we are calculating forward backward
     @property
@@ -95,9 +96,14 @@ def main(
     }
     input_bias = False
     ref_dtypes = [torch.bfloat16, torch.float16]
+    recompute_weight_casts = [True, False]
     experiment_list: List[Experiment] = []
-    for idx, (dtype, (name, (K, N))) in enumerate(
-        tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
+    for idx, (dtype, (name, (K, N)), recompute_weight_cast) in enumerate(
+        tqdm(
+            list(
+                product(ref_dtypes, name_to_shapes_70b.items(), recompute_weight_casts)
+            )
+        )
     ):
         if n_limit is not None and idx >= n_limit:
             break
@@ -106,7 +112,9 @@ def main(
         )

         linear_float8 = Float8Linear.from_float(
-            copy.deepcopy(linear_ref), emulate=False
+            copy.deepcopy(linear_ref),
+            emulate=False,
+            recompute_weight_cast=recompute_weight_cast,
         )

         bsz, seq_len = 4, 4096
@@ -155,6 +163,7 @@ def wrapper(*args, **kwargs):
             float8_time,
             dtype,
             compile,
+            recompute_weight_cast=recompute_weight_cast,
         )
         print(experiment)
         print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -169,6 +178,7 @@ def wrapper(*args, **kwargs):
         "ref_dtype",
         "compiled",
         "fp8_dtype",
+        "recompute_weight_cast",
         "ref_time_sec",
         "pt_fp8_time_sec",
         "ref_tops_sec",
@@ -187,6 +197,7 @@ def wrapper(*args, **kwargs):
                 experiment.dtype,
                 experiment.compiled,
                 experiment.float_8_dtype,
+                experiment.recompute_weight_cast,
                 experiment.ref_time_sec,
                 experiment.float8_time_sec,
                 experiment.ref_tops_sec,
@@ -214,6 +225,7 @@ def wrapper(*args, **kwargs):
             "shape",
             "ref_dtype",
             "compiled",
+            "recompute_weight_cast",
             "ref_time_sec",
             "pt_fp8_time_sec",
             "pt_fp8_speedup",
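
For context on what the new sweep enumerates, here is a minimal sketch of the benchmark loop after this change. The shape entry is an illustrative assumption (the real shape table and imports live outside this hunk); only the names ref_dtypes, name_to_shapes_70b, recompute_weight_casts, and the loop structure come from the patch itself.

    from itertools import product

    import torch

    ref_dtypes = [torch.bfloat16, torch.float16]
    name_to_shapes_70b = {"attn.w1": (8192, 11008)}  # illustrative shape only
    recompute_weight_casts = [True, False]

    # recompute_weight_cast becomes a third axis of the sweep, so each
    # (dtype, shape) pair is now benchmarked twice: once with the flag on
    # and once with it off.
    for idx, (dtype, (name, (K, N)), recompute_weight_cast) in enumerate(
        product(ref_dtypes, name_to_shapes_70b.items(), recompute_weight_casts)
    ):
        print(idx, dtype, name, (K, N), recompute_weight_cast)

Each configuration is then forwarded to Float8Linear.from_float(..., recompute_weight_cast=recompute_weight_cast) and recorded as a new recompute_weight_cast column in the results table, as shown in the diff above.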