Skip to content

Commit

Permalink
refactoring python 3.7 and more, to be adapted to python 3.6
Browse files Browse the repository at this point in the history
  • Loading branch information
nprouvost committed Jan 31, 2024
1 parent 4a85e12 commit 0620f4f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 61 deletions.
113 changes: 67 additions & 46 deletions mlprof/tasks/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Collection of the recurrent luigi parameters for different tasks.
"""

from __future__ import annotations
# from __future__ import annotations

import os
from dataclasses import dataclass
Expand Down Expand Up @@ -144,16 +144,16 @@ class ModelParameters(BaseTask):
model_label = luigi.Parameter(
default=law.NO_STR,
description="when set, use this label in plots; when empty, the 'network_name' field in the model json data is "
"used when existing, and full_model_name otherwise; default: empty",
"used when existing, and full_name otherwise; default: empty",
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.model = Model(
model_file=self.model_file,
name=self.name if self.name != law.NO_STR else None,
label=self.label if self.label != law.NO_STR else None,
name=self.model_name if self.model_name != law.NO_STR else None,
label=self.model_label if self.model_label != law.NO_STR else None,
)

def store_parts(self):
Expand All @@ -168,48 +168,69 @@ def store_parts(self):
return parts


# class MultiModelParameters(BaseTask):
# """
# General parameters for the model definition and the runtime measurement.
# """

# model_files = luigi.Parameter(
# default="$MLP_BASE/examples/simple_dnn/model.json",
# description="json file containing information of model to be tested; "
# "default: $MLP_BASE/examples/simple_dnn/model.json",
# )
# model_names = luigi.Parameter(
# default=law.NO_STR,
# description="when set, use this name for storing outputs instead of a hashed version of "
# "--model-file; default: empty",
# )
# model_labels = luigi.Parameter(
# default=law.NO_STR,
# description="when set, use this label in plots; when empty, the 'network_name' field in the "
# "model json data is used when existing, and full_model_name otherwise; default: empty",
# )

# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)

# # TODO: check that lengths match ...
# pass

# self.models = [
# Model()
# for x, y, z in zip(...)
# ]

# def store_parts(self):
# parts = super().store_parts()

# # build a combined string that represents the significant parameters
# params = [
# f"model_{self.model.full_name}",
# ]
# parts.insert_before("version", "model_params", "__".join(params))

# return parts
class MultiModelParameters(BaseTask):
"""
General parameters for the model definition and the runtime measurement.
"""

model_files = law.CSVParameter(
description="comma-separated list of json files containing information of models to be tested",
brace_expand=True,
)
model_names = law.CSVParameter(
default=law.NO_STR,
description="comma-separated list of names of models defined in --model-files to use in output paths "
"instead of a hashed version of model_files; when set, the number of names must match the number of "
"model files; default: ()",
)
model_labels = law.CSVParameter(
default=law.NO_STR,
description="when set, use this label in plots; when empty, the 'network_name' field in the "
"model json data is used when existing, and full_model_name otherwise; default: empty",
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# check that lengths match if initialized
from IPython import embed; embed()
if self.model_names[0] == law.NO_STR:
if self.model_labels[0] != law.NO_STR:
if len(self.model_files) != len(self.model_labels):
raise ValueError("the length of model_files and model_labels muss be the same")
elif self.model_labels[0] == law.NO_STR:
if len(self.model_files) != len(self.model_names):
raise ValueError("the length of model_files and model_names muss be the same")
else:
if len({len(self.model_files), len(self.model_names), len(self.model_labels)}) != 1:
raise ValueError("the length of model_names, model_files and model_labels muss be the same")

# if not initialized, change size objects for them to match
if len(self.model_names) != len(self.model_files):
self.model_names = (law.NO_STR,) * len(self.model_files)
if len(self.model_labels) != len(self.model_files):
self.model_labels = (law.NO_STR,) * len(self.model_files)

# define Model objects
self.models = [
Model(
model_file=x,
name=y if y != law.NO_STR else None,
label=z if z != law.NO_STR else None,
)
for x, y, z in zip(self.model_files, self.model_names, self.model_labels)
]

def store_parts(self):
parts = super().store_parts()

# build a combined string that represents the significant parameters
params = [
f"model_{model.full_name}" for model in self.models
]
parts.insert_before("version", "model_params", "__".join(params))

return parts


class BatchSizesParameters(BaseTask):
Expand Down
20 changes: 5 additions & 15 deletions mlprof/tasks/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from mlprof.tasks.base import CommandTask, PlotTask, view_output_plots
from mlprof.tasks.parameters import (
RuntimeParameters, ModelParameters, CMSSWParameters, BatchSizesParameters, CustomPlotParameters,
RuntimeParameters, ModelParameters, MultiModelParameters, CMSSWParameters, BatchSizesParameters,
CustomPlotParameters,
)
from mlprof.tasks.sandboxes import CMSSWSandboxTask
from mlprof.plotting.plotter import plot_batch_size_several_measurements
Expand Down Expand Up @@ -176,23 +177,20 @@ def run(self):
output = self.output()
output.parent.touch()

# get name network for legend
model_data = self.model_data
network_name = model_data["network_name"]

# create the plot
plot_batch_size_several_measurements(
self.batch_sizes,
[self.input().path],
output.path,
[network_name],
[self.model.full_model_label],
self.custom_plot_params,
)
print("plot saved")


class PlotRuntimesMultipleParams(
RuntimeParameters,
MultiModelParameters,
BatchSizesParameters,
PlotTask,
CustomPlotParameters,
Expand All @@ -205,15 +203,6 @@ class PlotRuntimesMultipleParams(

sandbox = "bash::$MLP_BASE/sandboxes/plotting.sh"

model_files = law.CSVParameter(
description="comma-separated list of json files containing information of models to be tested",
brace_expand=True,
)
model_names = law.CSVParameter(
default=(),
description="comma-separated list of names of models defined in --model-files to use in output paths; "
"when set, the number of names must match the number of model files; default: ()",
)
cmssw_versions = law.CSVParameter(
default=(CMSSWParameters.cmssw_version._default,),
description=f"comma-separated list of CMSSW versions; default: ({CMSSWParameters.cmssw_version._default},)",
Expand Down Expand Up @@ -272,6 +261,7 @@ def run(self):

# create the plot
# TODO: maybe adjust labels
from IPython import embed; embed()
plot_batch_size_several_measurements(
self.batch_sizes,
input_paths,
Expand Down

0 comments on commit 0620f4f

Please sign in to comment.