diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py
index 640e4185..c744fe88 100644
--- a/tritonbench/operators/geglu/operator.py
+++ b/tritonbench/operators/geglu/operator.py
@@ -8,6 +8,7 @@
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
+    Mode,
     register_benchmark,
     register_x_val,
 )
@@ -59,6 +60,12 @@ def liger_geglu(self, input) -> Callable:
 
     @register_benchmark()
     def inductor_geglu(self, input) -> Callable:
+        # TODO: remove this once we have a better way to handle backward benchmarking
+        # We need to run backward multiple times for proper benchmarking
+        # so donated buffer have to be disabled
+        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
+            import torch._functorch.config
+            torch._functorch.config.donated_buffer = False
         compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input)
 
diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py
index ecfc4444..db48a03d 100644
--- a/tritonbench/operators/layer_norm/operator.py
+++ b/tritonbench/operators/layer_norm/operator.py
@@ -34,6 +34,7 @@ def torch_layer_norm(self, *args):
 
     @register_benchmark()
     def torch_compile_layer_norm(self, *args):
+        # TODO: remove this once we have a better way to handle backward benchmarking
         # We need to run backward multiple times for proper benchmarking
         # so donated buffer have to be disabled
         if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py
index ab53ab76..b414513f 100644
--- a/tritonbench/operators/swiglu/operator.py
+++ b/tritonbench/operators/swiglu/operator.py
@@ -7,6 +7,7 @@
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
+    Mode,
     register_benchmark,
     register_x_val,
 )
@@ -59,6 +60,12 @@ def liger_swiglu(self, input) -> Callable:
 
     @register_benchmark()
     def inductor_swiglu(self, input) -> Callable:
+        # TODO: remove this once we have a better way to handle backward benchmarking
+        # We need to run backward multiple times for proper benchmarking
+        # so donated buffer have to be disabled
+        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
+            import torch._functorch.config
+            torch._functorch.config.donated_buffer = False
         compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)
 
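# ----------------------------------------------------------------------------
# Reviewer sketch (not part of the patch): a minimal, self-contained example of
# why donated buffers are disabled above. With the default setting, AOTAutograd
# may donate forward buffers to the compiled backward graph, which assumes the
# backward pass runs at most once; a benchmark that replays backward repeatedly
# then fails. ToyModel and the loop below are hypothetical and only illustrate
# the pattern; the real benchmarks compile self.baseline_model /
# self.baseline_op instead.
import torch
import torch._functorch.config

# Disable donated buffers before compiling so backward can be re-run.
torch._functorch.config.donated_buffer = False


class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(self.proj(x))


device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyModel().to(device)
x = torch.randn(8, 16, device=device, requires_grad=True)

compiled = torch.compile(model)
out = compiled(x)
grad = torch.ones_like(out)

# A backward benchmark replays the same backward graph many times, so the
# graph has to be retained across iterations.
for _ in range(3):
    out.backward(grad, retain_graph=True)
# ----------------------------------------------------------------------------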