Commit

resolve conflicts

ryokan0123 committed Apr 1, 2024
2 parents 2e629ff + d41b888 commit 9278a8e
Showing 56 changed files with 114 additions and 94 deletions.
68 changes: 2 additions & 66 deletions main.py
@@ -1,71 +1,7 @@
-from __future__ import annotations
-
-import json
-from pathlib import Path
-
-from jsonargparse import ActionConfigFile, ArgumentParser
-from loguru import logger
-
-from src.embedders import TextEmbedder
-from src.evaluators import EmbeddingEvaluator
-from src.utils.score_recorder import JsonScoreRecorder
-
-
-def main(
-    text_embedder: TextEmbedder,
-    evaluators: dict[str, EmbeddingEvaluator],
-    save_dir: str | None = None,
-    overwrite_cache: bool = False,
-):
-    logger.info(f"Start evaluating the following tasks\n{list(evaluators.keys())}")
-
-    if save_dir:
-        Path(save_dir).mkdir(parents=True, exist_ok=True)
-
-    score_recorder = JsonScoreRecorder(save_dir)
-
-    for eval_name, evaluator in evaluators.items():
-        logger.info(f"Evaluating {eval_name}")
-
-        cache_dir = None
-        if save_dir is not None:
-            cache_dir = Path(save_dir) / "cache" / eval_name
-
-        metrics = evaluator(text_embedder, cache_dir=cache_dir, overwrite_cache=overwrite_cache)
-        score_recorder.record_task_scores(
-            scores=metrics,
-            dataset_name=eval_name,
-            task_name=evaluator.__class__.__name__.replace("Evaluator", ""),
-        )
-
-        logger.info(f"Results for {eval_name}\n{json.dumps(metrics.as_dict(), indent=4, ensure_ascii=False)}")
-
-    logger.info(f"Saving result summary to {Path(save_dir) / 'summary.json'}")
-    score_recorder.record_summary()
-
+from jmteb.main import main, parse_args

if __name__ == "__main__":
-    parser = ArgumentParser(parser_mode="jsonnet")
-
-    parser.add_subclass_arguments(TextEmbedder, nested_key="embedder", required=True)
-    parser.add_argument(
-        "--evaluators", type=dict[str, EmbeddingEvaluator], enable_path=True, default="src/configs/jmteb.jsonnet"
-    )
-    parser.add_argument("--config", action=ActionConfigFile, help="Path to the config file.")
-    parser.add_argument("--save_dir", type=str, default=None, help="Directory to save the outputs")
-    parser.add_argument("--overwrite_cache", type=bool, default=False, help="Overwrite the save_dir if it exists")
-    parser.add_argument("--eval_exclude", type=list[str], default=None, help="Evaluators to exclude.")
-
-    args = parser.parse_args()
-
-    if args.eval_exclude is not None:
-        evaluator_keys = list(args.evaluators.keys())
-        # remove evaluators in eval_exclude
-        for key in evaluator_keys:
-            if key in args.eval_exclude:
-                args.evaluators.pop(key)
-
-    args = parser.instantiate_classes(args)
+    args = parse_args()

    main(
        text_embedder=args.embedder,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -10,7 +10,7 @@ authors = [
]
description = "The evaluation scripts for JMTEB (Japanese Massive Text Embedding Benchmark)"
name = "JMTEB"
-packages = [{from = ".", include = "src"}]
+packages = [{from = "src", include = "jmteb"}]
readme = "README.md"
version = "1.0.0"

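The packaging change above is what drives the import updates throughout this commit: the package is now built from src/ and installed under the name jmteb, so code that previously imported from the src.* namespace switches to jmteb.*. A minimal sketch of the resulting import style (not itself part of the diff):

```python
# Old layout (removed in this commit):
#   from src.embedders import TextEmbedder
# New layout introduced by this commit:
from jmteb.embedders import TextEmbedder
from jmteb.evaluators import EmbeddingEvaluator
```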
2 changes: 0 additions & 2 deletions src/embedders/__init__.py

This file was deleted.

File renamed without changes.
9 changes: 9 additions & 0 deletions src/jmteb/__main__.py
@@ -0,0 +1,9 @@
from jmteb.main import main, parse_args

args = parse_args()
main(
    text_embedder=args.embedder,
    evaluators=args.evaluators,
    save_dir=args.save_dir,
    overwrite_cache=args.overwrite_cache,
)
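Together with the packaging change above, this __main__.py makes the benchmark launchable as a module (python -m jmteb). Below is a minimal sketch of an equivalent programmatic call, using only names introduced in this commit; the empty evaluators dict and the "results" directory are illustrative placeholders, since evaluators are normally instantiated from a jsonnet config (default: src/configs/jmteb.jsonnet):

```python
from jmteb.embedders import SentenceBertEmbedder
from jmteb.main import main

# Illustrative only: an empty evaluators dict keeps the sketch
# self-contained but evaluates nothing; in practice the dict is
# built from the jsonnet config by parse_args().
embedder = SentenceBertEmbedder(model_name_or_path="prajjwal1/bert-tiny")
main(
    text_embedder=embedder,
    evaluators={},
    save_dir="results",  # placeholder output directory
    overwrite_cache=False,
)
```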
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 2 additions & 0 deletions src/jmteb/embedders/__init__.py
@@ -0,0 +1,2 @@
from jmteb.embedders.base import TextEmbedder
from jmteb.embedders.sbert_embedder import SentenceBertEmbedder
File renamed without changes.
@@ -3,7 +3,7 @@
import numpy as np
from sentence_transformers import SentenceTransformer

-from src.embedders.base import TextEmbedder
+from jmteb.embedders.base import TextEmbedder


class SentenceBertEmbedder(TextEmbedder):
File renamed without changes.
2 changes: 1 addition & 1 deletion src/evaluators/base.py → src/jmteb/evaluators/base.py
@@ -6,7 +6,7 @@

from pydantic.dataclasses import dataclass

-from src.embedders import TextEmbedder
+from jmteb.embedders import TextEmbedder


@dataclass
File renamed without changes.
File renamed without changes.
@@ -6,8 +6,8 @@
from loguru import logger
from sklearn.metrics import accuracy_score, f1_score

-from src.embedders.base import TextEmbedder
-from src.evaluators.base import EmbeddingEvaluator, EvaluationResults
+from jmteb.embedders.base import TextEmbedder
+from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .classifiers import Classifier, KnnClassifier, LogRegClassifier
from .data import ClassificationDataset
File renamed without changes.
File renamed without changes.
@@ -13,8 +13,8 @@
)
from sklearn.metrics import homogeneity_completeness_v_measure

-from src.embedders.base import TextEmbedder
-from src.evaluators.base import EmbeddingEvaluator, EvaluationResults
+from jmteb.embedders.base import TextEmbedder
+from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .data import ClusteringDataset

File renamed without changes.
@@ -3,8 +3,8 @@
from os import PathLike
from pathlib import Path

-from src.embedders.base import TextEmbedder
-from src.evaluators.base import EmbeddingEvaluator, EvaluationResults
+from jmteb.embedders.base import TextEmbedder
+from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .data import PairClassificationDataset
from .threshold_accuracy import ThresholdAccuracyMetric
File renamed without changes.
File renamed without changes.
@@ -12,8 +12,8 @@
from loguru import logger
from torch import Tensor

-from src.embedders.base import TextEmbedder
-from src.evaluators.base import EmbeddingEvaluator, EvaluationResults
+from jmteb.embedders.base import TextEmbedder
+from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .data import RetrievalDocDataset, RetrievalQueryDataset

File renamed without changes.
File renamed without changes.
@@ -9,8 +9,8 @@
from scipy.stats import pearsonr, spearmanr
from torch import Tensor

-from src.embedders.base import TextEmbedder
-from src.evaluators.base import EmbeddingEvaluator, EvaluationResults
+from jmteb.embedders.base import TextEmbedder
+from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .data import STSDataset

68 changes: 68 additions & 0 deletions src/jmteb/main.py
@@ -0,0 +1,68 @@
from __future__ import annotations

import json
from pathlib import Path

from jsonargparse import ActionConfigFile, ArgumentParser
from loguru import logger

from jmteb.embedders import TextEmbedder
from jmteb.evaluators import EmbeddingEvaluator
from jmteb.utils.score_recorder import JsonScoreRecorder


def main(
    text_embedder: TextEmbedder,
    evaluators: dict[str, EmbeddingEvaluator],
    save_dir: str | None = None,
    overwrite_cache: bool = False,
):
    logger.info(f"Start evaluating the following tasks\n{list(evaluators.keys())}")

    if save_dir:
        Path(save_dir).mkdir(parents=True, exist_ok=True)

    score_recorder = JsonScoreRecorder(save_dir)

    for eval_name, evaluator in evaluators.items():
        logger.info(f"Evaluating {eval_name}")

        cache_dir = None
        if save_dir is not None:
            cache_dir = Path(save_dir) / "cache" / eval_name

        metrics = evaluator(text_embedder, cache_dir=cache_dir, overwrite_cache=overwrite_cache)
        score_recorder.record_task_scores(
            scores=metrics,
            dataset_name=eval_name,
            task_name=evaluator.__class__.__name__.replace("Evaluator", ""),
        )

        logger.info(f"Results for {eval_name}\n{json.dumps(metrics.as_dict(), indent=4, ensure_ascii=False)}")

    logger.info(f"Saving result summary to {Path(save_dir) / 'summary.json'}")
    score_recorder.record_summary()


def parse_args():
    parser = ArgumentParser(parser_mode="jsonnet")

    parser.add_subclass_arguments(TextEmbedder, nested_key="embedder", required=True)
    parser.add_argument(
        "--evaluators", type=dict[str, EmbeddingEvaluator], enable_path=True, default="src/configs/jmteb.jsonnet"
    )
    parser.add_argument("--config", action=ActionConfigFile, help="Path to the config file.")
    parser.add_argument("--save_dir", type=str, default=None, help="Directory to save the outputs")
    parser.add_argument("--overwrite_cache", type=bool, default=False, help="Overwrite the save_dir if it exists")
    parser.add_argument("--eval_exclude", type=list[str], default=None, help="Evaluators to exclude.")

    args = parser.parse_args()

    if args.eval_exclude is not None:
        evaluator_keys = list(args.evaluators.keys())
        # remove evaluators in eval_exclude
        for key in evaluator_keys:
            if key in args.eval_exclude:
                args.evaluators.pop(key)

    args = parser.instantiate_classes(args)
File renamed without changes.
@@ -7,7 +7,7 @@
from pathlib import Path
from typing import Any

-from src.evaluators import EvaluationResults
+from jmteb.evaluators import EvaluationResults


class AbstractScoreRecorder(ABC):
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -1,5 +1,7 @@
import pytest

from jmteb.embedders.sbert_embedder import SentenceBertEmbedder


def pytest_addoption(parser: pytest.Parser):
    parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
@@ -17,3 +19,8 @@ def pytest_collection_modifyitems(config: pytest.Config, items: pytest.Parser):
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)


@pytest.fixture(scope="module")
def embedder(model_name_or_path: str = "prajjwal1/bert-tiny"):
    return SentenceBertEmbedder(model_name_or_path=model_name_or_path)
2 changes: 1 addition & 1 deletion tests/embedders/test_sbert.py
@@ -1,6 +1,6 @@
import numpy as np

-from src.embedders.sbert_embedder import SentenceBertEmbedder
+from jmteb.embedders.sbert_embedder import SentenceBertEmbedder

MODEL_NAME_OR_PATH = "prajjwal1/bert-tiny"
OUTPUT_DIM = 128
2 changes: 1 addition & 1 deletion tests/evaluator/conftest.py
@@ -1,6 +1,6 @@
import pytest

-from src.embedders.sbert_embedder import SentenceBertEmbedder
+from jmteb.embedders.sbert_embedder import SentenceBertEmbedder


@pytest.fixture(scope="module")
4 changes: 2 additions & 2 deletions tests/evaluator/test_classification_evaluator.py
@@ -1,11 +1,11 @@
-from src.evaluators.classification import (
+from jmteb.evaluators.classification import (
    ClassificationDataset,
    ClassificationEvaluator,
    ClassificationInstance,
    KnnClassifier,
    LogRegClassifier,
)
-from src.evaluators.classification.data import JsonlClassificationDataset
+from jmteb.evaluators.classification.data import JsonlClassificationDataset


class DummyClassificationDataset(ClassificationDataset):
4 changes: 2 additions & 2 deletions tests/evaluator/test_clustering_evaluator.py
@@ -1,9 +1,9 @@
-from src.evaluators.clustering import (
+from jmteb.evaluators.clustering import (
    ClusteringDataset,
    ClusteringEvaluator,
    ClusteringInstance,
)
-from src.evaluators.clustering.data import JsonlClusteringDataset
+from jmteb.evaluators.clustering.data import JsonlClusteringDataset


class DummyClusteringDataset(ClusteringDataset):
4 changes: 2 additions & 2 deletions tests/evaluator/test_pair_classification_evaluator.py
@@ -1,9 +1,9 @@
-from src.evaluators.pair_classification import (
+from jmteb.evaluators.pair_classification import (
    PairClassificationDataset,
    PairClassificationEvaluator,
    PairClassificationInstance,
)
-from src.evaluators.pair_classification.data import JsonlPairClassificationDataset
+from jmteb.evaluators.pair_classification.data import JsonlPairClassificationDataset


class DummyBinaryDataset(PairClassificationDataset):
4 changes: 2 additions & 2 deletions tests/evaluator/test_retrieval_evaluator.py
@@ -1,11 +1,11 @@
-from src.evaluators.retrieval import (
+from jmteb.evaluators.retrieval import (
    RetrievalDoc,
    RetrievalDocDataset,
    RetrievalEvaluator,
    RetrievalQuery,
    RetrievalQueryDataset,
)
-from src.evaluators.retrieval.data import (
+from jmteb.evaluators.retrieval.data import (
    JsonlRetrievalDocDataset,
    JsonlRetrievalQueryDataset,
)
4 changes: 2 additions & 2 deletions tests/evaluator/test_sts_evaluator.py
@@ -1,5 +1,5 @@
-from src.evaluators.sts import STSDataset, STSEvaluator, STSInstance
-from src.evaluators.sts.data import JsonlSTSDataset
+from jmteb.evaluators.sts import STSDataset, STSEvaluator, STSInstance
+from jmteb.evaluators.sts.data import JsonlSTSDataset


class DummySTSDataset(STSDataset):
