Add dashboard metric support for sam and sam2 benchmarks (#1407)
* Add dashboard metric support for sam and sam2 benchmarks

Summary:
Also makes `.github/workflows/dashboard_perf_test.yml` more complete by adding runs for both compile and autoquant-all for llama, sam, and sam2.

Test Plan:
tested locally:
```
SEGMENT_ANYTHING_FAST_USE_FLASH_4=0 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path output

cat output.json
{"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "torch.bfloat16", "extra_info": {"device": "cuda", "arch": "NVIDIA H100"}}, "model": {"name": "vit_h", "type": "model", "origins": ["pytorch"]}, "metric": {"name": "memory(MiB)", "benchmark_values": [27740], "target_value": null}}
{"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "torch.bfloat16", "extra_info": {"device": "cuda", "arch": "NVIDIA H100"}}, "model": {"name": "vit_h", "type": "model", "origins": ["pytorch"]}, "metric": {"name": "img_s(avg)", "benchmark_values": [42.65887304183299], "target_value": null}}

python server.py ~/checkpoints/sam2 large --port 4000 --host localhost --use_autoquant --benchmark --output_json_path output

{"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "autoquant", "extra_info": {"device": "cuda", "arch": "NVIDIA H100"}}, "model": {"name": "sam2-large", "type": "model", "origins": ["pytorch"]}, "metric": {"name": "memory(MiB)", "benchmark_values": [4294], "target_value": null}}
{"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "autoquant", "extra_info": {"device": "cuda", "arch": "NVIDIA H100"}}, "model": {"name": "sam2-large", "type": "model", "origins": ["pytorch"]}, "metric": {"name": "memory(%)", "benchmark_values": [4], "target_value": null}}
{"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "autoquant", "extra_info": {"device": "cuda", "arch": "NVIDIA H100"}}, "model": {"name": "sam2-large", "type": "model", "origins": ["pytorch"]}, "metric": {"name": "time_s(avg)", "benchmark_values": [0.4720584869384766], "target_value": null}}
```

Will also check CI and ClickHouse.
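
The benchmark records are written one JSON object per line. A minimal sketch for reading such a file back locally (assuming the file ends up named `output.json`, since `write_json_result` appends a `.json` extension to the given path):

```
# Minimal sketch: read back the emitted benchmark records (one JSON object per line).
import json

with open("output.json") as f:
    for line in f:
        record = json.loads(line)
        metric = record["metric"]
        print(record["model"]["name"], metric["name"], metric["benchmark_values"])
```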

Reviewers:

Subscribers:

Tasks:

Tags:

* ruff

* add util file

* add scripts

* tiktoken

* dep

* lm_eval

* use default autoquant for now

* path for sam eval

* add conda_run

* more deps

* invalid requirement

* git dep

* include subgraph_utils folder in installation

* reduce num_workers to 12

* try sam again with 8 workers

* skipping sam

* fix sam2 path

* checking path

* add sam2 config for ci

* skip sam2

* remove config
jerryzh168 authored Dec 17, 2024
1 parent ace7219 commit cbd7c29
Showing 14 changed files with 227 additions and 85 deletions.
33 changes: 28 additions & 5 deletions .github/workflows/dashboard_perf_test.yml
@@ -6,15 +6,15 @@ on:
- ciflow/benchmark/*
workflow_dispatch:
schedule:
- cron: 0 7 * * 0-6
- cron: 0 7 * * *

jobs:
benchmark:
runs-on: linux.aws.a100
strategy:
matrix:
torch-spec:
- '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124'
- '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124'
steps:
- uses: actions/checkout@v3

@@ -31,14 +31,37 @@ jobs:
${CONDA_RUN} pip install ${{ matrix.torch-spec }}
${CONDA_RUN} pip install -r dev-requirements.txt
${CONDA_RUN} pip install .
# SAM 2.1
${CONDA_RUN} pip install -r examples/sam2_amg_server/requirements.txt
# llama3
export CHECKPOINT_PATH=checkpoints
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
${CONDA_RUN} python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf --hf_token ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
${CONDA_RUN} python scripts/download.py --repo_id ${MODEL_REPO} --hf_token ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
${CONDA_RUN} python scripts/convert_hf_checkpoint.py --checkpoint_dir "${CHECKPOINT_PATH}/${MODEL_REPO}"
mkdir -p ${{ runner.temp }}/benchmark-results
${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --output_json_path ${{ runner.temp }}/benchmark-results/benchmark-results.json
# llama3 - compile baseline
${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
# llama3 - autoquant
${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --quantization autoquant --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
# skipping SAM because of https://hud.pytorch.org/pr/pytorch/ao/1407
# # SAM
# ${CONDA_RUN} pip install git+https://github.com/pytorch-labs/segment-anything-fast.git@main
# # SAM compile baseline
# ${CONDA_RUN} sh torchao/_models/sam/setup.sh
# ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
# ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --compression autoquant --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
# SAM 2.1
# ${CONDA_RUN} sh scripts/download_sam2_ckpts.sh ${CHECKPOINT_PATH}/sam2
# cd examples/sam2_amg_server
# hydra.errors.MissingConfigException: Cannot find primary config 'configs/sam2.1/sam2.1_hiera_l.yaml'. Check that it's in your config search path.
# ${CONDA_RUN} python server.py ${CHECKPOINT_PATH}/sam2 large --port 4000 --host localhost --fast --benchmark --dry --output_json_path ${{ runner.temp }}/benchmark-results/sam2-benchmark-results.json
# ${CONDA_RUN} python server.py ${CHECKPOINT_PATH}/sam2 large --port 4000 --host localhost --fast --use_autoquant --benchmark --dry --output_json_path ${{ runner.temp }}/benchmark-results/sam2-benchmark-results.json
- name: Upload the benchmark results to OSS benchmark database for the dashboard
uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main
9 changes: 8 additions & 1 deletion dev-requirements.txt
@@ -9,11 +9,18 @@ sentencepiece # for gpt-fast tokenizer
expecttest

# For prototype features and benchmarks
bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers
bitsandbytes # needed for testing triton quant / dequant ops for 8-bit optimizers
matplotlib
pandas
fire # QOL for commandline scripts
tabulate # QOL for printing tables to stdout
tiktoken
blobfile
lm_eval
# sam
diskcache
pycocotools
tqdm

# Custom CUDA Extensions
ninja
35 changes: 28 additions & 7 deletions examples/sam2_amg_server/server.py
@@ -28,6 +28,10 @@
import asyncio
from contextlib import asynccontextmanager
import contextlib
from torchao._models.utils import (
get_arch_name,
write_json_result,
)

from torch._inductor import config as inductorconfig
inductorconfig.triton.unique_kernel_names = True
@@ -269,8 +273,10 @@ def benchmark_fn(func, inp, mask_generator, warmup=3, runs=10):
t = time.time()
for _ in range(runs):
func(inp, mask_generator)
print(f"Benchmark took {(time.time() - t)/runs}s per iteration.")
max_memory_allocated()
avg_time_per_run = (time.time() - t)/runs
print(f"Benchmark took {avg_time_per_run}s per iteration.")
max_memory_allocated_bytes, max_memory_allocated_percentage = max_memory_allocated()
return avg_time_per_run, max_memory_allocated_bytes, max_memory_allocated_percentage


def max_memory_allocated():
@@ -279,6 +285,7 @@ def max_memory_allocated():
max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
print(f"max_memory_allocated_bytes: {max_memory_allocated_bytes}MiB or {max_memory_allocated_percentage}%")
return max_memory_allocated_bytes, max_memory_allocated_percentage


def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False):
@@ -527,10 +534,10 @@ def set_furious(mask_generator):
mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16

def set_autoquant(mask_generator):
import torchao
from torchao import autoquant
from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
# NOTE: Not baseline feature
mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
mask_generator.predictor._transforms_device = mask_generator.predictor.device
torch.set_float32_matmul_precision('high')
# NOTE: this fails when we run
@@ -556,7 +563,8 @@ def main(checkpoint_path,
dry=False,
batch_size=1,
load_fast="",
save_fast=""):
save_fast="",
output_json_path=None):
if verbose:
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
@@ -626,9 +634,9 @@
if benchmark:
print(f"batch size {batch_size} dog benchmark")
if batch_size == 1:
benchmark_fn(image_tensor_to_masks, image_tensor, mask_generator)
result = benchmark_fn(image_tensor_to_masks, image_tensor, mask_generator)
else:
benchmark_fn(image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)
result = benchmark_fn(image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)

for i, shapes in enumerate([example_shapes(), example_shapes_2()]):
print(f"batch size {batch_size} example shapes {i} benchmark")
@@ -644,6 +652,19 @@
print("len(random_images): ", len(random_images))
benchmark_fn(image_tensors_to_masks, random_images, mask_generator)

if output_json_path:
headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
name = "sam2-" + model_type
arch = get_arch_name()
dtype = "autoquant" if use_autoquant else ("compile" if fast else "base")
avg_time_per_run, max_memory_allocated_bytes, max_memory_allocated_percentage = result
memory_result = [name, dtype, device, arch, "memory(MiB)", max_memory_allocated_bytes, None]
memory_percent_result = [name, dtype, device, arch, "memory(%)", max_memory_allocated_percentage, None]
performance_result = [name, dtype, device, arch, "time_s(avg)", avg_time_per_run, None]
write_json_result(output_json_path, headers, memory_result)
write_json_result(output_json_path, headers, memory_percent_result)
write_json_result(output_json_path, headers, performance_result)

if profile is not None:
print(f"Saving profile under {profile}")
if batch_size == 1:
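
For orientation: each row written above pairs the `headers` list with per-metric values, and `write_json_result` (now provided by `torchao/_models/utils.py`; its previous definition appears in the generate.py diff further down) turns each pair into one of the JSON records shown in the Test Plan. A rough usage sketch, with placeholder numbers standing in for `benchmark_fn`'s return value (taken from the Test Plan run above):

```
# Illustrative only: placeholder numbers stand in for benchmark_fn's return value.
from torchao._models.utils import get_arch_name, write_json_result

avg_time_per_run, max_memory_mib, max_memory_pct = 0.47, 4294, 4
headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
name, dtype, device, arch = "sam2-large", "autoquant", "cuda", get_arch_name()

# Each call appends one JSON record to output.json.
write_json_result("output", headers, [name, dtype, device, arch, "memory(MiB)", max_memory_mib, None])
write_json_result("output", headers, [name, dtype, device, arch, "memory(%)", max_memory_pct, None])
write_json_result("output", headers, [name, dtype, device, arch, "time_s(avg)", avg_time_per_run, None])
```
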
6 changes: 3 additions & 3 deletions scripts/convert_hf_checkpoint.py
@@ -32,7 +32,7 @@ def convert_hf_checkpoint(
model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
model_map_json = None

try:
assert model_map_json_safetensors.is_file()
model_map_json = model_map_json_safetensors
@@ -46,7 +46,7 @@ def convert_hf_checkpoint(
print(f"Found pytorch index at {model_map_json_pytorch}")
except AssertionError:
print(f"{model_map_json_pytorch} not found")

if model_map_json is None: raise Exception("No model map found!")

with open(model_map_json) as json_map:
@@ -85,7 +85,7 @@ def permute(w, n_head):
else:
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)

if config.tie_word_embeddings:
merged_result["lm_head.weight"] = merged_result["model.embed_tokens.weight"].clone()

68 changes: 68 additions & 0 deletions scripts/download_sam2_ckpts.sh
@@ -0,0 +1,68 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Use either wget or curl to download the checkpoints
if command -v wget &> /dev/null; then
CMD="wget -P"
elif command -v curl &> /dev/null; then
CMD="curl -L -O"
else
echo "Please install wget or curl to download the checkpoints."
exit 1
fi

# Define the URLs for SAM 2 checkpoints
# SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
# sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
# sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
# sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
# sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"

# Download each of the four checkpoints using wget
# echo "Downloading sam2_hiera_tiny.pt checkpoint..."
# $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }

# echo "Downloading sam2_hiera_small.pt checkpoint..."
# $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }

# echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
# $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }

# echo "Downloading sam2_hiera_large.pt checkpoint..."
# $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }

# Define the URLs for SAM 2.1 checkpoints
SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"

# $1 is the directory to store the checkpoint
DEFAULT_DIR=test
if [ -z "$1" ]; then
DIR_NAME=$DEFAULT_DIR
else
# Use provided directory name
DIR_NAME=$1
fi

# SAM 2.1 checkpoints
echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
$CMD $DIR_NAME $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }

echo "Downloading sam2.1_hiera_small.pt checkpoint..."
$CMD $DIR_NAME $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }

echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
$CMD $DIR_NAME $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }

echo "Downloading sam2.1_hiera_large.pt checkpoint..."
$CMD $DIR_NAME $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }

echo "All checkpoints are downloaded successfully."
6 changes: 6 additions & 0 deletions scripts/run_ruff_fix.sh
@@ -0,0 +1,6 @@
ruff check . --fix
# --isolated skips the allowlist entirely, so this applies to all files
# please be careful when using this: large changes mean everyone needs to rebase
ruff check --isolated --select F821,F823,W191 --fix
ruff check --select F,I --fix
ruff format .
70 changes: 11 additions & 59 deletions torchao/_models/llama/generate.py
@@ -3,7 +3,6 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import platform
import sys
@@ -19,6 +18,10 @@
import torchao
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes, TORCH_VERSION_AT_LEAST_2_5
from torchao._models.utils import (
get_arch_name,
write_json_result,
)

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False

@@ -37,14 +40,6 @@ def elapsed_time(self, other_event):
return abs(other_event.event_time - self.event_time) * 1000


def get_arch_name() -> str:
if torch.cuda.is_available():
return torch.cuda.get_device_name()
else:
# This returns x86_64 or arm64 (for aarch64)
return platform.machine()


def device_timer(device):
if "cuda" in device:
return torch.cuda.Event(enable_timing=True)
@@ -65,39 +60,6 @@ def device_sync(device):
print(f"device={device} is not yet supported")


def write_json_result(output_json_path, headers, row):
"""
Write the result into JSON format, so that it can be uploaded to the benchmark database
to be displayed on OSS dashboard. The JSON format is defined at
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
"""
mapping_headers = {headers[i]: v for i, v in enumerate(row)}
record = {
"benchmark": {
"name": "TorchAO benchmark",
"mode": "inference",
"dtype": mapping_headers["dtype"],
"extra_info": {
"device": mapping_headers["device"],
"arch": mapping_headers["arch"],
},
},
"model": {
"name": mapping_headers["name"],
"type": "model",
"origins": ["pytorch"],
},
"metric": {
"name": mapping_headers["metric"],
"benchmark_values": [mapping_headers["actual"]],
"target_value": mapping_headers["target"],
},
}

with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f:
print(json.dumps(record), file=f)


default_device = (
"cuda"
if torch.cuda.is_available()
@@ -728,20 +690,10 @@ def ffn_or_attn_only(mod, fqn):
example_input=inputs,
)
if "autoquant-all" == quantization:
all_qtensor_classes = (
torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST
+ torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST
+ torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
)
if torchao.utils.is_sm_89():
# this is fp8 related subclasses, should rename
all_qtensor_classes += (
torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST
)
model = autoquant(
model,
manual=True,
qtensor_class_list=all_qtensor_classes,
qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
example_input=inputs,
)
else:
@@ -978,13 +930,13 @@ def callback(x):
f.write(result_txt)
f.close()

headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
name = checkpoint_path.parent.name
arch = get_arch_name()
dtype = quantization or str(precision)
memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
if output_json_path:
headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
name = checkpoint_path.parent.name
arch = get_arch_name()
dtype = quantization or str(precision)
memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
write_json_result(output_json_path, headers, memory_result)
write_json_result(output_json_path, headers, performance_result)

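
The new `torchao/_models/utils.py` is among the changed files not shown in this excerpt. Reconstructed from the code removed from generate.py above, it presumably contains roughly the following (a sketch, not the verbatim file from the commit):

```
# Sketch of torchao/_models/utils.py, reconstructed from the code removed from generate.py above.
# The actual file in the commit may differ in details.
import json
import os
import platform

import torch


def get_arch_name() -> str:
    if torch.cuda.is_available():
        return torch.cuda.get_device_name()
    else:
        # This returns x86_64 or arm64 (for aarch64)
        return platform.machine()


def write_json_result(output_json_path, headers, row):
    """
    Write the result into JSON format, so that it can be uploaded to the benchmark database
    to be displayed on the OSS dashboard. The JSON format is defined at
    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
    """
    mapping_headers = {headers[i]: v for i, v in enumerate(row)}
    record = {
        "benchmark": {
            "name": "TorchAO benchmark",
            "mode": "inference",
            "dtype": mapping_headers["dtype"],
            "extra_info": {
                "device": mapping_headers["device"],
                "arch": mapping_headers["arch"],
            },
        },
        "model": {
            "name": mapping_headers["name"],
            "type": "model",
            "origins": ["pytorch"],
        },
        "metric": {
            "name": mapping_headers["metric"],
            "benchmark_values": [mapping_headers["actual"]],
            "target_value": mapping_headers["target"],
        },
    }

    # Append one record per line so multiple metrics accumulate in the same file.
    with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f:
        print(json.dumps(record), file=f)
```
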
(Diffs for the remaining changed files are not shown in this excerpt.)