Skip to content

Commit 3a49331

Browse files
committed
allow special tokens when encoding text for token accounting
1 parent 14bd655 commit 3a49331

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/bespokelabs/curator/request_processor/openai_online_request_processor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ def num_tokens_consumed_from_request(
674674
num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
675675
for key, value in message.items():
676676
try:
677-
num_tokens += len(encoding.encode(str(value)))
677+
num_tokens += len(encoding.encode(str(value), disallowed_special=()))
678678
except TypeError:
679679
logger.warning(
680680
f"Failed to encode value {value} with tiktoken to count tokens. Instead assuming a token for every 4 characters."
@@ -688,11 +688,13 @@ def num_tokens_consumed_from_request(
688688
else:
689689
prompt = api_specific_request_json["prompt"]
690690
if isinstance(prompt, str): # single prompt
691-
prompt_tokens = len(encoding.encode(prompt))
691+
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
692692
num_tokens = prompt_tokens + completion_tokens
693693
return num_tokens
694694
elif isinstance(prompt, list): # multiple prompts
695-
prompt_tokens = sum([len(encoding.encode(p)) for p in prompt])
695+
prompt_tokens = sum(
696+
[len(encoding.encode(p, disallowed_special=())) for p in prompt]
697+
)
696698
num_tokens = prompt_tokens + completion_tokens * len(prompt)
697699
return num_tokens
698700
else:
@@ -703,10 +705,10 @@ def num_tokens_consumed_from_request(
703705
elif api_endpoint == "embeddings":
704706
input = api_specific_request_json["input"]
705707
if isinstance(input, str): # single input
706-
num_tokens = len(encoding.encode(input))
708+
num_tokens = len(encoding.encode(input, disallowed_special=()))
707709
return num_tokens
708710
elif isinstance(input, list): # multiple inputs
709-
num_tokens = sum([len(encoding.encode(i)) for i in input])
711+
num_tokens = sum([len(encoding.encode(i, disallowed_special=())) for i in input])
710712
return num_tokens
711713
else:
712714
raise TypeError(

0 commit comments

Comments
 (0)