Skip to content

Commit

Permalink
fix using the correct ml-model inst when configuring categories
Browse files Browse the repository at this point in the history
  • Loading branch information
mafrahm committed Oct 10, 2023
1 parent 8844cb3 commit be638b7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
4 changes: 4 additions & 0 deletions hbw/config/categories.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import law

from columnflow.config_util import create_category_combinations
from columnflow.ml import MLModel
from hbw.util import call_once_on_config

import order as od
Expand Down Expand Up @@ -154,6 +155,9 @@ def add_categories_production(config: od.Config) -> None:

@call_once_on_config()
def add_categories_ml(config, ml_model_inst):
# if not already done, get the ml_model instance
if isinstance(ml_model_inst, str):
ml_model_inst = MLModel.get_cls(ml_model_inst)(config)

# add ml categories directly to the config
ml_categories = []
Expand Down
3 changes: 1 addition & 2 deletions hbw/production/categories.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from columnflow.production.categories import category_ids

from hbw.config.categories import add_categories_production, add_categories_ml
from hbw.ml.dense_classifier import dense_test

np = maybe_import("numpy")
ak = maybe_import("awkward")
Expand Down Expand Up @@ -76,7 +75,7 @@ def ml_cats_init(self: Producer) -> None:

# add categories to config inst
add_categories_production(self.config_inst)
add_categories_ml(self.config_inst, dense_test)
add_categories_ml(self.config_inst, self.ml_model_name)


# get all the derived DenseClassifier models and instantiate a corresponding producer
Expand Down

0 comments on commit be638b7

Please sign in to comment.