Partial response to review comments and CI failure.
Added Profiler classes to handle the rocprof V1/V2 switch more cleanly.
Made tuningRunner.py use Profiler to get consistent arguments.
Some places in tuningRunner.py use runPipeline, some don't yet.
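
In outline, the new interface works like this (a sketch using names from the diff below, not a standalone program; './bin/a.out' and 'gfx90a' are placeholder arguments):

    # Pick rocprof V1 or V2 once; all call sites then share one object.
    profiler = makeProfiler('2')                  # or '1' for legacy rocprof
    # Wrap an arbitrary command line in the right rocprof invocation.
    cmd = profiler.profilerCommand(['./bin/a.out'], 'gfx90a')
    # After running cmd, parse the files that this rocprof version writes.
    ns = profiler.getNanoSeconds(profiler.benchmarkingResultFileName)
    pct = profiler.getBankConflict(profiler.benchmarkingMetricsFileName)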
pcf000 committed Oct 10, 2024
1 parent 4a6ed1b commit 390b3fc
Showing 2 changed files with 161 additions and 151 deletions.
274 changes: 145 additions & 129 deletions mlir/utils/performance/perfRunner.py
@@ -27,7 +27,7 @@
BENCHMARKINGV1_RESULT_FILE_NAME = 'results.stats.csv'
BENCHMARKINGV1_METRICS_FILE_NAME = 'results.csv'
ROCPROFV1 = '/opt/rocm/bin/rocprof'
ROCPROFV1_OPTS = ['--stats', '-o', BENCHMARKINGV1_METRICS_FILE_NAME]
ROCPROFV1_OPTS = ['--stats', '-o', BENCHMARKINGV1_RESULT_FILE_NAME]
BENCHMARKINGV2_RESULT_FILE_NAME = './pmc_1/results_stats.csv'
BENCHMARKINGV2_METRICS_FILE_NAME = './pmc_1/results_stats.csv'
ROCPROFV2 = '/opt/rocm/bin/rocprofv2'
@@ -36,6 +36,123 @@
ROCPROF_OPTS = ROCPROFV2_OPTS
BENCHMARKING_RESULT_FILE_NAME = BENCHMARKINGV2_RESULT_FILE_NAME
BENCHMARKING_METRICS_FILE_NAME = BENCHMARKINGV2_METRICS_FILE_NAME
DEFAULT_PROFILER_VERSION = '2'

class ProfilerBase:
rocprof : str
rocprofOpts : list
benchmarkingResultFileName : str
benchmarkingMetricsFileName : str

def profilerCommand(self, command, arch):
return [self.rocprof] + getMetricArgsForRocprof(arch) + self.rocprofOpts + command

class ProfilerV1(ProfilerBase):
def getNanoSeconds(self, fileName):
if not os.path.exists(fileName):
return np.nan
with open(fileName, 'r') as csv_file:
reader = csv.DictReader(csv_file, delimiter = ',')
result = 0
for row in reader:
result += int(row['AverageNs'])
return result

def getBankConflict(self, fileName):
if not os.path.exists(fileName):
return np.nan
with open(fileName, 'r') as csv_file:
reader = csv.DictReader(csv_file, delimiter = ',')
header = reader.fieldnames
if 'LDSBankConflict' not in header:
return np.nan
sum = 0
count = 0
for row in reader:
sum += float(row['LDSBankConflict'])
count += 1
return sum / count

def __init__(self):
self.rocprof = ROCPROFV1
self.rocprofOpts = ROCPROFV1_OPTS
self.benchmarkingResultFileName = BENCHMARKINGV1_RESULT_FILE_NAME
self.benchmarkingMetricsFileName = BENCHMARKINGV1_METRICS_FILE_NAME

class ProfilerV2(ProfilerBase):
def getNanoSeconds(self, fileName):
if not os.path.exists(fileName):
return np.nan
with open(fileName) as csvfile:
reader = csv.reader(csvfile, delimiter=',')
headers = next(reader)
rows = [row for row in list(reader) if row]
if len(headers) == 2 + len(rows[0]):
# Correct the header by removing 'sig' and 'obj'. Remove once the
# bug (https://github.com/ROCm/rocprofiler/issues/144) is fixed.
headers.remove('sig')
headers.remove('obj')
end_index = headers.index('EndNs')
begin_index = headers.index('BeginNs')
name_index = headers.index('KernelName')
# rocprofv2 doesn't give us an AverageNs field, so we must compute
# it from the total and the number of rows. For bwd kernels, if
# not others, we'll have more than one kernel function, and will
# need to compute separate averages and then accumulate them.
results = defaultdict(int)
counts = defaultdict(int)
for row in rows:
assert len(row) == len(headers)
kernel_name = row[name_index]
results[kernel_name] += int(row[end_index]) - int(row[begin_index])
counts[kernel_name] += 1
result = 0
for kernel,total in results.items():
result += float(total) / float(counts[kernel])
return result

# Bank conflict metric: the percentage of GPU time that LDS is stalled by
# bank conflicts. Value range: 0% (optimal) to 100% (bad).
def getBankConflict(self, fileName):
if not os.path.exists(fileName):
return np.nan
with open(fileName) as csvfile:
reader = csv.reader(csvfile, delimiter=',')
headers = next(reader)
if 'LDSBankConflict' not in headers:
return np.nan
rows = [row for row in list(reader) if row]
if len(headers) == 2 + len(rows[0]):
# Correct the header by removing 'sig' and 'obj'. Remove once the
# bug (https://github.com/ROCm/rocprofiler/issues/144) is fixed.
headers.remove('sig')
headers.remove('obj')
ldsbc_index = headers.index('LDSBankConflict')
sum = 0
count = 0
for row in rows:
assert len(row) == len(headers)
sum += float(row[ldsbc_index])
count += 1
return sum / count

def __init__(self):
self.rocprof = ROCPROFV2
self.rocprofOpts = ROCPROFV2_OPTS
self.benchmarkingResultFileName = BENCHMARKINGV2_RESULT_FILE_NAME
self.benchmarkingMetricsFileName = BENCHMARKINGV2_METRICS_FILE_NAME

profiler = None

def makeProfiler(version):
if version == '1':
return ProfilerV1()
elif version == '2':
return ProfilerV2()
else:
print(f"impossible rocprof version: {version}")
sys.exit(1)

ROCMLIR_INPUT_METRICS_FILE_NAME = 'rocmlir_metrics.txt'
MIOPENDRIVER = '/opt/rocm/bin/MIOpenDriver'
DIRECTIONS = ['-F 1', '-F 2', '-F 4']
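
To make the per-kernel averaging in ProfilerV2.getNanoSeconds concrete, here is a self-contained illustration on synthetic rows (not part of the commit; a real results_stats.csv comes from rocprofv2 --stats and may also carry the sig/obj columns handled by the workaround above):

    # Illustration only; synthetic stand-in for ./pmc_1/results_stats.csv.
    import csv, io
    from collections import defaultdict

    data = io.StringIO("KernelName,BeginNs,EndNs\n"
                       "fwd_kernel,100,400\n"     # 300 ns
                       "fwd_kernel,500,700\n"     # 200 ns; average 250 ns
                       "bwd_kernel,800,900\n")    # 100 ns; average 100 ns
    reader = csv.reader(data)
    headers = next(reader)
    name = headers.index('KernelName')
    begin, end = headers.index('BeginNs'), headers.index('EndNs')
    totals, counts = defaultdict(int), defaultdict(int)
    for row in reader:
        totals[row[name]] += int(row[end]) - int(row[begin])
        counts[row[name]] += 1
    # Average each kernel separately, then accumulate the averages.
    print(sum(t / counts[k] for k, t in totals.items()))   # 350.0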
@@ -75,6 +192,10 @@ class Paths:
configuration_file_path : str
mlir_paths: Optional[MLIRPaths] = None

def cpuRunnerCommand(paths: Paths):
mlir_cpu_runner_args = ['-O2', f'--shared-libs={paths.mlir_paths.libmlir_rocm_runtime_path},{paths.mlir_paths.libconv_validation_wrappers_path},{paths.mlir_paths.libmlir_runtime_utils_path},{paths.mlir_paths.libmlir_c_runner_utils_path}', '--entry-point-result=void']
return [paths.mlir_paths.cpu_runner_path] + mlir_cpu_runner_args

def find_mlir_build_dir() -> str:
"""
Finds mlir build dir searching either WORKSPACE dir
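
Spelled out, cpuRunnerCommand above assembles a command of this shape (illustration only; the paths and shared-library file names below are placeholders, with the real values coming from MLIRPaths):

    # Illustration only; not part of the commit.
    ['/build/bin/mlir-cpu-runner', '-O2',
     '--shared-libs=/build/lib/libmlir_rocm_runtime.so,'
     '/build/lib/libconv-validation-wrappers.so,'
     '/build/lib/libmlir_runner_utils.so,'
     '/build/lib/libmlir_c_runner_utils.so',
     '--entry-point-result=void']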
@@ -138,51 +259,6 @@ def create_paths(config_file_path, mlir_build_dir_path) -> Paths:
return Paths(config_file_path, mlir_paths)

# utility functions.
def getNanoSeconds(fileName):
pass

def getNanoSecondsV1(fileName):
if not os.path.exists(fileName):
return np.nan
with open(fileName, 'r') as csv_file:
reader = csv.DictReader(csv_file, delimiter = ',')

result = 0
for row in reader:
result += int(row['AverageNs'])
return result

def getNanoSecondsV2(fileName):
if not os.path.exists(fileName):
return np.nan
with open(fileName) as csvfile:
reader = csv.reader(csvfile, delimiter=',')
headers = next(reader)
rows = [row for row in list(reader) if row]
if len(headers) == 2 + len(rows[0]):
# Correct the header by removing 'sig' and 'obj'. Remove once the
# bug (https://github.com/ROCm/rocprofiler/issues/144) is fixed.
headers.remove('sig')
headers.remove('obj')
end_index = headers.index('EndNs')
begin_index = headers.index('BeginNs')
name_index = headers.index('KernelName')
# rocprofv2 doesn't give us an AverageNs field, so we must compute
# it from the total and the number of rows. For bwd kernels, if
# not others, we'll have more than one kernel function, and will
# need to compute separate averages and then accumulate them.
results = defaultdict(int)
counts = defaultdict(int)
for row in rows:
assert len(row) == len(headers)
kernel_name = row[name_index]
results[kernel_name] += int(row[end_index]) - int(row[begin_index])
counts[kernel_name] += 1
result = 0
for kernel,total in results.items():
result += float(total) / float(counts[kernel])
return result


def getMetricArgsForRocprof(arch):
chip = GFX_CHIP_RE.search(arch).group(0)
@@ -193,51 +269,6 @@ def getMetricArgsForRocprof(arch):
metrics = ['-i', metrics_path]
return metrics


# Bank conflict metric: the percentage of GPU time that LDS is stalled by
# bank conflicts. Value range: 0% (optimal) to 100% (bad).
def getBankConflict(fileName):
pass

def getBankConflictV1(fileName):
if not os.path.exists(fileName):
return np.nan
with open(fileName, 'r') as csv_file:
reader = csv.DictReader(csv_file, delimiter = ',')
header = reader.fieldnames
if 'LDSBankConflict' not in header:
return np.nan
sum = 0
count = 0
for row in reader:
sum += float(row['LDSBankConflict'])
count += 1
return sum / count

def getBankConflictV2(fileName):
if not os.path.exists(fileName):
return np.nan
with open(fileName) as csvfile:
reader = csv.reader(csvfile, delimiter=',')
headers = next(reader)
if 'LDSBankConflict' not in headers:
return np.nan
rows = [row for row in list(reader) if row]
if len(headers) == 2 + len(rows[0]):
# Correct the header by removing 'sig' and 'obj'. Remove once the
# bug (https://github.com/ROCm/rocprofiler/issues/144) is fixed.
headers.remove('sig')
headers.remove('obj')
ldsbc_index = headers.index('LDSBankConflict')
sum = 0
count = 0
for row in rows:
assert len(row) == len(headers)
sum += float(row[ldsbc_index])
count += 1
return sum / count


# Tuning databases
MaybeTuningDb = Optional[Dict[Tuple[str, str], str]]
def read_tuning_db(path: Optional[str]) -> MaybeTuningDb:
@@ -386,7 +417,7 @@ def computeTFlops(self, ns):

def tableEntry(self, nanoSeconds):
# Future(kdrewnia): This can just be a dict literal on Python 3.7+
bankConflict = getBankConflict(BENCHMARKING_METRICS_FILE_NAME)
bankConflict = profiler.getBankConflict(BENCHMARKING_METRICS_FILE_NAME)
result = OrderedDict()
values = [self.direction, self.dataType, self.chip, self.numCU, self.filterLayout, self.inputLayout, self.outputLayout,
self.n, self.c, self.hi, self.wi, self.k, self.y, self.x, self.dilationH, self.dilationW,
@@ -701,7 +732,7 @@ def computeTFlops(self, ns):

def tableEntry(self, nanoSeconds):
# Future(kdrewnia): This can just be a dict literal on Python 3.7+
bankConflict = getBankConflict(BENCHMARKING_METRICS_FILE_NAME)
bankConflict = profiler.getBankConflict(BENCHMARKING_METRICS_FILE_NAME)
result = OrderedDict()
values = [self.dataType, self.outDataType, self.chip, self.numCU, self.transA, self.transB, \
self.g, self.m, self.k, self.n, self.perfConfig, bankConflict, self.computeTFlops(nanoSeconds)]
@@ -1003,8 +1034,7 @@ def runConfigWithMLIR(config: PerfConfiguration, paths: Paths, arch, rocmlir_gen
print("Running MLIR Benchmark: ", repr(config))
rocmlirGenCommand = paths.mlir_paths.rocmlir_gen_path + ' -ph ' + commandLineOptions
rocmlirDriverCommand = [paths.mlir_paths.rocmlir_driver_path, '-c']
mlir_cpu_runner_args = [f'--shared-libs={paths.mlir_paths.libmlir_rocm_runtime_path},{paths.mlir_paths.libconv_validation_wrappers_path},{paths.mlir_paths.libmlir_runtime_utils_path},{paths.mlir_paths.libmlir_c_runner_utils_path}', '--entry-point-result=void']
profilerCommand = [ROCPROF] + getMetricArgsForRocprof(arch) + ROCPROF_OPTS + [paths.mlir_paths.cpu_runner_path] + mlir_cpu_runner_args
profilerCommand = profiler.profilerCommand(cpuRunnerCommand(paths), arch)
runPipeline([rocmlirGenCommand.split(), rocmlirDriverCommand, profilerCommand])

# Benchmarking function.
@@ -1019,25 +1049,29 @@ def benchmarkMLIR(commandLine, confClass, paths: Paths, arch, numCU, tuningDb: M

runConfigWithMLIR(config, paths, arch, rocmlir_gen_flags)
# get nanoseconds from rocprof output.
nanoSeconds = getNanoSeconds(BENCHMARKING_RESULT_FILE_NAME)
nanoSeconds = profiler.getNanoSeconds(BENCHMARKING_RESULT_FILE_NAME)
return config.tableEntry(nanoSeconds)

#Generate MLIR vs. MIOpen or rocBLAS performance results
def generatePerformanceResults(configs, confClass, paths: Paths, arch, numCU, tuningDb: MaybeTuningDb, quickTuningDb: MaybeTuningDb, rocmlir_gen_flags):
# Generate MLIR vs. MIOpen or rocBLAS performance results
def generatePerformanceResults(configs, confClass, paths: Paths, arch, numCU,
tuningDb: MaybeTuningDb, quickTuningDb: MaybeTuningDb, rocmlir_gen_flags):
# Never pass tuning DB to this run
mlir_df = pd.DataFrame(benchmarkMLIR(testVector.split(sep=' '), confClass, paths, arch, numCU, None, rocmlir_gen_flags)
for testVector in configs)
mlir_df = pd.DataFrame(benchmarkMLIR(testVector.split(sep=' '), confClass, paths, arch, numCU,
None, rocmlir_gen_flags)
for testVector in configs)
tuned_df = None
if tuningDb:
tuned_df = pd.DataFrame(benchmarkMLIR(testVector.split(sep=' '), confClass, paths, arch, numCU, tuningDb, rocmlir_gen_flags)
for testVector in configs)
tuned_df = pd.DataFrame(benchmarkMLIR(testVector.split(sep=' '), confClass, paths, arch, numCU,
tuningDb, rocmlir_gen_flags)
for testVector in configs)
quick_tuned_df = None
if quickTuningDb:
quick_tuned_df = pd.DataFrame(benchmarkMLIR(testVector.split(sep=' '), confClass, paths, arch, numCU, quickTuningDb, rocmlir_gen_flags)
for testVector in configs)
quick_tuned_df = pd.DataFrame(benchmarkMLIR(testVector.split(sep=' '), confClass, paths, arch, numCU,
quickTuningDb, rocmlir_gen_flags)
for testVector in configs)

external_df = pd.DataFrame(confClass.benchmarkExternal(testVector.split(sep=' '), paths, arch, numCU)
for testVector in configs)
for testVector in configs)

externalName = confClass.EXTERNAL_NAME
df = mlir_df.merge(external_df, on=confClass.TABLE_COLUMNS[:-2],
@@ -1144,8 +1178,7 @@ def runFusionKernel(mlirfile, rocmlirGenArgs, paths: Paths):
commands.append(rocmlirGenCommand)
kernelPipelineCommand = [paths.mlir_paths.rocmlir_driver_path, '-host-pipeline', 'mhal,runner', '-kernel-pipeline','full']
commands.append(kernelPipelineCommand)
mlir_cpu_runner_args = [f'--shared-libs={paths.mlir_paths.libmlir_rocm_runtime_path},{paths.mlir_paths.libconv_validation_wrappers_path},{paths.mlir_paths.libmlir_runtime_utils_path},{paths.mlir_paths.libmlir_c_runner_utils_path}', '--entry-point-result=void']
profilerCommand = [ROCPROF] + getMetricArgsForRocprof(getChip()) + ROCPROF_OPTS + [paths.mlir_paths.cpu_runner_path] + mlir_cpu_runner_args
profilerCommand = profiler.profilerCommand(cpuRunnerCommand(paths), getChip())
commands.append(profilerCommand)
runPipeline(commands, initial_stdin=mlirfile)

@@ -1220,7 +1253,7 @@ def benchmarkFusionKernels(test_dir, paths: Paths, arch, numCU, tuningDb: MaybeT
mlirfile.seek(0)
runFusionKernel(mlirfile, rocmlirGenArgs, paths)
# Get nanoseconds of fusion test
nanoSeconds = getNanoSeconds(BENCHMARKING_RESULT_FILE_NAME)
nanoSeconds = profiler.getNanoSeconds(BENCHMARKING_RESULT_FILE_NAME)
oneEntry = config.tableEntry(nanoSeconds)
# Keep the best performance
if testVector in perfResults and oneEntry['TFlops'] <= perfResults[testVector]['TFlops']:
@@ -1229,7 +1262,7 @@ def benchmarkFusionKernels(test_dir, paths: Paths, arch, numCU, tuningDb: MaybeT
# Run gemm or conv op with the same configuration
runConfigWithMLIR(config, paths, arch, '')
# Get nanoseconds of gemm/conv
nanoSeconds = getNanoSeconds(BENCHMARKING_RESULT_FILE_NAME)
nanoSeconds = profiler.getNanoSeconds(BENCHMARKING_RESULT_FILE_NAME)
oneEntry['MLIR TFlops'] = config.computeTFlops(nanoSeconds)
oneEntry['Fusion/MLIR'] = oneEntry['TFlops']/oneEntry['MLIR TFlops']
oneEntry['FileName'] = filename
@@ -1471,30 +1504,13 @@ def main(args=None):
)

parser.add_argument("--rv", "--rocprof-version", choices=['1', '2'],
default='2',
default=DEFAULT_PROFILER_VERSION,
help="Use rocprof (V1) or rocprofv2.")

parsed_args = parser.parse_args(args)

global getNanoSeconds, getBankConflict, ROCPROF, ROCPROF_OPTS
global BENCHMARKING_RESULT_FILE_NAME, BENCHMARKING_METRICS_FILE_NAME
if parsed_args.rv == '1':
ROCPROF = ROCPROFV1
ROCPROF_OPTS = ROCPROFV1_OPTS
BENCHMARKING_RESULT_FILE_NAME = BENCHMARKINGV1_RESULT_FILE_NAME
BENCHMARKING_METRICS_FILE_NAME = BENCHMARKINGV1_METRICS_FILE_NAME
getNanoSeconds = getNanoSecondsV1
getBankConflict = getBankConflictV1
elif parsed_args.rv == '2':
ROCPROF = ROCPROFV2
ROCPROF_OPTS = ROCPROFV2_OPTS
BENCHMARKING_RESULT_FILE_NAME = BENCHMARKINGV2_RESULT_FILE_NAME
BENCHMARKING_METRICS_FILE_NAME = BENCHMARKINGV2_METRICS_FILE_NAME
getNanoSeconds = getNanoSecondsV2
getBankConflict = getBankConflictV2
else:
print(f"impossible rocprof version: {parsed_args.rv}")
sys.exit(1)
global profiler
profiler = makeProfiler(parsed_args.rv)

rocmlir_gen_flags = ''
if 'rocmlir_gen_flags' in parsed_args:
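
With this change the profiler version is chosen once at startup: for example, "python3 perfRunner.py --rv 1 ..." selects legacy rocprof, and the default ("--rv 2", via DEFAULT_PROFILER_VERSION) selects rocprofv2; all other arguments (elided here) are unchanged.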