Skip to content

Commit

Permalink
refactoring multi models
Browse files Browse the repository at this point in the history
  • Loading branch information
nprouvost authored Jan 24, 2024
1 parent 76c5795 commit 53c25d5
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 186 deletions.
160 changes: 131 additions & 29 deletions mlprof/tasks/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,60 @@
Collection of the recurrent luigi parameters for different tasks.
"""

from __future__ import annotations

import os
from dataclasses import dataclass

import luigi
import law

from mlprof.tasks.base import BaseTask


@dataclass
class Model:
    """
    Description of a single model, identified by the path to its json file.

    *name* optionally overrides the name derived from *model_file* (see :py:attr:`full_name`) and
    *label* optionally overrides the label derived from the model data (see
    :py:attr:`full_model_label`).
    """

    model_file: str
    name: str | None = None
    label: str | None = None

    def __post_init__(self):
        # The cache is set up here instead of in a hand-written __init__: an explicit __init__ is
        # not replaced by the dataclass decorator, so the fields above would never be assigned and
        # keyword construction would fail in object.__init__.
        # cached data
        self._data = None

    @property
    def data(self):
        """
        Content of the json file at *model_file*, loaded on first access and cached afterwards.
        """
        if self._data is None:
            self._data = law.LocalFileTarget(self.model_file).load(formatter="json")
        return self._data

    @property
    def full_name(self):
        """
        The explicit *name* when set, otherwise the basename of *model_file* without extension,
        followed by a hash of the expanded file path.
        """
        if self.name:
            return self.name

        # create a hash
        model_file = os.path.expandvars(os.path.expanduser(self.model_file))
        name = os.path.splitext(os.path.basename(model_file))[0]
        return f"{name}{law.util.create_hash(model_file)}"

    @property
    def full_model_label(self):
        """
        The explicit *label* when set, otherwise the "network_name" entry of the model data when
        present and non-empty, otherwise :py:attr:`full_name`.
        """
        if self.label:
            return self.label

        # get the network_name field in the model data
        network_name = self.data.get("network_name")
        if network_name:
            return network_name

        # fallback to the full model name
        return self.full_name


class CMSSWParameters(BaseTask):
"""
Parameters related to the CMSSW environment
Expand Down Expand Up @@ -40,16 +86,6 @@ class RuntimeParameters(BaseTask):
General parameters for the model definition and the runtime measurement.
"""

model_file = luigi.Parameter(
default="$MLP_BASE/examples/simple_dnn/model.json",
description="json file containing information of model to be tested; "
"default: $MLP_BASE/examples/simple_dnn/model.json",
)
model_name = luigi.Parameter(
default=law.NO_STR,
description="when set, use this name for storing outputs instead of a hashed version of "
"--model-file; default: empty",
)
input_type = luigi.Parameter(
default="random",
description="either 'random', 'incremental', 'zeros', or a path to a root file; default: random",
Expand All @@ -76,40 +112,106 @@ def __init__(self, *args, **kwargs):
f"a path to an existing root file",
)

# cached model content
self._model_data = None
def store_parts(self):
    """
    Extend the store parts of the parent class with a "runtime_params" entry, inserted before
    the "version" part, that encodes the input (a hash of the input file when one is set, the
    input type otherwise), the number of events and the number of calls.
    """
    parts = super().store_parts()

    # build a combined string that represents the significant parameters
    params = [
        f"input_{law.util.create_hash(self.input_file) if self.input_file else self.input_type}",
        f"nevents_{self.n_events}",
        f"ncalls_{self.n_calls}",
    ]
    parts.insert_before("version", "runtime_params", "__".join(params))

    return parts

@property
def model_data(self):
    """
    Content of the json file at *model_file*, loaded on first access and cached afterwards.
    """
    if self._model_data is None:
        self._model_data = law.LocalFileTarget(self.model_file).load(formatter="json")
    return self._model_data

@property
def full_model_name(self):
    """
    The value of *model_name* when set, otherwise the basename of *model_file* without
    extension, followed by a hash of the expanded file path.
    """
    if self.model_name not in (None, law.NO_STR):
        return self.model_name

    # create a hash
    model_file = os.path.expandvars(os.path.expanduser(self.model_file))
    model_name = os.path.splitext(os.path.basename(model_file))[0]
    return f"{model_name}{law.util.create_hash(model_file)}"

class ModelParameters(BaseTask):
    """
    Parameters that select the model to be tested and control its name and plot label.

    On construction, the parameter values are bundled into a :py:class:`Model` instance that is
    available as the *model* attribute.
    """

    model_file = luigi.Parameter(
        default="$MLP_BASE/examples/simple_dnn/model.json",
        description="json file containing information of model to be tested; "
        "default: $MLP_BASE/examples/simple_dnn/model.json",
    )
    model_name = luigi.Parameter(
        default=law.NO_STR,
        description="when set, use this name for storing outputs instead of a hashed version of "
        "--model-file; default: empty",
    )
    model_label = luigi.Parameter(
        default=law.NO_STR,
        description="when set, use this label in plots; when empty, the 'network_name' field in the model json data is "
        "used when existing, and full_model_name otherwise; default: empty",
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The parameters are declared as model_name / model_label (not name / label); an unset
        # parameter (law.NO_STR) is passed on as None so that Model applies its own fallbacks.
        self.model = Model(
            model_file=self.model_file,
            name=self.model_name if self.model_name != law.NO_STR else None,
            label=self.model_label if self.model_label != law.NO_STR else None,
        )

    def store_parts(self):
        """
        Extend the store parts of the parent class with a "model_params" entry, inserted before
        the "version" part, that encodes the full name of the model.
        """
        parts = super().store_parts()

        # build a combined string that represents the significant parameters
        # (only model related ones, as this class does not declare input or event parameters)
        params = [
            f"model_{self.model.full_name}",
        ]
        parts.insert_before("version", "model_params", "__".join(params))

        return parts


# class MultiModelParameters(BaseTask):
# """
# General parameters for the model definition and the runtime measurement.
# """

# model_files = luigi.Parameter(
# default="$MLP_BASE/examples/simple_dnn/model.json",
# description="json file containing information of model to be tested; "
# "default: $MLP_BASE/examples/simple_dnn/model.json",
# )
# model_names = luigi.Parameter(
# default=law.NO_STR,
# description="when set, use this name for storing outputs instead of a hashed version of "
# "--model-file; default: empty",
# )
# model_labels = luigi.Parameter(
# default=law.NO_STR,
# description="when set, use this label in plots; when empty, the 'network_name' field in the model json data is "
# "used when existing, and full_model_name otherwise; default: empty",
# )

# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)

# # TODO: check that lengths match ...
# pass

# self.models = [
# Model()
# for x, y, z in zip(...)
# ]

# def store_parts(self):
# parts = super().store_parts()

# # build a combined string that represents the significant parameters
# params = [
# f"model_{self.model.full_name}",
# ]
# parts.insert_before("version", "model_params", "__".join(params))

# return parts


class BatchSizesParameters(BaseTask):
"""
Parameter to add several batch sizes to perform the measurement on
Expand Down
Loading

0 comments on commit 53c25d5

Please sign in to comment.