diff --git a/src/promptflow-core/promptflow/core/_flow.py b/src/promptflow-core/promptflow/core/_flow.py
index 8b1e4bed3cb..c4027fc2656 100644
--- a/src/promptflow-core/promptflow/core/_flow.py
+++ b/src/promptflow-core/promptflow/core/_flow.py
@@ -25,6 +25,7 @@
     handle_openai_error,
     handle_openai_error_async,
     num_tokens_from_messages,
+    num_tokens_for_tools,
     prepare_open_ai_request_params,
     resolve_references,
     send_request_to_llm,
@@ -484,11 +485,13 @@ def render(self, *args, **kwargs):
         # For chat mode, the message generated is list type. Convert to string type and return to user.
         return str(prompt)
 
-    def estimate_token_count(self, *args, **kwargs):
+    def estimate_token_count(self, model: Union[str, None] = None, *args, **kwargs):
         """Estimate the token count.
         LLM will reject the request when prompt token + response token is greater than the maximum number of
         tokens supported by the model. It is used to estimate the number of total tokens in this round of chat.
 
+        :param model: Optional OpenAI model name used to select the tokenizer.
+            Use it when the Azure OpenAI deployment name is not a valid OpenAI model name.
         :param args: positional arguments are not supported.
         :param kwargs: prompty inputs with key word arguments.
         :return: Estimate total token count
@@ -508,9 +511,13 @@ def estimate_token_count(self, *args, **kwargs):
                 raise UserErrorException("Max_token needs to be integer.")
             elif response_max_token <= 1:
                 raise UserErrorException(f"{response_max_token} is less than the minimum of max_tokens.")
-        total_token = num_tokens_from_messages(prompt, self._model._model, working_dir=self.path.parent) + (
+
+        total_token = num_tokens_from_messages(prompt, model or self._model._model, working_dir=self.path.parent) + (
             response_max_token or 0
         )
+
+        if self._model.parameters.get("tools", None):
+            total_token += num_tokens_for_tools(self._model.parameters["tools"], model or self._model._model)
         return total_token
diff --git a/src/promptflow-core/promptflow/core/_prompty_utils.py b/src/promptflow-core/promptflow/core/_prompty_utils.py
index 47ef499ce1e..9d069b8b809 100644
--- a/src/promptflow-core/promptflow/core/_prompty_utils.py
+++ b/src/promptflow-core/promptflow/core/_prompty_utils.py
@@ -326,6 +326,12 @@ def num_tokens_from_messages(messages, model, working_dir):
         for key, value in message.items():
             if isinstance(value, str):
                 num_tokens += len(encoding.encode(value))
+            if key == "tool_calls" and isinstance(value, list):
+                for tool in value:
+                    num_tokens += len(encoding.encode(tool.get("name")))
+                    num_tokens += tokens_per_name
+                    if tool.get("arguments", None):
+                        num_tokens += len(encoding.encode(tool["arguments"]))
             elif isinstance(value, list):
                 for item in value:
                     value_type = item.get("type", "text")
@@ -349,6 +355,91 @@ def num_tokens_from_messages(messages, model, working_dir):
     return num_tokens
 
 
+def num_tokens_for_tools(functions: list[dict], model: str) -> int:
+    """Calculate the number of tokens used by functions for a given model.
+
+    Args:
+        functions (list[dict[str, Any]]): A list of function specifications.
+        model (str): The model name to use for token encoding.
+
+    Returns:
+        int: The number of tokens used by the function definitions.
+ """ + # Initialize function settings to 0 + func_init = 0 + prop_init = 0 + prop_key = 0 + enum_init = 0 + enum_item = 0 + func_end = 0 + + if model in ["gpt-4o", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "gpt-4o-2024-08-06"]: + # Set function settings for the above models + func_init = 7 + prop_init = 3 + prop_key = 3 + enum_init = -3 + enum_item = 3 + func_end = 12 + elif model in [ + "gpt-3.5-turbo", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", + ]: + # Set function settings for the above models + func_init = 10 + prop_init = 3 + prop_key = 3 + enum_init = -3 + enum_item = 3 + func_end = 12 + else: + msg = f"num_tokens_for_tools() is not implemented for model {model}." + raise NotImplementedError(msg) + + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print("Warning: model not found. Using o200k_base encoding.") + encoding = tiktoken.get_encoding("o200k_base") + + func_token_count = 0 + if len(functions) > 0: + for f in functions: + func_token_count += func_init # Add tokens for start of each function + function = f["function"] + f_name = function["name"] + f_desc = function["description"] + if f_desc.endswith("."): + f_desc = f_desc[:-1] + line = f_name + ":" + f_desc + func_token_count += len(encoding.encode(line)) # Add tokens for set name and description + if len(function["parameters"]["properties"]) > 0: + func_token_count += prop_init # Add tokens for start of each property + for key in list(function["parameters"]["properties"].keys()): + func_token_count += prop_key # Add tokens for each set property + p_name = key + p_type = function["parameters"]["properties"][key]["type"] + p_desc = function["parameters"]["properties"][key]["description"] + if "enum" in function["parameters"]["properties"][key]: + func_token_count += enum_init # Add tokens if property has enum list + for item in function["parameters"]["properties"][key]["enum"]: + func_token_count += enum_item + func_token_count += len(encoding.encode(item)) + if p_desc.endswith("."): + p_desc = p_desc[:-1] + line = f"{p_name}:{p_type}:{p_desc}" + func_token_count += len(encoding.encode(line)) + func_token_count += func_end + + return func_token_count + + def _get_image_obj(image_str, working_dir): mime_pattern_with_content = MIME_PATTERN.pattern[:-1] + r":\s*(.*)$" match = re.match(mime_pattern_with_content, image_str)