diff --git a/mlprof/tasks/parameters.py b/mlprof/tasks/parameters.py index 88fb66a..7f6fa74 100644 --- a/mlprof/tasks/parameters.py +++ b/mlprof/tasks/parameters.py @@ -4,7 +4,7 @@ Collection of the recurrent luigi parameters for different tasks. """ -from __future__ import annotations +# from __future__ import annotations import os from dataclasses import dataclass @@ -144,7 +144,7 @@ class ModelParameters(BaseTask): model_label = luigi.Parameter( default=law.NO_STR, description="when set, use this label in plots; when empty, the 'network_name' field in the model json data is " - "used when existing, and full_model_name otherwise; default: empty", + "used when existing, and full_name otherwise; default: empty", ) def __init__(self, *args, **kwargs): @@ -152,8 +152,8 @@ def __init__(self, *args, **kwargs): self.model = Model( model_file=self.model_file, - name=self.name if self.name != law.NO_STR else None, - label=self.label if self.label != law.NO_STR else None, + name=self.model_name if self.model_name != law.NO_STR else None, + label=self.model_label if self.model_label != law.NO_STR else None, ) def store_parts(self): @@ -168,48 +168,69 @@ def store_parts(self): return parts -# class MultiModelParameters(BaseTask): -# """ -# General parameters for the model definition and the runtime measurement. -# """ - -# model_files = luigi.Parameter( -# default="$MLP_BASE/examples/simple_dnn/model.json", -# description="json file containing information of model to be tested; " -# "default: $MLP_BASE/examples/simple_dnn/model.json", -# ) -# model_names = luigi.Parameter( -# default=law.NO_STR, -# description="when set, use this name for storing outputs instead of a hashed version of " -# "--model-file; default: empty", -# ) -# model_labels = luigi.Parameter( -# default=law.NO_STR, -# description="when set, use this label in plots; when empty, the 'network_name' field in the " -# "model json data is used when existing, and full_model_name otherwise; default: empty", -# ) - -# def __init__(self, *args, **kwargs): -# super().__init__(*args, **kwargs) - -# # TODO: check that lengths match ... -# pass - -# self.models = [ -# Model() -# for x, y, z in zip(...) -# ] - -# def store_parts(self): -# parts = super().store_parts() - -# # build a combined string that represents the significant parameters -# params = [ -# f"model_{self.model.full_name}", -# ] -# parts.insert_before("version", "model_params", "__".join(params)) - -# return parts +class MultiModelParameters(BaseTask): + """ + General parameters for the model definition and the runtime measurement. + """ + + model_files = law.CSVParameter( + description="comma-separated list of json files containing information of models to be tested", + brace_expand=True, + ) + model_names = law.CSVParameter( + default=law.NO_STR, + description="comma-separated list of names of models defined in --model-files to use in output paths " + "instead of a hashed version of model_files; when set, the number of names must match the number of " + "model files; default: ()", + ) + model_labels = law.CSVParameter( + default=law.NO_STR, + description="when set, use this label in plots; when empty, the 'network_name' field in the " + "model json data is used when existing, and full_model_name otherwise; default: empty", + ) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # check that lengths match if initialized + from IPython import embed; embed() + if self.model_names[0] == law.NO_STR: + if self.model_labels[0] != law.NO_STR: + if len(self.model_files) != len(self.model_labels): + raise ValueError("the length of model_files and model_labels muss be the same") + elif self.model_labels[0] == law.NO_STR: + if len(self.model_files) != len(self.model_names): + raise ValueError("the length of model_files and model_names muss be the same") + else: + if len({len(self.model_files), len(self.model_names), len(self.model_labels)}) != 1: + raise ValueError("the length of model_names, model_files and model_labels muss be the same") + + # if not initialized, change size objects for them to match + if len(self.model_names) != len(self.model_files): + self.model_names = (law.NO_STR,) * len(self.model_files) + if len(self.model_labels) != len(self.model_files): + self.model_labels = (law.NO_STR,) * len(self.model_files) + + # define Model objects + self.models = [ + Model( + model_file=x, + name=y if y != law.NO_STR else None, + label=z if z != law.NO_STR else None, + ) + for x, y, z in zip(self.model_files, self.model_names, self.model_labels) + ] + + def store_parts(self): + parts = super().store_parts() + + # build a combined string that represents the significant parameters + params = [ + f"model_{model.full_name}" for model in self.models + ] + parts.insert_before("version", "model_params", "__".join(params)) + + return parts class BatchSizesParameters(BaseTask): diff --git a/mlprof/tasks/runtime.py b/mlprof/tasks/runtime.py index e13c79e..a8f49d0 100644 --- a/mlprof/tasks/runtime.py +++ b/mlprof/tasks/runtime.py @@ -12,7 +12,8 @@ from mlprof.tasks.base import CommandTask, PlotTask, view_output_plots from mlprof.tasks.parameters import ( - RuntimeParameters, ModelParameters, CMSSWParameters, BatchSizesParameters, CustomPlotParameters, + RuntimeParameters, ModelParameters, MultiModelParameters, CMSSWParameters, BatchSizesParameters, + CustomPlotParameters, ) from mlprof.tasks.sandboxes import CMSSWSandboxTask from mlprof.plotting.plotter import plot_batch_size_several_measurements @@ -176,16 +177,12 @@ def run(self): output = self.output() output.parent.touch() - # get name network for legend - model_data = self.model_data - network_name = model_data["network_name"] - # create the plot plot_batch_size_several_measurements( self.batch_sizes, [self.input().path], output.path, - [network_name], + [self.model.full_model_label], self.custom_plot_params, ) print("plot saved") @@ -193,6 +190,7 @@ def run(self): class PlotRuntimesMultipleParams( RuntimeParameters, + MultiModelParameters, BatchSizesParameters, PlotTask, CustomPlotParameters, @@ -205,15 +203,6 @@ class PlotRuntimesMultipleParams( sandbox = "bash::$MLP_BASE/sandboxes/plotting.sh" - model_files = law.CSVParameter( - description="comma-separated list of json files containing information of models to be tested", - brace_expand=True, - ) - model_names = law.CSVParameter( - default=(), - description="comma-separated list of names of models defined in --model-files to use in output paths; " - "when set, the number of names must match the number of model files; default: ()", - ) cmssw_versions = law.CSVParameter( default=(CMSSWParameters.cmssw_version._default,), description=f"comma-separated list of CMSSW versions; default: ({CMSSWParameters.cmssw_version._default},)", @@ -272,6 +261,7 @@ def run(self): # create the plot # TODO: maybe adjust labels + from IPython import embed; embed() plot_batch_size_several_measurements( self.batch_sizes, input_paths,