Skip to content

Commit 5408adb

Browse files
committed
create disable_validation
1 parent 09f3696 commit 5408adb

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

doc/dev/python_scheduling/autotune_outer_reduction.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
test_model,
2424
at_least_one_div,
2525
ceil_div,
26-
floor_div,
27-
round_down_pow2,
2826
round_up_pow2,
2927
round_up_multiple_of,
3028
round_down_pow2_or_multiple_of,
@@ -238,7 +236,6 @@ def get_grid_outer_reduction_configurations(
238236
bdimy = min(ceil_div(threads_per_cta, bdimx), num_reductions)
239237
bdimy = round_down_pow2_or_multiple_of(bdimy, 8)
240238

241-
242239
gidim = ceil_div(num_iterations, gidim * bdimx * vectorize_factor)
243240
num_reductions_available = ceil_div(
244241
num_reductions, grdim * bdimy * reduction_unroll_factor
@@ -284,9 +281,9 @@ def get_grid_outer_reduction_configurations(
284281
vectorization_factor_options,
285282
reduction_unroll_factor_options,
286283
):
287-
# yield from get_block_outer_reduction_configurations(
288-
# threads_per_cta, vectorize_factor, reduction_unroll_factor
289-
# )
284+
yield from get_block_outer_reduction_configurations(
285+
threads_per_cta, vectorize_factor, reduction_unroll_factor
286+
)
290287
yield from get_grid_outer_reduction_configurations(
291288
threads_per_cta, vectorize_factor, reduction_unroll_factor
292289
)

doc/dev/python_scheduling/autotune_utils.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import itertools
99
from nvfuser import FusionCache, FusionDefinition
1010
from dataclasses import dataclass, astuple
11-
from typing import Callable
1211

1312
# ================================ Description ================================
1413
# This file contains the utility function for autotuning scripts.
@@ -183,7 +182,14 @@ def separate_data(script_config, parameters, performance):
183182

184183

185184
# Apply schedule decorator, run fusion, and profile performance
186-
def run_profile(autotune_config, presched_fd, inputs, scheduler_config=None):
185+
def run_profile(
186+
autotune_config,
187+
presched_fd,
188+
inputs,
189+
scheduler_config=None,
190+
*,
191+
disable_validation=False,
192+
):
187193
scheduled_fd = autotune_config.custom_scheduler(presched_fd, scheduler_config)
188194
nvf_outputs = scheduled_fd.execute(inputs, profile=True)
189195

@@ -193,15 +199,14 @@ def run_profile(autotune_config, presched_fd, inputs, scheduler_config=None):
193199
inp.grad.data.zero_()
194200

195201
# validate correctness
196-
"""
197-
eager_output = autotune_config.eager_reference(inputs)
198-
assert torch.allclose(
199-
nvf_outputs[0].to(torch.double),
200-
eager_output.to(torch.double),
201-
atol=5e-1,
202-
rtol=5e-1,
203-
)
204-
"""
202+
if not disable_validation:
203+
eager_output = autotune_config.eager_reference(inputs)
204+
assert torch.allclose(
205+
nvf_outputs[0].to(torch.double),
206+
eager_output.to(torch.double),
207+
atol=5e-1,
208+
rtol=5e-1,
209+
)
205210

206211
prof = scheduled_fd.profile()
207212
bandwidth = prof.kernel_profiles[0].effective_bandwidth_gbs
@@ -326,7 +331,11 @@ def test_model(clf, script_config, autotune_config):
326331
autotune_config.create_fusion_func()(presched_fd)
327332

328333
_, est_time_ms = run_profile(
329-
autotune_config, presched_fd, inputs, estimate_config
334+
autotune_config,
335+
presched_fd,
336+
inputs,
337+
estimate_config,
338+
disable_validation=True,
330339
)
331340
est_perfs.append(est_time_ms)
332341
print(
@@ -344,7 +353,9 @@ def test_model(clf, script_config, autotune_config):
344353
with FusionDefinition() as presched_fd:
345354
autotune_config.create_fusion_func()(presched_fd)
346355

347-
_, nvf_time_ms = run_profile(autotune_config, presched_fd, inputs)
356+
_, nvf_time_ms = run_profile(
357+
autotune_config, presched_fd, inputs, disable_validation=True
358+
)
348359
nvf_perfs.append(nvf_time_ms)
349360
print(
350361
f"{script_config.empirical_batch_size}, {hidden_shape}, {nvf_time_ms: .3f}"

0 commit comments

Comments
 (0)