From 3ee836491cf30bf3bab9967722f48b6b38a37bb9 Mon Sep 17 00:00:00 2001
From: Paul Fuqua
Date: Fri, 2 Aug 2024 16:38:29 -0500
Subject: [PATCH] Use rocprofv2 instead of rocprof.

Abstract the boilerplate for collecting results from a process.
Account for .MLIR_N_REPEATS in rocprofv2 results, which don't include it.
Account for nrepeats in a smarter way -- count the rows, verifying as we go.
Don't run attention benchmarks from perfRunner.py on gfx110x.
Don't run the CK benchmarks for gfx110x, because ck-benchmark-driver won't
compile there.
getFusionTestInfo and runFusionKernel turn out to be mostly the same, so
factor the shared pipeline construction into makeBasicFusionPipeline.
Invent --rocprof-version to switch between rocprof and rocprofv2.
Change the default to rocprofv2.
---
 mlir/utils/performance/perfRunner.py | 337 +++++++++++++++------------
 1 file changed, 192 insertions(+), 145 deletions(-)
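Note on the repeat accounting: rocprofv2's stats file has no AverageNs
column -- just one row per kernel launch with BeginNs/EndNs/KernelName --
so the new getNanoSecondsV2 averages each kernel over its repeats and then
sums the per-kernel averages. A minimal sketch of that arithmetic, with
made-up kernel names and timings:

    from collections import defaultdict

    # Hypothetical rocprofv2-style rows: (KernelName, BeginNs, EndNs).
    # With MLIR_N_REPEATS > 1, each kernel appears once per repeat.
    rows = [("gemm_kernel", 1000, 1500),    # repeat 1: 500 ns
            ("gemm_kernel", 2000, 2600),    # repeat 2: 600 ns
            ("reduce_kernel", 3000, 3100)]  # second kernel, one launch: 100 ns

    totals, counts = defaultdict(int), defaultdict(int)
    for name, begin, end in rows:
        totals[name] += end - begin
        counts[name] += 1

    # Average each kernel over its repeats, then accumulate the averages:
    # (500 + 600) / 2 + 100 / 1 = 650 ns for the whole pipeline.
    print(sum(totals[k] / counts[k] for k in totals))  # 650.0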
diff --git a/mlir/utils/performance/perfRunner.py b/mlir/utils/performance/perfRunner.py
index a9e4c589bfd7..bc719ae28d69 100644
--- a/mlir/utils/performance/perfRunner.py
+++ b/mlir/utils/performance/perfRunner.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 import csv
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 import getopt
 import os
 import subprocess
@@ -13,6 +13,7 @@
 import glob
 import argparse
 import re
+import tempfile
 from dataclasses import dataclass
 from typing import Optional, Dict, Tuple
 
@@ -23,11 +24,20 @@
 from perfCommonUtils import Operation, GEMMLibrary
 
 # global variables.
-ROCPROF = '/opt/rocm/bin/rocprof'
-MIOPENDRIVER = '/opt/rocm/bin/MIOpenDriver'
-BENCHMARKING_RESULT_FILE_NAME = 'results.stats.csv'
-BENCHMARKING_METRICS_FILE_NAME = 'results.csv'
+BENCHMARKINGV1_RESULT_FILE_NAME = 'results.stats.csv'
+BENCHMARKINGV1_METRICS_FILE_NAME = 'results.csv'
+ROCPROFV1 = '/opt/rocm/bin/rocprof'
+ROCPROFV1_OPTS = ['--stats', '-o', BENCHMARKINGV1_METRICS_FILE_NAME]
+BENCHMARKINGV2_RESULT_FILE_NAME = './pmc_1/results_stats.csv'
+BENCHMARKINGV2_METRICS_FILE_NAME = './pmc_1/results_stats.csv'
+ROCPROFV2 = '/opt/rocm/bin/rocprofv2'
+ROCPROFV2_OPTS = ['--plugin', 'file', '--plugin-version', '1', '--kernel-trace', '-o', 'stats', '-fi', '100']
+ROCPROF = ROCPROFV2
+ROCPROF_OPTS = ROCPROFV2_OPTS
+BENCHMARKING_RESULT_FILE_NAME = BENCHMARKINGV2_RESULT_FILE_NAME
+BENCHMARKING_METRICS_FILE_NAME = BENCHMARKINGV2_METRICS_FILE_NAME
 ROCMLIR_INPUT_METRICS_FILE_NAME = 'rocmlir_metrics.txt'
+MIOPENDRIVER = '/opt/rocm/bin/MIOpenDriver'
 DIRECTIONS = ['-F 1', '-F 2', '-F 4']
 DATA_TYPES = ['conv', 'convfp16', 'convint8']
 LAYOUTS = ['NHWC', 'NCHW']
@@ -129,6 +139,9 @@ def create_paths(config_file_path, mlir_build_dir_path) -> Paths:
 
 # utility functions.
 def getNanoSeconds(fileName):
+    pass  # rebound in main() to getNanoSecondsV1 or getNanoSecondsV2
+
+def getNanoSecondsV1(fileName):
     if not os.path.exists(fileName):
         return np.nan
     with open(fileName, 'r') as csv_file:
@@ -137,9 +150,40 @@
         result = 0
         for row in reader:
             result += int(row['AverageNs'])
-        csv_file.close()
         return result
 
+def getNanoSecondsV2(fileName):
+    if not os.path.exists(fileName):
+        return np.nan
+    with open(fileName) as csvfile:
+        reader = csv.reader(csvfile, delimiter=',')
+        headers = next(reader)
+        rows = [row for row in list(reader) if row]
+        if len(headers) == 2 + len(rows[0]):
+            # Correct the header by removing 'sig' and 'obj'. Remove once the
+            # bug (https://github.com/ROCm/rocprofiler/issues/144) is fixed.
+            headers.remove('sig')
+            headers.remove('obj')
+        end_index = headers.index('EndNs')
+        begin_index = headers.index('BeginNs')
+        name_index = headers.index('KernelName')
+        # rocprofv2 doesn't give us an AverageNs field, so we must compute
+        # it from the total and the number of rows. For bwd kernels, if
+        # not others, we'll have more than one kernel function, and will
+        # need to compute separate averages and then accumulate them.
+        results = defaultdict(int)
+        counts = defaultdict(int)
+        for row in rows:
+            assert len(row) == len(headers)
+            kernel_name = row[name_index]
+            results[kernel_name] += int(row[end_index]) - int(row[begin_index])
+            counts[kernel_name] += 1
+        result = 0
+        for kernel, total in results.items():
+            result += float(total) / float(counts[kernel])
+        return result
+
+
 def getMetricArgsForRocprof(arch):
     chip = GFX_CHIP_RE.search(arch).group(0)
     current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -153,21 +197,46 @@ def getMetricArgsForRocprof(arch):
 # Bank conflict functions.The percentage of GPUTime LDS is stalled by bank
 # conflicts. Value range: 0% (optimal) to 100% (bad).
 def getBankConflict(fileName):
+    pass  # rebound in main() to getBankConflictV1 or getBankConflictV2
+
+def getBankConflictV1(fileName):
     if not os.path.exists(fileName):
-        result = "NaN"
-        return result
+        return np.nan
     with open(fileName, 'r') as csv_file:
         reader = csv.DictReader(csv_file, delimiter = ',')
         header = reader.fieldnames
         if 'LDSBankConflict' not in header:
             return np.nan
-
-        result = []
+        total = 0
+        count = 0
         for row in reader:
-            result.append(float(row['LDSBankConflict']))
-        csv_file.close()
-        result_average = sum(result) / len(result)
-        return result_average
+            total += float(row['LDSBankConflict'])
+            count += 1
+        return total / count
+
+def getBankConflictV2(fileName):
+    if not os.path.exists(fileName):
+        return np.nan
+    with open(fileName) as csvfile:
+        reader = csv.reader(csvfile, delimiter=',')
+        headers = next(reader)
+        if 'LDSBankConflict' not in headers:
+            return np.nan
+        rows = [row for row in list(reader) if row]
+        if len(headers) == 2 + len(rows[0]):
+            # Correct the header by removing 'sig' and 'obj'. Remove once the
+            # bug (https://github.com/ROCm/rocprofiler/issues/144) is fixed.
+            headers.remove('sig')
+            headers.remove('obj')
+        ldsbc_index = headers.index('LDSBankConflict')
+        total = 0
+        count = 0
+        for row in rows:
+            assert len(row) == len(headers)
+            total += float(row[ldsbc_index])
+            count += 1
+        return total / count
+
 
 # Tuning databases
 MaybeTuningDb = Optional[Dict[Tuple[str, str], str]]
@@ -209,10 +278,10 @@ def getMilliseconds(output):
     return float(result.group(1))
 
 
-def runPipeline(proc_specs):
+def runPipeline(proc_specs, initial_stdin=subprocess.DEVNULL):
     procs = []
     for proc in proc_specs:
-        prev_stdout = procs[-1].stdout if procs else subprocess.DEVNULL
+        prev_stdout = procs[-1].stdout if procs else initial_stdin
         po = subprocess.Popen(proc, stdin=prev_stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         procs.append(po)
     try:
@@ -525,6 +594,7 @@ def __init__(self, dtype: str, direction: str, filterLayout: str, inputLayout:st
 
     @classmethod
     def benchmarkExternal(cls, commandLine, paths: Paths, arch, numCU):
+        os.system("rm -f "+BENCHMARKING_RESULT_FILE_NAME)
         os.system("rm -f "+BENCHMARKING_METRICS_FILE_NAME)
         config = cls.fromCommandLine(commandLine, arch, numCU)
         MIOpenDriverCommand = [MIOPENDRIVER, *commandLine, '-V', '0', '-t', '1']
@@ -626,7 +696,7 @@ class GemmConfiguration(PerfConfiguration):
     TABLE_COLUMNS = reportUtils.GEMM_TEST_PARAMETERS + ['LDSBankConflict'] + ['TFlops']
     def computeTFlops(self, ns):
         # NaN will propagate as expected
-        # Repeats are handled by the fact that we're using avarageNs
+        # Repeats are handled by the fact that we're using averageNs
         return (2.0 * self.g * self.m * self.k * self.n) / (float(ns) * 1e-9) / 1e12
 
     def tableEntry(self, nanoSeconds):
@@ -894,6 +964,7 @@ def benchmarkExternal(cls, commandLine, paths: Paths, arch, numCU):
         benchmarkArgs = config.generateMlirDriverCommandLine("")
         # remove the result file generated by rocprof in previous benchmarking
         os.system("rm -f "+BENCHMARKING_RESULT_FILE_NAME)
+        os.system("rm -f "+BENCHMARKING_METRICS_FILE_NAME)
         print(f"Running rocBLAS benchmark {config!r}")
         profilerCommand = [paths.mlir_paths.rocblas_benchmark_driver_path] + \
             benchmarkArgs.split()
@@ -926,14 +997,14 @@
 def runConfigWithMLIR(config: PerfConfiguration, paths: Paths, arch, rocmlir_gen_flags, debug=True):
     # remove the result file generated by rocprof in previous benchmarking
     os.system("rm -f "+BENCHMARKING_RESULT_FILE_NAME)
+    os.system("rm -f "+BENCHMARKING_METRICS_FILE_NAME)
     commandLineOptions = config.generateMlirDriverCommandLine(rocmlir_gen_flags)
     if debug:
         print("Running MLIR Benchmark: ", repr(config))
     rocmlirGenCommand = paths.mlir_paths.rocmlir_gen_path + ' -ph ' + commandLineOptions
     rocmlirDriverCommand = [paths.mlir_paths.rocmlir_driver_path, '-c']
     mlir_cpu_runner_args = [f'--shared-libs={paths.mlir_paths.libmlir_rocm_runtime_path},{paths.mlir_paths.libconv_validation_wrappers_path},{paths.mlir_paths.libmlir_runtime_utils_path},{paths.mlir_paths.libmlir_c_runner_utils_path}', '--entry-point-result=void']
-    profilerCommand = [ROCPROF] + getMetricArgsForRocprof(arch) + ['--stats', '-o', BENCHMARKING_METRICS_FILE_NAME, paths.mlir_paths.cpu_runner_path] + mlir_cpu_runner_args
-
+    profilerCommand = [ROCPROF] + getMetricArgsForRocprof(arch) + ROCPROF_OPTS + [paths.mlir_paths.cpu_runner_path] + mlir_cpu_runner_args
     runPipeline([rocmlirGenCommand.split(), rocmlirDriverCommand, profilerCommand])
 
 # Benchmarking function.
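A note on the 'sig'/'obj' fixup above: per the linked rocprofiler issue, the
header row of results_stats.csv carries two fields that the data rows lack,
so column indices taken from the raw header would be off by two. A
self-contained sketch of the workaround, with made-up CSV contents:

    import csv, io

    # Header mimics the bug: it names 'sig' and 'obj', but the data row
    # carries only four fields.
    raw = ("Index,KernelName,sig,obj,BeginNs,EndNs\n"
           "0,gemm_kernel,1000,1500\n")
    reader = csv.reader(io.StringIO(raw), delimiter=',')
    headers = next(reader)
    rows = [row for row in reader if row]
    if len(headers) == 2 + len(rows[0]):
        headers.remove('sig')
        headers.remove('obj')
    # Indices now line up with the data rows.
    assert headers.index('BeginNs') == 2 and headers.index('EndNs') == 3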
@@ -973,8 +1044,6 @@ def generatePerformanceResults(configs, confClass, paths: Paths, arch, numCU, tu
                       suffixes=('', f" ({externalName})"))
     externalTFlopsCol = f"{externalName} TFlops (no MLIR Kernels)"
     df.rename(columns={'TFlops': 'MLIR TFlops', f"TFlops ({externalName})": externalTFlopsCol}, inplace=True)
-#    if tuned_df is None and quick_tuned_df is None:
-#        df.drop(columns=['PerfConfig'], inplace=True)
     if tuned_df is not None:
         # No need for suffixes, the conflicting columns have been renamed
         # Also note that we're ignoring PerfConfig with the -3
@@ -1048,50 +1117,7 @@ def findRunCommand(filename):
     print("WARNING: cannot find valid RUN command in ", filename)
     return None, None
 
-# Extract testVector and test function name from the test file
-def getFusionTestInfo(filename, paths: Paths):
-    testEntry = {}
-    rocmlirCommand, futName = findRunCommand(filename)
-    if not rocmlirCommand:
-        return testEntry
-    # rocmlir-gen -fut test -arch gfx90a --clone-harness
-    rocmlirgenCommand = [paths.mlir_paths.rocmlir_gen_path, '-fut', futName, '-arch', getChip(), '--clone-harness', filename]
-    p0 = subprocess.Popen(rocmlirgenCommand, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
-    if "-migraphx-to-tosa" in rocmlirCommand:
-        rocmlirOptCommand = [paths.mlir_paths.rocmlir_opt_path, '-migraphx-to-tosa']
-        rocmlirDriverCommand = [paths.mlir_paths.rocmlir_driver_path, '-host-pipeline', 'highlevel', '-targets', getChip()]
-        # rocmlir-opt -migraphx-to-tosa ../mlir/test/fusion/resnet50-e2e/mixr-resnet-fusion-case-1.mlir
-        p1 = subprocess.Popen(rocmlirOptCommand, stdin=p0.stdout, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
-        # pipe to rocmlir-driver -host-pipeline partition,highlevel -targets gfx90a
-        p2 = subprocess.Popen(rocmlirDriverCommand, stdin=p1.stdout, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
-        p1.stdout.close()
-    elif "migraphx" in rocmlirCommand:
-        rocmlirMigraphxCommand = [paths.mlir_paths.rocmlir_driver_path, '-kernel-pipeline', 'migraphx']
-        rocmlirDriverCommand = [paths.mlir_paths.rocmlir_driver_path, '-host-pipeline', 'migraphx,highlevel', '-targets', getChip()]
-        # rocmlir-driver -kernel-pipeline migraphx ../mlir/test/fusion/resnet50-e2e/mixr-resnet-fusion-case-1.mlir
-        p1 = subprocess.Popen(rocmlirMigraphxCommand, stdin=p0.stdout, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
-        # pipe to rocmlir-driver -host-pipeline partition,highlevel -targets gfx90a
-        p2 = subprocess.Popen(rocmlirDriverCommand, stdin=p1.stdout, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
-        p1.stdout.close()
-    else:
-        rocmlirDriverCommand = [paths.mlir_paths.rocmlir_driver_path, '-host-pipeline', 'highlevel', '-targets', getChip()]
-        # rocmlir-driver -host-pipeline partition,highlevel -targets gfx90a
-        p2 = subprocess.Popen(rocmlirDriverCommand, stdin=p0.stdout, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
-
-    # pipe to rocmlir_gen --emit-tuning-key
-    tuningKey = subprocess.Popen([paths.mlir_paths.rocmlir_gen_path, '--emit-tuning-key', '-'], stdin=p2.stdout,
-                                 stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    p2.stdout.close()
-    output, _ = tuningKey.communicate()
-    result = output.decode('utf-8').strip().split('\t')
-    testEntry = {'filename' : filename, 'testVector' : result[2], 'futName' : futName}
-    return testEntry
-
-def runFusionKernel(filename, rocmlirGenArgs, paths: Paths):
-    os.system("rm -f "+BENCHMARKING_RESULT_FILE_NAME)
-
-    rocmlirCommand, futName = findRunCommand(filename)
-
+def makeBasicFusionPipeline(filename, rocmlirCommand, futName, paths: Paths):
     # rocmlir-gen -fut test -arch gfx90a --clone-harness
     rocmlirgenCommand = [paths.mlir_paths.rocmlir_gen_path, '-fut', futName, '-arch', getChip(), '--clone-harness', filename]
     commands = [rocmlirgenCommand]
@@ -1108,15 +1134,20 @@
     else:
         rocmlirDriverCommand = [paths.mlir_paths.rocmlir_driver_path, '-host-pipeline', 'highlevel', '-targets', getChip()]
     commands.append(rocmlirDriverCommand)
+    return commands
 
+def runFusionKernel(mlirfile, rocmlirGenArgs, paths: Paths):
+    os.system("rm -f "+BENCHMARKING_RESULT_FILE_NAME)
+    os.system("rm -f "+BENCHMARKING_METRICS_FILE_NAME)
+    commands = []
     rocmlirGenCommand = [paths.mlir_paths.rocmlir_gen_path] + rocmlirGenArgs
     commands.append(rocmlirGenCommand)
     kernelPipelineCommand = [paths.mlir_paths.rocmlir_driver_path, '-host-pipeline', 'mhal,runner', '-kernel-pipeline','full']
     commands.append(kernelPipelineCommand)
     mlir_cpu_runner_args = [f'--shared-libs={paths.mlir_paths.libmlir_rocm_runtime_path},{paths.mlir_paths.libconv_validation_wrappers_path},{paths.mlir_paths.libmlir_runtime_utils_path},{paths.mlir_paths.libmlir_c_runner_utils_path}', '--entry-point-result=void']
-    profilerCommand = [ROCPROF] + getMetricArgsForRocprof(getChip()) + ['--stats', '-o', BENCHMARKING_METRICS_FILE_NAME] + [paths.mlir_paths.cpu_runner_path] + mlir_cpu_runner_args
+    profilerCommand = [ROCPROF] + getMetricArgsForRocprof(getChip()) + ROCPROF_OPTS + [paths.mlir_paths.cpu_runner_path] + mlir_cpu_runner_args
     commands.append(profilerCommand)
-    runPipeline(commands)
+    runPipeline(commands, initial_stdin=mlirfile)
 
 # Generate fusion vs. gemm/conv performance results
 def benchmarkFusionKernels(test_dir, paths: Paths, arch, numCU, tuningDb: MaybeTuningDb):
@@ -1124,12 +1155,6 @@ def benchmarkFusionKernels(test_dir, paths: Paths, arch, numCU, tuningDb: MaybeT
     allTests = []
     perfResults = {} #associate testVector to config and performances
     chip = GFX_CHIP_RE.search(arch).group(0)
 
-    # Prepare test cases
-    for filename in glob.glob(test_dir+'/*.mlir'):
-        testEntry = getFusionTestInfo(filename, paths)
-        if testEntry:
-            allTests.append(testEntry)
-
     if tuningDb:
         # Force all split-K factors to 1, to avoid trouble because fusion
         # and split-K aren't compatible.  Crude parser approximating
@@ -1140,63 +1165,75 @@ def benchmarkFusionKernels(test_dir, paths: Paths, arch, numCU, tuningDb: MaybeT
             splitPerf[6] = '1'
             tuningDb[arch,config] = ','.join(splitPerf)
 
-    # Profile each test case
-    for test in allTests:
-        filename = test['filename']
-        testVector = test['testVector']
-        futName = test['futName']
-
-        print("Profiling:", filename)
-        # Sanity check
-        if not testVector:
-            print("\tCannot find a test vector")
-            continue
-        if not futName:
-            print("\tCannot find rocmlir-gen with -fut")
-            continue
+    # Prepare test cases
+    for filename in sorted(glob.glob(test_dir+'/*.mlir')):
+        with tempfile.TemporaryFile() as mlirfile:
+            # Extract testVector and test function name from the test file
+            rocmlirCommand, futName = findRunCommand(filename)
+            if not rocmlirCommand:
+                continue
+            commands = makeBasicFusionPipeline(filename, rocmlirCommand, futName, paths)
+            partialCode, _ = runPipeline(commands)
+            mlirfile.seek(0)
+            mlirfile.write(bytes(partialCode))
+            mlirfile.seek(0)
+            tuningKeyCommand = [paths.mlir_paths.rocmlir_gen_path, '--emit-tuning-key', '-']
+            output, _ = runPipeline([tuningKeyCommand], initial_stdin=mlirfile)
+            result = output.decode('utf-8').strip().split('\t')
+            testVector = result[2]
+
+            print("Profiling:", filename)
+            # Sanity check
+            if not testVector:
+                print("\tCannot find a test vector")
+                continue
+            if not futName:
+                print("\tCannot find rocmlir-gen with -fut")
+                continue
 
-        commandLine = testVector.split(sep=' ')
-        if commandLine[0].startswith('conv'):
-            op = 'conv'
-            config = ConvConfiguration.fromCommandLine(commandLine, arch, numCU)
-        else:
-            op = 'gemm'
-            config = GemmConfiguration.fromCommandLine(commandLine, arch, numCU)
-
-        # Find the best perf_config
-        bestPerf =""
-        if tuningDb:
-            configStr = config.toCommandLine()
-            if (arch, configStr) in tuningDb:
-                bestPerf = tuningDb[arch, configStr]
-                config.setPerfConfig(bestPerf)
-            else: # Tuning DB present but doesn't contain config, add a NaN entry
-                if not testVector in perfResults:
-                    oneEntry = config.tableEntry(np.nan)
-                    oneEntry['MLIR TFlops'] = np.nan
-                    oneEntry['Fusion/MLIR'] = np.nan
-                    oneEntry['FileName'] = filename
-                    perfResults[testVector] = oneEntry
+            commandLine = testVector.split(sep=' ')
+            if commandLine[0].startswith('conv'):
+                op = 'conv'
+                config = ConvConfiguration.fromCommandLine(commandLine, arch, numCU)
+            else:
+                op = 'gemm'
+                config = GemmConfiguration.fromCommandLine(commandLine, arch, numCU)
+
+            # Find the best perf_config
+            bestPerf = ""
+            if tuningDb:
+                configStr = config.toCommandLine()
+                if (arch, configStr) in tuningDb:
+                    bestPerf = tuningDb[arch, configStr]
+                    config.setPerfConfig(bestPerf)
+                else: # Tuning DB present but doesn't contain config, add a NaN entry
+                    if not testVector in perfResults:
+                        oneEntry = config.tableEntry(np.nan)
+                        oneEntry['MLIR TFlops'] = np.nan
+                        oneEntry['Fusion/MLIR'] = np.nan
+                        oneEntry['FileName'] = filename
+                        perfResults[testVector] = oneEntry
+                    continue
+
+            # Run fusion test
+            rocmlirGenArgs = ['-ph', '-fut='+futName+'_wrapper', '--perf_config='+bestPerf, '-']
+            mlirfile.seek(0)
+            runFusionKernel(mlirfile, rocmlirGenArgs, paths)
+            # Get nanoseconds of fusion test
+            nanoSeconds = getNanoSeconds(BENCHMARKING_RESULT_FILE_NAME)
+            oneEntry = config.tableEntry(nanoSeconds)
+            # Keep the best performance
+            if testVector in perfResults and oneEntry['TFlops'] <= perfResults[testVector]['TFlops']:
                 continue
 
-        # Run fusion test
-        rocmlirGenArgs = ['-ph', '-fut='+futName+'_wrapper', '--perf_config='+bestPerf, '-']
-        runFusionKernel(filename, rocmlirGenArgs, paths)
-        # Get nanoseconds of fusion test
-        nanoSeconds = getNanoSeconds(BENCHMARKING_RESULT_FILE_NAME)
-        oneEntry = config.tableEntry(nanoSeconds)
-        # Keep the best performance
-        if testVector in perfResults and oneEntry['TFlops'] <= perfResults[testVector]['TFlops']:
-            continue
-
-        # Run gemm or conv op with the same configuration
-        runConfigWithMLIR(config, paths, arch, '')
-        # Get nanoseconds of gemm/conv
-        nanoSeconds = getNanoSeconds(BENCHMARKING_RESULT_FILE_NAME)
-        oneEntry['MLIR TFlops'] = config.computeTFlops(nanoSeconds)
-        oneEntry['Fusion/MLIR'] = oneEntry['TFlops']/oneEntry['MLIR TFlops']
-        oneEntry['FileName'] = filename
-        perfResults[testVector] = oneEntry
+            # Run gemm or conv op with the same configuration
+            runConfigWithMLIR(config, paths, arch, '')
+            # Get nanoseconds of gemm/conv
+            nanoSeconds = getNanoSeconds(BENCHMARKING_RESULT_FILE_NAME)
+            oneEntry['MLIR TFlops'] = config.computeTFlops(nanoSeconds)
+            oneEntry['Fusion/MLIR'] = oneEntry['TFlops']/oneEntry['MLIR TFlops']
+            oneEntry['FileName'] = filename
+            perfResults[testVector] = oneEntry
 
     df = pd.DataFrame(perfResults.values())
     df.fillna('NaN', inplace=True)
@@ -1220,26 +1257,12 @@ def tuneMLIRKernels(configs, arch, numCU):
         if config.inputLayout == 'nchw':
             MIOpenDriverCommand = [MIOPENDRIVER, *commandLine, '-V', '0']
             print(' '.join(MIOpenDriverCommand))
-            p1 = subprocess.Popen(MIOpenDriverCommand,
-                                  stdout=subprocess.PIPE,
-                                  stderr=subprocess.PIPE,
-                                  env=envs)
-            # get output.
-            try:
-                _, errs = p1.communicate(timeout=300)
-                if len(errs) > 0 and p1.returncode != 0:
-                    raise OSError(errs.decode('utf-8'))
-            except subprocess.TimeoutExpired:
-                p1.kill()
-                print("MIOpen tuning timed out")
-                _, errs = p1.communicate()
+            runPipeline([MIOpenDriverCommand])
 
 def is_xdlops_present() -> bool:
     """This function checks whether a GPU with xdlops support is present"""
     xdlop_supported_gpus = ['gfx908', 'gfx90a', 'gfx942']
-    xdlop_supported_gpus_str = xdlop_supported_gpus[0]
-    for gpu in xdlop_supported_gpus[1:]:
-        xdlop_supported_gpus_str += '|' + gpu
+    xdlop_supported_gpus_str = '|'.join(xdlop_supported_gpus)
     r = subprocess.run(f"/opt/rocm/bin/rocm_agent_enumerator -t GPU | grep -q -E '{xdlop_supported_gpus_str}'",
                        check=True, shell=True)
     if r.returncode == 0:
@@ -1285,7 +1308,7 @@ def getNumCU(chip):
         rocminfo = subprocess.check_output("/opt/rocm/bin/rocminfo",
                                            stderr=subprocess.PIPE)
     except subprocess.CalledProcessError as e:
-        print(e.stderr.decode('utf-8'))
+        print(f"Process error: {e.stderr.decode('utf-8')}")
         raise
     except Exception as e:
         print(f"Exception: {e}")
@@ -1371,7 +1394,7 @@ def main(args=None):
     mutex_arg_group.add_argument(
         "--batch_all",
         action="store_true",
-        help="CSV batch benchmarking with MLIR and external reference (defalut on no args)"
+        help="CSV batch benchmarking with MLIR and external reference (default on no args)"
     )
     mutex_arg_group.add_argument(
         "--external",
@@ -1447,8 +1470,32 @@ def main(args=None):
         help='Force a set of datatypes'
     )
 
+    parser.add_argument("--rv", "--rocprof-version", choices=['1', '2'],
+                        default='2',
+                        help="Use rocprof (V1) or rocprofv2.")
+
     parsed_args = parser.parse_args(args)
 
+    global getNanoSeconds, getBankConflict, ROCPROF, ROCPROF_OPTS
+    global BENCHMARKING_RESULT_FILE_NAME, BENCHMARKING_METRICS_FILE_NAME
+    if parsed_args.rv == '1':
+        ROCPROF = ROCPROFV1
+        ROCPROF_OPTS = ROCPROFV1_OPTS
+        BENCHMARKING_RESULT_FILE_NAME = BENCHMARKINGV1_RESULT_FILE_NAME
+        BENCHMARKING_METRICS_FILE_NAME = BENCHMARKINGV1_METRICS_FILE_NAME
+        getNanoSeconds = getNanoSecondsV1
+        getBankConflict = getBankConflictV1
+    elif parsed_args.rv == '2':
+        ROCPROF = ROCPROFV2
+        ROCPROF_OPTS = ROCPROFV2_OPTS
+        BENCHMARKING_RESULT_FILE_NAME = BENCHMARKINGV2_RESULT_FILE_NAME
+        BENCHMARKING_METRICS_FILE_NAME = BENCHMARKINGV2_METRICS_FILE_NAME
+        getNanoSeconds = getNanoSecondsV2
+        getBankConflict = getBankConflictV2
+    else:
+        print(f"impossible rocprof version: {parsed_args.rv}")
+        sys.exit(1)
+
     rocmlir_gen_flags = ''
     if 'rocmlir_gen_flags' in parsed_args:
         rocmlir_gen_flags = parsed_args.rocmlir_gen_flags
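For reference, the two settings of the new flag assemble profiler command
lines roughly as follows (constants as defined at the top of the file; the
runner part here is illustrative, not the exact perfRunner invocation):

    ROCPROFV1 = '/opt/rocm/bin/rocprof'
    ROCPROFV1_OPTS = ['--stats', '-o', 'results.csv']
    ROCPROFV2 = '/opt/rocm/bin/rocprofv2'
    ROCPROFV2_OPTS = ['--plugin', 'file', '--plugin-version', '1',
                      '--kernel-trace', '-o', 'stats', '-fi', '100']

    runner = ['mlir-cpu-runner', '--entry-point-result=void']  # illustrative
    print(' '.join([ROCPROFV1, *ROCPROFV1_OPTS, *runner]))
    print(' '.join([ROCPROFV2, *ROCPROFV2_OPTS, *runner]))

Passing --rocprof-version 1 (or --rv 1) restores the old rocprof behavior;
the default is now rocprofv2, whose stats land in ./pmc_1/results_stats.csv.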