27 changes: 23 additions & 4 deletions docs/how_to/tutorials/e2e_opt_model.py
@@ -95,13 +95,32 @@
# leverage MetaSchedule to tune the model and store the tuning logs to the database. We also
# apply the database to the model to get the best performance.
#

TOTAL_TRIALS = 8000 # Change to 20000 for better performance if needed
# The ResNet18 model will be divided into 20 independent tuning tasks during compilation.
# To ensure each task receives adequate tuning resources in one iteration while providing
# early feedback:
#
# - To quickly observe tuning progress, each task is allocated a maximum of 16 trials per
# iteration (controlled by ``max_trials_per_task=16``). Setting ``total_trials`` to at least
# ``320 (20 tasks * 16 trials)`` ensures every task receives one full iteration of tuning.
# - If ``max_trials_per_task`` is unspecified, it defaults to ``min(max_trials_per_iter,
#   total_trials)`` trials per task per iteration, where ``max_trials_per_iter`` defaults to 64.
#   This can undersubscribe the tuning when ``total_trials`` is too small (e.g.,
#   ``64 < total_trials < 20 * 64``): some tasks may be skipped entirely, leaving their
#   operators unoptimized or without thread binding. Explicitly setting both parameters
#   avoids this and gives deterministic resource allocation across all tasks; a quick
#   sanity check of this budget follows the constants below.

TOTAL_TRIALS = 320 # Change to 20000 for better performance if needed
target = tvm.target.Target("nvidia/geforce-rtx-3090-ti") # Change to your target device
work_dir = "tuning_logs"

if not IS_IN_CI:
mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=TOTAL_TRIALS)(mod)
mod = relax.get_pipeline(
"static_shape_tuning",
target=target,
work_dir=work_dir,
total_trials=TOTAL_TRIALS,
max_trials_per_task=16,
)(mod)

# Only show the main function
mod["main"].show()
@@ -119,6 +138,6 @@
# Need to allocate data and params on GPU device
gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
gpu_params = [tvm.nd.array(p, dev) for p in params["main"]]
gpu_out = vm["main"](gpu_data, *gpu_params).numpy()
gpu_out = vm["main"](gpu_data, *gpu_params)[0].numpy()

print(gpu_out.shape)
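Putting the two edits in this file together, the tuned compile-and-run flow looks roughly like the following sketch. It assumes the ResNet18 ``mod`` and ``params`` imported earlier in the tutorial and a CUDA-capable GPU; the ``tvm.compile`` call mirrors the docstring example in ``pipeline.py`` below, and ``relax.VirtualMachine`` is the standard Relax VM wrapper rather than something shown in this diff.

import numpy as np
import tvm
from tvm import relax

TOTAL_TRIALS = 320  # 20 tasks * 16 trials per task
target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")

# Tune every task for up to 16 trials and apply the best schedules found.
mod = relax.get_pipeline(
    "static_shape_tuning",
    target=target,
    work_dir="tuning_logs",
    total_trials=TOTAL_TRIALS,
    max_trials_per_task=16,
)(mod)  # `mod` is the ResNet18 IRModule imported earlier in the tutorial

# Build and run on the GPU; index [0] takes the first returned output,
# matching the change in the diff above.
dev = tvm.device("cuda", 0)
ex = tvm.compile(mod, target=target)
vm = relax.VirtualMachine(ex, dev)
gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
gpu_params = [tvm.nd.array(p, dev) for p in params["main"]]
gpu_out = vm["main"](gpu_data, *gpu_params)[0].numpy()
print(gpu_out.shape)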
14 changes: 12 additions & 2 deletions python/tvm/relax/pipeline.py
@@ -21,7 +21,7 @@
as it is or serves as a basis to do further composition.
"""
# pylint: disable=unused-argument
from typing import Union
from typing import Union, Optional

import tvm
from tvm import meta_schedule as ms
@@ -111,6 +111,7 @@ def static_shape_tuning_pipeline(
target: Union[str, tvm.target.Target],
work_dir: str = "tuning_logs",
cpu_weight_prepack: bool = False,
max_trials_per_task: Optional[int] = None,
):
"""Tune the static shape model and store the log to database.

@@ -128,6 +129,9 @@
cpu_weight_prepack : bool
Whether to enable the cpu weight prepack feature.

max_trials_per_task : Optional[int]
The maximum number of trials to run per task.

Note
----
`cpu_weight_prepack` is expected to be `True` when running on CPU for
@@ -142,6 +146,7 @@
target="llvm -num-cores 16",
work_dir="tuning_logs",
cpu_weight_prepack=True,
max_trials_per_task=64,
)(mod)

ex = tvm.compile(mod, target=target)
@@ -177,7 +182,12 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
*pre_tuning_layout_rewrite,
# Skip tuning if total_trials is 0
(
transform.MetaScheduleTuneIRMod({}, work_dir, total_trials)
transform.MetaScheduleTuneIRMod(
params={},
work_dir=work_dir,
max_trials_global=total_trials,
max_trials_per_task=max_trials_per_task,
)
if total_trials > 0
else tvm.transform.Sequential([])
),
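For the CPU path described in the docstring, the new parameter is forwarded in the same way. A minimal usage sketch, assuming a static-shape ``mod`` and a 16-core LLVM host; the 1280-trial budget is an assumed example (20 tasks * 64 trials), not a value taken from the diff:

import tvm
from tvm import relax

target = "llvm -num-cores 16"
mod = relax.get_pipeline(
    "static_shape_tuning",
    target=target,
    work_dir="tuning_logs",
    total_trials=1280,        # assumed budget: e.g. 20 tasks * 64 trials each
    cpu_weight_prepack=True,  # CPU weight prepack, as in the docstring example
    max_trials_per_task=64,   # forwarded to MetaScheduleTuneIRMod per the diff above
)(mod)
ex = tvm.compile(mod, target=target)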