From b0a58f75823ed6236b8ed2190270530fbe93a509 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Mon, 18 Sep 2023 11:29:25 -0400 Subject: [PATCH] chore: add warning to function calling with chat_round --- kani/kani.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/kani/kani.py b/kani/kani.py index 60c0198..0e2e846 100644 --- a/kani/kani.py +++ b/kani/kani.py @@ -1,6 +1,7 @@ import asyncio import inspect import logging +import warnings import weakref from typing import AsyncIterable, Callable @@ -124,12 +125,21 @@ async def chat_round(self, query: str, **kwargs) -> ChatMessage: :param kwargs: Additional arguments to pass to the model engine (e.g. hyperparameters). :returns: The model's reply. """ + # warn if the user has functions defined and has not explicitly silenced them in this call + if self.functions and "include_functions" not in kwargs: + warnings.warn( + f"You have defined functions in the body of {type(self).__name__} but chat_round() will not call" + " functions. Use full_round() instead.\nIf this is intentional, use chat_round(...," + " include_functions=False) to silence this warning." + ) + kwargs = {**kwargs, "include_functions": False} + # do the chat round async with self.lock: # add the user's chat input to the state await self.add_to_history(ChatMessage.user(query.strip())) # and get a completion - completion = await self.get_model_completion(include_functions=False, **kwargs) + completion = await self.get_model_completion(**kwargs) message = completion.message await self.add_to_history(message) return message