fix: Allow aggregated tasks within benchmarks #1771

Merged
26 commits
1338736
fix: Allow aggregated tasks within benchmarks
KennethEnevoldsen Jan 11, 2025
f2920ff
feat: Update task filtering, fixing bug on MTEB
KennethEnevoldsen Jan 13, 2025
12aaa97
format
KennethEnevoldsen Jan 13, 2025
1be8ed8
remove "en-ext" from AmazonCounterfactualClassification
KennethEnevoldsen Jan 13, 2025
8aab5d0
fixed mteb(deu)
KennethEnevoldsen Jan 13, 2025
4dfe2ec
fix: simplify in a few areas
KennethEnevoldsen Jan 13, 2025
cd87ebb
wip
KennethEnevoldsen Jan 14, 2025
450953d
Merge branch 'correct-mteb-eng' into KennethEnevoldsen/issue-Allow-ag…
KennethEnevoldsen Jan 14, 2025
87816f1
tmp
KennethEnevoldsen Jan 15, 2025
f73ffb7
sav
KennethEnevoldsen Jan 16, 2025
33578ec
Allow aggregated tasks within benchmarks
KennethEnevoldsen Jan 17, 2025
54d16f9
Merge remote-tracking branch 'origin' into KennethEnevoldsen/issue-Al…
KennethEnevoldsen Jan 17, 2025
b11f6b1
ensure correct formatting of eval_langs
KennethEnevoldsen Jan 17, 2025
0718389
ignore aggregate dataset
KennethEnevoldsen Jan 17, 2025
2bc375c
clean up dummy cases
KennethEnevoldsen Jan 17, 2025
5a9bd8c
add to mteb(eng, classic)
KennethEnevoldsen Jan 17, 2025
36cee38
format
KennethEnevoldsen Jan 17, 2025
8bb9026
clean up
KennethEnevoldsen Jan 17, 2025
60a8f0f
Allow aggregated tasks within benchmarks
KennethEnevoldsen Jan 17, 2025
f65c68e
added fixed from comments
KennethEnevoldsen Jan 19, 2025
76d511c
Merge branch 'main' of https://github.com/embeddings-benchmark/mteb i…
KennethEnevoldsen Jan 19, 2025
14f3ae1
fix merge
KennethEnevoldsen Jan 19, 2025
66fb570
format
KennethEnevoldsen Jan 19, 2025
063e357
Updated task type
KennethEnevoldsen Jan 21, 2025
6b1e190
Merge branch 'main' of https://github.com/embeddings-benchmark/mteb i…
KennethEnevoldsen Jan 28, 2025
6e37e07
Added minor fix for dummy tasks
KennethEnevoldsen Jan 28, 2025
Allow aggregated tasks within benchmarks
Fixes #1231
KennethEnevoldsen committed Jan 17, 2025
commit 33578ece6a0ca28400ba1f1c9125cf4cb702f092
13 changes: 9 additions & 4 deletions mteb/abstasks/TaskMetadata.py
@@ -18,7 +18,7 @@
from typing_extensions import Literal, TypedDict

if TYPE_CHECKING:
from mteb.abstasks import AbsTask
pass

from ..encoder_interface import PromptType
from ..languages import (
@@ -85,6 +85,7 @@
"machine-translated and verified",
"machine-translated and localized",
"LM-generated and verified",
"multiple",
]

TASK_TYPE = Literal[
@@ -172,9 +173,10 @@
"gpl-3.0",
"cdla-sharing-1.0",
"mpl-2.0",
"multiple",
]
)

MODALITIES = Literal["text"]
METRIC_NAME = str
METRIC_VALUE = Union[int, float, dict[str, Any]]

@@ -233,13 +235,13 @@ class TaskMetadata(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)

dataset: dict
dataset: dict[str, Any]

name: str
description: str
prompt: str | PromptDict | None = None
type: TASK_TYPE
modalities: list[Literal["text"]] = ["text"]
modalities: list[MODALITIES] = ["text"]
category: TASK_CATEGORY | None = None
reference: STR_URL | None = None

@@ -432,3 +434,6 @@ def n_samples(self) -> dict[str, int] | None:
def __hash__(self) -> int:
return hash(self.model_dump_json())

@property
def revision(self) -> str:
return self.dataset["revision"]
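
For downstream code, the new revision property gives typed access to the dataset revision instead of indexing into the dataset dict. A minimal usage sketch, assuming a registered task fetched via mteb.get_tasks (the task name is only illustrative):

import mteb

task = mteb.get_tasks(tasks=["Banking77Classification"])[0]  # illustrative task name

# Equivalent to task.metadata.dataset["revision"], now exposed as a property
print(task.metadata.revision)
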
2 changes: 1 addition & 1 deletion mteb/abstasks/__init__.py
@@ -13,4 +13,4 @@
from .AbsTaskSpeedTask import *
from .AbsTaskSTS import *
from .AbsTaskSummarization import *
from .MultilingualTask import *
from .MultilingualTask import *
171 changes: 171 additions & 0 deletions mteb/abstasks/aggregate_task_metadata.py
@@ -0,0 +1,171 @@
from __future__ import annotations

import logging
from datetime import datetime
from typing import Any, Literal

from pydantic import ConfigDict, model_validator

from mteb.abstasks.AbsTask import AbsTask
from mteb.abstasks.TaskMetadata import (
ANNOTATOR_TYPE,
LANGUAGES,
LICENSES,
MODALITIES,
SAMPLE_CREATION_METHOD,
STR_DATE,
TASK_DOMAIN,
TASK_SUBTYPE,
HFSubset,
TaskMetadata,
)
from mteb.languages import ISO_LANGUAGE_SCRIPT

logger = logging.getLogger(__name__)


class AggregateTaskMetadata(TaskMetadata):
"""Metadata for an aggregation of tasks. This description only covers exceptions to the TaskMetadata. Many of the field if not filled out will be
autofilled from its tasks.

Attributes:
name: The name of the aggregated task.
description: A description of the task. Should explain the aggregation.
prompt: An aggregate task does not have a prompt, thus this value is always None.
dataset: The dataset for the aggregated task is specified in its tasks. The aggregate task thus only specified the revision and uses a
placeholder path.
tasks: A list of tasks, the majority of the metadata is described within its tasks.
eval_splits: The splits of the tasks used for evaluation.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

name: str
description: str
dataset: dict[str, Any] = {
"path": "aggregate tasks do not have a path", # just a place holder
"revision": "1",
}

tasks: list[AbsTask]
main_score: str
type: Literal["aggregate-task"] = "aggregate-task"
eval_splits: list[str]
eval_langs: LANGUAGES = []
prompt: None = None
reference: str | None = None
bibtex_citation: str | None = None

@property
def hf_subsets_to_langscripts(self) -> dict[HFSubset, list[ISO_LANGUAGE_SCRIPT]]:
"""Return a dictionary mapping huggingface subsets to languages."""
return {"default": self.eval_langs} # type: ignore

@model_validator(mode="after") # type: ignore
def compute_unfilled_cases(self) -> AggregateTaskMetadata:
if not self.eval_langs:
self.eval_langs = self.compute_eval_langs()
if not self.date:
self.date = self.compute_date()
if not self.domains:
self.domains = self.compute_domains()
if not self.task_subtypes:
self.task_subtypes = self.compute_task_subtypes()
if not self.license:
self.license = self.compute_license()
if not self.annotations_creators:
self.annotations_creators = self.compute_annotations_creators()
if not self.dialect:
self.dialect = self.compute_dialect()
if not self.sample_creation:
self.sample_creation = self.compute_sample_creation()
if not self.modalities:
self.modalities = self.compute_modalities()

return self

def compute_eval_langs(self) -> list[ISO_LANGUAGE_SCRIPT]:
langs = set()
for task in self.tasks:
langs.update(set(task.metadata.languages))
return list(langs)

def compute_date(self) -> tuple[STR_DATE, STR_DATE] | None:
# get min max date from tasks
dates = []
for task in self.tasks:
if task.metadata.date:
dates.append(datetime.fromisoformat(task.metadata.date[0]))
dates.append(datetime.fromisoformat(task.metadata.date[1]))

if not dates:
return None

min_date = min(dates)
max_date = max(dates)
return min_date.isoformat(), max_date.isoformat()

def compute_domains(self) -> list[TASK_DOMAIN] | None:
domains = set()
for task in self.tasks:
if task.metadata.domains:
domains.update(set(task.metadata.domains))
if domains:
return list(domains)
return None

def compute_task_subtypes(self) -> list[TASK_SUBTYPE] | None:
subtypes = set()
for task in self.tasks:
if task.metadata.task_subtypes:
subtypes.update(set(task.metadata.task_subtypes))
if subtypes:
return list(subtypes)
return None

def compute_license(self) -> LICENSES | None:
licenses = set()
for task in self.tasks:
if task.metadata.license:
licenses.add(task.metadata.license)
if len(licenses) > 1:
return "multiple"
return None

def compute_annotations_creators(self) -> ANNOTATOR_TYPE | None:
creators = set()
for task in self.tasks:
if task.metadata.annotations_creators:
creators.add(task.metadata.annotations_creators)
if len(creators) > 1:
logger.warning(
f"Multiple annotations_creators found for tasks in {self.name}. Using None as annotations_creators."
)
return None

def compute_dialect(self) -> list[str] | None:
dialects = set()
for task in self.tasks:
if task.metadata.dialect:
dialects.update(set(task.metadata.dialect))
if dialects:
return list(dialects)
return None

def compute_sample_creation(self) -> SAMPLE_CREATION_METHOD | None:
sample_creations = set()
for task in self.tasks:
if task.metadata.sample_creation:
sample_creations.add(task.metadata.sample_creation)
if len(sample_creations) > 1:
return "multiple"
return None

def compute_modalities(self) -> list[MODALITIES] | None:
modalities = set()
for task in self.tasks:
if task.metadata.modalities:
modalities.update(set(task.metadata.modalities))
if modalities:
return list(modalities)
return None
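
Taken together, the @model_validator and the compute_* helpers mean an aggregate task only needs to declare its name, description, member tasks, main score and eval splits; languages, dates, domains, licenses, sample creation and modalities are derived from the member tasks when left unset. A minimal sketch of how such metadata might be declared, assuming two existing STS tasks as members (the task names, aggregate name and score below are illustrative placeholders, not part of this commit):

import mteb

from mteb.abstasks.aggregate_task_metadata import AggregateTaskMetadata

member_tasks = mteb.get_tasks(tasks=["STS12", "STS13"])  # illustrative member tasks

metadata = AggregateTaskMetadata(
    name="STSAggregateExample",  # hypothetical aggregate name
    description="Illustrative aggregate over STS12 and STS13.",
    tasks=list(member_tasks),
    main_score="cosine_spearman",
    eval_splits=["test"],
)

# Unset fields are filled in by the validator from the member tasks:
print(metadata.eval_langs)  # union of the member tasks' languages
print(metadata.date)        # (earliest, latest) dates across the member tasks
print(metadata.license)     # "multiple" only if the member tasks' licenses differ

Note that this reflects the compute_* fallbacks above as written: if the member tasks share a single license, compute_license leaves license as None rather than propagating it.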