diff --git a/.github/workflows/python-package-pip.yaml b/.github/workflows/python-package-pip.yaml index b8aaf46..e6b6ba8 100644 --- a/.github/workflows/python-package-pip.yaml +++ b/.github/workflows/python-package-pip.yaml @@ -10,7 +10,7 @@ on: env: CACHE_NUMBER: 0 # increase to reset cache manually - TMPDIR: '/tmp' + AAC_METRICS_TMP_PATH: '/tmp' # Cancel workflow if a new push occurs concurrency: @@ -23,7 +23,7 @@ jobs: strategy: matrix: - os: [ubuntu-latest,windows-latest] + os: [ubuntu-latest,windows-latest,macos-latest] python-version: ["3.9"] java-version: ["11"] diff --git a/.gitmodules b/.gitmodules index 206d686..2a01231 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,7 @@ [submodule "caption-evaluation-tools"] path = tests/caption-evaluation-tools url = https://github.com/audio-captioning/caption-evaluation-tools + ignore = dirty branch = master [submodule "fense"] diff --git a/CHANGELOG.md b/CHANGELOG.md index ccdc785..ae7fa7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,18 @@ All notable changes to this project will be documented in this file. +## [0.5.0] 2023-12-08 +### Added +- New `Vocab` metric to compute vocabulary size and vocabulary ratio. +- New `BERTScoreMRefs` metric wrapper to compute BERTScore with multiple references. + +### Changed +- Rename metric `FluErr` to `FER`. + +### Fixed +- `METEOR` localization issue. ([#9](https://github.com/Labbeti/aac-metrics/issues/9)) +- `SPIDErMax` output when `return_all_scores=False`. + ## [0.4.6] 2023-10-10 ### Added - Argument `clean_archives` for `SPICE` download. diff --git a/CITATION.cff b/CITATION.cff index 872db92..e9ecb6d 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -19,5 +19,5 @@ keywords: - captioning - audio-captioning license: MIT -version: 0.4.6 -date-released: '2023-10-10' +version: 0.5.0 +date-released: '2023-12-08' diff --git a/README.md b/README.md index 1d32314..dec1684 100644 --- a/README.md +++ b/README.md @@ -17,20 +17,22 @@ Metrics for evaluating Automated Audio Captioning systems, designed for PyTorch. ## Why using this package? 
-- **Easy installation and download** -- **Same results than [caption-evaluation-tools](https://github.com/audio-captioning/caption-evaluation-tools) and [fense](https://github.com/blmoistawinde/fense) repositories** -- **Provides the following metrics:** +- **Easy to install and download** +- **Produces same results than [caption-evaluation-tools](https://github.com/audio-captioning/caption-evaluation-tools) and [fense](https://github.com/blmoistawinde/fense) repositories** +- **Provides 12 different metrics:** - BLEU [[1]](#bleu) - ROUGE-L [[2]](#rouge-l) - METEOR [[3]](#meteor) - CIDEr-D [[4]](#cider) - SPICE [[5]](#spice) - SPIDEr [[6]](#spider) - - SPIDEr-max [[7]](#spider-max) - - SBERT-sim [[8]](#fense) - - Fluency Error [[8]](#fense) - - FENSE [[8]](#fense) - - SPIDEr-FL [[9]](#spider-fl) + - BERTScore [[7]](#bertscore) + - SPIDEr-max [[8]](#spider-max) + - SBERT-sim [[9]](#fense) + - FER [[9]](#fense) + - FENSE [[9]](#fense) + - SPIDEr-FL [[10]](#spider-fl) + - Vocab (unique word vocabulary) ## Installation Install the pip package: @@ -100,7 +102,7 @@ Each metrics also exists as a python class version, like `aac_metrics.classes.ci ## Metrics ### Legacy metrics -| Metric | Python Class | Origin | Range | Short description | +| Metric name | Python Class | Origin | Range | Short description | |:---|:---|:---|:---|:---| | BLEU [[1]](#bleu) | `BLEU` | machine translation | [0, 1] | Precision of n-grams | | ROUGE-L [[2]](#rouge-l) | `ROUGEL` | text summarization | [0, 1] | FScore of the longest common subsequence | @@ -108,20 +110,29 @@ Each metrics also exists as a python class version, like `aac_metrics.classes.ci | CIDEr-D [[4]](#cider) | `CIDErD` | image captioning | [0, 10] | Cosine-similarity of TF-IDF computed on n-grams | | SPICE [[5]](#spice) | `SPICE` | image captioning | [0, 1] | FScore of a semantic graph | | SPIDEr [[6]](#spider) | `SPIDEr` | image captioning | [0, 5.5] | Mean of CIDEr-D and SPICE | +| BERTScore [[7]](#bertscore) | `BERTScoreMRefs` | text generation | [0, 1] | Fscore of BERT embeddings. In contrast to torchmetrics, it supports multiple references per file. 
| ### AAC-specific metrics | Metric name | Python Class | Origin | Range | Short description | |:---|:---|:---|:---|:---| -| SPIDEr-max [[7]](#spider-max) | `SPIDErMax` | audio captioning | [0, 5.5] | Max of SPIDEr scores for multiples candidates | -| SBERT-sim [[8]](#spider-max) | `SBERTSim` | audio captioning | [-1, 1] | Cosine-similarity of **Sentence-BERT embeddings** | -| Fluency error rate [[8]](#spider-max) | `FluErr` | audio captioning | [0, 1] | Detect fluency errors in sentences with a pretrained model | -| FENSE [[8]](#fense) | `FENSE` | audio captioning | [-1, 1] | Combines SBERT-sim and Fluency Error rate | -| SPIDEr-FL [[9]](#spider-fl) | `SPIDErFL` | audio captioning | [0, 5.5] | Combines SPIDEr and Fluency Error rate | +| SPIDEr-max [[8]](#spider-max) | `SPIDErMax` | audio captioning | [0, 5.5] | Max of SPIDEr scores for multiples candidates | +| SBERT-sim [[9]](#fense) | `SBERTSim` | audio captioning | [-1, 1] | Cosine-similarity of **Sentence-BERT embeddings** | +| Fluency Error Rate [[9]](#fense) | `FER` | audio captioning | [0, 1] | Detect fluency errors in sentences with a pretrained model | +| FENSE [[9]](#fense) | `FENSE` | audio captioning | [-1, 1] | Combines SBERT-sim and Fluency Error rate | +| SPIDEr-FL [[10]](#spider-fl) | `SPIDErFL` | audio captioning | [0, 5.5] | Combines SPIDEr and Fluency Error rate | + +### Other metrics +| Metric name | Python Class | Origin | Range | Short description | +|:---|:---|:---|:---|:---| +| Vocabulary | `Vocab` | text generation | [0, +$\infty$[ | Number of unique words in candidates. | -### AAC metrics not implemented -- CB-Score [[10]](#cb-score) -- SPICE+ [[11]](#spice-plus) -- ACES [[12]](#aces) (can be found here: https://github.com/GlJS/ACES) +### Future directions +This package currently does not include all metrics dedicated to audio captioning. Feel free to do a pull request / or ask to me by email if you want to include them. Those metrics not included are listed here: +- CB-Score [[11]](#cb-score) +- SPICE+ [[12]](#spice-plus) +- ACES [[13]](#aces) (can be found here: https://github.com/GlJS/ACES) +- SBF [[14]](#sbf) +- s2v [[15]](#s2v) ## Requirements This package has been developped for Ubuntu 20.04, and it is expected to work on most Linux distributions. Windows is not officially supported. @@ -136,6 +147,7 @@ pyyaml >= 6.0 tqdm >= 4.64.0 sentence-transformers >= 2.2.2 transformers < 4.31.0 +torchmetrics >= 0.11.4 ``` ### External requirements @@ -154,64 +166,54 @@ No. Most of these metrics use numpy or external java programs to run, which prev ### Do metrics work on Windows/Mac OS? Maybe. Most of the metrics only need python to run, which can be done on Windows. However, you might expect errors with METEOR metric, SPICE-based metrics and PTB tokenizer, since they requires an external java program to run. -## SPIDEr-max metric +## About SPIDEr-max metric SPIDEr-max [[7]](#spider-max) is a metric based on SPIDEr that takes into account multiple candidates for the same audio. It computes the maximum of the SPIDEr scores for each candidate to balance the high sensitivity to the frequency of the words generated by the model. For more detail, please see the [documentation about SPIDEr-max](https://aac-metrics.readthedocs.io/en/stable/spider_max.html). ## References #### BLEU -[1] K. Papineni, S. Roukos, T. Ward, and W.-J. Zhu, “BLEU: a -method for automatic evaluation of machine translation,” in Proceed- -ings of the 40th Annual Meeting on Association for Computational -Linguistics - ACL ’02. 
Philadelphia, Pennsylvania: Association -for Computational Linguistics, 2001, p. 311. [Online]. Available: -http://portal.acm.org/citation.cfm?doid=1073083.1073135 +[1] K. Papineni, S. Roukos, T. Ward, and W.-J. Zhu, “BLEU: a method for automatic evaluation of machine translation,” in Proceedings of the 40th Annual Meeting on Association for Computational Linguistics - ACL ’02. Philadelphia, Pennsylvania: Association for Computational Linguistics, 2001, p. 311. [Online]. Available: http://portal.acm.org/citation.cfm?doid=1073083.1073135 #### ROUGE-L -[2] C.-Y. Lin, “ROUGE: A package for automatic evaluation of summaries,” -in Text Summarization Branches Out. Barcelona, Spain: Association -for Computational Linguistics, Jul. 2004, pp. 74–81. [Online]. Available: -https://aclanthology.org/W04-1013 +[2] C.-Y. Lin, “ROUGE: A package for automatic evaluation of summaries,” in Text Summarization Branches Out. Barcelona, Spain: Association for Computational Linguistics, Jul. 2004, pp. 74–81. [Online]. Available: https://aclanthology.org/W04-1013 #### METEOR -[3] M. Denkowski and A. Lavie, “Meteor Universal: Language Specific -Translation Evaluation for Any Target Language,” in Proceedings of the -Ninth Workshop on Statistical Machine Translation. Baltimore, Maryland, -USA: Association for Computational Linguistics, 2014, pp. 376–380. -[Online]. Available: http://aclweb.org/anthology/W14-3348 +[3] M. Denkowski and A. Lavie, “Meteor Universal: Language Specific Translation Evaluation for Any Target Language,” in Proceedings of the Ninth Workshop on Statistical Machine Translation. Baltimore, Maryland, USA: Association for Computational Linguistics, 2014, pp. 376–380. [Online]. Available: http://aclweb.org/anthology/W14-3348 #### CIDEr -[4] R. Vedantam, C. L. Zitnick, and D. Parikh, “CIDEr: Consensus-based -Image Description Evaluation,” arXiv:1411.5726 [cs], Jun. 2015, arXiv: -1411.5726. [Online]. Available: http://arxiv.org/abs/1411.5726 +[4] R. Vedantam, C. L. Zitnick, and D. Parikh, “CIDEr: Consensus-based Image Description Evaluation,” arXiv:1411.5726 [cs], Jun. 2015, [Online]. Available: http://arxiv.org/abs/1411.5726 #### SPICE -[5] P. Anderson, B. Fernando, M. Johnson, and S. Gould, “SPICE: Semantic -Propositional Image Caption Evaluation,” arXiv:1607.08822 [cs], Jul. 2016, -arXiv: 1607.08822. [Online]. Available: http://arxiv.org/abs/1607.08822 +[5] P. Anderson, B. Fernando, M. Johnson, and S. Gould, “SPICE: Semantic Propositional Image Caption Evaluation,” arXiv:1607.08822 [cs], Jul. 2016, [Online]. Available: http://arxiv.org/abs/1607.08822 #### SPIDEr -[6] S. Liu, Z. Zhu, N. Ye, S. Guadarrama, and K. Murphy, “Improved Image -Captioning via Policy Gradient optimization of SPIDEr,” 2017 IEEE Inter- -national Conference on Computer Vision (ICCV), pp. 873–881, Oct. 2017, -arXiv: 1612.00370. [Online]. Available: http://arxiv.org/abs/1612.00370 +[6] S. Liu, Z. Zhu, N. Ye, S. Guadarrama, and K. Murphy, “Improved Image Captioning via Policy Gradient optimization of SPIDEr,” 2017 IEEE International Conference on Computer Vision (ICCV), pp. 873–881, Oct. 2017, arXiv: 1612.00370. [Online]. Available: http://arxiv.org/abs/1612.00370 + +#### BERTScore +[7] T. Zhang*, V. Kishore*, F. Wu*, K. Q. Weinberger, and Y. Artzi, “BERTScore: Evaluating Text Generation with BERT,” 2020. [Online]. Available: https://openreview.net/forum?id=SkeHuCVFDr #### SPIDEr-max -[7] E. Labbé, T. Pellegrini, and J. Pinquier, “Is my automatic audio captioning system so bad? 
spider-max: a metric to consider several caption candidates,” Nov. 2022. [Online]. Available: https://hal.archives-ouvertes.fr/hal-03810396 +[8] E. Labbé, T. Pellegrini, and J. Pinquier, “Is my automatic audio captioning system so bad? spider-max: a metric to consider several caption candidates,” Nov. 2022. [Online]. Available: https://hal.archives-ouvertes.fr/hal-03810396 #### FENSE -[8] Z. Zhou, Z. Zhang, X. Xu, Z. Xie, M. Wu, and K. Q. Zhu, Can Audio Captions Be Evaluated with Image Caption Metrics? arXiv, 2022. [Online]. Available: http://arxiv.org/abs/2110.04684 +[9] Z. Zhou, Z. Zhang, X. Xu, Z. Xie, M. Wu, and K. Q. Zhu, Can Audio Captions Be Evaluated with Image Caption Metrics? arXiv, 2022. [Online]. Available: http://arxiv.org/abs/2110.04684 #### SPIDEr-FL -[9] DCASE website task6a description: https://dcase.community/challenge2023/task-automated-audio-captioning#evaluation +[10] DCASE website task6a description: https://dcase.community/challenge2023/task-automated-audio-captioning#evaluation #### CB-score [11] I. Martín-Morató, M. Harju, and A. Mesaros, “A Summarization Approach to Evaluating Audio Captioning,” Nov. 2022. [Online]. Available: https://dcase.community/documents/workshop2022/proceedings/DCASE2022Workshop_Martin-Morato_35.pdf #### SPICE-plus -[10] F. Gontier, R. Serizel, and C. Cerisara, “SPICE+: Evaluation of Automatic Audio Captioning Systems with Pre-Trained Language Models,” in ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2023, pp. 1–5. doi: 10.1109/ICASSP49357.2023.10097021. +[12] F. Gontier, R. Serizel, and C. Cerisara, “SPICE+: Evaluation of Automatic Audio Captioning Systems with Pre-Trained Language Models,” in ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2023, pp. 1–5. doi: 10.1109/ICASSP49357.2023.10097021. #### ACES -[12] G. Wijngaard, E. Formisano, B. L. Giordano, M. Dumontier, “ACES: Evaluating Automated Audio Captioning Models on the Semantics of Sounds”, in EUSIPCO 2023, 2023. +[13] G. Wijngaard, E. Formisano, B. L. Giordano, M. Dumontier, “ACES: Evaluating Automated Audio Captioning Models on the Semantics of Sounds”, in EUSIPCO 2023, 2023. + +#### SBF +[14] R. Mahfuz, Y. Guo, A. K. Sridhar, and E. Visser, Detecting False Alarms and Misses in Audio Captions. 2023. [Online]. Available: https://arxiv.org/pdf/2309.03326.pdf + +#### s2v +[15] S. Bhosale, R. Chakraborty, and S. K. Kopparapu, “A Novel Metric For Evaluating Audio Caption Similarity,” in ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2023, pp. 1–5. doi: 10.1109/ICASSP49357.2023.10096526. ## Citation If you use **SPIDEr-max**, you can cite the following paper using BibTex : @@ -227,20 +229,21 @@ If you use **SPIDEr-max**, you can cite the following paper using BibTex : } ``` -If you use this software, please consider cite it as below : +If you use this software, please consider cite it as "Labbe, E. (2013). 
aac-metrics: Metrics for evaluating Automated Audio Captioning systems for PyTorch.", or use the following BibTeX citation: + ``` @software{ - Labbe_aac-metrics_2023, + Labbe_aac_metrics_2023, author = {Labbé, Etienne}, license = {MIT}, - month = {10}, + month = {12}, title = {{aac-metrics}}, url = {https://github.com/Labbeti/aac-metrics/}, - version = {0.4.6}, + version = {0.5.0}, year = {2023}, } ``` ## Contact Maintainer: -- Etienne Labbé "Labbeti": labbeti.pub@gmail.com +- Étienne Labbé "Labbeti": labbeti.pub@gmail.com diff --git a/docs/aac_metrics.classes.bert_score_mrefs.rst b/docs/aac_metrics.classes.bert_score_mrefs.rst new file mode 100644 index 0000000..f567b02 --- /dev/null +++ b/docs/aac_metrics.classes.bert_score_mrefs.rst @@ -0,0 +1,7 @@ +aac\_metrics.classes.bert\_score\_mrefs module +============================================== + +.. automodule:: aac_metrics.classes.bert_score_mrefs + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/aac_metrics.classes.fer.rst b/docs/aac_metrics.classes.fer.rst new file mode 100644 index 0000000..1c5d23b --- /dev/null +++ b/docs/aac_metrics.classes.fer.rst @@ -0,0 +1,7 @@ +aac\_metrics.classes.fer module +=============================== + +.. automodule:: aac_metrics.classes.fer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/aac_metrics.classes.vocab.rst b/docs/aac_metrics.classes.vocab.rst new file mode 100644 index 0000000..88ea94a --- /dev/null +++ b/docs/aac_metrics.classes.vocab.rst @@ -0,0 +1,7 @@ +aac\_metrics.classes.vocab module +================================= + +.. automodule:: aac_metrics.classes.vocab + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/aac_metrics.functional.bert_score_mrefs.rst b/docs/aac_metrics.functional.bert_score_mrefs.rst new file mode 100644 index 0000000..8142197 --- /dev/null +++ b/docs/aac_metrics.functional.bert_score_mrefs.rst @@ -0,0 +1,7 @@ +aac\_metrics.functional.bert\_score\_mrefs module +================================================= + +.. automodule:: aac_metrics.functional.bert_score_mrefs + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/aac_metrics.functional.fer.rst b/docs/aac_metrics.functional.fer.rst new file mode 100644 index 0000000..0fec93f --- /dev/null +++ b/docs/aac_metrics.functional.fer.rst @@ -0,0 +1,7 @@ +aac\_metrics.functional.fer module +================================== + +.. automodule:: aac_metrics.functional.fer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/aac_metrics.functional.fluerr.rst b/docs/aac_metrics.functional.fluerr.rst deleted file mode 100644 index 000cef6..0000000 --- a/docs/aac_metrics.functional.fluerr.rst +++ /dev/null @@ -1,7 +0,0 @@ -aac\_metrics.functional.fluerr module -===================================== - -.. automodule:: aac_metrics.functional.fluerr - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/aac_metrics.functional.vocab.rst b/docs/aac_metrics.functional.vocab.rst new file mode 100644 index 0000000..0c9df8e --- /dev/null +++ b/docs/aac_metrics.functional.vocab.rst @@ -0,0 +1,7 @@ +aac\_metrics.functional.vocab module +==================================== + +.. 
automodule:: aac_metrics.functional.vocab + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/aac_metrics.utils.cmdline.rst b/docs/aac_metrics.utils.cmdline.rst new file mode 100644 index 0000000..c9a2505 --- /dev/null +++ b/docs/aac_metrics.utils.cmdline.rst @@ -0,0 +1,7 @@ +aac\_metrics.utils.cmdline module +================================= + +.. automodule:: aac_metrics.utils.cmdline + :members: + :undoc-members: + :show-inheritance: diff --git a/pyproject.toml b/pyproject.toml index 5ce6ac1..e344ac4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,15 +21,7 @@ classifiers = [ maintainers = [ {name = "Etienne Labbé (Labbeti)", email = "labbeti.pub@gmail.com"}, ] -dependencies = [ - "torch>=1.10.1", - "numpy>=1.21.2", - "pyyaml>=6.0", - "tqdm>=4.64.0", - "sentence-transformers>=2.2.2", - "transformers<4.31.0", -] -dynamic = ["version"] +dynamic = ["version", "dependencies", "optional-dependencies"] [project.urls] Homepage = "https://pypi.org/project/aac-metrics/" @@ -43,19 +35,11 @@ aac-metrics-download = "aac_metrics.download:_main_download" aac-metrics-eval = "aac_metrics.eval:_main_eval" aac-metrics-info = "aac_metrics.info:print_install_info" -[project.optional-dependencies] -dev = [ - "pytest==7.1.2", - "flake8==4.0.1", - "black==22.8.0", - "scikit-image==0.19.2", - "matplotlib==3.5.2", - "torchmetrics>=0.10", -] - [tool.setuptools.packages.find] where = ["src"] # list of folders that contain the packages (["."] by default) include = ["aac_metrics*"] # package names should match these glob patterns (["*"] by default) [tool.setuptools.dynamic] version = {attr = "aac_metrics.__version__"} +dependencies = {file = ["requirements.txt"]} +optional-dependencies = {dev = { file = ["requirements-dev.txt"] }} diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..a782ffb --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +pytest==7.1.2 +flake8==4.0.1 +black==22.8.0 +scikit-image==0.19.2 +matplotlib==3.5.2 +ipykernel==6.9.1 +twine==4.0.1 diff --git a/requirements.txt b/requirements.txt index cff75af..a87060f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ pyyaml>=6.0 tqdm>=4.64.0 sentence-transformers>=2.2.2 transformers<4.31.0 +torchmetrics>=0.11.4 diff --git a/src/aac_metrics/__init__.py b/src/aac_metrics/__init__.py index a4aa875..199bdbb 100644 --- a/src/aac_metrics/__init__.py +++ b/src/aac_metrics/__init__.py @@ -10,14 +10,14 @@ __maintainer__ = "Etienne Labbé (Labbeti)" __name__ = "aac-metrics" __status__ = "Development" -__version__ = "0.4.6" +__version__ = "0.5.0" from .classes.base import AACMetric from .classes.bleu import BLEU from .classes.cider_d import CIDErD from .classes.evaluate import Evaluate, DCASE2023Evaluate, _get_metric_factory_classes -from .classes.fluerr import FluErr +from .classes.fer import FER from .classes.fense import FENSE from .classes.meteor import METEOR from .classes.rouge_l import ROUGEL @@ -26,6 +26,7 @@ from .classes.spider import SPIDEr from .classes.spider_fl import SPIDErFL from .classes.spider_max import SPIDErMax +from .classes.vocab import Vocab from .functional.evaluate import evaluate, dcase2023_evaluate from .utils.paths import ( get_default_cache_path, @@ -44,7 +45,7 @@ "Evaluate", "DCASE2023Evaluate", "FENSE", - "FluErr", + "FER", "METEOR", "ROUGEL", "SBERTSim", @@ -52,6 +53,7 @@ "SPIDEr", "SPIDErFL", "SPIDErMax", + "Vocab", "evaluate", "dcase2023_evaluate", "get_default_cache_path", diff --git 
a/src/aac_metrics/classes/__init__.py b/src/aac_metrics/classes/__init__.py index c2614ed..1058fae 100644 --- a/src/aac_metrics/classes/__init__.py +++ b/src/aac_metrics/classes/__init__.py @@ -1,11 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from .bleu import BLEU +from .bleu import BLEU, BLEU1, BLEU2, BLEU3, BLEU4 from .cider_d import CIDErD from .evaluate import DCASE2023Evaluate, Evaluate from .fense import FENSE -from .fluerr import FluErr +from .fer import FER from .meteor import METEOR from .rouge_l import ROUGEL from .sbert_sim import SBERTSim @@ -13,15 +13,20 @@ from .spider import SPIDEr from .spider_fl import SPIDErFL from .spider_max import SPIDErMax +from .vocab import Vocab __all__ = [ "BLEU", + "BLEU1", + "BLEU2", + "BLEU3", + "BLEU4", "CIDErD", "DCASE2023Evaluate", "Evaluate", "FENSE", - "FluErr", + "FER", "METEOR", "ROUGEL", "SBERTSim", @@ -29,4 +34,5 @@ "SPIDEr", "SPIDErFL", "SPIDErMax", + "Vocab", ] diff --git a/src/aac_metrics/classes/base.py b/src/aac_metrics/classes/base.py index 0f0882b..3df7b33 100644 --- a/src/aac_metrics/classes/base.py +++ b/src/aac_metrics/classes/base.py @@ -1,10 +1,12 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from typing import Any, Generic, Optional, TypeVar +from typing import Any, ClassVar, Generic, Optional, TypeVar, Union -from torch import nn +from torch import nn, Tensor + +DefaultOutType = Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor] OutType = TypeVar("OutType") @@ -12,14 +14,14 @@ class AACMetric(nn.Module, Generic[OutType]): """Base Metric module for AAC metrics. Similar to torchmetrics.Metric.""" # Global values - full_state_update: Optional[bool] = False - higher_is_better: Optional[bool] = None - is_differentiable: Optional[bool] = False + full_state_update: ClassVar[Optional[bool]] = False + higher_is_better: ClassVar[Optional[bool]] = None + is_differentiable: ClassVar[Optional[bool]] = False # The theorical minimal value of the main global score of the metric. - min_value: Optional[float] = None + min_value: ClassVar[Optional[float]] = None # The theorical maximal value of the main global score of the metric. - max_value: Optional[float] = None + max_value: ClassVar[Optional[float]] = None def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/aac_metrics/classes/bert_score_mrefs.py b/src/aac_metrics/classes/bert_score_mrefs.py new file mode 100644 index 0000000..4f07c06 --- /dev/null +++ b/src/aac_metrics/classes/bert_score_mrefs.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Union + +import torch + +from torch import nn, Tensor +from torchmetrics.text.bert import _DEFAULT_MODEL + +from aac_metrics.classes.base import AACMetric +from aac_metrics.functional.bert_score_mrefs import ( + bert_score_mrefs, + _load_model_and_tokenizer, +) + + +class BERTScoreMRefs(AACMetric): + """BERTScore metric which supports multiple references. + + The implementation is based on the bert_score implementation of torchmetrics. + + - Paper: https://arxiv.org/pdf/1904.09675.pdf + + For more information, see :func:`~aac_metrics.functional.bert_score.bert_score_mrefs`. 
+ """ + + full_state_update = False + higher_is_better = True + is_differentiable = False + + min_value = 0.0 + max_value = 1.0 + + def __init__( + self, + return_all_scores: bool = True, + model: Union[str, nn.Module] = _DEFAULT_MODEL, + device: Union[str, torch.device, None] = "auto", + batch_size: int = 32, + num_threads: int = 0, + max_length: int = 64, + reset_state: bool = True, + idf: bool = False, + reduction: str = "max", + filter_nan: bool = True, + verbose: int = 0, + ) -> None: + model, tokenizer = _load_model_and_tokenizer( + model, None, device, reset_state, verbose + ) + + super().__init__() + self._return_all_scores = return_all_scores + self._model = model + self._tokenizer = tokenizer + self._device = device + self._batch_size = batch_size + self._num_threads = num_threads + self._max_length = max_length + self._reset_state = reset_state + self._idf = idf + self._reduction = reduction + self._filter_nan = filter_nan + self._verbose = verbose + + self._candidates = [] + self._mult_references = [] + + def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return bert_score_mrefs( + candidates=self._candidates, + mult_references=self._mult_references, + return_all_scores=self._return_all_scores, + model=self._model, + tokenizer=self._tokenizer, + device=self._device, + batch_size=self._batch_size, + num_threads=self._num_threads, + max_length=self._max_length, + reset_state=self._reset_state, + idf=self._idf, + reduction=self._reduction, + filter_nan=self._filter_nan, + verbose=self._verbose, + ) + + def extra_repr(self) -> str: + if isinstance(self._model, str): + model_name = self._model + else: + model_name = self._model.__class__.__name__ + + hparams = {"model": model_name, "idf": self._idf} + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ + + def get_output_names(self) -> tuple[str, ...]: + return ( + "bert_score.precision", + "bert_score.recalll", + "bert_score.f1", + ) + + def reset(self) -> None: + self._candidates = [] + self._mult_references = [] + return super().reset() + + def update( + self, + candidates: list[str], + mult_references: list[list[str]], + ) -> None: + self._candidates += candidates + self._mult_references += mult_references diff --git a/src/aac_metrics/classes/bleu.py b/src/aac_metrics/classes/bleu.py index 926945a..c6f1af2 100644 --- a/src/aac_metrics/classes/bleu.py +++ b/src/aac_metrics/classes/bleu.py @@ -63,7 +63,9 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: ) def extra_repr(self) -> str: - return f"n={self._n}" + hparams = {"n": self._n} + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ def get_output_names(self) -> tuple[str, ...]: return (f"bleu_{self._n}",) @@ -96,7 +98,13 @@ def __init__( verbose: int = 0, tokenizer: Callable[[str], list[str]] = str.split, ) -> None: - super().__init__(return_all_scores, 1, option, verbose, tokenizer) + super().__init__( + return_all_scores=return_all_scores, + n=1, + option=option, + verbose=verbose, + tokenizer=tokenizer, + ) class BLEU2(BLEU): @@ -107,7 +115,13 @@ def __init__( verbose: int = 0, tokenizer: Callable[[str], list[str]] = str.split, ) -> None: - super().__init__(return_all_scores, 2, option, verbose, tokenizer) + super().__init__( + return_all_scores=return_all_scores, + n=2, + option=option, + verbose=verbose, + tokenizer=tokenizer, + ) class BLEU3(BLEU): @@ -118,7 +132,13 @@ def __init__( verbose: int = 0, tokenizer: Callable[[str], list[str]] = str.split, ) -> None: - 
super().__init__(return_all_scores, 3, option, verbose, tokenizer) + super().__init__( + return_all_scores=return_all_scores, + n=3, + option=option, + verbose=verbose, + tokenizer=tokenizer, + ) class BLEU4(BLEU): @@ -129,4 +149,10 @@ def __init__( verbose: int = 0, tokenizer: Callable[[str], list[str]] = str.split, ) -> None: - super().__init__(return_all_scores, 4, option, verbose, tokenizer) + super().__init__( + return_all_scores=return_all_scores, + n=4, + option=option, + verbose=verbose, + tokenizer=tokenizer, + ) diff --git a/src/aac_metrics/classes/cider_d.py b/src/aac_metrics/classes/cider_d.py index b22c7f5..77e5eaa 100644 --- a/src/aac_metrics/classes/cider_d.py +++ b/src/aac_metrics/classes/cider_d.py @@ -59,7 +59,9 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: ) def extra_repr(self) -> str: - return f"n={self._n}, sigma={self._sigma}" + hparams = {"n": self._n, "sigma": self._sigma} + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ def get_output_names(self) -> tuple[str, ...]: return ("cider_d",) diff --git a/src/aac_metrics/classes/evaluate.py b/src/aac_metrics/classes/evaluate.py index 3df9836..a4eee40 100644 --- a/src/aac_metrics/classes/evaluate.py +++ b/src/aac_metrics/classes/evaluate.py @@ -5,23 +5,27 @@ import pickle import zlib -from typing import Any, Callable, Iterable, Union +from pathlib import Path +from typing import Any, Callable, Iterable, Optional, Union import torch from torch import Tensor from aac_metrics.classes.base import AACMetric -from aac_metrics.classes.bleu import BLEU +from aac_metrics.classes.bert_score_mrefs import BERTScoreMRefs +from aac_metrics.classes.bleu import BLEU, BLEU1, BLEU2, BLEU3, BLEU4 from aac_metrics.classes.cider_d import CIDErD from aac_metrics.classes.fense import FENSE -from aac_metrics.classes.fluerr import FluErr +from aac_metrics.classes.fer import FER from aac_metrics.classes.meteor import METEOR from aac_metrics.classes.rouge_l import ROUGEL from aac_metrics.classes.sbert_sim import SBERTSim from aac_metrics.classes.spice import SPICE from aac_metrics.classes.spider import SPIDEr +from aac_metrics.classes.spider_max import SPIDErMax from aac_metrics.classes.spider_fl import SPIDErFL +from aac_metrics.classes.vocab import Vocab from aac_metrics.functional.evaluate import ( DEFAULT_METRICS_SET_NAME, METRICS_SETS, @@ -48,9 +52,9 @@ def __init__( metrics: Union[ str, Iterable[str], Iterable[AACMetric] ] = DEFAULT_METRICS_SET_NAME, - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> None: @@ -120,9 +124,9 @@ class DCASE2023Evaluate(Evaluate): def __init__( self, preprocess: bool = True, - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> None: @@ -139,9 +143,9 @@ def __init__( def _instantiate_metrics_classes( metrics: Union[str, Iterable[str], Iterable[AACMetric]] = "aac", - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, device: Union[str, torch.device, None] = "auto", 
verbose: int = 0, ) -> list[AACMetric]: @@ -172,37 +176,50 @@ def _instantiate_metrics_classes( def _get_metric_factory_classes( return_all_scores: bool = True, - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, device: Union[str, torch.device, None] = "auto", verbose: int = 0, - init_kwds: dict[str, Any] = ..., + init_kwds: Optional[dict[str, Any]] = None, ) -> dict[str, Callable[[], AACMetric]]: - if init_kwds is ...: + if init_kwds is None or init_kwds is ...: init_kwds = {} init_kwds = init_kwds | dict(return_all_scores=return_all_scores) factory = { + "bert_score": lambda: BERTScoreMRefs( + verbose=verbose, + **init_kwds, + ), "bleu": lambda: BLEU( **init_kwds, ), - "bleu_1": lambda: BLEU( - n=1, + "bleu_1": lambda: BLEU1( + **init_kwds, + ), + "bleu_2": lambda: BLEU2( **init_kwds, ), - "bleu_2": lambda: BLEU( - n=2, + "bleu_3": lambda: BLEU3( + **init_kwds, ), - "bleu_3": lambda: BLEU( - n=3, + "bleu_4": lambda: BLEU4( **init_kwds, ), - "bleu_4": lambda: BLEU( - n=4, + "cider_d": lambda: CIDErD( **init_kwds, ), + "fense": lambda: FENSE( + device=device, + verbose=verbose, + **init_kwds, + ), + "fer": lambda: FER( + device=device, + verbose=verbose, + ), "meteor": lambda: METEOR( cache_path=cache_path, java_path=java_path, @@ -212,7 +229,9 @@ def _get_metric_factory_classes( "rouge_l": lambda: ROUGEL( **init_kwds, ), - "cider_d": lambda: CIDErD( + "sbert_sim": lambda: SBERTSim( + device=device, + verbose=verbose, **init_kwds, ), "spice": lambda: SPICE( @@ -229,25 +248,22 @@ def _get_metric_factory_classes( verbose=verbose, **init_kwds, ), - "sbert_sim": lambda: SBERTSim( - device=device, - verbose=verbose, - **init_kwds, - ), - "fluerr": lambda: FluErr( - device=device, - verbose=verbose, - ), - "fense": lambda: FENSE( + "spider_fl": lambda: SPIDErFL( + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, device=device, verbose=verbose, **init_kwds, ), - "spider_fl": lambda: SPIDErFL( + "spider_max": lambda: SPIDErMax( cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, - device=device, + verbose=verbose, + **init_kwds, + ), + "vocab": lambda: Vocab( verbose=verbose, **init_kwds, ), diff --git a/src/aac_metrics/classes/fense.py b/src/aac_metrics/classes/fense.py index e0fc73d..1253b9b 100644 --- a/src/aac_metrics/classes/fense.py +++ b/src/aac_metrics/classes/fense.py @@ -11,7 +11,7 @@ from aac_metrics.classes.base import AACMetric from aac_metrics.functional.fense import fense, _load_models_and_tokenizer -from aac_metrics.functional.fluerr import ERROR_NAMES +from aac_metrics.functional.fer import ERROR_NAMES pylog = logging.getLogger(__name__) @@ -82,11 +82,18 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: ) def extra_repr(self) -> str: - return f"error_threshold={self._error_threshold}, penalty={self._penalty}, device={self._device}, batch_size={self._batch_size}" + hparams = { + "error_threshold": self._error_threshold, + "penalty": self._penalty, + "device": self._device, + "batch_size": self._batch_size, + } + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ def get_output_names(self) -> tuple[str, ...]: - return ("sbert_sim", "fluerr", "fense") + tuple( - f"fluerr.{name}_prob" for name in ERROR_NAMES + return ("sbert_sim", "fer", "fense") + tuple( + f"fer.{name}_prob" for name in ERROR_NAMES ) def reset(self) -> None: diff --git 
a/src/aac_metrics/classes/fer.py b/src/aac_metrics/classes/fer.py new file mode 100644 index 0000000..4217a4f --- /dev/null +++ b/src/aac_metrics/classes/fer.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from typing import Union + +import torch + +from torch import Tensor + +from aac_metrics.classes.base import AACMetric +from aac_metrics.functional.fer import ( + fer, + _load_echecker_and_tokenizer, + ERROR_NAMES, +) + + +pylog = logging.getLogger(__name__) + + +class FER(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): + """Return Fluency Error Rate (FER) detected by a pre-trained BERT model. + + - Paper: https://arxiv.org/abs/2110.04684 + - Original implementation: https://github.com/blmoistawinde/fense + + For more information, see :func:`~aac_metrics.functional.fer.fer`. + """ + + full_state_update = False + higher_is_better = False + is_differentiable = False + + min_value = -1.0 + max_value = 1.0 + + def __init__( + self, + return_all_scores: bool = True, + echecker: str = "echecker_clotho_audiocaps_base", + error_threshold: float = 0.9, + device: Union[str, torch.device, None] = "auto", + batch_size: int = 32, + reset_state: bool = True, + return_probs: bool = False, + verbose: int = 0, + ) -> None: + echecker, echecker_tokenizer = _load_echecker_and_tokenizer(echecker, None, device, reset_state, verbose) # type: ignore + + super().__init__() + self._return_all_scores = return_all_scores + self._echecker = echecker + self._echecker_tokenizer = echecker_tokenizer + self._error_threshold = error_threshold + self._device = device + self._batch_size = batch_size + self._reset_state = reset_state + self._return_probs = return_probs + self._verbose = verbose + + self._candidates = [] + + def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return fer( + candidates=self._candidates, + return_all_scores=self._return_all_scores, + echecker=self._echecker, + echecker_tokenizer=self._echecker_tokenizer, + error_threshold=self._error_threshold, + device=self._device, + batch_size=self._batch_size, + reset_state=self._reset_state, + return_probs=self._return_probs, + verbose=self._verbose, + ) + + def extra_repr(self) -> str: + hparams = {"device": self._device, "batch_size": self._batch_size} + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ + + def get_output_names(self) -> tuple[str, ...]: + return ("fer",) + tuple(f"fer.{name}_prob" for name in ERROR_NAMES) + + def reset(self) -> None: + self._candidates = [] + return super().reset() + + def update( + self, + candidates: list[str], + *args, + **kwargs, + ) -> None: + self._candidates += candidates diff --git a/src/aac_metrics/classes/meteor.py b/src/aac_metrics/classes/meteor.py index 98e517e..fe5e6cf 100644 --- a/src/aac_metrics/classes/meteor.py +++ b/src/aac_metrics/classes/meteor.py @@ -1,7 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from typing import Optional, Union +from pathlib import Path +from typing import Iterable, Optional, Union from torch import Tensor @@ -28,11 +29,13 @@ class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor def __init__( self, return_all_scores: bool = True, - cache_path: str = ..., - java_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, java_max_memory: str = "2G", language: str = "en", use_shell: Optional[bool] = None, + params: Optional[Iterable[float]] = None, + weights: 
Optional[Iterable[float]] = None, verbose: int = 0, ) -> None: super().__init__() @@ -42,6 +45,8 @@ def __init__( self._java_max_memory = java_max_memory self._language = language self._use_shell = use_shell + self._params = params + self._weights = weights self._verbose = verbose self._candidates = [] @@ -57,11 +62,15 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: java_max_memory=self._java_max_memory, language=self._language, use_shell=self._use_shell, + params=self._params, + weights=self._weights, verbose=self._verbose, ) def extra_repr(self) -> str: - return f"java_max_memory={self._java_max_memory}, language={self._language}" + hparams = {"java_max_memory": self._java_max_memory, "language": self._language} + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ def get_output_names(self) -> tuple[str, ...]: return ("meteor",) diff --git a/src/aac_metrics/classes/rouge_l.py b/src/aac_metrics/classes/rouge_l.py index 503c912..5002d89 100644 --- a/src/aac_metrics/classes/rouge_l.py +++ b/src/aac_metrics/classes/rouge_l.py @@ -47,7 +47,9 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: ) def extra_repr(self) -> str: - return f"beta={self._beta}" + hparams = {"beta": self._beta} + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ def get_output_names(self) -> tuple[str, ...]: return ("rouge_l",) diff --git a/src/aac_metrics/classes/sbert_sim.py b/src/aac_metrics/classes/sbert_sim.py index 99eaf52..b32d0d4 100644 --- a/src/aac_metrics/classes/sbert_sim.py +++ b/src/aac_metrics/classes/sbert_sim.py @@ -68,7 +68,9 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: ) def extra_repr(self) -> str: - return f"device={self._device}, batch_size={self._batch_size}" + hparams = {"device": self._device, "batch_size": self._batch_size} + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ def get_output_names(self) -> tuple[str, ...]: return ("sbert_sim",) diff --git a/src/aac_metrics/classes/spice.py b/src/aac_metrics/classes/spice.py index 17f1f9c..2cdff11 100644 --- a/src/aac_metrics/classes/spice.py +++ b/src/aac_metrics/classes/spice.py @@ -3,6 +3,7 @@ import logging +from pathlib import Path from typing import Iterable, Optional, Union from torch import Tensor @@ -32,9 +33,9 @@ class SPICE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor] def __init__( self, return_all_scores: bool = True, - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -74,7 +75,9 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: ) def extra_repr(self) -> str: - return f"java_max_memory={self._java_max_memory}" + hparams = {"java_max_memory": self._java_max_memory} + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ def get_output_names(self) -> tuple[str, ...]: return ("spice",) diff --git a/src/aac_metrics/classes/spider.py b/src/aac_metrics/classes/spider.py index 0ecb30d..6605420 100644 --- a/src/aac_metrics/classes/spider.py +++ b/src/aac_metrics/classes/spider.py @@ -3,6 +3,7 @@ import logging +from pathlib import Path from typing import Iterable, Optional, Union from torch import Tensor @@ -36,9 +37,9 @@ def 
__init__( n: int = 4, sigma: float = 6.0, # SPICE args - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -78,9 +79,13 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: ) def extra_repr(self) -> str: - return ( - f"n={self._n}, sigma={self._sigma}, java_max_memory={self._java_max_memory}" - ) + hparams = { + "n": self._n, + "sigma": self._sigma, + "java_max_memory": self._java_max_memory, + } + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ def get_output_names(self) -> tuple[str, ...]: return ("cider_d", "spice", "spider") diff --git a/src/aac_metrics/classes/spider_fl.py b/src/aac_metrics/classes/spider_fl.py index f11f078..070dc7a 100644 --- a/src/aac_metrics/classes/spider_fl.py +++ b/src/aac_metrics/classes/spider_fl.py @@ -3,6 +3,7 @@ import logging +from pathlib import Path from typing import Iterable, Optional, Union import torch @@ -11,7 +12,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer from aac_metrics.classes.base import AACMetric -from aac_metrics.functional.fluerr import ( +from aac_metrics.functional.fer import ( BERTFlatClassifier, _load_echecker_and_tokenizer, ) @@ -41,9 +42,9 @@ def __init__( n: int = 4, sigma: float = 6.0, # SPICE args - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -126,7 +127,7 @@ def extra_repr(self) -> str: return extra def get_output_names(self) -> tuple[str, ...]: - return ("cider_d", "spice", "spider", "spider_fl", "fluerr") + return ("cider_d", "spice", "spider", "spider_fl", "fer") def reset(self) -> None: self._candidates = [] diff --git a/src/aac_metrics/classes/spider_max.py b/src/aac_metrics/classes/spider_max.py index a43c730..4e60640 100644 --- a/src/aac_metrics/classes/spider_max.py +++ b/src/aac_metrics/classes/spider_max.py @@ -3,6 +3,7 @@ import logging +from pathlib import Path from typing import Iterable, Optional, Union from torch import Tensor @@ -37,9 +38,9 @@ def __init__( n: int = 4, sigma: float = 6.0, # SPICE args - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -79,9 +80,13 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: ) def extra_repr(self) -> str: - return ( - f"n={self._n}, sigma={self._sigma}, java_max_memory={self._java_max_memory}" - ) + hparams = { + "n": self._n, + "sigma": self._sigma, + "java_max_memory": self._java_max_memory, + } + repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items()) + return repr_ def get_output_names(self) -> tuple[str, ...]: return ("spider_max",) diff --git a/src/aac_metrics/classes/vocab.py b/src/aac_metrics/classes/vocab.py new file mode 100644 index 0000000..3f3f16c --- /dev/null +++ b/src/aac_metrics/classes/vocab.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python +# -*- coding: 
utf-8 -*- + +import logging +import math + +from typing import Callable, Union + +import torch + +from torch import Tensor + +from aac_metrics.classes.base import AACMetric +from aac_metrics.functional.vocab import vocab + + +pylog = logging.getLogger(__name__) + + +class Vocab(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): + """VocabStats class. + + For more information, see :func:`~aac_metrics.functional.vocab.vocab`. + """ + + full_state_update = False + higher_is_better = True + is_differentiable = False + + min_value = 0.0 + max_value = math.inf + + def __init__( + self, + return_all_scores: bool = True, + seed: Union[None, int, torch.Generator] = 1234, + tokenizer: Callable[[str], list[str]] = str.split, + dtype: torch.dtype = torch.float64, + pop_strategy: str = "max", + verbose: int = 0, + ) -> None: + super().__init__() + self._return_all_scores = return_all_scores + self._seed = seed + self._tokenizer = tokenizer + self._dtype = dtype + self._pop_strategy = pop_strategy + self._verbose = verbose + + self._candidates = [] + self._mult_references = [] + + def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return vocab( + candidates=self._candidates, + mult_references=self._mult_references, + return_all_scores=self._return_all_scores, + seed=self._seed, + tokenizer=self._tokenizer, + dtype=self._dtype, + pop_strategy=self._pop_strategy, + verbose=self._verbose, + ) + + def get_output_names(self) -> tuple[str, ...]: + return ( + "vocab", + "vocab.mrefs_full", + "vocab.ratio_full", + "vocab.mrefs_avg", + "vocab.mrefs_std", + "vocab.ratio_avg", + ) + + def reset(self) -> None: + self._candidates = [] + self._mult_references = [] + return super().reset() + + def update( + self, + candidates: list[str], + mult_references: Union[list[list[str]], None] = None, + ) -> None: + self._candidates += candidates + + if mult_references is not None: + if self._mult_references is None: + self._mult_references = [] + else: + self._mult_references += mult_references + else: + self._mult_references = None + + if self._mult_references is not None and len(self._candidates) != len( + self._mult_references + ): + raise ValueError( + f"Invalid number of sentences for {self.__class__.__name__}. 
(found {len(candidates)} candidates and {len(self._mult_references)} references)" + ) diff --git a/src/aac_metrics/download.py b/src/aac_metrics/download.py index 9aea94d..e1e9e38 100644 --- a/src/aac_metrics/download.py +++ b/src/aac_metrics/download.py @@ -5,21 +5,26 @@ import os import os.path as osp import shutil -import sys from argparse import ArgumentParser, Namespace +from pathlib import Path +from typing import Union from zipfile import ZipFile from torch.hub import download_url_to_file +import aac_metrics + +from aac_metrics.classes.bert_score_mrefs import BERTScoreMRefs from aac_metrics.classes.fense import FENSE from aac_metrics.functional.meteor import DNAME_METEOR_CACHE from aac_metrics.functional.spice import ( - FNAME_SPICE_JAR, - DNAME_SPICE_LOCAL_CACHE, DNAME_SPICE_CACHE, + DNAME_SPICE_LOCAL_CACHE, + FNAME_SPICE_JAR, check_spice_install, ) +from aac_metrics.utils.cmdline import _str_to_bool, _setup_logging from aac_metrics.utils.paths import ( _get_cache_path, _get_tmp_path, @@ -74,18 +79,17 @@ "fname": osp.join("SPICE-1.0", "stanford-corenlp-full-2015-12-09.zip"), }, } -_TRUE_VALUES = ("true", "1", "t") -_FALSE_VALUES = ("false", "0", "f") def download_metrics( - cache_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, clean_archives: bool = True, ptb_tokenizer: bool = True, meteor: bool = True, spice: bool = True, fense: bool = True, + bert_score: bool = True, verbose: int = 0, ) -> None: """Download the code needed for SPICE, METEOR, PTB Tokenizer and FENSE. @@ -97,6 +101,7 @@ def download_metrics( :param meteor: If True, downloads the METEOR code in cache directory. defaults to True. :param spice: If True, downloads the SPICE code in cache directory. defaults to True. :param fense: If True, downloads the FENSE models. defaults to True. + :param bert_score: If True, downloads the BERTScore model. defaults to True. :param verbose: The verbose level. defaults to 0. """ if verbose >= 1: @@ -125,6 +130,9 @@ def download_metrics( if fense: _download_fense(verbose) + if bert_score: + _download_bert_score(verbose) + if verbose >= 1: pylog.info("aac-metrics download finished.") @@ -226,7 +234,7 @@ def _download_spice( try: check_spice_install(cache_path) return None - except (FileNotFoundError, NotADirectoryError): + except (FileNotFoundError, NotADirectoryError, PermissionError): pass # Download JAR files for SPICE metric @@ -315,6 +323,15 @@ def _download_fense( _ = FENSE(device="cpu") +def _download_bert_score( + verbose: int = 0, +) -> None: + # Download models files for BERTScore metric + if verbose >= 1: + pylog.info("Downloading BERT model for BERTScore metric...") + _ = BERTScoreMRefs(device="cpu") + + def _get_main_download_args() -> Namespace: parser = ArgumentParser( description="Download models and external code to evaluate captions." 
@@ -368,33 +385,9 @@ def _get_main_download_args() -> Namespace: return args -def _setup_logging(verbose: int = 1) -> None: - format_ = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s" - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter(format_)) - pkg_logger = logging.getLogger("aac_metrics") - - found = False - for handler in pkg_logger.handlers: - if isinstance(handler, logging.StreamHandler) and handler.stream is sys.stdout: - found = True - break - if not found: - pkg_logger.addHandler(handler) - - if verbose <= 0: - level = logging.WARNING - elif verbose == 1: - level = logging.INFO - else: - level = logging.DEBUG - pkg_logger.setLevel(level) - - def _main_download() -> None: args = _get_main_download_args() - - _setup_logging(args.verbose) + _setup_logging(aac_metrics.__package__, args.verbose) download_metrics( cache_path=args.cache_path, @@ -408,17 +401,5 @@ def _main_download() -> None: ) -def _str_to_bool(s: str) -> bool: - s = str(s).strip().lower() - if s in _TRUE_VALUES: - return True - elif s in _FALSE_VALUES: - return False - else: - raise ValueError( - f"Invalid argument {s=}. (expected one of {_TRUE_VALUES + _FALSE_VALUES})" - ) - - if __name__ == "__main__": _main_download() diff --git a/src/aac_metrics/eval.py b/src/aac_metrics/eval.py index cd19be9..df733b8 100644 --- a/src/aac_metrics/eval.py +++ b/src/aac_metrics/eval.py @@ -10,18 +10,20 @@ import yaml +import aac_metrics + from aac_metrics.functional.evaluate import ( evaluate, - METRICS_SETS, DEFAULT_METRICS_SET_NAME, + METRICS_SETS, ) from aac_metrics.utils.checks import check_metric_inputs, check_java_path +from aac_metrics.utils.cmdline import _str_to_bool, _str_to_opt_str, _setup_logging from aac_metrics.utils.paths import ( get_default_cache_path, get_default_java_path, get_default_tmp_path, ) -from aac_metrics.download import _setup_logging pylog = logging.getLogger(__name__) @@ -125,10 +127,10 @@ def _load_columns(data: list[dict[str, str]], columns: list[str]) -> list[list[s def _get_main_evaluate_args() -> Namespace: - parser = ArgumentParser(description="Evaluate an output file.") + parser = ArgumentParser(description="Evaluate candidates from a file.") parser.add_argument( - "--input_file", + "--input", "-i", type=str, default="", @@ -137,7 +139,7 @@ def _get_main_evaluate_args() -> Namespace: ) parser.add_argument( "--cand_columns", - "-cc", + "-cands", type=str, nargs="+", default=("caption_predicted", "preds", "cands"), @@ -145,7 +147,7 @@ def _get_main_evaluate_args() -> Namespace: ) parser.add_argument( "--mrefs_columns", - "-rc", + "-mrefs", type=str, nargs="+", default=( @@ -159,7 +161,15 @@ def _get_main_evaluate_args() -> Namespace: help="The column names of the candidates in the CSV file. defaults to ('caption_1', 'caption_2', 'caption_3', 'caption_4', 'caption_5', 'captions').", ) parser.add_argument( - "--metrics_set_name", + "--strict", + "-s", + type=_str_to_bool, + default=False, + help="If True, assume that all columns must be in CSV file. defaults to False.", + ) + parser.add_argument( + "--metrics", + "-m", type=str, default=DEFAULT_METRICS_SET_NAME, choices=tuple(METRICS_SETS.keys()), @@ -167,23 +177,52 @@ def _get_main_evaluate_args() -> Namespace: ) parser.add_argument( "--cache_path", + "-cache", type=str, default=get_default_cache_path(), help=f"Cache directory path. defaults to '{get_default_cache_path()}'.", ) parser.add_argument( "--java_path", + "-java", type=str, default=get_default_java_path(), help=f"Java executable path. 
defaults to '{get_default_java_path()}'.", ) parser.add_argument( "--tmp_path", + "-tmp", type=str, default=get_default_tmp_path(), help=f"Temporary directory path. defaults to '{get_default_tmp_path()}'.", ) - parser.add_argument("--verbose", type=int, default=0, help="Verbose level.") + parser.add_argument( + "--device", + type=str, + default="auto", + help="Device used for model-based metrics. defaults to 'auto'.", + ) + parser.add_argument( + "--verbose", + "-v", + type=int, + default=1, + help="Verbose level. defaults to 1.", + ) + parser.add_argument( + "--corpus_out", + "-co", + type=_str_to_opt_str, + default=None, + help="Output YAML path containing corpus scores. defaults to None.", + ) + parser.add_argument( + "--sentences_out", + "-so", + type=_str_to_opt_str, + default=None, + help="Output CSV path containing sentences scores. defaults to None.", + ) args = parser.parse_args() return args @@ -191,17 +230,19 @@ def _get_main_evaluate_args() -> Namespace: def _main_eval() -> None: args = _get_main_evaluate_args() - - _setup_logging(args.verbose) + _setup_logging(aac_metrics.__package__, args.verbose) if not check_java_path(args.java_path): raise RuntimeError(f"Invalid Java executable. ({args.java_path})") if args.verbose >= 1: - pylog.info(f"Load file {args.input_file}...") + pylog.info(f"Load file {args.input}...") candidates, mult_references = load_csv_file( - args.input_file, args.cand_columns, args.mrefs_columns + fpath=args.input, + cands_columns=args.cand_columns, + mrefs_columns=args.mrefs_columns, + strict=args.strict, ) check_metric_inputs(candidates, mult_references) @@ -211,20 +252,45 @@ def _main_eval() -> None: f"Found {len(candidates)} candidates, {len(mult_references)} references and [{min(refs_lens)}, {max(refs_lens)}] references per candidate." 
) - corpus_scores, _sents_scores = evaluate( + corpus_scores, sents_scores = evaluate( candidates=candidates, mult_references=mult_references, preprocess=True, - metrics=args.metrics_set_name, + metrics=args.metrics, cache_path=args.cache_path, java_path=args.java_path, tmp_path=args.tmp_path, + device=args.device, verbose=args.verbose, ) corpus_scores = {k: v.item() for k, v in corpus_scores.items()} + sents_scores = {k: v.tolist() for k, v in sents_scores.items()} pylog.info(f"Global scores:\n{yaml.dump(corpus_scores, sort_keys=False)}") + if args.corpus_out is not None: + with open(args.corpus_out, "w") as file: + yaml.dump(corpus_scores, file, indent=4) + pylog.info(f"Corpus scores saved in '{args.corpus_out}'.") + + if args.sentences_out is not None: + fieldnames = ["index", "candidate"] + list(sents_scores.keys()) + + n_cands = len(next(iter(sents_scores.values()))) + rows = [ + ( + {"index": i, "candidate": candidates[i]} + | {k: sents_scores[k][i] for k in sents_scores.keys()} + ) + for i in range(n_cands) + ] + with open(args.sentences_out, "w") as file: + writer = csv.DictWriter(file, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + pylog.info(f"Sentences scores saved in '{args.sentences_out}'.") + if __name__ == "__main__": _main_eval() diff --git a/src/aac_metrics/functional/__init__.py b/src/aac_metrics/functional/__init__.py index 88d8298..04819ea 100644 --- a/src/aac_metrics/functional/__init__.py +++ b/src/aac_metrics/functional/__init__.py @@ -1,11 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from .bleu import bleu +from .bleu import bleu, bleu_1, bleu_2, bleu_3, bleu_4 from .cider_d import cider_d from .evaluate import dcase2023_evaluate, evaluate from .fense import fense -from .fluerr import fluerr +from .fer import fer from .meteor import meteor from .rouge_l import rouge_l from .sbert_sim import sbert_sim @@ -13,15 +13,20 @@ from .spider import spider from .spider_fl import spider_fl from .spider_max import spider_max +from .vocab import vocab __all__ = [ "bleu", + "bleu_1", + "bleu_2", + "bleu_3", + "bleu_4", "cider_d", "dcase2023_evaluate", "evaluate", "fense", - "fluerr", + "fer", "meteor", "rouge_l", "sbert_sim", @@ -29,4 +34,5 @@ "spider", "spider_fl", "spider_max", + "vocab", ] diff --git a/src/aac_metrics/functional/bert_score_mrefs.py b/src/aac_metrics/functional/bert_score_mrefs.py new file mode 100644 index 0000000..d239652 --- /dev/null +++ b/src/aac_metrics/functional/bert_score_mrefs.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Callable, Optional, Union + +import torch + +from torch import nn, Tensor +from torchmetrics.functional.text.bert import bert_score, _DEFAULT_MODEL +from transformers.models.auto.modeling_auto import AutoModel +from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers import logging as tfmers_logging + +from aac_metrics.utils.collections import flat_list, unflat_list, duplicate_list + + +def bert_score_mrefs( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + model: Union[str, nn.Module] = _DEFAULT_MODEL, + tokenizer: Optional[Callable] = None, + device: Union[str, torch.device, None] = "auto", + batch_size: int = 32, + num_threads: int = 0, + max_length: int = 64, + reset_state: bool = True, + idf: bool = False, + reduction: str = "max", + filter_nan: bool = True, + verbose: int = 0, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + """BERTScore metric which 
supports multiple references. + + The implementation is based on the bert_score implementation of torchmetrics. + + - Paper: https://arxiv.org/pdf/1904.09675.pdf + + :param candidates: The list of sentences to evaluate. + :param mult_references: The list of list of sentences used as target. + :param return_all_scores: If True, returns a tuple containing the globals and locals scores. + Otherwise returns a scalar tensor containing the main global score. + defaults to True. + :param model: The model name or the instantiated model to use to compute token embeddings. + defaults to "roberta-large". + :param tokenizer: The fast tokenizer used to split sentences into words. + If None, use the tokenizer corresponding to the model argument. + defaults to None. + :param device: The PyTorch device used to run the BERT model. defaults to "auto". + :param batch_size: The batch size used in the model forward. + :param num_threads: A number of threads to use for a dataloader. defaults to 0. + :param max_length: Max length when encoding sentences to tensor ids. defaults to 64. + :param idf: Whether or not using Inverse document frequency to ponderate the BERTScores. defaults to False. + :param reduction: The reduction function to apply between multiple references for each audio. defaults to "mean". + :param filter_nan: If True, replace NaN scores by 0.0. defaults to True. + :param verbose: The verbose level. defaults to 0. + :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. + """ + if isinstance(model, str): + if tokenizer is not None: + raise ValueError( + f"Invalid argument combinaison {model=} with {tokenizer=}." + ) + model, tokenizer = _load_model_and_tokenizer( + model, tokenizer, device, reset_state, verbose + ) + + elif isinstance(model, nn.Module): + if tokenizer is None: + raise ValueError( + f"Invalid argument combinaison {model=} with {tokenizer=}." + ) + + else: + raise ValueError( + f"Invalid argument type {type(model)=}. (expected str or nn.Module)" + ) + + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + flat_mrefs, sizes = flat_list(mult_references) + duplicated_cands = duplicate_list(candidates, sizes) + + tfmers_verbosity = tfmers_logging.get_verbosity() + if verbose <= 1: + tfmers_logging.set_verbosity_error() + + sents_scores = bert_score( + duplicated_cands, + flat_mrefs, + model_name_or_path=None, + model=model, # type: ignore + user_tokenizer=tokenizer, + device=device, + batch_size=batch_size, + num_threads=num_threads, + verbose=verbose >= 3, + max_length=max_length, + idf=idf, + ) + if verbose <= 1: + # Restore previous verbosity level + tfmers_logging.set_verbosity(tfmers_verbosity) + + # sents_scores keys: "precision", "recall", "f1" + sents_scores = {k: unflat_list(v, sizes) for k, v in sents_scores.items()} # type: ignore + + if not return_all_scores: + sents_scores = {"f1": sents_scores["f1"]} + + dtype = torch.float32 + + if reduction == "mean": + reduction_fn = torch.mean + elif reduction == "max": + reduction_fn = max_reduce + elif reduction == "min": + reduction_fn = min_reduce + else: + REDUCTIONS = ("mean", "max", "min") + raise ValueError( + f"Invalid argument {reduction=}. 
(expected one of {REDUCTIONS})" + ) + + if len(sizes) > 0 and all(size == sizes[0] for size in sizes): + sents_scores = { + k: reduction_fn(torch.as_tensor(v, dtype=dtype), dim=1) + for k, v in sents_scores.items() + } + else: + sents_scores = { + k: torch.stack([reduction_fn(torch.as_tensor(vi, dtype=dtype)) for vi in v]) + for k, v in sents_scores.items() + } + + sents_scores = {f"bert_score.{k}": v for k, v in sents_scores.items()} + + if filter_nan: + # avoid NaN that can occur in some cases + sents_scores = { + k: v.masked_fill(v.isnan(), 0.0) for k, v in sents_scores.items() + } + + corpus_scores = {k: v.mean() for k, v in sents_scores.items()} + + if return_all_scores: + return corpus_scores, sents_scores + else: + return corpus_scores["bert_score.f1"] + + +def _load_model_and_tokenizer( + model: Union[str, nn.Module], + tokenizer: Optional[Callable], + device: Union[str, torch.device, None], + reset_state: bool, + verbose: int, +) -> tuple[nn.Module, Optional[Callable]]: + state = torch.random.get_rng_state() + + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + if isinstance(model, str): + tfmers_verbosity = tfmers_logging.get_verbosity() + if verbose <= 1: + tfmers_logging.set_verbosity_error() + + # WARNING: tokenizer must be initialized BEFORE model to avoid connection errors + tokenizer = AutoTokenizer.from_pretrained(model) + model = AutoModel.from_pretrained(model) # type: ignore + + if verbose <= 1: + # Restore previous verbosity level + tfmers_logging.set_verbosity(tfmers_verbosity) + + model.eval() # type: ignore + model.to(device=device) # type: ignore + + if reset_state: + torch.random.set_rng_state(state) + + return model, tokenizer # type: ignore + + +def max_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor: + if dim is None: + return x.max() + else: + return x.max(dim=dim).values + + +def min_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor: + if dim is None: + return x.min() + else: + return x.min(dim=dim).values diff --git a/src/aac_metrics/functional/evaluate.py b/src/aac_metrics/functional/evaluate.py index 4d0212f..542ff80 100644 --- a/src/aac_metrics/functional/evaluate.py +++ b/src/aac_metrics/functional/evaluate.py @@ -5,22 +5,26 @@ import time from functools import partial -from typing import Any, Callable, Iterable, Union +from pathlib import Path +from typing import Any, Callable, Iterable, Optional, Union import torch from torch import Tensor -from aac_metrics.functional.bleu import bleu +from aac_metrics.functional.bert_score_mrefs import bert_score_mrefs +from aac_metrics.functional.bleu import bleu, bleu_1, bleu_2, bleu_3, bleu_4 from aac_metrics.functional.cider_d import cider_d from aac_metrics.functional.fense import fense -from aac_metrics.functional.fluerr import fluerr +from aac_metrics.functional.fer import fer from aac_metrics.functional.meteor import meteor from aac_metrics.functional.rouge_l import rouge_l from aac_metrics.functional.sbert_sim import sbert_sim from aac_metrics.functional.spice import spice from aac_metrics.functional.spider import spider from aac_metrics.functional.spider_fl import spider_fl +from aac_metrics.functional.spider_max import spider_max +from aac_metrics.functional.vocab import vocab from aac_metrics.utils.checks import check_metric_inputs from aac_metrics.utils.tokenization import preprocess_mono_sents, preprocess_mult_sents @@ -52,7 +56,7 @@ # DCASE challenge task6a metrics for 2023 "dcase2023": ( "meteor", 
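# --- Editor's note: illustrative usage sketch for the new `bert_score_mrefs`
# function introduced above in bert_score_mrefs.py (not part of this patch).
# Assumes torchmetrics and transformers are installed; the default
# "roberta-large" checkpoint is downloaded on first use. Sentences are arbitrary.
from aac_metrics.functional.bert_score_mrefs import bert_score_mrefs

cands = ["a man is speaking", "rain falls"]
mrefs = [
    ["a man speaks.", "someone speaks."],
    ["rain is falling hard on a surface"],
]

# With return_all_scores=True (the default), a pair of dicts is returned:
# corpus-level and sentence-level "bert_score.precision/recall/f1" tensors.
corpus_scores, sents_scores = bert_score_mrefs(cands, mrefs)
print(corpus_scores["bert_score.f1"])
# --- end of editor's note ---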
- "spider_fl", # includes cider_d, spice, spider, fluerr + "spider_fl", # includes cider_d, spice, spider, fer ), # All metrics "all": ( @@ -62,8 +66,10 @@ "bleu_4", "meteor", "rouge_l", - "fense", # includes sbert, fluerr - "spider_fl", # includes cider_d, spice, spider, fluerr + "fense", # includes sbert, fer + "spider_fl", # includes cider_d, spice, spider, fer + "vocab", + "bert_score", ), } DEFAULT_METRICS_SET_NAME = "default" @@ -76,9 +82,9 @@ def evaluate( metrics: Union[ str, Iterable[str], Iterable[Callable[[list, list], tuple]] ] = DEFAULT_METRICS_SET_NAME, - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> tuple[dict[str, Tensor], dict[str, Tensor]]: @@ -122,7 +128,9 @@ def evaluate( outs_sents = {} for i, metric in enumerate(metrics): - if hasattr(metric, "__qualname__"): + if isinstance(metric, partial): + name = metric.func.__qualname__ + elif hasattr(metric, "__qualname__"): name = metric.__qualname__ else: name = metric.__class__.__qualname__ @@ -161,9 +169,9 @@ def dcase2023_evaluate( candidates: list[str], mult_references: list[list[str]], preprocess: bool = True, - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> tuple[dict[str, Tensor], dict[str, Tensor]]: @@ -196,9 +204,9 @@ def dcase2023_evaluate( def _instantiate_metrics_functions( metrics: Union[str, Iterable[str], Iterable[Callable[[list, list], tuple]]] = "all", - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> list[Callable]: @@ -234,41 +242,57 @@ def _instantiate_metrics_functions( def _get_metric_factory_functions( return_all_scores: bool = True, - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, device: Union[str, torch.device, None] = "auto", verbose: int = 0, - init_kwds: dict[str, Any] = ..., + init_kwds: Optional[dict[str, Any]] = None, ) -> dict[str, Callable[[list[str], list[list[str]]], Any]]: - if init_kwds is ...: + if init_kwds is None or init_kwds is ...: init_kwds = {} init_kwds = init_kwds | dict(return_all_scores=return_all_scores) factory = { + "bert_score": partial( + bert_score_mrefs, + **init_kwds, + ), "bleu": partial( bleu, **init_kwds, ), "bleu_1": partial( - bleu, - n=1, + bleu_1, **init_kwds, ), "bleu_2": partial( - bleu, - n=2, + bleu_2, **init_kwds, ), "bleu_3": partial( - bleu, - n=3, + bleu_3, **init_kwds, ), "bleu_4": partial( - bleu, - n=4, + bleu_4, + **init_kwds, + ), + "cider_d": partial( + cider_d, + **init_kwds, + ), + "fer": partial( + fer, + device=device, + verbose=verbose, + **init_kwds, + ), + "fense": partial( + fense, + device=device, + verbose=verbose, **init_kwds, ), "meteor": partial( @@ -282,8 +306,10 @@ def _get_metric_factory_functions( rouge_l, **init_kwds, ), - "cider_d": partial( - cider_d, + "sbert_sim": partial( + sbert_sim, + device=device, + 
verbose=verbose, **init_kwds, ), "spice": partial( @@ -302,21 +328,11 @@ def _get_metric_factory_functions( verbose=verbose, **init_kwds, ), - "sbert_sim": partial( - sbert_sim, - device=device, - verbose=verbose, - **init_kwds, - ), - "fluerr": partial( - fluerr, - device=device, - verbose=verbose, - **init_kwds, - ), - "fense": partial( - fense, - device=device, + "spider_max": partial( + spider_max, + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, verbose=verbose, **init_kwds, ), @@ -329,5 +345,10 @@ def _get_metric_factory_functions( verbose=verbose, **init_kwds, ), + "vocab": partial( + vocab, + verbose=verbose, + **init_kwds, + ), } return factory diff --git a/src/aac_metrics/functional/fense.py b/src/aac_metrics/functional/fense.py index 778b43b..b5070c5 100644 --- a/src/aac_metrics/functional/fense.py +++ b/src/aac_metrics/functional/fense.py @@ -16,8 +16,8 @@ from torch import Tensor from transformers.models.auto.tokenization_auto import AutoTokenizer -from aac_metrics.functional.fluerr import ( - fluerr, +from aac_metrics.functional.fer import ( + fer, _load_echecker_and_tokenizer, BERTFlatClassifier, ) @@ -74,12 +74,36 @@ def fense( # Init models sbert_model, echecker, echecker_tokenizer = _load_models_and_tokenizer( - sbert_model, echecker, echecker_tokenizer, device, reset_state, verbose + sbert_model=sbert_model, + echecker=echecker, + echecker_tokenizer=echecker_tokenizer, + device=device, + reset_state=reset_state, + verbose=verbose, ) - - sbert_sim_outs: tuple = sbert_sim(candidates, mult_references, True, sbert_model, device, batch_size, reset_state, verbose) # type: ignore - fluerr_outs: tuple = fluerr(candidates, True, echecker, echecker_tokenizer, error_threshold, device, batch_size, reset_state, return_probs, verbose) # type: ignore - fense_outs = _fense_from_outputs(sbert_sim_outs, fluerr_outs, penalty) + sbert_sim_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = sbert_sim( # type: ignore + candidates=candidates, + mult_references=mult_references, + return_all_scores=True, + sbert_model=sbert_model, + device=device, + batch_size=batch_size, + reset_state=reset_state, + verbose=verbose, + ) + fer_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = fer( # type: ignore + candidates=candidates, + return_all_scores=True, + echecker=echecker, + echecker_tokenizer=echecker_tokenizer, + error_threshold=error_threshold, + device=device, + batch_size=batch_size, + reset_state=reset_state, + return_probs=return_probs, + verbose=verbose, + ) + fense_outs = _fense_from_outputs(sbert_sim_outs, fer_outs, penalty) if return_all_scores: return fense_outs @@ -89,30 +113,28 @@ def fense( def _fense_from_outputs( sbert_sim_outs: tuple[dict[str, Tensor], dict[str, Tensor]], - fluerr_outs: tuple[dict[str, Tensor], dict[str, Tensor]], - penalty: float, + fer_outs: tuple[dict[str, Tensor], dict[str, Tensor]], + penalty: float = 0.9, ) -> tuple[dict[str, Tensor], dict[str, Tensor]]: - """Combines SBERT and FluErr outputs. + """Combines SBERT and FER outputs. 
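# --- Editor's note: illustrative sketch of the refactored `fense` call shown
# above (not part of this patch). Assumes the SBERT and echecker checkpoints
# can be downloaded; sentences are arbitrary.
from aac_metrics.functional.fense import fense

cands = ["a man is speaking", "rain falls"]
mrefs = [
    ["a man speaks.", "someone speaks."],
    ["rain is falling hard on a surface"],
]

# The returned dicts now use the "fer" key (previously "fluerr") next to
# "sbert_sim" and "fense".
corpus_scores, sents_scores = fense(
    candidates=cands, mult_references=mrefs, return_all_scores=True
)
print(corpus_scores["fense"], corpus_scores["fer"])
# --- end of editor's note ---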
Based on https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L121 """ sbert_sim_outs_corpus, sbert_sim_outs_sents = sbert_sim_outs - fluerr_outs_corpus, fluerr_outs_sents = fluerr_outs + fer_outs_corpus, fer_outs_sents = fer_outs sbert_sims_scores = sbert_sim_outs_sents["sbert_sim"] - fluerr_scores = fluerr_outs_sents["fluerr"] - fense_scores = sbert_sims_scores * (1.0 - penalty * fluerr_scores) + fer_scores = fer_outs_sents["fer"] + fense_scores = sbert_sims_scores * (1.0 - penalty * fer_scores) fense_score = torch.as_tensor( - fense_scores.cpu().numpy().mean(), + fense_scores.cpu() + .numpy() + .mean(), # note: use numpy mean to keep the same values than the original fense device=fense_scores.device, ) - fense_outs_corpus = ( - sbert_sim_outs_corpus | fluerr_outs_corpus | {"fense": fense_score} - ) - fense_outs_sents = ( - sbert_sim_outs_sents | fluerr_outs_sents | {"fense": fense_scores} - ) + fense_outs_corpus = sbert_sim_outs_corpus | fer_outs_corpus | {"fense": fense_score} + fense_outs_sents = sbert_sim_outs_sents | fer_outs_sents | {"fense": fense_scores} fense_outs = fense_outs_corpus, fense_outs_sents return fense_outs diff --git a/src/aac_metrics/functional/fer.py b/src/aac_metrics/functional/fer.py new file mode 100644 index 0000000..b30fdaa --- /dev/null +++ b/src/aac_metrics/functional/fer.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +BASED ON https://github.com/blmoistawinde/fense/ +""" + +import hashlib +import logging +import os +import re +import requests + +from collections import namedtuple +from os import environ, makedirs +from os.path import exists, expanduser, join +from typing import Mapping, Optional, Union + +import numpy as np +import torch + +from torch import nn, Tensor +from tqdm import tqdm +from transformers import logging as tfmers_logging +from transformers.models.auto.modeling_auto import AutoModel +from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + + +# config according to the settings on your computer, this should be default setting of shadowsocks +DEFAULT_PROXIES = { + "http": "socks5h://127.0.0.1:1080", + "https": "socks5h://127.0.0.1:1080", +} +PRETRAIN_ECHECKERS_DICT = { + "echecker_clotho_audiocaps_base": ( + "https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_base.ckpt", + "1a719f090af70614bbdb9f9437530b7e133c48cfa4a58d964de0d47fc974a2fa", + ), + "echecker_clotho_audiocaps_tiny": ( + "https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_tiny.ckpt", + "90ed0ac5033ec497ec66d4f68588053813e085671136dae312097c96c504f673", + ), +} + +RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"]) + +pylog = logging.getLogger(__name__) + + +ERROR_NAMES = ( + "add_tail", + "repeat_event", + "repeat_adv", + "remove_conj", + "remove_verb", + "error", +) + + +class BERTFlatClassifier(nn.Module): + def __init__(self, model_type: str, num_classes: int = 5) -> None: + super().__init__() + self.model_type = model_type + self.num_classes = num_classes + self.encoder = AutoModel.from_pretrained(model_type) + self.dropout = nn.Dropout(self.encoder.config.hidden_dropout_prob) + self.clf = nn.Linear(self.encoder.config.hidden_size, num_classes) + + @classmethod + def from_pretrained( + cls, + model_name: str = "echecker_clotho_audiocaps_base", + device: Union[str, torch.device, None] = "auto", + use_proxy: bool = False, + proxies: 
Optional[dict[str, str]] = None, + verbose: int = 0, + ) -> "BERTFlatClassifier": + return __load_pretrain_echecker(model_name, device, use_proxy, proxies, verbose) + + def forward( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + token_type_ids: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + outputs = self.encoder(input_ids, attention_mask, token_type_ids) + x = outputs.last_hidden_state[:, 0, :] + x = self.dropout(x) + logits = self.clf(x) + return logits + + +def fer( + candidates: list[str], + return_all_scores: bool = True, + echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base", + echecker_tokenizer: Optional[AutoTokenizer] = None, + error_threshold: float = 0.9, + device: Union[str, torch.device, None] = "auto", + batch_size: int = 32, + reset_state: bool = True, + return_probs: bool = False, + verbose: int = 0, +) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]: + """Return Fluency Error Rate (FER) detected by a pre-trained BERT model. + + - Paper: https://arxiv.org/abs/2110.04684 + - Original implementation: https://github.com/blmoistawinde/fense + + :param candidates: The list of sentences to evaluate. + :param mult_references: The list of list of sentences used as target. + :param return_all_scores: If True, returns a tuple containing the globals and locals scores. + Otherwise returns a scalar tensor containing the main global score. + defaults to True. + :param echecker: The echecker model used to detect fluency errors. + Can be "echecker_clotho_audiocaps_base", "echecker_clotho_audiocaps_tiny", "none" or None. + defaults to "echecker_clotho_audiocaps_base". + :param echecker_tokenizer: The tokenizer of the echecker model. + If None and echecker is not None, this value will be inferred with `echecker.model_type`. + defaults to None. + :param error_threshold: The threshold used to detect fluency errors for echecker model. defaults to 0.9. + :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". + :param batch_size: The batch size of the echecker models. defaults to 32. + :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. + :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to False. + :param verbose: The verbose level. defaults to 0. + :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. 
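# --- Editor's note: illustrative usage sketch for `fer` (not part of this
# patch). Unlike most metrics in the package, `fer` only needs candidates,
# no references; the echecker checkpoint is downloaded on first use.
from aac_metrics.functional.fer import fer

cands = ["a man is speaking", "rain falls"]

corpus_scores, sents_scores = fer(cands, return_all_scores=True)
print(corpus_scores["fer"])   # corpus-level fluency error rate
print(sents_scores["fer"])    # per-sentence error indicators (0.0 or 1.0)
# --- end of editor's note ---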
+ """ + + # Init models + echecker, echecker_tokenizer = _load_echecker_and_tokenizer( + echecker, echecker_tokenizer, device, reset_state, verbose + ) + + # Compute and apply fluency error detection penalty + probs_outs_sents = __detect_error_sents( + echecker, + echecker_tokenizer, # type: ignore + candidates, + batch_size, + device, + ) + fer_scores = (probs_outs_sents["error"] > error_threshold).astype(float) + + fer_scores = torch.from_numpy(fer_scores) + fer_score = fer_scores.mean() + + if return_all_scores: + fer_outs_corpus = { + "fer": fer_score, + } + fer_outs_sents = { + "fer": fer_scores, + } + + if return_probs: + probs_outs_sents = {f"fer.{k}_prob": v for k, v in probs_outs_sents.items()} + probs_outs_sents = { + k: torch.from_numpy(v) for k, v in probs_outs_sents.items() + } + probs_outs_corpus = {k: v.mean() for k, v in probs_outs_sents.items()} + + fer_outs_corpus = probs_outs_corpus | fer_outs_corpus + fer_outs_sents = probs_outs_sents | fer_outs_sents + + fer_outs = fer_outs_corpus, fer_outs_sents + + return fer_outs + else: + return fer_score + + +# - Private functions +def _load_echecker_and_tokenizer( + echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base", + echecker_tokenizer: Optional[AutoTokenizer] = None, + device: Union[str, torch.device, None] = "auto", + reset_state: bool = True, + verbose: int = 0, +) -> tuple[BERTFlatClassifier, AutoTokenizer]: + state = torch.random.get_rng_state() + + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + if isinstance(echecker, str): + echecker = __load_pretrain_echecker(echecker, device, verbose=verbose) + + if echecker_tokenizer is None: + echecker_tokenizer = AutoTokenizer.from_pretrained(echecker.model_type) # type: ignore + + echecker = echecker.eval() + for p in echecker.parameters(): + p.requires_grad_(False) + + if reset_state: + torch.random.set_rng_state(state) + + return echecker, echecker_tokenizer # type: ignore + + +def __detect_error_sents( + echecker: BERTFlatClassifier, + echecker_tokenizer: PreTrainedTokenizerFast, + sents: list[str], + batch_size: int, + device: Union[str, torch.device, None], + max_len: int = 64, +) -> dict[str, np.ndarray]: + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + if len(sents) <= batch_size: + batch = __infer_preprocess( + echecker_tokenizer, + sents, + max_len=max_len, + device=device, + dtype=torch.long, + ) + logits: Tensor = echecker(**batch) + assert not logits.requires_grad + # batch_logits: (bsize, num_classes=6) + # note: fix error in the original fense code: https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L69 + probs = logits.sigmoid().transpose(0, 1).cpu().numpy() + probs_dic: dict[str, np.ndarray] = dict(zip(ERROR_NAMES, probs)) + + else: + dic_lst_probs = {name: [] for name in ERROR_NAMES} + + for i in range(0, len(sents), batch_size): + batch = __infer_preprocess( + echecker_tokenizer, + sents[i : i + batch_size], + max_len=max_len, + device=device, + dtype=torch.long, + ) + + batch_logits: Tensor = echecker(**batch) + assert not batch_logits.requires_grad + # batch_logits: (bsize, num_classes=6) + # classes: add_tail, repeat_event, repeat_adv, remove_conj, remove_verb, error + probs = batch_logits.sigmoid().cpu().numpy() + + for j, name in enumerate(dic_lst_probs.keys()): + dic_lst_probs[name].append(probs[:, j]) + + probs_dic = { + 
name: np.concatenate(probs) for name, probs in dic_lst_probs.items() + } + + return probs_dic + + +def __check_download_resource( + remote: RemoteFileMetadata, + use_proxy: bool = False, + proxies: Optional[dict[str, str]] = None, +) -> str: + proxies = DEFAULT_PROXIES if use_proxy and proxies is None else proxies + data_home = __get_data_home() + file_path = os.path.join(data_home, remote.filename) + if not os.path.exists(file_path): + # currently don't capture error at this level, assume download success + file_path = __download(remote, data_home, use_proxy, proxies) + return file_path + + +def __infer_preprocess( + tokenizer: PreTrainedTokenizerFast, + texts: list[str], + max_len: int, + device: Union[str, torch.device, None], + dtype: torch.dtype, +) -> Mapping[str, Tensor]: + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + texts = __text_preprocess(texts) # type: ignore + batch = tokenizer(texts, truncation=True, padding="max_length", max_length=max_len) + for k in ("input_ids", "attention_mask", "token_type_ids"): + batch[k] = torch.as_tensor(batch[k], device=device, dtype=dtype) # type: ignore + return batch + + +def __download( + remote: RemoteFileMetadata, + file_path: Optional[str] = None, + use_proxy: bool = False, + proxies: Optional[dict[str, str]] = DEFAULT_PROXIES, +) -> str: + data_home = __get_data_home() + file_path = __fetch_remote(remote, data_home, use_proxy, proxies) + return file_path + + +def __download_with_bar( + url: str, + file_path: str, + proxies: Optional[dict[str, str]] = DEFAULT_PROXIES, +) -> str: + # Streaming, so we can iterate over the response. + response = requests.get(url, stream=True, proxies=proxies) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 KB + progress_bar = tqdm(total=total_size_in_bytes, unit="B", unit_scale=True) + with open(file_path, "wb") as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + raise Exception("ERROR, something went wrong with the downloading") + return file_path + + +def __fetch_remote( + remote: RemoteFileMetadata, + dirname: Optional[str] = None, + use_proxy: bool = False, + proxies: Optional[dict[str, str]] = DEFAULT_PROXIES, +) -> str: + """Helper function to download a remote dataset into path + Fetch a dataset pointed by remote's url, save into path using remote's + filename and ensure its integrity based on the SHA256 Checksum of the + downloaded file. + Parameters + ---------- + remote : RemoteFileMetadata + Named tuple containing remote dataset meta information: url, filename + and checksum + dirname : string + Directory to save the file to. + Returns + ------- + file_path: string + Full path of the created file. + """ + + file_path = remote.filename if dirname is None else join(dirname, remote.filename) + proxies = None if not use_proxy else proxies + file_path = __download_with_bar(remote.url, file_path, proxies) + checksum = __sha256(file_path) + if remote.checksum != checksum: + raise IOError( + "{} has an SHA256 checksum ({}) " + "differing from expected ({}), " + "file may be corrupted.".format(file_path, checksum, remote.checksum) + ) + return file_path + + +def __get_data_home(data_home: Optional[str] = None) -> str: # type: ignore + """Return the path of the scikit-learn data dir. 
+ This folder is used by some large dataset loaders to avoid downloading the + data several times. + By default the data dir is set to a folder named 'fense_data' in the + user home folder. + Alternatively, it can be set by the 'FENSE_DATA' environment + variable or programmatically by giving an explicit folder path. The '~' + symbol is expanded to the user home folder. + If the folder does not already exist, it is automatically created. + Parameters + ---------- + data_home : str | None + The path to data dir. + """ + if data_home is None: + data_home = environ.get("FENSE_DATA", join(torch.hub.get_dir(), "fense_data")) + + data_home: str + data_home = expanduser(data_home) + if not exists(data_home): + makedirs(data_home) + return data_home + + +def __load_pretrain_echecker( + echecker_model: str, + device: Union[str, torch.device, None] = "auto", + use_proxy: bool = False, + proxies: Optional[dict[str, str]] = None, + verbose: int = 0, +) -> BERTFlatClassifier: + if echecker_model not in PRETRAIN_ECHECKERS_DICT: + raise ValueError( + f"Invalid argument {echecker_model=}. (expected one of {tuple(PRETRAIN_ECHECKERS_DICT.keys())})" + ) + + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(device, str): + device = torch.device(device) + + tfmers_logging.set_verbosity_error() # suppress loading warnings + url, checksum = PRETRAIN_ECHECKERS_DICT[echecker_model] + remote = RemoteFileMetadata( + filename=f"{echecker_model}.ckpt", url=url, checksum=checksum + ) + file_path = __check_download_resource(remote, use_proxy, proxies) + + if verbose >= 2: + pylog.debug(f"Loading echecker model from '{file_path}'.") + + model_states = torch.load(file_path) + + if verbose >= 2: + pylog.debug( + f"Loading echecker model type '{model_states['model_type']}' with '{model_states['num_classes']}' classes." 
+ ) + + echecker = BERTFlatClassifier( + model_type=model_states["model_type"], + num_classes=model_states["num_classes"], + ) + echecker.load_state_dict(model_states["state_dict"]) + echecker.eval() + echecker.to(device=device) + return echecker + + +def __sha256(path: str) -> str: + """Calculate the sha256 hash of the file at path.""" + sha256hash = hashlib.sha256() + chunk_size = 8192 + with open(path, "rb") as f: + while True: + buffer = f.read(chunk_size) + if not buffer: + break + sha256hash.update(buffer) + return sha256hash.hexdigest() + + +def __text_preprocess(inp: Union[str, list[str]]) -> Union[str, list[str]]: + if isinstance(inp, str): + return re.sub(r"[^\w\s]", "", inp).lower() + else: + return [re.sub(r"[^\w\s]", "", x).lower() for x in inp] diff --git a/src/aac_metrics/functional/meteor.py b/src/aac_metrics/functional/meteor.py index 8ea5dd4..381c821 100644 --- a/src/aac_metrics/functional/meteor.py +++ b/src/aac_metrics/functional/meteor.py @@ -8,8 +8,9 @@ import platform import subprocess +from pathlib import Path from subprocess import Popen -from typing import Optional, Union +from typing import Iterable, Optional, Union import torch @@ -31,11 +32,13 @@ def meteor( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, - cache_path: str = ..., - java_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, java_max_memory: str = "2G", language: str = "en", use_shell: Optional[bool] = None, + params: Optional[Iterable[float]] = None, + weights: Optional[Iterable[float]] = None, verbose: int = 0, ) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: """Metric for Evaluation of Translation with Explicit ORdering function. @@ -57,6 +60,12 @@ def meteor( :param use_shell: Optional argument to force use os-specific shell for the java subprogram. If None, it will use shell only on Windows OS. defaults to None. + :param params: List of 4 parameters (alpha, beta gamma delta) used in METEOR metric. + If None, it will use the default of the java program, which is (0.85, 0.2, 0.6, 0.75). + defaults to None. + :param weights: List of 4 parameters (w1, w2, w3, w4) used in METEOR metric. + If None, it will use the default of the java program, which is (1.0 1.0 0.6 0.8). + defaults to None. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ @@ -88,8 +97,11 @@ def meteor( f"Invalid argument {language=}. (expected one of {SUPPORTED_LANGUAGES})" ) + # Note: override localization to avoid errors due to double conversion (https://github.com/Labbeti/aac-metrics/issues/9) meteor_cmd = [ java_path, + "-Duser.country=US", + "-Duser.language=en", "-jar", f"-Xmx{java_max_memory}", meteor_jar_fpath, @@ -101,6 +113,24 @@ def meteor( "-norm", ] + if params is not None: + params = list(params) + if len(params) != 4: + raise ValueError( + f"Invalid argument {params=}. (expected 4 params but found {len(params)})" + ) + params_arg = " ".join(map(str, params)) + meteor_cmd += ["-p", f"{params_arg}"] + + if weights is not None: + weights = list(weights) + if len(weights) != 4: + raise ValueError( + f"Invalid argument {weights=}. 
(expected 4 params but found {len(weights)})" + ) + weights_arg = " ".join(map(str, weights)) + meteor_cmd += ["-w", f"{weights_arg}"] + if verbose >= 2: pylog.debug( f"Run METEOR java code with: {' '.join(meteor_cmd)} and {use_shell=}" @@ -116,7 +146,7 @@ def meteor( n_candidates = len(candidates) encoded_cands_and_mrefs = [ - encore_cand_and_refs(cand, refs) + _encode_cand_and_refs(cand, refs) for cand, refs in zip(candidates, mult_references) ] del candidates, mult_references @@ -135,16 +165,33 @@ def meteor( assert meteor_process.stdin is not None, "INTERNAL METEOR process error" if verbose >= 3: pylog.debug(f"Write line {eval_line=}.") - meteor_process.stdin.write("{}\n".format(eval_line).encode()) + + process_inputs = "{}\n".format(eval_line).encode() + meteor_process.stdin.write(process_inputs) meteor_process.stdin.flush() # Read scores assert meteor_process.stdout is not None, "INTERNAL METEOR process error" meteor_scores = [] - for _ in range(n_candidates): - meteor_scores_i = float(meteor_process.stdout.readline().strip()) + for i in range(n_candidates): + process_out_i = meteor_process.stdout.readline().strip() + try: + meteor_scores_i = float(process_out_i) + except ValueError as err: + pylog.error( + f"Invalid METEOR stdout. (cannot convert sentence score to float {process_out_i=} with {i=})" + ) + raise err meteor_scores.append(meteor_scores_i) - meteor_score = float(meteor_process.stdout.readline().strip()) + + process_out = meteor_process.stdout.readline().strip() + try: + meteor_score = float(process_out) + except ValueError as err: + pylog.error( + f"Invalid METEOR stdout. (cannot convert global score to float {process_out=})" + ) + raise err meteor_process.stdin.close() meteor_process.kill() @@ -167,7 +214,7 @@ def meteor( return meteor_score -def encore_cand_and_refs(candidate: str, references: list[str]) -> bytes: +def _encode_cand_and_refs(candidate: str, references: list[str]) -> bytes: # SCORE ||| reference 1 words ||| ... ||| reference N words ||| candidate words candidate = candidate.replace("|||", "").replace(" ", " ") score_line = " ||| ".join(("SCORE", " ||| ".join(references), candidate)) diff --git a/src/aac_metrics/functional/mult_cands.py b/src/aac_metrics/functional/mult_cands.py index 98ae01c..851dba0 100644 --- a/src/aac_metrics/functional/mult_cands.py +++ b/src/aac_metrics/functional/mult_cands.py @@ -9,6 +9,9 @@ from torch import Tensor +SELECTIONS = ("max", "min", "mean") + + def mult_cands_metric( metric: Callable, metric_out_name: str, @@ -31,7 +34,6 @@ def mult_cands_metric( :param **kwargs: The keywords arguments given to the metric call. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ - SELECTIONS = ("max", "min", "mean") if selection not in SELECTIONS: raise ValueError( f"Invalid argument {selection=}. 
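# --- Editor's note: illustrative sketch of the new `params`/`weights`
# arguments of `meteor` added above (not part of this patch). Requires Java
# and the METEOR jar in the aac-metrics cache directory; the values below are
# simply the java defaults quoted in the docstring.
from aac_metrics.functional.meteor import meteor

cands = ["a man is speaking", "rain falls"]
mrefs = [
    ["a man speaks.", "someone speaks."],
    ["rain is falling hard on a surface"],
]

corpus_scores, sents_scores = meteor(
    cands,
    mrefs,
    params=(0.85, 0.2, 0.6, 0.75),  # alpha, beta, gamma, delta
    weights=(1.0, 1.0, 0.6, 0.8),   # w1, w2, w3, w4
)
# --- end of editor's note ---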
(expected one of {SELECTIONS})" @@ -106,4 +108,5 @@ def mult_cands_metric( if return_all_scores: return outs_corpus, outs_sents else: - return outs_corpus[metric_out_name] + out_key = f"{metric_out_name}_{selection}" + return outs_corpus[out_key] diff --git a/src/aac_metrics/functional/spice.py b/src/aac_metrics/functional/spice.py index 359f5e5..c72c1c1 100644 --- a/src/aac_metrics/functional/spice.py +++ b/src/aac_metrics/functional/spice.py @@ -13,6 +13,7 @@ import tempfile import time +from pathlib import Path from subprocess import CalledProcessError from tempfile import NamedTemporaryFile from typing import Any, Iterable, Optional, Union @@ -42,9 +43,9 @@ def spice( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, diff --git a/src/aac_metrics/functional/spider.py b/src/aac_metrics/functional/spider.py index b1f65ea..c2b0a0d 100644 --- a/src/aac_metrics/functional/spider.py +++ b/src/aac_metrics/functional/spider.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from pathlib import Path from typing import Callable, Iterable, Optional, Union from torch import Tensor @@ -19,9 +20,9 @@ def spider( tokenizer: Callable[[str], list[str]] = str.split, return_tfidf: bool = False, # SPICE args - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -64,12 +65,12 @@ def spider( f"Number of candidates and mult_references are different (found {len(candidates)} != {len(mult_references)})." 
) - return_all_scores = True + sub_return_all_scores = True cider_d_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = cider_d( # type: ignore candidates=candidates, mult_references=mult_references, - return_all_scores=return_all_scores, + return_all_scores=sub_return_all_scores, n=n, sigma=sigma, tokenizer=tokenizer, @@ -78,7 +79,7 @@ def spider( spice_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = spice( # type: ignore candidates=candidates, mult_references=mult_references, - return_all_scores=return_all_scores, + return_all_scores=sub_return_all_scores, cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, diff --git a/src/aac_metrics/functional/spider_fl.py b/src/aac_metrics/functional/spider_fl.py index 441df68..9c71686 100644 --- a/src/aac_metrics/functional/spider_fl.py +++ b/src/aac_metrics/functional/spider_fl.py @@ -2,11 +2,12 @@ # -*- coding: utf-8 -*- """ -BASED ON https://github.com/blmoistawinde/fense/ +Original based on https://github.com/blmoistawinde/fense/ """ import logging +from pathlib import Path from typing import Callable, Iterable, Optional, Union import torch @@ -14,8 +15,8 @@ from torch import Tensor from transformers.models.auto.tokenization_auto import AutoTokenizer -from aac_metrics.functional.fluerr import ( - fluerr, +from aac_metrics.functional.fer import ( + fer, _load_echecker_and_tokenizer, BERTFlatClassifier, ) @@ -35,9 +36,9 @@ def spider_fl( tokenizer: Callable[[str], list[str]] = str.split, return_tfidf: bool = False, # SPICE args - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -96,15 +97,43 @@ def spider_fl( :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. 
""" - # Init models echecker, echecker_tokenizer = _load_echecker_and_tokenizer( - echecker, echecker_tokenizer, device, reset_state, verbose + echecker=echecker, + echecker_tokenizer=echecker_tokenizer, + device=device, + reset_state=reset_state, + verbose=verbose, ) - - spider_outs: tuple = spider(candidates, mult_references, True, n, sigma, tokenizer, return_tfidf, cache_path, java_path, tmp_path, n_threads, java_max_memory, timeout, verbose) # type: ignore - fluerr_outs: tuple = fluerr(candidates, True, echecker, echecker_tokenizer, error_threshold, device, batch_size, reset_state, return_probs, verbose) # type: ignore - spider_fl_outs = _spider_fl_from_outputs(spider_outs, fluerr_outs, penalty) + spider_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = spider( # type: ignore + candidates=candidates, + mult_references=mult_references, + return_all_scores=True, + n=n, + sigma=sigma, + tokenizer=tokenizer, + return_tfidf=return_tfidf, + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + n_threads=n_threads, + java_max_memory=java_max_memory, + timeout=timeout, + verbose=verbose, + ) + fer_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = fer( # type: ignore + candidates=candidates, + return_all_scores=True, + echecker=echecker, + echecker_tokenizer=echecker_tokenizer, + error_threshold=error_threshold, + device=device, + batch_size=batch_size, + reset_state=reset_state, + return_probs=return_probs, + verbose=verbose, + ) + spider_fl_outs = _spider_fl_from_outputs(spider_outs, fer_outs, penalty) if return_all_scores: return spider_fl_outs @@ -114,26 +143,26 @@ def spider_fl( def _spider_fl_from_outputs( spider_outs: tuple[dict[str, Tensor], dict[str, Tensor]], - fluerr_outs: tuple[dict[str, Tensor], dict[str, Tensor]], + fer_outs: tuple[dict[str, Tensor], dict[str, Tensor]], penalty: float = 0.9, ) -> tuple[dict[str, Tensor], dict[str, Tensor]]: - """Combines SPIDEr and FluErr outputs. + """Combines SPIDEr and FER outputs. 
Based on https://github.com/felixgontier/dcase-2023-baseline/blob/main/metrics.py#L48 """ spider_outs_corpus, spider_outs_sents = spider_outs - fluerr_outs_corpus, fluerr_outs_sents = fluerr_outs + fer_outs_corpus, fer_outs_sents = fer_outs spider_scores = spider_outs_sents["spider"] - fluerr_scores = fluerr_outs_sents["fluerr"] - spider_fl_scores = spider_scores * (1.0 - penalty * fluerr_scores) + fer_scores = fer_outs_sents["fer"] + spider_fl_scores = spider_scores * (1.0 - penalty * fer_scores) spider_fl_score = spider_fl_scores.mean() spider_fl_outs_corpus = ( - spider_outs_corpus | fluerr_outs_corpus | {"spider_fl": spider_fl_score} + spider_outs_corpus | fer_outs_corpus | {"spider_fl": spider_fl_score} ) spider_fl_outs_sents = ( - spider_outs_sents | fluerr_outs_sents | {"spider_fl": spider_fl_scores} + spider_outs_sents | fer_outs_sents | {"spider_fl": spider_fl_scores} ) spider_fl_outs = spider_fl_outs_corpus, spider_fl_outs_sents diff --git a/src/aac_metrics/functional/spider_max.py b/src/aac_metrics/functional/spider_max.py index e1f79c7..aefe6d3 100644 --- a/src/aac_metrics/functional/spider_max.py +++ b/src/aac_metrics/functional/spider_max.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from pathlib import Path from typing import Callable, Iterable, Optional, Union import torch @@ -22,9 +23,9 @@ def spider_max( tokenizer: Callable[[str], list[str]] = str.split, return_tfidf: bool = False, # SPICE args - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -66,14 +67,14 @@ def spider_max( :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ return mult_cands_metric( - spider, - "spider", - mult_candidates, - mult_references, - return_all_scores, - return_all_cands_scores, - "max", - torch.mean, + metric=spider, + metric_out_name="spider", + mult_candidates=mult_candidates, + mult_references=mult_references, + return_all_scores=return_all_scores, + return_all_cands_scores=return_all_cands_scores, + selection="max", + reduction=torch.mean, # CIDEr args n=n, sigma=sigma, diff --git a/src/aac_metrics/functional/vocab.py b/src/aac_metrics/functional/vocab.py new file mode 100644 index 0000000..760e02a --- /dev/null +++ b/src/aac_metrics/functional/vocab.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging + +from typing import Callable, Union + +import torch + +from torch import Tensor + + +pylog = logging.getLogger(__name__) + + +def vocab( + candidates: list[str], + mult_references: Union[list[list[str]], None], + return_all_scores: bool = True, + seed: Union[None, int, torch.Generator] = 1234, + tokenizer: Callable[[str], list[str]] = str.split, + dtype: torch.dtype = torch.float64, + pop_strategy: str = "max", + verbose: int = 0, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + """Compute vocabulary statistics. + + Returns the candidate corpus vocabulary length, the references vocabulary length, the average vocabulary length for single references, and the vocabulary ratios between candidates and references. + + :param candidates: The list of sentences to evaluate. + :param mult_references: The list of list of sentences used as target. Can also be None. 
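# --- Editor's note: illustrative sketch for `spider_max`, whose call into
# `mult_cands_metric` was rewritten with keyword arguments above (not part of
# this patch). It scores several candidates per audio item and keeps the best
# SPIDEr value for each; requires Java and the SPICE resources.
from aac_metrics.functional.spider_max import spider_max

# One *list* of candidates per audio item.
mult_cands = [
    ["a man is speaking", "a man speaks loudly"],
    ["rain falls", "heavy rain is falling"],
]
mrefs = [
    ["a man speaks.", "someone speaks."],
    ["rain is falling hard on a surface"],
]

corpus_scores, sents_scores = spider_max(mult_cands, mrefs, return_all_scores=True)
# --- end of editor's note ---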
+ :param return_all_scores: If True, returns a tuple containing the globals and locals scores. + Otherwise returns a scalar tensor containing the main global score. + defaults to True. + :param seed: Random seed used to compute average vocabulary length for multiple references. defaults to 1234. + :param tokenizer: The function used to split a sentence into tokens. defaults to str.split. + :param dtype: Torch floating point dtype for numerical precision. defaults to torch.float64. + :param pop_strategy: Strategy to compute average reference vocab. defaults to "max". + :param verbose: The verbose level. defaults to 0. + :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. + """ + tok_cands = list(map(tokenizer, candidates)) + del candidates + + vocab_cands_len = _corpus_vocab(tok_cands, dtype) + if not return_all_scores: + return vocab_cands_len + + sents_scores = {} + corpus_scores = { + "vocab.cands": vocab_cands_len, + } + + if mult_references is not None: + if len(mult_references) <= 0: + raise ValueError( + f"Invalid number of references. (found {len(mult_references)} references)" + ) + tok_mrefs = [list(map(tokenizer, refs)) for refs in mult_references] + del mult_references + + vocab_mrefs_len_full = _corpus_vocab( + [ref for refs in tok_mrefs for ref in refs], dtype + ) + vocab_ratio_len_full = vocab_cands_len / vocab_mrefs_len_full + + if isinstance(seed, int): + generator = torch.Generator().manual_seed(seed) + else: + generator = seed + + if pop_strategy == "max": + n_samples = max(len(refs) for refs in tok_mrefs) + elif pop_strategy == "min": + n_samples = min(len(refs) for refs in tok_mrefs) + elif isinstance(pop_strategy, int): + n_samples = pop_strategy + else: + POP_STRATEGIES = ("max", "min") + raise ValueError( + f"Invalid argument {pop_strategy=}. 
(expected one of {POP_STRATEGIES} or an integer value)" + ) + + if verbose >= 2: + pylog.debug(f"Found {n_samples=} with {pop_strategy=}.") + + vocab_mrefs_lens = torch.empty((n_samples,), dtype=dtype) + + for i in range(n_samples): + indexes = [ + int(torch.randint(0, len(refs), (), generator=generator).item()) + for refs in tok_mrefs + ] + popped_refs = [refs[idx] for idx, refs in zip(indexes, tok_mrefs)] + vocab_mrefs_len_i = _corpus_vocab(popped_refs, dtype) + vocab_mrefs_lens[i] = vocab_mrefs_len_i + + vocab_mrefs_avg = vocab_mrefs_lens.mean() + vocab_len_ratio_avg = vocab_cands_len / vocab_mrefs_avg + + corpus_scores |= { + "vocab.mrefs_full": vocab_mrefs_len_full, + "vocab.ratio_full": vocab_ratio_len_full, + "vocab.mrefs_avg": vocab_mrefs_avg, + "vocab.ratio_avg": vocab_len_ratio_avg, + } + + return corpus_scores, sents_scores + + +def _corpus_vocab(tok_sents: list[list[str]], dtype: torch.dtype) -> Tensor: + corpus_cands_vocab = set(token for sent in tok_sents for token in sent) + vocab_len = torch.as_tensor(len(corpus_cands_vocab), dtype=dtype) + return vocab_len + + +def _sent_vocab( + tok_sents: list[list[str]], + dtype: torch.dtype, +) -> tuple[Tensor, Tensor]: + sents_cands_vocabs = [set(sent) for sent in tok_sents] + sent_cands_vocabs_lens = torch.as_tensor( + list(map(len, sents_cands_vocabs)), dtype=dtype + ) + sent_cands_vocab_len = sent_cands_vocabs_lens.mean() + return sent_cands_vocab_len, sent_cands_vocabs_lens diff --git a/src/aac_metrics/utils/checks.py b/src/aac_metrics/utils/checks.py index 6164266..061d36f 100644 --- a/src/aac_metrics/utils/checks.py +++ b/src/aac_metrics/utils/checks.py @@ -26,11 +26,11 @@ def check_metric_inputs( error_msgs = [] if not is_mono_sents(candidates): - error_msg = "Invalid candidates type. (expected list[str])" + error_msg = f"Invalid candidates type. (expected list[str], found {candidates.__class__.__name__})" error_msgs.append(error_msg) if not is_mult_sents(mult_references): - error_msg = "Invalid mult_references type. (expected list[list[str]])" + error_msg = f"Invalid mult_references type. (expected list[list[str]], found {mult_references.__class__.__name__})" error_msgs.append(error_msg) if len(error_msgs) > 0: diff --git a/src/aac_metrics/utils/cmdline.py b/src/aac_metrics/utils/cmdline.py new file mode 100644 index 0000000..c5dc1e2 --- /dev/null +++ b/src/aac_metrics/utils/cmdline.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import sys + +from typing import Optional + + +_TRUE_VALUES = ("true", "1", "t", "yes", "y") +_FALSE_VALUES = ("false", "0", "f", "no", "n") + + +def _str_to_bool(s: str) -> bool: + s = str(s).strip().lower() + if s in _TRUE_VALUES: + return True + elif s in _FALSE_VALUES: + return False + else: + raise ValueError( + f"Invalid argument {s=}. 
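# --- Editor's note: illustrative usage sketch for the new `vocab` metric
# defined above in vocab.py (not part of this patch). It is corpus-level
# only, so the sentence-level dict comes back empty.
from aac_metrics.functional.vocab import vocab

cands = ["a man is speaking", "rain falls"]
mrefs = [
    ["a man speaks.", "someone speaks."],
    ["rain is falling hard on a surface"],
]

corpus_scores, sents_scores = vocab(cands, mrefs, return_all_scores=True)
# Keys: "vocab.cands", "vocab.mrefs_full", "vocab.ratio_full",
#       "vocab.mrefs_avg", "vocab.ratio_avg".
print(corpus_scores["vocab.cands"])      # candidate vocabulary size
print(corpus_scores["vocab.ratio_avg"])  # ratio vs. average single-reference vocabulary
# --- end of editor's note ---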
(expected one of {_TRUE_VALUES + _FALSE_VALUES})" + ) + + +def _str_to_opt_str(s: str) -> Optional[str]: + s = str(s) + if s.lower() == "none": + return None + else: + return s + + +def _setup_logging(pkg_name: str, verbose: int, set_format: bool = True) -> None: + handler = logging.StreamHandler(sys.stdout) + if set_format: + format_ = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s" + handler.setFormatter(logging.Formatter(format_)) + + pkg_logger = logging.getLogger(pkg_name) + + found = False + for handler in pkg_logger.handlers: + if isinstance(handler, logging.StreamHandler) and handler.stream is sys.stdout: + found = True + break + if not found: + pkg_logger.addHandler(handler) + + if verbose < 0: + level = logging.ERROR + elif verbose == 0: + level = logging.WARNING + elif verbose == 1: + level = logging.INFO + else: + level = logging.DEBUG + + pkg_logger.setLevel(level) diff --git a/src/aac_metrics/utils/collections.py b/src/aac_metrics/utils/collections.py index cff64d3..349aa39 100644 --- a/src/aac_metrics/utils/collections.py +++ b/src/aac_metrics/utils/collections.py @@ -23,3 +23,27 @@ def unflat_list(flatten_lst: list[T], sizes: list[int]) -> list[list[T]]: lst.append(flatten_lst[start:stop]) start = stop return lst + + +def duplicate_list(lst: list[T], sizes: list[int]) -> list[T]: + """Duplicate elements elements of a list with the corresponding sizes. + + Example 1 + ---------- + >>> lst = ["a", "b", "c", "d", "e"] + >>> sizes = [1, 0, 2, 1, 3] + >>> duplicate_list(lst, sizes) + ... ["a", "c", "c", "d", "e", "e", "e"] + """ + if len(lst) != len(sizes): + raise ValueError( + f"Invalid arguments lengths. (found {len(lst)=} != {len(sizes)=})" + ) + + out_size = sum(sizes) + out: list[T] = [None for _ in range(out_size)] # type: ignore + curidx = 0 + for size, elt in zip(sizes, lst): + out[curidx : curidx + size] = [elt] * size + curidx += size + return out diff --git a/src/aac_metrics/utils/imports.py b/src/aac_metrics/utils/imports.py index cd8bc28..98a93c4 100644 --- a/src/aac_metrics/utils/imports.py +++ b/src/aac_metrics/utils/imports.py @@ -7,7 +7,7 @@ @cache def _package_is_available(package_name: str) -> bool: - """Returns True if package is installed.""" + """Returns True if package is installed in the current python environment.""" try: return find_spec(package_name) is not None except AttributeError: diff --git a/src/aac_metrics/utils/paths.py b/src/aac_metrics/utils/paths.py index 04bf7c4..43af618 100644 --- a/src/aac_metrics/utils/paths.py +++ b/src/aac_metrics/utils/paths.py @@ -6,13 +6,14 @@ import os.path as osp import tempfile -from typing import Optional, Union +from pathlib import Path +from typing import Union, overload pylog = logging.getLogger(__name__) -__DEFAULT_PATHS: dict[str, dict[str, Optional[str]]] = { +__DEFAULT_GLOBALS: dict[str, dict[str, Union[str, None]]] = { "cache": { "user": None, "env": "AAC_METRICS_CACHE_PATH", @@ -39,7 +40,7 @@ def get_default_cache_path() -> str: Else if the environment variable AAC_METRICS_CACHE_PATH has been set to a string, it will return its value. Else it will be equal to "~/.cache" by default. """ - return __get_default_path("cache") + return __get_default_value("cache") def get_default_java_path() -> str: @@ -49,7 +50,7 @@ def get_default_java_path() -> str: Else if the environment variable AAC_METRICS_JAVA_PATH has been set to a string, it will return its value. Else it will be equal to "java" by default. 
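# --- Editor's note: small self-contained sketch of the new `duplicate_list`
# helper added above, used together with the existing `flat_list` and
# `unflat_list` (not part of this patch).
from aac_metrics.utils.collections import duplicate_list, flat_list, unflat_list

cands = ["a man is speaking", "rain falls"]
mrefs = [
    ["a man speaks.", "someone speaks."],
    ["rain is falling hard on a surface"],
]

flat_refs, sizes = flat_list(mrefs)       # 3 references in total, sizes == [2, 1]
dup_cands = duplicate_list(cands, sizes)  # each candidate repeated once per reference
assert dup_cands == ["a man is speaking", "a man is speaking", "rain falls"]
assert unflat_list(flat_refs, sizes) == mrefs
# --- end of editor's note ---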
""" - return __get_default_path("java") + return __get_default_value("java") def get_default_tmp_path() -> str: @@ -59,77 +60,87 @@ def get_default_tmp_path() -> str: Else if the environment variable AAC_METRICS_TMP_PATH has been set to a string, it will return its value. Else it will be equal to the value returned by :func:`~tempfile.gettempdir()` by default. """ - return __get_default_path("tmp") + return __get_default_value("tmp") -def set_default_cache_path(cache_path: Optional[str]) -> None: +def set_default_cache_path(cache_path: Union[str, Path, None]) -> None: """Override default cache directory path.""" - __set_default_path("cache", cache_path) + __set_default_value("cache", cache_path) -def set_default_java_path(java_path: Optional[str]) -> None: +def set_default_java_path(java_path: Union[str, Path, None]) -> None: """Override default java executable path.""" - __set_default_path("java", java_path) + __set_default_value("java", java_path) -def set_default_tmp_path(tmp_path: Optional[str]) -> None: +def set_default_tmp_path(tmp_path: Union[str, Path, None]) -> None: """Override default temporary directory path.""" - __set_default_path("tmp", tmp_path) + __set_default_value("tmp", tmp_path) # Private functions -def _get_cache_path(cache_path: Union[str, None] = ...) -> str: - return __get_path("cache", cache_path) +def _get_cache_path(cache_path: Union[str, Path, None] = None) -> str: + return __get_value("cache", cache_path) -def _get_java_path(java_path: Union[str, None] = ...) -> str: - return __get_path("java", java_path) +def _get_java_path(java_path: Union[str, Path, None] = None) -> str: + return __get_value("java", java_path) -def _get_tmp_path(tmp_path: Union[str, None] = ...) -> str: - return __get_path("tmp", tmp_path) +def _get_tmp_path(tmp_path: Union[str, Path, None] = None) -> str: + return __get_value("tmp", tmp_path) -def __get_default_path(path_name: str) -> str: - paths = __DEFAULT_PATHS[path_name] +def __get_default_value(value_name: str) -> str: + values = __DEFAULT_GLOBALS[value_name] - for name, path_or_var in paths.items(): - if path_or_var is None: + for source, value_or_env_varname in values.items(): + if value_or_env_varname is None: continue - if name.startswith("env"): - path = os.getenv(path_or_var, None) + if source.startswith("env"): + path = os.getenv(value_or_env_varname, None) else: - path = path_or_var + path = value_or_env_varname if path is not None: - path = __process_path(path) + path = __process_value(path) return path - pylog.error(f"Paths values: {paths}") + pylog.error(f"Paths values: {values}") raise RuntimeError( - f"Invalid default path for {path_name=}. (all default paths are None)" + f"Invalid default path for {value_name=}. (all default paths are None)" ) -def __set_default_path( - path_name: str, - path: Optional[str], +def __set_default_value( + value_name: str, + value: Union[str, Path, None], ) -> None: - if path is not ... and path is not None: - path = __process_path(path) - __DEFAULT_PATHS[path_name]["user"] = path + value = __process_value(value) + __DEFAULT_GLOBALS[value_name]["user"] = value -def __get_path(path_name: str, path: Union[str, None] = ...) -> str: - if path is ... or path is None: - return __get_default_path(path_name) +def __get_value(value_name: str, value: Union[str, Path, None] = None) -> str: + if value is ... 
or value is None: + return __get_default_value(value_name) else: - path = __process_path(path) - return path + value = __process_value(value) + return value -def __process_path(path: str) -> str: - path = osp.expanduser(path) - path = osp.expandvars(path) - return path +@overload +def __process_value(value: None) -> None: + ... + + +@overload +def __process_value(value: Union[str, Path]) -> str: + ... + + +def __process_value(value: Union[str, Path, None]) -> Union[str, None]: + value = str(value) + value = osp.expanduser(value) + value = osp.expandvars(value) + return value diff --git a/src/aac_metrics/utils/tokenization.py b/src/aac_metrics/utils/tokenization.py index bc16d14..0236895 100644 --- a/src/aac_metrics/utils/tokenization.py +++ b/src/aac_metrics/utils/tokenization.py @@ -8,7 +8,8 @@ import tempfile import time -from typing import Any, Hashable, Iterable, Optional +from pathlib import Path +from typing import Any, Hashable, Iterable, Optional, Union from aac_metrics.utils.checks import check_java_path, is_mono_sents from aac_metrics.utils.collections import flat_list, unflat_list @@ -24,7 +25,9 @@ # Path to the stanford corenlp jar FNAME_STANFORD_CORENLP_3_4_1_JAR = osp.join( - "aac-metrics", "stanford_nlp", "stanford-corenlp-3.4.1.jar" + "aac-metrics", + "stanford_nlp", + "stanford-corenlp-3.4.1.jar", ) # Punctuations to be removed from the sentences PTB_PUNCTUATIONS = ( @@ -51,9 +54,9 @@ def ptb_tokenize_batch( sentences: Iterable[str], audio_ids: Optional[Iterable[Hashable]] = None, - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, punctuations: Iterable[str] = PTB_PUNCTUATIONS, normalize_apostrophe: bool = False, verbose: int = 0, @@ -61,7 +64,7 @@ def ptb_tokenize_batch( """Use PTB Tokenizer to process sentences. Should be used only with all the sentences of a subset due to slow computation. :param sentences: The sentences to tokenize. - :param audio_ids: The optional audio names. None will use the audio index as name. defaults to None. + :param audio_ids: The optional audio names for the PTB Tokenizer program. None will use the audio index as name. defaults to None. :param cache_path: The path to the external directory containing the JAR program. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. :param tmp_path: The path to a temporary directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. 
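# --- Editor's note: illustrative sketch of the path helpers refactored above
# (not part of this patch). Paths can now be given as str, pathlib.Path or
# None; None falls back to the user override, then the environment variable
# (e.g. AAC_METRICS_TMP_PATH), then the library default. The directory used
# below is hypothetical.
from pathlib import Path
from aac_metrics.utils.paths import get_default_tmp_path, set_default_tmp_path

print(get_default_tmp_path())  # tempfile.gettempdir() or $AAC_METRICS_TMP_PATH by default
set_default_tmp_path(Path("~/.cache/aac-metrics-tmp"))  # "~" is expanded
print(get_default_tmp_path())
# --- end of editor's note ---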
@@ -191,9 +194,9 @@ def ptb_tokenize_batch( def preprocess_mono_sents( sentences: list[str], - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, punctuations: Iterable[str] = PTB_PUNCTUATIONS, normalize_apostrophe: bool = False, verbose: int = 0, @@ -229,9 +232,9 @@ def preprocess_mono_sents( def preprocess_mult_sents( mult_sentences: list[list[str]], - cache_path: str = ..., - java_path: str = ..., - tmp_path: str = ..., + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, punctuations: Iterable[str] = PTB_PUNCTUATIONS, normalize_apostrophe: bool = False, verbose: int = 0, @@ -246,8 +249,6 @@ def preprocess_mult_sents( :param verbose: The verbose level. defaults to 0. :returns: The multiple sentences processed by the tokenizer. """ - - # Flat list flatten_sents, sizes = flat_list(mult_sentences) flatten_sents = preprocess_mono_sents( sentences=flatten_sents, diff --git a/tests/test_all.py b/tests/test_all.py new file mode 100644 index 0000000..5323f4e --- /dev/null +++ b/tests/test_all.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import platform +import unittest + +from unittest import TestCase + +from aac_metrics.functional.evaluate import evaluate + + +class TestAll(TestCase): + def test_example_1(self) -> None: + cands: list[str] = ["a man is speaking", "rain falls"] + mrefs: list[list[str]] = [ + [ + "a man speaks.", + "someone speaks.", + "a man is speaking while a bird is chirping in the background", + ], + ["rain is falling hard on a surface"], + ] + + _ = evaluate(cands, mrefs, metrics="all") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bleu_tchmet.py b/tests/test_bleu_tchmet.py index 7d951a5..0858548 100644 --- a/tests/test_bleu_tchmet.py +++ b/tests/test_bleu_tchmet.py @@ -33,7 +33,7 @@ def test_bleu(self) -> None: n = 2 bleu_v1 = BLEU(n=n, return_all_scores=False) - score_v1 = bleu_v1(cands, mrefs) + score_v1: Tensor = bleu_v1(cands, mrefs) # type: ignore bleu_v2 = BLEUScore(n_gram=n, smooth=False) score_v2 = bleu_v2(cands, mrefs) diff --git a/tests/test_compare_cet.py b/tests/test_compare_cet.py index 01e6794..7c838e6 100644 --- a/tests/test_compare_cet.py +++ b/tests/test_compare_cet.py @@ -4,6 +4,7 @@ import importlib import os.path as osp import platform +import shutil import subprocess import sys import unittest @@ -19,6 +20,7 @@ from aac_metrics.utils.paths import ( get_default_tmp_path, ) +from aac_metrics.download import _download_spice class TestCompareCaptionEvaluationTools(TestCase): @@ -37,25 +39,38 @@ def _import_cet_eval_func( Tuple[Dict[str, float], Dict[int, Dict[str, float]]], ]: cet_path = osp.join(osp.dirname(__file__), "caption-evaluation-tools") - use_shell = platform.system() == "Windows" + on_windows = platform.system() == "Windows" - stanford_fpath = osp.join( + cet_cache_path = Path( cet_path, "coco_caption", "pycocoevalcap", + ) + stanford_fpath = cet_cache_path.joinpath( "spice", "lib", "stanford-corenlp-3.6.0.jar", ) if not osp.isfile(stanford_fpath): - command = "bash get_stanford_models.sh" - subprocess.check_call( - command.split(), - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - cwd=osp.join(cet_path, "coco_caption"), - shell=use_shell, - ) + if not on_windows: + # Use CET installation + command = ["bash", "get_stanford_models.sh"] + 
subprocess.check_call( + command, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + cwd=osp.join(cet_path, "coco_caption"), + shell=on_windows, + ) + else: + # Use aac-metrics SPICE installation, but it requires to move some files after + _download_spice(str(cet_cache_path), clean_archives=True, verbose=2) + shutil.copytree( + cet_cache_path.joinpath("aac-metrics", "spice"), + cet_cache_path.joinpath("spice"), + dirs_exist_ok=True, + ) + shutil.rmtree(cet_cache_path.joinpath("aac-metrics")) # Append cet_path to allow imports of "caption" in eval_metrics.py. sys.path.append(cet_path) @@ -106,7 +121,7 @@ def _test_with_example(self, cands: list[str], mrefs: list[list[str]]) -> None: if platform.system() == "Windows": return None - corpus_scores, _ = evaluate(cands, mrefs, metrics="dcase2020") + corpus_scores, _ = evaluate(cands, mrefs, metrics="dcase2020", preprocess=True) self.assertIsInstance(corpus_scores, dict) diff --git a/tests/test_doc_examples.py b/tests/test_doc_examples.py index a52e640..eb4ee0f 100644 --- a/tests/test_doc_examples.py +++ b/tests/test_doc_examples.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import platform import unittest from unittest import TestCase @@ -18,9 +17,6 @@ class TestReadmeExamples(TestCase): def test_example_1(self) -> None: - if platform.system() == "Windows": - return None - candidates: list[str] = ["a man is speaking", "rain falls"] mult_references: list[list[str]] = [ [ @@ -58,9 +54,6 @@ def test_example_1(self) -> None: ) def test_example_2(self) -> None: - if platform.system() == "Windows": - return None - candidates: list[str] = ["a man is speaking", "rain falls"] mult_references: list[list[str]] = [ [ @@ -73,9 +66,9 @@ def test_example_2(self) -> None: corpus_scores, _ = evaluate(candidates, mult_references, metrics="dcase2023") # print(corpus_scores) - # dict containing the score of each metric: "meteor", "cider_d", "spice", "spider", "spider_fl", "fluerr" + # dict containing the score of each metric: "meteor", "cider_d", "spice", "spider", "spider_fl", "fer" - expected_keys = ["meteor", "cider_d", "spice", "spider", "spider_fl", "fluerr"] + expected_keys = ["meteor", "cider_d", "spice", "spider", "spider_fl", "fer"] self.assertTrue(set(corpus_scores.keys()).issuperset(expected_keys)) def test_example_3(self) -> None: diff --git a/tests/test_pickable.py b/tests/test_pickable.py index 2603b99..4759c3f 100644 --- a/tests/test_pickable.py +++ b/tests/test_pickable.py @@ -18,7 +18,7 @@ def test_pickle_dump(self) -> None: try: pickle.dumps(metric) except pickle.PicklingError: - self.assert_(False, f"Cannot pickle {metric.__class__.__name__}.") + self.assertTrue(False, f"Cannot pickle {metric.__class__.__name__}.") if __name__ == "__main__": diff --git a/tests/test_sdmax.py b/tests/test_sdmax.py new file mode 100644 index 0000000..d729de2 --- /dev/null +++ b/tests/test_sdmax.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import unittest + +from unittest import TestCase + +import torch + +from torch import Tensor + +from aac_metrics.classes.spider import SPIDEr +from aac_metrics.classes.spider_max import SPIDErMax + + +class TestSPIDErMax(TestCase): + # Tests methods + def test_sd_vs_sdmax(self) -> None: + sd = SPIDEr(return_all_scores=False) + sdmax = SPIDErMax(return_all_scores=False) + + cands, mrefs = self._get_example_0() + mcands = [[cand] for cand in cands] + + sd_score = sd(cands, mrefs) + sdmax_score = sdmax(mcands, mrefs) + + assert isinstance(sd_score, Tensor) + assert 
isinstance(sdmax_score, Tensor) + self.assertTrue( + torch.allclose(sd_score, sdmax_score), f"{sd_score=}, {sdmax_score=}" + ) + + def _get_example_0(self) -> tuple[list[str], list[list[str]]]: + cands = [ + "a man is speaking", + "birds chirping", + "rain is falling in the background", + ] + mrefs = [ + [ + "man speaks", + "man is speaking", + "a man speaks", + "man talks", + "someone is talking", + ], + ["a bird is chirping"] * 5, + ["heavy rain noise"] * 5, + ] + return cands, mrefs + + +if __name__ == "__main__": + unittest.main()
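For readers of this patch, here is a minimal usage sketch (an editor's illustration, not part of the diff itself) tying together the public API it touches: the path helpers in aac_metrics.utils.paths, which now also accept pathlib.Path values, and the functional evaluate entry point, whose fluency-error key is renamed from "fluerr" to "fer". The example sentences are taken from the tests above; running it requires Java and the external metric resources (Stanford JAR, METEOR, SPICE files) to be installed beforehand.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Editor's usage sketch (illustration only, not part of the patch above).
# Assumes Java and the external metric resources have already been downloaded.

from pathlib import Path

from aac_metrics.functional.evaluate import evaluate
from aac_metrics.utils.paths import get_default_tmp_path, set_default_tmp_path

# The setters now accept str, pathlib.Path or None instead of str only.
set_default_tmp_path(Path("/tmp"))
print(get_default_tmp_path())  # "/tmp": the user-set value takes precedence
                               # over the AAC_METRICS_TMP_PATH environment variable

candidates: list[str] = ["a man is speaking", "rain falls"]
mult_references: list[list[str]] = [
    [
        "a man speaks.",
        "someone speaks.",
        "a man is speaking while a bird is chirping in the background",
    ],
    ["rain is falling hard on a surface"],
]

corpus_scores, _ = evaluate(candidates, mult_references, metrics="dcase2023")
# The fluency-error rate is now reported under the key "fer" (formerly "fluerr").
print(sorted(corpus_scores.keys()))
# contains at least: cider_d, fer, meteor, spice, spider, spider_fl

As documented in the getters' docstrings above, when no user value has been set the paths fall back to the AAC_METRICS_* environment variables and finally to built-in defaults such as tempfile.gettempdir().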