Skip to content

Commit

Permalink
Refactor weight normalization.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Dec 13, 2024
1 parent d4c9222 commit fcfa196
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 86 deletions.
13 changes: 8 additions & 5 deletions hbt/config/configs_hbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,8 +922,10 @@ def if_era(
cfg,
f"jec_{jec_source}",
{
"normalized_btag_weight": "normalized_btag_weight_{name}",
"normalized_njet_btag_weight": "normalized_njet_btag_weight_{name}",
"normalized_btag_deepjet_weight": "normalized_btag_deepjet_weight_{name}",
"normalized_njet_btag_deepjet_weight": "normalized_njet_btag_deepjet_weight_{name}",
"normalized_btag_pnet_weight": "normalized_btag_pnet_weight_{name}",
"normalized_njet_btag_pnet_weight": "normalized_njet_btag_pnet_weight_{name}",
},
)

Expand Down Expand Up @@ -1000,8 +1002,9 @@ def if_era(
cfg,
f"btag_{unc}",
{
"normalized_btag_weight": f"normalized_btag_weight_{unc}_" + "{direction}",
"normalized_njet_btag_weight": f"normalized_njet_btag_weight_{unc}_" + "{direction}",
"normalized_btag_deepjet_weight": f"normalized_btag_deepjet_weight_{unc}_" + "{direction}",
"normalized_njet_btag_deepjet_weight": f"normalized_njet_btag_deepjet_weight_{unc}_" + "{direction}",
# TODO: pnet here, or is this another shift? probably the latter
},
)

Expand Down Expand Up @@ -1168,7 +1171,7 @@ def add_external(name, value):
"pdf_weight": get_shifts("pdf"),
"murmuf_weight": get_shifts("murmuf"),
"normalized_pu_weight": get_shifts("minbias_xs"),
"normalized_njet_btag_weight": get_shifts(*(f"btag_{unc}" for unc in cfg.x.btag_unc_names)),
# "normalized_njet_btag_deepjet_weight": get_shifts(*(f"btag_{unc}" for unc in cfg.x.btag_unc_names)),
"electron_weight": get_shifts("e"),
"muon_weight": get_shifts("mu"),
"tau_weight": get_shifts(*(f"tau_{unc}" for unc in cfg.x.tau_unc_names)),
Expand Down
58 changes: 25 additions & 33 deletions hbt/production/btag.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,54 +110,46 @@ def _normalized_btag_weights_setup(self: Producer, reqs: dict, inputs: dict, rea
)

# get the unique process ids in that dataset
key = f"sum_mc_weight_selected_nobjet_{self.tagger_name}_per_process_and_njet"
key = f"sum_mc_weight_selected_nob_{self.tagger_name}_per_process_and_njet"
self.unique_process_ids = list(map(int, selection_stats[key].keys()))

# get the maximum numbers of jets
max_n_jets = max(map(int, sum((list(d.keys()) for d in selection_stats[key].values()), [])))

# helper to get numerators and denominators
def numerator_per_pid(pid):
key = f"sum_mc_weight_selected_nobjet_{self.tagger_name}_per_process"
# helper to get sums of mc weights per pid and njet, with an optional weight name
def sum_per_pid(pid, weight_name="", /):
if weight_name:
weight_name += "_"
key = f"sum_mc_weight_{weight_name}selected_nob_{self.tagger_name}_per_process"
return selection_stats[key].get(str(pid), 0.0)

def denominator_per_pid(weight_name, pid):
key = f"sum_mc_weight_{weight_name}_selected_nobjet_{self.tagger_name}_per_process"
return selection_stats[key].get(str(pid), 0.0)

def numerator_per_pid_njet(pid, n_jets):
key = f"sum_mc_weight_selected_nobjet_{self.tagger_name}_per_process_and_njet"
d = selection_stats[key].get(str(pid), {})
return d.get(str(n_jets), 0.0)

def denominator_per_pid_njet(weight_name, pid, n_jets):
key = f"sum_mc_weight_{weight_name}_selected_nobjet_{self.tagger_name}_per_process_and_njet"
d = selection_stats[key].get(str(pid), {})
return d.get(str(n_jets), 0.0)
def sum_per_pid_njet(pid, n_jets, weight_name="", /):
if weight_name:
weight_name += "_"
key = f"sum_mc_weight_{weight_name}selected_nob_{self.tagger_name}_per_process_and_njet"
return selection_stats[key].get(str(pid), {}).get(str(n_jets), 0.0)

