From 3a4933143ed22b70323682600c86fa67a83a725d Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Sat, 30 Nov 2024 08:57:03 -0800 Subject: [PATCH] allow special tokens when encoding text for token accounting --- .../openai_online_request_processor.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py index 132ae01a..cef4f277 100644 --- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -674,7 +674,7 @@ def num_tokens_consumed_from_request( num_tokens += 4 # every message follows {role/name}\n{content}\n for key, value in message.items(): try: - num_tokens += len(encoding.encode(str(value))) + num_tokens += len(encoding.encode(str(value), disallowed_special=())) except TypeError: logger.warning( f"Failed to encode value {value} with tiktoken to count tokens. Instead assuming a token for every 4 characters." @@ -688,11 +688,13 @@ def num_tokens_consumed_from_request( else: prompt = api_specific_request_json["prompt"] if isinstance(prompt, str): # single prompt - prompt_tokens = len(encoding.encode(prompt)) + prompt_tokens = len(encoding.encode(prompt, disallowed_special=())) num_tokens = prompt_tokens + completion_tokens return num_tokens elif isinstance(prompt, list): # multiple prompts - prompt_tokens = sum([len(encoding.encode(p)) for p in prompt]) + prompt_tokens = sum( + [len(encoding.encode(p, disallowed_special=())) for p in prompt] + ) num_tokens = prompt_tokens + completion_tokens * len(prompt) return num_tokens else: @@ -703,10 +705,10 @@ def num_tokens_consumed_from_request( elif api_endpoint == "embeddings": input = api_specific_request_json["input"] if isinstance(input, str): # single input - num_tokens = len(encoding.encode(input)) + num_tokens = len(encoding.encode(input, disallowed_special=())) return num_tokens elif isinstance(input, list): # multiple inputs - num_tokens = sum([len(encoding.encode(i)) for i in input]) + num_tokens = sum([len(encoding.encode(i, disallowed_special=())) for i in input]) return num_tokens else: raise TypeError(