8
8
import itertools
9
9
from nvfuser import FusionCache , FusionDefinition
10
10
from dataclasses import dataclass , astuple
11
- from typing import Callable
12
11
13
12
# ================================ Description ================================
14
13
# This file contains the utility functions for autotuning scripts.
@@ -183,7 +182,14 @@ def separate_data(script_config, parameters, performance):
183
182
184
183
185
184
# Apply schedule decorator, run fusion, and profile performance
186
- def run_profile (autotune_config , presched_fd , inputs , scheduler_config = None ):
185
+ def run_profile (
186
+ autotune_config ,
187
+ presched_fd ,
188
+ inputs ,
189
+ scheduler_config = None ,
190
+ * ,
191
+ disable_validation = False ,
192
+ ):
187
193
scheduled_fd = autotune_config .custom_scheduler (presched_fd , scheduler_config )
188
194
nvf_outputs = scheduled_fd .execute (inputs , profile = True )
189
195
@@ -193,15 +199,14 @@ def run_profile(autotune_config, presched_fd, inputs, scheduler_config=None):
193
199
inp .grad .data .zero_ ()
194
200
195
201
# validate correctness
196
- """
197
- eager_output = autotune_config.eager_reference(inputs)
198
- assert torch.allclose(
199
- nvf_outputs[0].to(torch.double),
200
- eager_output.to(torch.double),
201
- atol=5e-1,
202
- rtol=5e-1,
203
- )
204
- """
202
+ if not disable_validation :
203
+ eager_output = autotune_config .eager_reference (inputs )
204
+ assert torch .allclose (
205
+ nvf_outputs [0 ].to (torch .double ),
206
+ eager_output .to (torch .double ),
207
+ atol = 5e-1 ,
208
+ rtol = 5e-1 ,
209
+ )
205
210
206
211
prof = scheduled_fd .profile ()
207
212
bandwidth = prof .kernel_profiles [0 ].effective_bandwidth_gbs
@@ -326,7 +331,11 @@ def test_model(clf, script_config, autotune_config):
326
331
autotune_config .create_fusion_func ()(presched_fd )
327
332
328
333
_ , est_time_ms = run_profile (
329
- autotune_config , presched_fd , inputs , estimate_config
334
+ autotune_config ,
335
+ presched_fd ,
336
+ inputs ,
337
+ estimate_config ,
338
+ disable_validation = True ,
330
339
)
331
340
est_perfs .append (est_time_ms )
332
341
print (
@@ -344,7 +353,9 @@ def test_model(clf, script_config, autotune_config):
344
353
with FusionDefinition () as presched_fd :
345
354
autotune_config .create_fusion_func ()(presched_fd )
346
355
347
- _ , nvf_time_ms = run_profile (autotune_config , presched_fd , inputs )
356
+ _ , nvf_time_ms = run_profile (
357
+ autotune_config , presched_fd , inputs , disable_validation = True
358
+ )
348
359
nvf_perfs .append (nvf_time_ms )
349
360
print (
350
361
f"{ script_config .empirical_batch_size } , { hidden_shape } , { nvf_time_ms : .3f} "
0 commit comments