Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

git cloneだけでなく、リポジトリからpip installしてmain.pyの処理が実行できるようにする #1

Merged
merged 8 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ poetry run pytest tests
The following command evaluate the specified model on the all the tasks in JMTEB.

```bash
poetry run python main.py \
poetry run python -m jmteb \
--embedder SentenceBertEmbedder \
--embedder.model_name_or_path "<model_name_or_path>" \
--save_dir "output/<model_name_or_path>"
Expand All @@ -26,7 +26,7 @@ By default, the evaluation tasks are read from `src/configs/jmteb.jsonnet`.
If you want to evaluate the model on a specific task, you can specify the task via `--evaluators` option with the task config.

```bash
poetry run python main.py \
poetry run python -m jmteb \
--evaluators "src/configs/tasks/jsts.jsonnet" \
--embedder SentenceBertEmbedder \
--embedder.model_name_or_path "<model_name_or_path>" \
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
]
description = "The evaluation scripts for JMTEB (Japanese Massive Text Embedding Benchmark)"
name = "JMTEB"
packages = [{from = ".", include = "src"}]
packages = [{from = "src", include = "jmteb"}]
readme = "README.md"
version = "1.0.0"

Expand Down
2 changes: 0 additions & 2 deletions src/embedders/__init__.py

This file was deleted.

File renamed without changes.
18 changes: 14 additions & 4 deletions main.py → src/jmteb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from jsonargparse import ActionConfigFile, ArgumentParser
from loguru import logger

from src.embedders import TextEmbedder
from src.evaluators import EmbeddingEvaluator
from src.utils.score_recorder import JsonScoreRecorder
from jmteb.embedders import TextEmbedder
from jmteb.evaluators import EmbeddingEvaluator
from jmteb.utils.score_recorder import JsonScoreRecorder


