From 642ac1e854cd00f47340254f98bedb2258635ab8 Mon Sep 17 00:00:00 2001
From: Yueming Hao
Date: Mon, 9 Dec 2024 15:12:57 -0800
Subject: [PATCH] Disable donated_buffer for all ops' backward benchmarking
 (#104)

Summary:
This is still a temporary fix for backward benchmarking. Related discussion:
https://github.com/pytorch-labs/tritonbench/issues/40

Pull Request resolved: https://github.com/pytorch-labs/tritonbench/pull/104

Reviewed By: xuzhao9

Differential Revision: D66911331

Pulled By: FindHao

fbshipit-source-id: 6b3e5188fb6c929d6fe34aaf3a141bafa92c33f3
---
 tritonbench/operators/geglu/operator.py      | 7 +++++++
 tritonbench/operators/layer_norm/operator.py | 1 +
 tritonbench/operators/swiglu/operator.py     | 7 +++++++
 3 files changed, 15 insertions(+)

diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py
index 640e4185..c744fe88 100644
--- a/tritonbench/operators/geglu/operator.py
+++ b/tritonbench/operators/geglu/operator.py
@@ -8,6 +8,7 @@
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
+    Mode,
     register_benchmark,
     register_x_val,
 )
@@ -59,6 +60,12 @@ def liger_geglu(self, input) -> Callable:
 
     @register_benchmark()
     def inductor_geglu(self, input) -> Callable:
+        # TODO: remove this once we have a better way to handle backward benchmarking
+        # We need to run backward multiple times for proper benchmarking
+        # so donated buffer have to be disabled
+        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
+            import torch._functorch.config
+            torch._functorch.config.donated_buffer = False
         compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input)
 
diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py
index ecfc4444..db48a03d 100644
--- a/tritonbench/operators/layer_norm/operator.py
+++ b/tritonbench/operators/layer_norm/operator.py
@@ -34,6 +34,7 @@ def torch_layer_norm(self, *args):
 
     @register_benchmark()
     def torch_compile_layer_norm(self, *args):
+        # TODO: remove this once we have a better way to handle backward benchmarking
         # We need to run backward multiple times for proper benchmarking
         # so donated buffer have to be disabled
         if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py
index ab53ab76..b414513f 100644
--- a/tritonbench/operators/swiglu/operator.py
+++ b/tritonbench/operators/swiglu/operator.py
@@ -7,6 +7,7 @@
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
+    Mode,
     register_benchmark,
     register_x_val,
 )
@@ -59,6 +60,12 @@ def liger_swiglu(self, input) -> Callable:
 
     @register_benchmark()
     def inductor_swiglu(self, input) -> Callable:
+        # TODO: remove this once we have a better way to handle backward benchmarking
+        # We need to run backward multiple times for proper benchmarking
+        # so donated buffer have to be disabled
+        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
+            import torch._functorch.config
+            torch._functorch.config.donated_buffer = False
        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)
 
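
For readers unfamiliar with the flag, the sketch below illustrates the failure mode this patch works around: with donated_buffer left at its default, AOTAutograd may "donate" saved intermediate buffers to the compiled backward as scratch memory, so that backward cannot be replayed, yet a benchmark harness re-runs backward many times over the same graph. The snippet is illustrative only and is not tritonbench code; the nn.Linear model, tensor shapes, and iteration count are placeholder assumptions.

    import torch
    import torch._functorch.config

    # Disable buffer donation before the graph is compiled, mirroring what the
    # operators above do for Mode.BWD / Mode.FWD_BWD.
    torch._functorch.config.donated_buffer = False

    model = torch.compile(torch.nn.Linear(128, 128))  # hypothetical toy model
    x = torch.randn(32, 128, requires_grad=True)      # hypothetical shapes

    y = model(x)
    grad = torch.ones_like(y)
    for _ in range(10):  # a benchmark harness replays backward many times
        y.backward(grad, retain_graph=True)

Note that the operators only flip the flag for Mode.BWD and Mode.FWD_BWD, so forward-only benchmarks keep the memory savings of buffer donation.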