# extract the ratio per weight and pid
self.ratio_per_pid = {
weight_name: {
pid: safe_div(numerator_per_pid(pid), denominator_per_pid(weight_name, pid))
# ratio per weight and pid
# extract the ratio per weight, pid and also the jet multiplicity, using the latter as in index
self.ratio_per_pid = {}
self.ratio_per_pid_njet = {}
for route in self[self.btag_weights_cls].produced_columns:
weight_name = str(route)
if not weight_name.startswith(self.btag_weights_cls.weight_name):
continue
# normal ratio
self.ratio_per_pid[weight_name] = {
pid: safe_div(sum_per_pid(pid), sum_per_pid(pid, weight_name))
for pid in self.unique_process_ids
}
for weight_name in (str(route) for route in self[self.btag_weights_cls].produced_columns)
if weight_name.startswith(self.btag_weights_cls.weight_name)
}

# extract the ratio per weight, pid and also the jet multiplicity, using the latter as in index
# for a lookup table (since it naturally starts at 0)
self.ratio_per_pid_njet = {
weight_name: {
# per jet multiplicity ratio
self.ratio_per_pid_njet[weight_name] = {
pid: np.array([
safe_div(numerator_per_pid_njet(pid, n_jets), denominator_per_pid_njet(weight_name, pid, n_jets))
safe_div(sum_per_pid_njet(pid, n_jets), sum_per_pid_njet(pid, n_jets, weight_name))
for n_jets in range(max_n_jets + 1)
])
for pid in self.unique_process_ids
}
for weight_name in (str(route) for route in self[self.btag_weights_cls].produced_columns)
if weight_name.startswith(self.btag_weights_cls.weight_name)
}


# derive for btaggers
Expand Down
16 changes: 4 additions & 12 deletions hbt/production/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,8 @@ def denominator_per_pid(weight_name, pid):


@producer(
uses={
"pdf_weight", "pdf_weight_up", "pdf_weight_down",
},
produces={
"normalized_pdf_weight", "normalized_pdf_weight_up", "normalized_pdf_weight_down",
},
uses={"pdf_weight{,_up,_down}"},
produces={"normalized_pdf_weight{,_up,_down}"},
# only run on mc
mc_only=True,
)
Expand Down Expand Up @@ -158,12 +154,8 @@ def normalized_pdf_weight_setup(


@producer(
uses={
"murmuf_weight", "murmuf_weight_up", "murmuf_weight_down",
},
produces={
"normalized_murmuf_weight", "normalized_murmuf_weight_up", "normalized_murmuf_weight_down",
},
uses={"murmuf_weight{,_up,_down}"},
produces={"normalized_murmuf_weight{,_up,_down}"},
# only run on mc
mc_only=True,
)
Expand Down
65 changes: 30 additions & 35 deletions hbt/selection/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,14 @@ def default(
event_sel = reduce(and_, results.steps.values())
results.event = event_sel

# combined event seleciton after all but the bjet step
tagger_name = btag_weights_deepjet.tagger_name
event_sel_nob_deepjet = results.steps[f"all_but_bjet_{tagger_name}"] = reduce(and_, [
mask for step_name, mask in results.steps.items()
if step_name != f"bjet_{tagger_name}"
])
event_sel_nob_pnet = None
if self.has_dep(btag_weights_pnet):
tagger_name = btag_weights_pnet.tagger_name
event_sel_nob_pnet = results.steps[f"all_but_bjet_{tagger_name}"] = reduce(and_, [
# combined event selection after all but the bjet step
def event_sel_nob(btag_weight_cls):
tagger_name = btag_weights_deepjet.tagger_name
var_sel = results.steps[f"all_but_bjet_{tagger_name}"] = reduce(and_, [
mask for step_name, mask in results.steps.items()
if step_name != f"bjet_{tagger_name}"
])
return var_sel

# increment stats
events, results = setup_and_increment_stats(
Expand All @@ -157,8 +152,10 @@ def default(
results=results,
stats=stats,
event_sel=event_sel,
event_sel_nob_deepjet=event_sel_nob_deepjet,
event_sel_nob_pnet=event_sel_nob_pnet,
event_sel_variations={
"nob_deepjet": event_sel_nob(btag_weights_deepjet),
"nob_pnet": event_sel_nob(btag_weights_pnet) if self.has_dep(btag_weights_pnet) else None,
},
njets=results.x.n_central_jets,
)

Expand Down Expand Up @@ -287,8 +284,10 @@ def empty_call(
results=results,
stats=stats,
event_sel=results.event,
event_sel_nob_deepjet=results.event,
event_sel_nob_pnet=results.event if self.has_dep(btag_weights_pnet) else None,
event_sel_variations={
"nob_deepjet": results.event,
"nob_pnet": results.event if self.has_dep(btag_weights_pnet) else None,
},
njets=ak.num(events.Jet, axis=1),
)

Expand All @@ -302,8 +301,7 @@ def setup_and_increment_stats(
results: SelectionResult,
stats: defaultdict,
event_sel: np.ndarray | ak.Array,
event_sel_nob_deepjet: np.ndarray | ak.Array | None = None,
event_sel_nob_pnet: np.ndarray | ak.Array | None = None,
event_sel_variations: dict[str, np.ndarray | ak.Array] | None = None,
njets: np.ndarray | ak.Array | None = None,
**kwargs,
) -> tuple[ak.Array, SelectionResult]:
Expand All @@ -316,31 +314,31 @@ def setup_and_increment_stats(
:param results: The current selection results.
:param stats: The stats dictionary.
:param event_sel: The general event selection mask.
:param event_sel_nob_deepjet: The event selection mask without the bjet step for deepjet.
:param event_sel_variations: Named variations of the event selection mask for additional stats.
:param event_sel_nob_pnet: The event selection mask without the bjet step for pnet.
:param njets: The number of central jets.
:return: The updated events and results objects in a tuple.
"""
if event_sel_variations is None:
event_sel_variations = {}
event_sel_variations = {n: s for n, s in event_sel_variations.items() if s is not None}

# start creating a weight, group and group combination map
weight_map = {
"num_events": Ellipsis,
"num_events_selected": event_sel,
}
if event_sel_nob_deepjet is not None:
weight_map["num_events_selected_nobjet_deepjet"] = event_sel_nob_deepjet
if event_sel_nob_pnet is not None:
weight_map["num_events_selected_nobjet_pnet"] = event_sel_nob_pnet
for var_name, var_sel in event_sel_variations.items():
weight_map[f"num_events_selected_{var_name}"] = var_sel
group_map = {}
group_combinations = []

# add mc info
if self.dataset_inst.is_mc:
weight_map["sum_mc_weight"] = events.mc_weight
weight_map["sum_mc_weight_selected"] = (events.mc_weight, event_sel)
if event_sel_nob_deepjet is not None:
weight_map["sum_mc_weight_selected_nobjet_deepjet"] = (events.mc_weight, event_sel_nob_deepjet)
if event_sel_nob_pnet is not None:
weight_map["sum_mc_weight_selected_nobjet_pnet"] = (events.mc_weight, event_sel_nob_pnet)
for var_name, var_sel in event_sel_variations.items():
weight_map[f"sum_mc_weight_selected_{var_name}"] = (events.mc_weight, var_sel)

# pu weights with variations
for route in sorted(self[pu_weight].produced_columns):
Expand All @@ -364,17 +362,14 @@ def setup_and_increment_stats(
if not self.has_dep(prod):
continue
for route in sorted(self[prod].produced_columns):
name = str(route)
if not name.startswith(prod.weight_name):
weight_name = str(route)
if not weight_name.startswith(prod.weight_name):
continue
weight_map[f"sum_{name}"] = events[name]
weight_map[f"sum_{name}_selected"] = (events[name], event_sel)
if event_sel_nob_deepjet is not None:
weight_map[f"sum_{name}_selected_nobjet_deepjet"] = (events[name], event_sel_nob_deepjet)
weight_map[f"sum_mc_weight_{name}_selected_nobjet_deepjet"] = (events.mc_weight * events[name], event_sel_nob_deepjet) # noqa: E501
if event_sel_nob_pnet is not None:
weight_map[f"sum_{name}_selected_nobjet_pnet"] = (events[name], event_sel_nob_pnet)
weight_map[f"sum_mc_weight_{name}_selected_nobjet_pnet"] = (events.mc_weight * events[name], event_sel_nob_pnet) # noqa: E501
weight_map[f"sum_{weight_name}"] = events[weight_name]
weight_map[f"sum_{weight_name}_selected"] = (events[weight_name], event_sel)
for var_name, var_sel in event_sel_variations.items():
weight_map[f"sum_{weight_name}_selected_{var_name}"] = (events[weight_name], var_sel)
weight_map[f"sum_mc_weight_{weight_name}_selected_{var_name}"] = (events.mc_weight * events[weight_name], var_sel) # noqa: E501

# groups
group_map = {
Expand Down
2 changes: 1 addition & 1 deletion modules/columnflow

0 comments on commit fcfa196

Please sign in to comment.