
fix: fixed initial_prompt_tokens calculation
adubovik committed Nov 8, 2024
1 parent b727ff1 commit c68499f
Showing 5 changed files with 15 additions and 13 deletions.
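
In short: both truncation helpers previously seeded the prompt-token estimate with the tokenizer's fixed per-request constant; this commit replaces that constant with a call that tokenizes the request itself against an empty message list, so request-level tokens are counted too. A minimal before/after sketch, using only names that appear in the diffs below:

# Before: a fixed constant, blind to the request's contents.
initial_prompt_tokens = tokenizer.TOKENS_PER_REQUEST

# After: tokenize the request with no messages, so the baseline includes
# the per-request overhead plus any request-level tokens.
initial_prompt_tokens = tokenizer.tokenize_request(request, [])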
aidial_adapter_openai/gpt.py (8 changes: 6 additions & 2 deletions)
@@ -23,14 +23,17 @@


 def plain_text_truncate_prompt(
-    messages: List[dict], max_prompt_tokens: int, tokenizer: PlainTextTokenizer
+    request: dict,
+    messages: List[dict],
+    max_prompt_tokens: int,
+    tokenizer: PlainTextTokenizer,
 ) -> Tuple[List[dict], DiscardedMessages, TruncatedTokens]:
     return truncate_prompt(
         messages=messages,
         message_tokens=tokenizer.tokenize_request_message,
         is_system_message=lambda message: message["role"] == "system",
         max_prompt_tokens=max_prompt_tokens,
-        initial_prompt_tokens=tokenizer.TOKENS_PER_REQUEST,
+        initial_prompt_tokens=tokenizer.tokenize_request(request, []),
     )


@@ -58,6 +61,7 @@ async def gpt_chat_completion(

request["messages"], discarded_messages, estimated_prompt_tokens = (
plain_text_truncate_prompt(
request=request,
messages=cast(List[dict], request["messages"]),
max_prompt_tokens=max_prompt_tokens,
tokenizer=tokenizer,
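For orientation, initial_prompt_tokens is the baseline that truncate_prompt starts from before adding per-message costs; if it undercounts, the truncated prompt can exceed max_prompt_tokens. A simplified sketch of that budgeting loop (hypothetical: it ignores the system-message handling and the DiscardedMessages/TruncatedTokens bookkeeping the real helper performs):

from typing import Callable, List

def truncate_newest_first(
    messages: List[dict],
    message_tokens: Callable[[dict], int],
    max_prompt_tokens: int,
    initial_prompt_tokens: int,
) -> List[dict]:
    # Keep the newest messages that still fit on top of the
    # request-level baseline.
    budget = max_prompt_tokens - initial_prompt_tokens
    kept: List[dict] = []
    for message in reversed(messages):
        cost = message_tokens(message)
        if cost > budget:
            break
        budget -= cost
        kept.append(message)
    return list(reversed(kept))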
aidial_adapter_openai/gpt4_multi_modal/chat_completion.py (6 changes: 3 additions & 3 deletions)
@@ -119,9 +119,9 @@ async def predict_non_stream(


 def multi_modal_truncate_prompt(
+    request: dict,
     messages: List[MultiModalMessage],
     max_prompt_tokens: int,
-    initial_prompt_tokens: int,
     tokenizer: MultiModalTokenizer,
 ) -> Tuple[List[MultiModalMessage], DiscardedMessages, TruncatedTokens]:
     return truncate_prompt(
@@ -130,7 +130,7 @@ def multi_modal_truncate_prompt(
         is_system_message=lambda message: message.raw_message["role"]
         == "system",
         max_prompt_tokens=max_prompt_tokens,
-        initial_prompt_tokens=initial_prompt_tokens,
+        initial_prompt_tokens=tokenizer.tokenize_request(request, []),
     )


@@ -218,9 +218,9 @@ async def chat_completion(
     if max_prompt_tokens is not None:
         multi_modal_messages, discarded_messages, estimated_prompt_tokens = (
             multi_modal_truncate_prompt(
+                request=request,
                 messages=multi_modal_messages,
                 max_prompt_tokens=max_prompt_tokens,
-                initial_prompt_tokens=tokenizer.TOKENS_PER_REQUEST,
                 tokenizer=tokenizer,
             )
         )
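Note the symmetry with gpt.py: after this change both truncation helpers take the raw request and compute the baseline internally via tokenizer.tokenize_request(request, []), instead of threading a precomputed initial_prompt_tokens value through their signatures.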
aidial_adapter_openai/utils/tokenizer.py (4 changes: 1 addition & 3 deletions)
@@ -88,9 +88,7 @@ def _tokens_per_request_message_name(self) -> int:
         return 1

     def tokenize_request(
-        self,
-        original_request: dict,
-        messages: List[MessageType],
+        self, original_request: dict, messages: List[MessageType]
     ) -> int:
         tokens = self.TOKENS_PER_REQUEST

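This formatting-only hunk is also where the fix bottoms out: tokenize_request seeds its count with TOKENS_PER_REQUEST, so calling it with an empty message list returns the per-request baseline plus whatever the request itself contributes. A rough sketch of the method's likely shape (only the signature and the first line appear in this diff; the rest is an assumption):

    def tokenize_request(
        self, original_request: dict, messages: List[MessageType]
    ) -> int:
        tokens = self.TOKENS_PER_REQUEST  # fixed per-request overhead
        # Assumed: request-level fields and each message add their own
        # costs, so tokenize_request(request, []) yields just the baseline.
        for message in messages:
            tokens += self.tokenize_request_message(message)
        return tokens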
tests/test_discard_messages.py (4 changes: 2 additions & 2 deletions)
@@ -141,7 +141,7 @@ def test_discarded_messages_without_error(
 ):
     tokenizer = PlainTextTokenizer(model="gpt-4")
     truncated_messages, discarded_messages, _used_tokens = (
-        plain_text_truncate_prompt(messages, max_prompt_tokens, tokenizer)
+        plain_text_truncate_prompt({}, messages, max_prompt_tokens, tokenizer)
     )
     assert (truncated_messages, discarded_messages) == response

@@ -157,5 +157,5 @@ def test_discarded_messages_with_error(
     tokenizer = PlainTextTokenizer(model="gpt-4")

     with pytest.raises(DialException) as e_info:
-        plain_text_truncate_prompt(messages, max_prompt_tokens, tokenizer)
+        plain_text_truncate_prompt({}, messages, max_prompt_tokens, tokenizer)
     assert e_info.value.message == error_message
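
A note on the test updates: the tests pass {} as the request, and tokenize_request({}, []) should then reduce to the bare per-request constant (assuming an empty dict contributes no extra tokens), so the existing token expectations in these cases remain valid.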
tests/test_multimodal_truncate.py (6 changes: 3 additions & 3 deletions)
@@ -45,7 +45,7 @@ def test_multimodal_truncate_with_system_and_last_user_error():
         ),
     ]
     with pytest.raises(TruncatePromptSystemAndLastUserError):
-        multi_modal_truncate_prompt(transformations, 15, 0, tokenizer)
+        multi_modal_truncate_prompt({}, transformations, 15, tokenizer)


def test_multimodal_truncate_with_system_error():
@@ -57,7 +57,7 @@ def test_multimodal_truncate_with_system_error():
         ),
     ]
     with pytest.raises(TruncatePromptSystemError):
-        multi_modal_truncate_prompt(transformations, 9, 3, tokenizer)
+        multi_modal_truncate_prompt({}, transformations, 9, tokenizer)


@pytest.mark.parametrize(
@@ -194,9 +194,9 @@ def test_multimodal_truncate(
 ):
     truncated, actual_discarded_messages, actual_used_tokens = (
         multi_modal_truncate_prompt(
+            {},
             transformations,
             max_prompt_tokens,
-            initial_prompt_tokens=3,
             tokenizer=tokenizer,
         )
     )
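Same pattern here, with one visible nuance: the hard-coded baselines these tests used to pass (0 and 3 positionally, and initial_prompt_tokens=3 in the parametrized case) are replaced by the value computed from the empty request, presumably the multi-modal tokenizer's TOKENS_PER_REQUEST.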
