Skip to content

Commit

Permalink
Merge pull request #345 from parea-ai/fix-answer-context
Browse files Browse the repository at this point in the history
fix: provide default for non openai models
  • Loading branch information
joschkabraun committed Jan 26, 2024
2 parents 3f38f32 + d0fa8ba commit 4e48c36
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 25 deletions.
13 changes: 3 additions & 10 deletions parea/evals/general/answer_matches_target_recall.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
from collections import Counter

from parea.evals.utils import get_tokens
from parea.schemas.log import Log


def answer_matches_target_recall(log: Log) -> float:
"""Prop. of tokens in target/reference answer which are also in model generation."""
target = log.target
output = log.output

provider = log.configuration.provider
model = log.configuration.model

if provider == "openai":
import tiktoken

encoding = tiktoken.encoding_for_model(model)
target_tokens = encoding.encode(target)
output_tokens = encoding.encode(output)
else:
raise NotImplementedError
target_tokens = get_tokens(model, target)
output_tokens = get_tokens(model, output)

if len(target_tokens) == 0:
return 1.0
Expand Down
17 changes: 3 additions & 14 deletions parea/evals/rag/answer_context_faithfulness_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import Counter

from parea.evals.utils import get_tokens
from parea.schemas.log import Log


Expand All @@ -22,22 +23,10 @@ def answer_context_faithfulness_precision_factory(context_field: Optional[str] =
def answer_context_faithfulness_precision(log: Log) -> float:
"""Prop. of tokens in model generation which are also present in the retrieved context."""
context = log.inputs[context_field]

provider = log.configuration.provider
model = log.configuration.model

if provider == "openai":
import tiktoken

try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
context_tokens = encoding.encode(context)
output_tokens = encoding.encode(log.output)
else:
raise NotImplementedError
context_tokens = get_tokens(model, context)
output_tokens = get_tokens(model, log.output)

if len(context_tokens) == 0:
return 1.0
Expand Down
19 changes: 19 additions & 0 deletions parea/evals/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Callable, Union

import json
import re
import string
import warnings

import openai
import pysbd
import tiktoken
from attrs import define
from openai import __version__ as openai_version

Expand Down Expand Up @@ -118,3 +121,19 @@ def run_evals_in_thread_and_log(trace_id: str, log: Log, eval_funcs: list[EvalFu
kwargs={"trace_id": trace_id, "log": log, "eval_funcs": eval_funcs, "verbose": verbose},
)
logging_thread.start()


def get_tokens(model: str, text: str) -> Union[list[int], list[str]]:
    """Tokenize ``text`` for the given ``model``.

    Uses tiktoken's model-specific encoding when the model is known
    (returning a list of integer token ids).  For models tiktoken does
    not recognize -- or when tiktoken is not installed -- falls back to
    a normalized word split (returning a list of word strings):
    lowercase, strip punctuation, drop the articles "a"/"an"/"the",
    and split on whitespace.

    Args:
        model: Model name passed to ``tiktoken.encoding_for_model``.
        text: Text to tokenize; falsy input yields an empty list.

    Returns:
        A list of token ids (tiktoken path) or word tokens (fallback).
    """
    if not text:
        return []
    try:
        # Local import so a missing tiktoken package also triggers the fallback.
        import tiktoken

        return tiktoken.encoding_for_model(model).encode(text)
    except (ImportError, KeyError):
        # encoding_for_model raises KeyError for models it does not know.
        # Approximate tokens with normalized words instead.
        normalized = text.lower()
        normalized = "".join(char for char in normalized if char not in set(string.punctuation))
        normalized = re.sub(r"\b(a|an|the)\b", " ", normalized, flags=re.UNICODE)
        # str.split() already collapses runs of whitespace.
        return normalized.split()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.35"
version = "0.2.36"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
Expand Down

0 comments on commit 4e48c36

Please sign in to comment.