def main(
Expand Down Expand Up @@ -49,7 +49,10 @@ def main(

parser.add_subclass_arguments(TextEmbedder, nested_key="embedder", required=True)
parser.add_argument(
"--evaluators", type=dict[str, EmbeddingEvaluator], enable_path=True, default="src/configs/jmteb.jsonnet"
"--evaluators",
type=dict[str, EmbeddingEvaluator],
enable_path=True,
default=str(Path(__file__).parent / "configs" / "jmteb.jsonnet"),
)
parser.add_argument("--config", action=ActionConfigFile, help="Path to the config file.")
parser.add_argument("--save_dir", type=str, default=None, help="Directory to save the outputs")
Expand All @@ -66,6 +69,13 @@ def main(
args.evaluators.pop(key)

args = parser.instantiate_classes(args)
if isinstance(args.evaluators, str):
raise ValueError(
"Evaluators should be a dictionary, not a string.\n"
"Perhaps you provided a path to a config file, "
"but the path does not exist or the config format is broken.\n"
f"Please check {args.evaluators}"
)

main(
text_embedder=args.embedder,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 2 additions & 0 deletions src/jmteb/embedders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from jmteb.embedders.base import TextEmbedder
from jmteb.embedders.sbert_embedder import SentenceBertEmbedder
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from sentence_transformers import SentenceTransformer

from src.embedders.base import TextEmbedder
from jmteb.embedders.base import TextEmbedder


class SentenceBertEmbedder(TextEmbedder):
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion src/evaluators/base.py → src/jmteb/evaluators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pydantic.dataclasses import dataclass

from src.embedders import TextEmbedder
from jmteb.embedders import TextEmbedder


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from loguru import logger
from sklearn.metrics import accuracy_score, f1_score

from src.embedders.base import TextEmbedder
from src.evaluators.base import EmbeddingEvaluator, EvaluationResults
from jmteb.embedders.base import TextEmbedder
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .classifiers import Classifier, KnnClassifier, LogRegClassifier
from .data import ClassificationDataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
)
from sklearn.metrics import homogeneity_completeness_v_measure

from src.embedders.base import TextEmbedder
from src.evaluators.base import EmbeddingEvaluator, EvaluationResults
from jmteb.embedders.base import TextEmbedder
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .data import ClusteringDataset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from os import PathLike
from pathlib import Path

from src.embedders.base import TextEmbedder
from src.evaluators.base import EmbeddingEvaluator, EvaluationResults
from jmteb.embedders.base import TextEmbedder
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .data import PairClassificationDataset
from .threshold_accuracy import ThresholdAccuracyMetric
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from loguru import logger
from torch import Tensor

from src.embedders.base import TextEmbedder
from src.evaluators.base import EmbeddingEvaluator, EvaluationResults
from jmteb.embedders.base import TextEmbedder
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .data import RetrievalDocDataset, RetrievalQueryDataset

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from scipy.stats import pearsonr, spearmanr
from torch import Tensor

from src.embedders.base import TextEmbedder
from src.evaluators.base import EmbeddingEvaluator, EvaluationResults
from jmteb.embedders.base import TextEmbedder
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults

from .data import STSDataset

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pathlib import Path
from typing import Any

from src.evaluators import EvaluationResults
from jmteb.evaluators import EvaluationResults


class AbstractScoreRecorder(ABC):
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

from jmteb.embedders.sbert_embedder import SentenceBertEmbedder


def pytest_addoption(parser: pytest.Parser):
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
Expand All @@ -17,3 +19,8 @@ def pytest_collection_modifyitems(config: pytest.Config, items: pytest.Parser):
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)


@pytest.fixture(scope="module")
def embedder(model_name_or_path: str = "prajjwal1/bert-tiny"):
return SentenceBertEmbedder(model_name_or_path=model_name_or_path)
2 changes: 1 addition & 1 deletion tests/embedders/test_sbert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from src.embedders.sbert_embedder import SentenceBertEmbedder
from jmteb.embedders.sbert_embedder import SentenceBertEmbedder

MODEL_NAME_OR_PATH = "prajjwal1/bert-tiny"
OUTPUT_DIM = 128
Expand Down
2 changes: 1 addition & 1 deletion tests/evaluator/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from src.embedders.sbert_embedder import SentenceBertEmbedder
from jmteb.embedders.sbert_embedder import SentenceBertEmbedder


@pytest.fixture(scope="module")
Expand Down
4 changes: 2 additions & 2 deletions tests/evaluator/test_classification_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from src.evaluators.classification import (
from jmteb.evaluators.classification import (
ClassificationDataset,
ClassificationEvaluator,
ClassificationInstance,
KnnClassifier,
LogRegClassifier,
)
from src.evaluators.classification.data import JsonlClassificationDataset
from jmteb.evaluators.classification.data import JsonlClassificationDataset


class DummyClassificationDataset(ClassificationDataset):
Expand Down
4 changes: 2 additions & 2 deletions tests/evaluator/test_clustering_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from src.evaluators.clustering import (
from jmteb.evaluators.clustering import (
ClusteringDataset,
ClusteringEvaluator,
ClusteringInstance,
)
from src.evaluators.clustering.data import JsonlClusteringDataset
from jmteb.evaluators.clustering.data import JsonlClusteringDataset


class DummyClusteringDataset(ClusteringDataset):
Expand Down
4 changes: 2 additions & 2 deletions tests/evaluator/test_pair_classification_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from src.evaluators.pair_classification import (
from jmteb.evaluators.pair_classification import (
PairClassificationDataset,
PairClassificationEvaluator,
PairClassificationInstance,
)
from src.evaluators.pair_classification.data import JsonlPairClassificationDataset
from jmteb.evaluators.pair_classification.data import JsonlPairClassificationDataset


class DummyBinaryDataset(PairClassificationDataset):
Expand Down
4 changes: 2 additions & 2 deletions tests/evaluator/test_retrieval_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from src.evaluators.retrieval import (
from jmteb.evaluators.retrieval import (
RetrievalDoc,
RetrievalDocDataset,
RetrievalEvaluator,
RetrievalQuery,
RetrievalQueryDataset,
)
from src.evaluators.retrieval.data import (
from jmteb.evaluators.retrieval.data import (
JsonlRetrievalDocDataset,
JsonlRetrievalQueryDataset,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/evaluator/test_sts_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src.evaluators.sts import STSDataset, STSEvaluator, STSInstance
from src.evaluators.sts.data import JsonlSTSDataset
from jmteb.evaluators.sts import STSDataset, STSEvaluator, STSInstance
from jmteb.evaluators.sts.data import JsonlSTSDataset


class DummySTSDataset(STSDataset):
Expand Down
Loading