estimate token use before sending openai completions #1112

Open · wants to merge 1 commit into main
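As a rough illustration of the approach in this change, here is a minimal, self-contained sketch of the prompt-token estimate performed before a request is sent. It relies on tiktoken's encoding_for_model/encode (the same calls added below in garak/generators/openai.py) and falls back to a naive whitespace split when the model name is not recognised; the helper name and example model are illustrative only, not part of this diff.

# Illustrative sketch only; not part of the diff.
import tiktoken

def estimate_prompt_tokens(prompt: str, model_name: str) -> int:
    """Count prompt tokens with tiktoken, falling back to a whitespace split."""
    try:
        encoding = tiktoken.encoding_for_model(model_name)
        return len(encoding.encode(prompt))
    except KeyError:
        # tiktoken does not know this model name: use a naive word count instead
        return len(prompt.split())

# e.g. estimate_prompt_tokens("say yes or no", "gpt-3.5-turbo-instruct")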
27 changes: 27 additions & 0 deletions garak/generators/openai.py
@@ -14,6 +14,7 @@
import json
import logging
import re
import tiktoken
from typing import List, Union

import openai
@@ -223,6 +224,32 @@ def _call_model(
            if hasattr(self, arg) and arg not in self.suppressed_params:
                create_args[arg] = getattr(self, arg)

        # basic token boundary validation to ensure requests are not rejected for exceeding target context length
        generation_max_tokens = create_args.get("max_tokens", None)
        if generation_max_tokens is not None:
            # count tokens in prompt and ensure max_tokens requested is <= context_len allowed
            if (
                hasattr(self, "context_len")
                and self.context_len is not None
                and generation_max_tokens > self.context_len
            ):
                logging.warning(
                    f"Requested max_tokens {generation_max_tokens} exceeds context length {self.context_len}, reducing requested maximum"
                )
                generation_max_tokens = self.context_len
            prompt_tokens = 0
            try:
                encoding = tiktoken.encoding_for_model(self.name)
                prompt_tokens = len(encoding.encode(prompt))
            except KeyError as e:
                prompt_tokens = len(prompt.split())  # extra naive fallback
            generation_max_tokens -= prompt_tokens
            create_args["max_tokens"] = generation_max_tokens
            if generation_max_tokens < 1:  # allow at least a binary result token
                raise garak.exception.GarakException(
                    "A response cannot be created within the available context length"
                )

        if self.generator == self.client.completions:
            if not isinstance(prompt, str):
                msg = (
45 changes: 44 additions & 1 deletion tests/generators/test_openai_compatible.py
@@ -16,7 +16,11 @@
# GENERATORS = [
# classname for (classname, active) in _plugins.enumerate_plugins("generators")
# ]
GENERATORS = ["generators.openai.OpenAIGenerator", "generators.nim.NVOpenAIChat", "generators.groq.GroqChat"]
GENERATORS = [
    "generators.openai.OpenAIGenerator",
    "generators.nim.NVOpenAIChat",
    "generators.groq.GroqChat",
]

MODEL_NAME = "gpt-3.5-turbo-instruct"
ENV_VAR = os.path.abspath(
@@ -98,3 +102,42 @@ def test_openai_multiprocessing(openai_compat_mocks, classname):
    with Pool(parallel_attempts) as attempt_pool:
        for result in attempt_pool.imap_unordered(generate_in_subprocess, prompts):
            assert result is not None


def test_validate_call_model_token_restrictions(openai_compat_mocks):
    import lorem
    import json
    from garak.exception import GarakException

    generator = build_test_instance(OpenAICompatible)
    mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
    with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
        mock_response = openai_compat_mocks["chat"]
        respx_mock.post("chat/completions").mock(
            return_value=httpx.Response(
                mock_response["code"], json=mock_response["json"]
            )
        )
        generator._call_model("test values")
        resp_body = json.loads(respx_mock.routes[0].calls[0].request.content)
        assert (
            resp_body["max_tokens"] < generator.max_tokens
        ), "request max_tokens must account for prompt tokens"

        test_large_context = ""
        while len(test_large_context.split()) < generator.max_tokens:
            test_large_context += lorem.paragraph() + "\n"
        large_context_len = len(test_large_context.split())
        with pytest.raises(GarakException) as exc_info:
            generator._call_model(test_large_context)
        assert "cannot be created" in str(
            exc_info.value
        ), "a prompt larger than max_tokens must raise an exception"

        generator.context_len = large_context_len * 2
        generator.max_tokens = generator.context_len - (large_context_len // 2)
        generator._call_model("test values")
        resp_body = json.loads(respx_mock.routes[0].calls[1].request.content)
        assert (
            resp_body["max_tokens"] < generator.context_len
        ), "request max_tokens must be less than the model context length"
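The assertions above exercise the budgeting arithmetic added to _call_model: the requested max_tokens is first capped at the model's context length, the estimated prompt tokens are then subtracted, and the call fails when no room is left for even a single response token. A standalone restatement of that logic, with an illustrative function name and a plain ValueError standing in for GarakException, might look like this:

def budget_max_tokens(requested_max_tokens: int, context_len: int, prompt_tokens: int) -> int:
    # never request more than the model's context window allows
    max_tokens = min(requested_max_tokens, context_len)
    # leave room for the prompt itself
    max_tokens -= prompt_tokens
    if max_tokens < 1:  # allow at least a single response token
        raise ValueError("A response cannot be created within the available context length")
    return max_tokens

# e.g. budget_max_tokens(4096, 4096, 4200) raises; budget_max_tokens(4096, 4096, 100) == 3996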