diff --git a/parea/evals/general/answer_matches_target_recall.py b/parea/evals/general/answer_matches_target_recall.py
index 86c22140..c53e167b 100644
--- a/parea/evals/general/answer_matches_target_recall.py
+++ b/parea/evals/general/answer_matches_target_recall.py
@@ -1,5 +1,6 @@
 from collections import Counter
 
+from parea.evals.utils import get_tokens
 from parea.schemas.log import Log
 
 
@@ -7,18 +8,10 @@ def answer_matches_target_recall(log: Log) -> float:
     """Prop. of tokens in target/reference answer which are also in model generation."""
     target = log.target
     output = log.output
-
-    provider = log.configuration.provider
     model = log.configuration.model
 
-    if provider == "openai":
-        import tiktoken
-
-        encoding = tiktoken.encoding_for_model(model)
-        target_tokens = encoding.encode(target)
-        output_tokens = encoding.encode(output)
-    else:
-        raise NotImplementedError
+    target_tokens = get_tokens(model, target)
+    output_tokens = get_tokens(model, output)
 
     if len(target_tokens) == 0:
         return 1.0
diff --git a/parea/evals/rag/answer_context_faithfulness_precision.py b/parea/evals/rag/answer_context_faithfulness_precision.py
index bfc3c4be..dc1c0868 100644
--- a/parea/evals/rag/answer_context_faithfulness_precision.py
+++ b/parea/evals/rag/answer_context_faithfulness_precision.py
@@ -2,6 +2,7 @@
 
 from collections import Counter
 
+from parea.evals.utils import get_tokens
 from parea.schemas.log import Log
 
 
@@ -22,22 +23,10 @@ def answer_context_faithfulness_precision_factory(context_field: Optional[str] =
     def answer_context_faithfulness_precision(log: Log) -> float:
         """Prop. of tokens in model generation which are also present in the retrieved context."""
         context = log.inputs[context_field]
-
-        provider = log.configuration.provider
         model = log.configuration.model
 
-        if provider == "openai":
-            import tiktoken
-
-            try:
-                encoding = tiktoken.encoding_for_model(model)
-            except KeyError:
-                print("Warning: model not found. Using cl100k_base encoding.")
-                encoding = tiktoken.get_encoding("cl100k_base")
-            context_tokens = encoding.encode(context)
-            output_tokens = encoding.encode(log.output)
-        else:
-            raise NotImplementedError
+        context_tokens = get_tokens(model, context)
+        output_tokens = get_tokens(model, log.output)
 
         if len(context_tokens) == 0:
             return 1.0
diff --git a/parea/evals/utils.py b/parea/evals/utils.py
index 814f9c09..51be3903 100644
--- a/parea/evals/utils.py
+++ b/parea/evals/utils.py
@@ -1,10 +1,13 @@
 from typing import Callable, Union
 
 import json
+import re
+import string
 import warnings
 
 import openai
 import pysbd
+import tiktoken
 from attrs import define
 from openai import __version__ as openai_version
 
@@ -118,3 +121,19 @@ def run_evals_in_thread_and_log(trace_id: str, log: Log, eval_funcs: list[EvalFu
         kwargs={"trace_id": trace_id, "log": log, "eval_funcs": eval_funcs, "verbose": verbose},
     )
     logging_thread.start()
+
+
+def get_tokens(model: str, text: str) -> Union[list[int], list[str]]:
+    if not text:
+        return []
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+        tokens = encoding.encode(text)
+    except KeyError:
+        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
+        text = text.lower()
+        text = "".join(char for char in text if char not in set(string.punctuation))
+        text = re.sub(regex, " ", text)
+        text = " ".join(text.split())
+        tokens = text.split()
+    return tokens
diff --git a/pyproject.toml b/pyproject.toml
index 402f9fe6..96699798 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "parea-ai"
 packages = [{ include = "parea" }]
-version = "0.2.35"
+version = "0.2.36"
 description = "Parea python sdk"
 readme = "README.md"
 authors = ["joel-parea-ai "]
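A quick usage sketch of the new get_tokens helper introduced above; it assumes tiktoken recognizes "gpt-3.5-turbo" and does not recognize the hypothetical model name "my-local-model", in which case the helper falls back to simple text normalization instead of raising NotImplementedError:

    from parea.evals.utils import get_tokens

    # Known OpenAI model: tiktoken BPE token ids (a list of ints).
    openai_tokens = get_tokens("gpt-3.5-turbo", "The quick brown fox.")

    # Unknown model: lowercases, strips punctuation and the articles a/an/the,
    # then whitespace-splits, yielding a list of strings:
    # ["quick", "brown", "fox"]
    fallback_tokens = get_tokens("my-local-model", "The quick brown fox.")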