refactor!: draft token counting rewrite
I think these changes are too drastic, so I'll probably discard them. Saving to this branch for archival purposes
zhudotexe committed Sep 5, 2024
1 parent 8d6562f commit 3cf8ae4
Showing 3 changed files with 79 additions and 30 deletions.
Empty file added breaking-changes.md
Empty file.
90 changes: 61 additions & 29 deletions kani/engines/base.py
@@ -1,8 +1,11 @@
import abc
import inspect
import warnings
from collections.abc import AsyncIterable
from typing import Awaitable

from kani.ai_function import AIFunction
from kani.exceptions import KaniException
from kani.models import ChatMessage


@@ -60,8 +63,21 @@ class BaseEngine(abc.ABC):
"""The maximum context size supported by this engine's LM."""

@abc.abstractmethod
def message_len(self, message: ChatMessage) -> int:
"""Return the length, in tokens, of the given chat message."""
def prompt_len(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **kwargs
) -> int | Awaitable[int]:
"""
Returns the number of tokens used by the given prompt (i.e., list of messages and functions), or a best
estimate if the exact count is unavailable.

This method MAY be asynchronous. Use :meth:`.Kani.prompt_token_len` for a higher-level interface that handles
asynchrony.

:param messages: The messages in the prompt.
:param functions: The functions included in the prompt.
:param kwargs: Any additional parameters to pass to the underlying token counting implementation
    (engine-specific).
"""
raise NotImplementedError
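For context, here is a minimal sketch of what a concrete engine could look like under the new interface. The whitespace-based estimator, its overhead constants, and the stub `predict` are illustrative assumptions, not part of this commit:

```python
from kani.ai_function import AIFunction
from kani.engines.base import BaseEngine, Completion
from kani.models import ChatMessage


class WhitespaceTokenEngine(BaseEngine):
    """Toy engine that estimates token counts by splitting on whitespace."""

    max_context_size = 4096

    def prompt_len(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **kwargs) -> int:
        # One "token" per whitespace-separated word, plus a flat per-message
        # overhead for role tags/delimiters (5) and a flat per-function cost
        # for schema text (50) -- both constants are arbitrary for this sketch.
        n = sum(len((m.text or "").split()) + 5 for m in messages)
        if functions:
            n += 50 * len(functions)
        return n

    async def predict(self, messages, functions=None, **hyperparams):
        # A real engine would call a model here; this stub just echoes.
        return Completion(message=ChatMessage.assistant("..."))
```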

@abc.abstractmethod
@@ -79,26 +95,6 @@ async def predict(
raise NotImplementedError

# ==== optional interface ====
token_reserve: int = 0
"""Optional: The number of tokens to reserve for internal engine mechanisms (e.g. if an engine has to set up the
model's reply with a delimiting token).
Default: 0
"""

def function_token_reserve(self, functions: list[AIFunction]) -> int:
"""Optional: How many tokens are required to build a prompt to expose the given functions to the model.
Default: If this is not implemented and the user passes in functions, log a warning that the engine does not
support function calling.
"""
if functions:
warnings.warn(
f"The {type(self).__name__} engine is conversational only and does not support function calling.\n"
"Developers: If this warning is incorrect, please implement `function_token_reserve()`."
)
return 0

async def stream(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
) -> AsyncIterable[str | BaseCompletion]:
@@ -132,6 +128,49 @@ async def close(self):
"""Optional: Clean up any resources the engine might need."""
pass

# ==== deprecated: old-style token counting ====
def message_len(self, message: ChatMessage) -> int:
"""
Returns the estimated number of tokens used by a single given message.

.. note::
    The token count returned by this may not exactly reflect the actual token count (e.g., due to prompt
    formatting or not having access to the tokenizer). It should, however, be a safe overestimate to use as
    an upper bound.

.. deprecated:: 1.1.0
    Use :meth:`prompt_len` instead.
"""
if inspect.iscoroutinefunction(self.prompt_len):
raise KaniException(
"This engine's token counting method is asynchronous only. Please use `await"
" Kani.prompt_token_len([message])` instead."
)
return self.prompt_len([message])

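Since `prompt_len` MAY be a coroutine function, callers outside of `Kani` need to handle both shapes. A caller-side sketch (not part of this diff) that works either way:

```python
import inspect

from kani.ai_function import AIFunction
from kani.engines.base import BaseEngine
from kani.models import ChatMessage


async def count_prompt_tokens(
    engine: BaseEngine, messages: list[ChatMessage], functions: list[AIFunction] | None = None
) -> int:
    """Call engine.prompt_len whether it is implemented sync or async."""
    result = engine.prompt_len(messages, functions)
    if inspect.isawaitable(result):
        result = await result
    return result
```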
token_reserve: int = 0
"""
Optional: The number of tokens to reserve for internal engine mechanisms (e.g. if an engine has to set up the
model's reply with a delimiting token). Default: 0

.. deprecated:: 1.1.0
    Use :meth:`prompt_len` instead.
"""

def function_token_reserve(self, functions: list[AIFunction]) -> int:
"""
Optional: How many tokens are required to build a prompt to expose the given functions to the model.

Default: If this is not implemented and the user passes in functions, log a warning that the engine does not
support function calling.

.. deprecated:: 1.1.0
    Use :meth:`prompt_len` instead.
"""
if functions:
warnings.warn(
f"The {type(self).__name__} engine is conversational only and does not support function calling.\n"
"Developers: If this warning is incorrect, please implement `function_token_reserve()`."
)
return 0

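To make the migration concrete, the deprecated accounting can be written as one expression; a single `prompt_len` call is intended to subsume all three terms. This is a rough equivalence sketch, assuming a synchronous engine:

```python
def old_style_prompt_len(engine, messages, functions) -> int:
    # Deprecated accounting: per-message costs plus two separate reserves.
    return (
        sum(engine.message_len(m) for m in messages)
        + engine.token_reserve
        + engine.function_token_reserve(functions)
    )

# New-style accounting folds all of the above into one engine-defined call:
# engine.prompt_len(messages, functions)
```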
# ==== internal ====
__ignored_repr_attrs__ = ("token_cache",)

@@ -163,12 +202,8 @@ def __init__(self, engine: BaseEngine, *args, **kwargs):

# passthrough attrs
self.max_context_size = engine.max_context_size
self.token_reserve = engine.token_reserve

# passthrough methods
def message_len(self, message: ChatMessage) -> int:
return self.engine.message_len(message)

async def predict(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
) -> BaseCompletion:
@@ -180,9 +215,6 @@ async def stream(
async for elem in self.engine.stream(messages, functions, **hyperparams):
yield elem

def function_token_reserve(self, functions: list[AIFunction]) -> int:
return self.engine.function_token_reserve(functions)

async def close(self):
return await self.engine.close()

19 changes: 18 additions & 1 deletion kani/kani.py
@@ -346,8 +346,25 @@ def always_len(self) -> int:
+ self.desired_response_tokens
)

async def prompt_token_len(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **kwargs) -> int:
    """
    Returns the number of tokens used by the given prompt (i.e., list of messages and functions).

    In general, this is preferred over :meth:`message_token_len`.
    """
if inspect.iscoroutinefunction(self.engine.prompt_len):
return await self.engine.prompt_len(messages, functions, **kwargs)
return self.engine.prompt_len(messages, functions, **kwargs)

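Example application-side usage of the new high-level method (a sketch: the engine choice, model name, and API key are placeholders, and this assumes this branch's `Kani.prompt_token_len`):

```python
import asyncio

from kani import Kani
from kani.engines.openai import OpenAIEngine


async def main():
    engine = OpenAIEngine(api_key="...", model="gpt-4o-mini")
    ai = Kani(engine)
    await ai.chat_round_str("Hello!")
    # Works regardless of whether the engine's prompt_len is sync or async.
    n = await ai.prompt_token_len(ai.chat_history)
    print(f"current prompt uses ~{n} tokens")
    await engine.close()


asyncio.run(main())
```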
def message_token_len(self, message: ChatMessage):
"""Returns the number of tokens used by a given message."""
"""
Returns the estimated number of tokens used by a single given message.

.. note::
    The token count returned by this may not exactly reflect the actual token count (e.g., due to prompt
    formatting or not having access to the tokenizer). It should, however, be a safe overestimate to use as
    an upper bound.
"""
return self.engine.message_len(message)

async def get_model_completion(self, include_functions: bool = True, **kwargs) -> BaseCompletion:
