Add proton profiling (#102)
Summary:
Add Proton profiler support. The profiler instruments the `run()` function to profile the benchmark process.
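
For orientation, here is a minimal sketch of the Proton session/scope pattern this change builds on (not the code from this diff; the softmax workload is a stand-in and a CUDA device is assumed):

```python
import torch
import triton.profiler as proton

session_id = proton.start()      # begin a profiling session; results go to proton.hatchet
proton.deactivate(session_id)    # pause collection outside the region of interest

x = torch.randn(4096, 256, device="cuda")

proton.activate(session_id)      # resume collection around the benchmarked call
with proton.scope("x_val_[4096, 256]"):
    torch.softmax(x, dim=-1)
proton.deactivate(session_id)

proton.finalize()                # flush; inspect with `proton-viewer proton.hatchet -m time`
```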

Pull Request resolved: #102

Test Plan:
```
$ python run.py --op softmax --num-inputs 3 --metrics proton
Entering scope x_val_[4096, 256]
Exiting scope x_val_[4096, 256]
Entering scope x_val_[4096, 384]
Exiting scope x_val_[4096, 384]
Entering scope x_val_[4096, 512]
Exiting scope x_val_[4096, 512]
      x_val    naive_softmax-proton    triton_softmax-proton
-----------  ----------------------  -----------------------
[4096, 256]
[4096, 384]
[4096, 512]
```

```
$ proton-viewer  proton.hatchet -m time
```

<img width="1548" alt="image" src="https://github.com/user-attachments/assets/eb450119-08bc-48cb-b148-e84127f9c7b1">

Reviewed By: adamomainz

Differential Revision: D66909978

Pulled By: xuzhao9

fbshipit-source-id: 3f5226c7abb3732c86dc9f908f7ce9dfdd529d29
xuzhao9 authored and facebook-github-bot committed Dec 7, 2024
1 parent 62d311e commit 87dffcc
Showing 6 changed files with 111 additions and 25 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -16,3 +16,5 @@ __pycache__/
*.egg-info/
torch_compile_debug/
build/
/*.csv
*.hatchet
1 change: 1 addition & 0 deletions tritonbench/components/do_bench/__init__.py
@@ -0,0 +1 @@
from .run import do_bench_wrapper
34 changes: 34 additions & 0 deletions tritonbench/components/do_bench/run.py
@@ -0,0 +1,34 @@
import torch
import triton


def do_bench_wrapper(
    fn,
    warmup,
    rep,
    grad_to_none,
    use_cuda_graphs: bool = False,
    bypass_fail: bool = False,
):
    """Wrapper to triton's do_bench to gain latency."""
    if use_cuda_graphs:
        with torch.cuda.stream(torch.cuda.Stream()):
            return triton.testing.do_bench_cudagraph(
                fn,
                rep=rep,
                return_mode="median",
                grad_to_none=grad_to_none,
            )
    else:
        try:
            return triton.testing.do_bench(
                fn,
                warmup=warmup,
                rep=rep,
                return_mode="median",
                grad_to_none=grad_to_none,
            )
        except Exception as e:
            if not bypass_fail:
                raise e
            return None
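
For illustration, a minimal sketch of how this wrapper is meant to be called, mirroring the call site added to `triton_op.py` below (the softmax closure and the timing arguments are placeholders, and a CUDA device is assumed):

```python
import torch

from tritonbench.components.do_bench import do_bench_wrapper

x = torch.randn(4096, 512, device="cuda")
fn = lambda: torch.softmax(x, dim=-1)

# Returns the median latency in milliseconds; with bypass_fail=True a failing
# benchmark yields None instead of raising.
latency = do_bench_wrapper(
    fn,
    warmup=25,
    rep=100,
    grad_to_none=None,
    use_cuda_graphs=False,
    bypass_fail=True,
)
print(latency)
```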
1 change: 1 addition & 0 deletions tritonbench/components/proton/__init__.py
@@ -0,0 +1 @@
from .trace import proton_trace
25 changes: 25 additions & 0 deletions tritonbench/components/proton/trace.py
@@ -0,0 +1,25 @@
from typing import Callable, Optional

import triton.profiler as proton


def proton_trace(
    session_id: int,
    scope_name: str,
    fn: Callable,
    warmup: int,
    flops: Optional[int] = None,
    bytes: Optional[int] = None,
):
    # warmup
    for _ in range(warmup):
        fn()
    metrics_dict = {}
    if flops:
        metrics_dict["flops"] = flops
    if bytes:
        metrics_dict["bytes"] = bytes
    proton.activate(session_id)
    with proton.scope(scope_name, metrics=metrics_dict):
        fn()
    proton.deactivate(session_id)
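
A hedged usage sketch of `proton_trace` on its own; the session handling mirrors what `run()` in `triton_op.py` does below, while the workload, warmup count, and byte count are placeholder values:

```python
import torch
import triton.profiler as proton

from tritonbench.components.proton import proton_trace

session_id = proton.start()
proton.deactivate(session_id)  # keep collection off until proton_trace activates it

x = torch.randn(4096, 384, device="cuda")
fn = lambda: torch.softmax(x, dim=-1)

proton_trace(
    session_id,
    "x_val_[4096, 384]",       # scope name as shown by proton-viewer
    fn,
    warmup=25,
    bytes=2 * 4096 * 384 * 4,  # rough fp32 read+write byte count, illustrative only
)

proton.finalize()  # writes proton.hatchet for `proton-viewer proton.hatchet -m time`
```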
73 changes: 48 additions & 25 deletions tritonbench/utils/triton_op.py
@@ -23,6 +23,7 @@
import torch
import triton

from tritonbench.components.do_bench import do_bench_wrapper
from tritonbench.components.ncu import ncu_analyzer, nsys_analyzer
from tritonbench.utils.env_utils import (
    apply_precision,
@@ -704,15 +705,26 @@ def run(
        """Benchmarking the operator and returning its metrics."""
        metrics = []
        try:
            if "proton" in self.required_metrics:
                import triton.profiler as proton

                self._proton_session_id = proton.start()
                proton.enter_scope(f"tritonbench_run_op_{self.name}")
                proton.deactivate(self._proton_session_id)
            input_id_range = range(self._input_id, self._input_id + self._num_inputs)
            if tqdm is not None:
                input_id_range = tqdm(input_id_range)
            if self._input_id:
                for _dryrun_input_id in range(self._input_id):
                    self.example_inputs = self.get_example_inputs()
            for input_id in input_id_range:
                self._cur_input_id = input_id
                self.example_inputs = self.get_example_inputs()
                x_val = self.get_x_val(self.example_inputs)
                if "proton" in self.required_metrics:
                    proton.activate(self._proton_session_id)
                    proton.enter_scope(f"x_val_{x_val}")
                    proton.deactivate(self._proton_session_id)
                self._cur_input_id = input_id
                if self.example_inputs is None:
                    logger.warn(
                        f"The input generator get_input_iter() has depleted at id {input_id}. Available number of "
@@ -731,7 +743,6 @@ def run(
                self.baseline_fn = None
                self.baseline_metrics = None
                self._op_flops = {}
                x_val = self.get_x_val(self.example_inputs)
                if self._only:
                    benchmarks = self._only
                else:
@@ -772,8 +783,15 @@ def _reduce_benchmarks(acc, bm_name: str):
                    _reduce_benchmarks, benchmarks, {}
                )
                metrics.append((x_val, y_vals))
                del self.example_inputs
                gc.collect()
                del self.example_inputs  # save some memory
                if "proton" in self.required_metrics:
                    proton.activate(self._proton_session_id)
                    proton.exit_scope()
                    proton.deactivate(self._proton_session_id)
            if "proton" in self.required_metrics:
                proton.activate(self._proton_session_id)
                proton.exit_scope()
                proton.finalize()
        except (KeyboardInterrupt, Exception):
            logger.warning(
                "Caught exception, terminating early with partial results",
@@ -966,27 +984,14 @@ def _init_extra_metrics() -> Dict[str, Any]:
            if {"latency", "tflops", "speedup", "compile_time"} & set(
                self.required_metrics
            ):
                if self.use_cuda_graphs:
                    with torch.cuda.stream(torch.cuda.Stream()):
                        metrics.latency = triton.testing.do_bench_cudagraph(
                            fn,
                            rep=rep,
                            return_mode="median",
                            grad_to_none=self.get_grad_to_none(self.example_inputs),
                        )
                else:
                    try:
                        metrics.latency = triton.testing.do_bench(
                            fn,
                            warmup=warmup,
                            rep=rep,
                            return_mode="median",
                            grad_to_none=self.get_grad_to_none(self.example_inputs),
                        )
                    except Exception as e:
                        if not self.tb_args.bypass_fail:
                            raise e
                        metrics.latency = None
                metrics.latency = do_bench_wrapper(
                    fn,
                    warmup,
                    rep,
                    grad_to_none=self.get_grad_to_none(self.example_inputs),
                    use_cuda_graphs=self.use_cuda_graphs,
                    bypass_fail=self.tb_args.bypass_fail,
                )
            if {
                "gpu_peak_mem",
                "gpu_mem_footprint_compression_ratio",
@@ -1116,6 +1121,20 @@ def _init_extra_metrics() -> Dict[str, Any]:
                )
            if "kineto_trace" in self.required_metrics:
                metrics.kineto_trace = self.kineto_trace(input_id, fn)
            if "proton" in self.required_metrics:
                from tritonbench.components.proton import proton_trace

                scope_name = fn_name
                flops = self.flops() if self.has_metric("flops") else None
                num_bytes = self.bytes() if self.has_metric("bytes") else None
                proton_trace(
                    self._proton_session_id,
                    scope_name,
                    fn,
                    warmup=warmup,
                    flops=flops,
                    bytes=num_bytes,
                )
            if "best_config" in self.required_metrics:
                metrics.best_config = self.best_config(fn)
            # run the hidden metric "_compile_time_in_task"
@@ -1590,3 +1609,7 @@ def run_and_capture(self, *args, **kwargs):
    @classmethod
    def has_bwd(cls) -> bool:
        return cls.get_bwd_fn is not BenchmarkOperator.get_bwd_fn

    @classmethod
    def has_metric(cls, metric_name: str) -> bool:
        return bool(getattr(cls, metric_name, None))
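
As a standalone illustration of the new `has_metric` helper (the demo class below is hypothetical, not part of tritonbench):

```python
class _DemoOp:
    @classmethod
    def has_metric(cls, metric_name: str) -> bool:
        return bool(getattr(cls, metric_name, None))

    def flops(self) -> int:  # hypothetical per-input flop count
        return 2 * 4096 * 4096


assert _DemoOp.has_metric("flops")      # the class defines flops() -> truthy
assert not _DemoOp.has_metric("bytes")  # no bytes() defined -> False
```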
