Skip to content

Commit

Permalink
working version without documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
nprouvost committed Feb 1, 2024
1 parent 0620f4f commit efe74f6
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 33 deletions.
16 changes: 5 additions & 11 deletions mlprof/tasks/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,21 @@
Collection of the recurrent luigi parameters for different tasks.
"""

# from __future__ import annotations

import os
from dataclasses import dataclass

import luigi
import law

from mlprof.tasks.base import BaseTask


@dataclass
class Model:
class Model():

model_file: str
name: str | None = None
label: str | None = None
def __init__(self, model_file: str, name, label, **kwargs):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_file = model_file
self.name = name
self.label = label

# cached data
self._data = None
Expand Down Expand Up @@ -193,7 +188,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# check that lengths match if initialized
from IPython import embed; embed()
if self.model_names[0] == law.NO_STR:
if self.model_labels[0] != law.NO_STR:
if len(self.model_files) != len(self.model_labels):
Expand Down
95 changes: 73 additions & 22 deletions mlprof/tasks/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def run(self):
output.parent.touch()

# get model data
model_data = self.model_data
model_data = self.model.data

# resolve the graph path relative to the model file
graph_path = os.path.expandvars(os.path.expanduser(model_data["file"]))
Expand Down Expand Up @@ -78,9 +78,9 @@ def run(self):
}

# load the template content
if self.model_data["inference_engine"] == "tf":
if model_data["inference_engine"] == "tf":
template = "$MLP_BASE/cmssw/MLProf/RuntimeMeasurement/test/tf_runtime_template_cfg.py"
elif self.model_data["inference_engine"] == "onnx":
elif model_data["inference_engine"] == "onnx":
template = "$MLP_BASE/cmssw/MLProf/ONNXRuntimeModule/test/onnx_runtime_template_cfg.py"
else:
raise Exception("The only inference_engine supported are 'tf' and 'onnx'")
Expand Down Expand Up @@ -222,34 +222,86 @@ def __init__(self, *args, **kwargs):
if len(self.model_names) not in (n_models, 0):
raise ValueError("the number of model names does not match the number of model files")

# list of sequences over which the product is performed
self.product_names = ["model_file", "cmssw_version", "scram_arch"]
self.product_sequences = [
# list of sequences over which the product is performed for the requirements
self.product_names_req = ["model_file", "model_name", "cmssw_version", "scram_arch"]
self.product_sequences_req = [
list(zip(self.model_files, self.model_names or (n_models * [None]))),
self.cmssw_versions,
self.scram_archs,
]

params_to_write = []
# list of sequences over which the product is performed for the output file name
self.product_names_out = ["model_name", "cmssw_version", "scram_arch"]
self.product_sequences_out = [
tuple([model.full_name for model in self.models]),
self.cmssw_versions,
self.scram_archs,
]

# list of sequences over which the product is performed for the labels in plot
self.product_names_labels = ["model_label", "cmssw_version", "scram_arch"]
self.product_sequences_labels = [
tuple([model.full_model_label for model in self.models]),
self.cmssw_versions,
self.scram_archs,
]

# define output product
self.output_product = list(itertools.product(*self.product_sequences_out))
self.output_product_dict = [dict(zip(self.product_names_out, values)) for values in self.output_product]

# retrieve the names of the params to be put in output
self.params_to_write_outputs = []
for iparam, param in enumerate(self.product_names_out):
if len(self.product_sequences_out[iparam]) > 1:
self.params_to_write_outputs += [param]

# create output representation to be used in output file name
self.output_product_params_to_write = [
combination_dict[key_to_write]
for combination_dict in self.output_product_dict
for key_to_write in self.params_to_write_outputs
]

# create params_to_write for labels if model_files or cmssw_versions is None? ->
# gets difficult with itertools product if only one param is changed
self.out_params_repr = "_".join(self.output_product_params_to_write)

# define label product
self.labels_products = list(itertools.product(*self.product_sequences_labels))
self.labels_products_dict = [dict(zip(self.product_names_labels, values)) for values in self.labels_products]

# retrieve the names of the params to be put in labels
self.params_to_write_labels = []
for iparam, param in enumerate(self.product_names_labels):
if len(self.product_sequences_labels[iparam]) > 1:
self.params_to_write_labels += [param]

# create list of labels to plot
self.params_product_params_to_write = [
combination_dict[key_to_write]
for combination_dict in self.labels_products_dict
for key_to_write in self.params_to_write_labels
]

def flatten_tuple(self, value):
for x in value:
if isinstance(x, tuple):
yield from self.flatten_tuple(x)
else:
yield x

def requires(self):
flattened_product = [
tuple(self.flatten_tuple(tuple_of_args)) for tuple_of_args in itertools.product(*self.product_sequences_req)
]
return [
MergeRuntimes.req(self, **dict(zip(self.product_names, values)))
for values in itertools.product(*self.product_sequences)
MergeRuntimes.req(self, **dict(zip(self.product_names_req, values)))
for values in flattened_product
]

def output(self):
# TODO: encode all important params in a human-readable way
# TODO: also check which parameters should go into store parts
return self.local_target("test.pdf")
# self.fill_params_to_write()
# all_params = self.factorize_params()
# all_params_list = ["_".join(all_params_item) for all_params_item in all_params]
# all_params_repr = "_".join(all_params_list)
# return self.local_target(f"runtime_plot_params_{all_params_repr}_different_batch_sizes_{self.batch_sizes_repr}.pdf") # noqa
return self.local_target(
f"runtime_plot_{self.out_params_repr}_different_batch_sizes_{self.batch_sizes_repr}.pdf"
)

@view_output_plots
def run(self):
Expand All @@ -260,12 +312,11 @@ def run(self):
input_paths = [inp.path for inp in self.input()]

# create the plot
# TODO: maybe adjust labels
from IPython import embed; embed()

plot_batch_size_several_measurements(
self.batch_sizes,
input_paths,
output.path,
list(itertools.product(self.product_sequences)),
self.params_product_params_to_write,
self.custom_plot_params,
)

0 comments on commit efe74f6

Please sign in to comment.