Skip to content

Commit

Permalink
fix mixins of custom tasks (I hope)
Browse files Browse the repository at this point in the history
  • Loading branch information
mafrahm committed Nov 25, 2024
1 parent 09a4ab4 commit f62312a
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 42 deletions.
17 changes: 4 additions & 13 deletions hbw/tasks/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
CalibratorsMixin,
ProducersMixin,
MLModelTrainingMixin,
MLModelMixin,
MLModelsMixin,
MLModelDataMixin,
SelectorStepsMixin,
Expand Down Expand Up @@ -403,13 +402,9 @@ def workflow_run(self):


class MLEvaluationSingleFold(
# NOTE: this should probably be a MLModelTrainingMixin, but I'll postpone this until the MultiConfigTask
# is implemented
# NOTE: mixins might need fixing, needs to be checked
HBWTask,
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
MLModelTrainingMixin,
law.LocalWorkflow,
RemoteWorkflow,
):
Expand Down Expand Up @@ -541,13 +536,9 @@ def run(self):


class PlotMLResultsSingleFold(
# NOTE: this should probably be a MLModelTrainingMixin, but I'll postpone this until the MultiConfigTask
# is implemented
# NOTE: mixins might need fixing, needs to be checked
HBWTask,
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
MLModelTrainingMixin,
law.LocalWorkflow,
RemoteWorkflow,
):
Expand Down
29 changes: 13 additions & 16 deletions hbw/tasks/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

from columnflow.tasks.framework.base import Requirements
from columnflow.tasks.framework.mixins import (
SelectorMixin,
CalibratorsMixin,
ProducersMixin,
MLModelMixin,
# SelectorMixin,
# CalibratorsMixin,
# ProducersMixin,
# MLModelMixin,
MLModelTrainingMixin,
)
# from columnflow.tasks.framework.remote import RemoteWorkflow
from columnflow.util import DotDict
Expand Down Expand Up @@ -97,12 +98,11 @@ def run(self):


class Optimizer(
# NOTE: mixins might need fixing, needs to be tested
HBWTask,
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
MLModelTrainingMixin,
law.LocalWorkflow,
# RemoteWorkflow,
):
"""
Workflow that runs optimization. Needs to be run from within the sandbox
Expand Down Expand Up @@ -191,12 +191,11 @@ def run(self):


class Objective(
# NOTE: mixins might need fixing, needs to be tested
HBWTask,
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
MLModelTrainingMixin,
law.LocalWorkflow,
# RemoteWorkflow,
):
"""
Objective to optimize.
Expand Down Expand Up @@ -281,11 +280,9 @@ def run(self):


class DummyObjective(
# NOTE: mixins might need fixing, needs to be tested
HBWTask,
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
MLModelTrainingMixin,
law.LocalWorkflow,
# RemoteWorkflow,
):
Expand Down
6 changes: 4 additions & 2 deletions hbw/tasks/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import law
import order as od

from columnflow.tasks.framework.base import Requirements, ShiftTask
from columnflow.tasks.framework.base import Requirements, MultiConfigTask
from columnflow.tasks.framework.mixins import (
CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin,
CategoriesMixin,
)
from columnflow.tasks.framework.plotting import (
PlotBase, PlotBase1D, ProcessPlotSettingMixin, VariablePlotSettingMixin,
PlotShiftMixin,
)
from columnflow.tasks.framework.decorators import view_output_plots
from columnflow.tasks.framework.remote import RemoteWorkflow
Expand Down Expand Up @@ -108,7 +109,7 @@ def plot_multi_weight_producer(

class PlotVariablesMultiWeightProducer(
HBWTask,
ShiftTask,
PlotShiftMixin,
VariablePlotSettingMixin,
ProcessPlotSettingMixin,
PlotBase1D,
Expand All @@ -117,6 +118,7 @@ class PlotVariablesMultiWeightProducer(
ProducersMixin,
SelectorStepsMixin,
CalibratorsMixin,
MultiConfigTask,
law.LocalWorkflow,
RemoteWorkflow,
):
Expand Down
3 changes: 3 additions & 0 deletions hbw/tasks/postfit_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import luigi
import order as od

from columnflow.tasks.framework.base import ConfigTask
from columnflow.tasks.framework.mixins import (
InferenceModelMixin, MLModelsMixin, ProducersMixin, SelectorStepsMixin,
CalibratorsMixin,
Expand Down Expand Up @@ -107,6 +108,7 @@ def plot_postfit_shapes(


class PlotPostfitShapes(
# NOTE: mixins might be wrong and could (should?) be extended to MultiConfigTask
HBWTask,
PlotBase1D,
# to correctly setup our InferenceModel, we need all these mixins, but hopefully, all these
Expand All @@ -116,6 +118,7 @@ class PlotPostfitShapes(
ProducersMixin,
SelectorStepsMixin,
CalibratorsMixin,
ConfigTask,
):
"""
Task that creates Postfit shape plots based on a fit_diagnostics file.
Expand Down
40 changes: 29 additions & 11 deletions hbw/tasks/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import law
import luigi

from columnflow.tasks.framework.base import Requirements
from columnflow.tasks.framework.base import Requirements, MultiConfigTask, ConfigTask
from columnflow.tasks.framework.mixins import (
InferenceModelMixin, MLModelsMixin, ProducersMixin, SelectorStepsMixin,
CalibratorsMixin,
Expand All @@ -31,16 +31,24 @@


class ControlPlots(
law.WrapperTask,
HBWTask,
ProducersMixin,
SelectorStepsMixin,
CalibratorsMixin,
MultiConfigTask,
):
split_resolved_boosted = False

"""
Helper task to produce default set of control plots
"""

split_resolved_boosted = False
output_collection_cls = law.NestedSiblingFileCollection

@property
def config_inst(self):
return self.config_insts[0]

def requires(self):
lepton_tag = self.config_inst.x.lepton_tag
lepton_channels = self.config_inst.x.lepton_channels
Expand Down Expand Up @@ -85,22 +93,30 @@ def requires(self):

return reqs

def output(self):
return self.requires()
# def output(self):
# return self.requires()

def run(self):
pass
# def run(self):
# pass


class MLInputPlots(
HBWTask,
ProducersMixin,
SelectorStepsMixin,
CalibratorsMixin,
MultiConfigTask,
):
"""
Helper task to produce default set of control plots
"""

output_collection_cls = law.NestedSiblingFileCollection

@property
def config_inst(self):
return self.config_insts[0]

def requires(self):
lepton_tag = self.config_inst.x.lepton_tag
lepton_channels = self.config_inst.x.lepton_channels
Expand Down Expand Up @@ -129,11 +145,11 @@ def requires(self):

return reqs

def output(self):
return self.requires()
# def output(self):
# return self.requires()

def run(self):
pass
# def run(self):
# pass


class InferencePlots(
Expand All @@ -148,6 +164,7 @@ class InferencePlots(
ProducersMixin,
SelectorStepsMixin,
CalibratorsMixin,
ConfigTask,
# law.LocalWorkflow,
# RemoteWorkflow,
):
Expand Down Expand Up @@ -218,6 +235,7 @@ class ShiftedInferencePlots(
ProducersMixin,
SelectorStepsMixin,
CalibratorsMixin,
ConfigTask,
):
sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox"))

Expand Down

0 comments on commit f62312a

Please sign in to comment.