Skip to content

Commit

Permalink
Fixing secondary labels (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
lastmansleeping committed Oct 18, 2022
1 parent 3af0577 commit 276ded7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
13 changes: 13 additions & 0 deletions docs/source/misc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.1.13] - 2022-10-17

### Fixed

- Bug in metrics_helper when used without secondary_labels

### Added

- RankMatchFailure metric for evaluation
- RankMatchFailure auxiliary loss

## [0.1.12] - 2022-04-26

## [0.1.11] - 2021-01-18

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ def get_grouped_stats(
sum_new_reciprocal_rank = df_grouped_batch.apply(lambda x: (1.0 / x[new_rank_col]).sum())

# Aggregate secondary label metrics by group keys
df_secondary_labels_metrics = df_secondary_labels_metrics.groupby(group_keys).sum()
if secondary_labels:
df_secondary_labels_metrics = df_secondary_labels_metrics.groupby(group_keys).sum()
else:
# Compute overall stats if group keys are not specified
query_count = [df_clicked.shape[0]]
Expand All @@ -307,7 +308,8 @@ def get_grouped_stats(
sum_new_reciprocal_rank = [(1.0 / df_clicked[new_rank_col]).sum()]

# Aggregate secondary label metrics
df_secondary_labels_metrics = df_secondary_labels_metrics.sum().to_frame().T
if secondary_labels:
df_secondary_labels_metrics = df_secondary_labels_metrics.sum().to_frame().T

df_label_stats = pd.DataFrame(
{
Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def getReadMe():
setup(
name="ml4ir",
packages=find_namespace_packages(include=["ml4ir.*"]),
version="0.1.12",
version="0.1.13",
description="Machine Learning libraries for Information Retrieval",
long_description=getReadMe(),
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 276ded7

Please sign in to comment.