diff --git a/examples/quick_start/openai_example_func_call.py b/examples/quick_start/openai_example_func_call.py
new file mode 100644
index 0000000000..1a6f0f08a7
--- /dev/null
+++ b/examples/quick_start/openai_example_func_call.py
@@ -0,0 +1,100 @@
+"""
+Usage:
+export OPENAI_API_KEY=sk-******
+python3 openai_example_func_call.py
+"""
+
+import sglang as sgl
+import json
+
+
+def get_current_weather(location: str, unit: str = "fahrenheit"):
+    """Get the current weather in a given location"""
+    if "tokyo" in location.lower():
+        return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
+    elif "san francisco" in location.lower():
+        return json.dumps(
+            {"location": "San Francisco", "temperature": "72", "unit": unit}
+        )
+    elif "paris" in location.lower():
+        return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
+    else:
+        return json.dumps({"location": location, "temperature": "unknown"})
+
+
+def get_n_day_weather_forecast(location: str, num_days: int, unit: str = "fahrenheit"):
+    """Get an N-day weather forecast in a given location"""
+    if "tokyo" in location.lower():
+        return json.dumps(
+            {
+                "location": "Tokyo",
+                "num_days": num_days,
+                "forecast": "all sunny",
+                "unit": unit,
+            }
+        )
+    elif "san francisco" in location.lower():
+        return json.dumps(
+            {
+                "location": "San Francisco",
+                "num_days": num_days,
+                "forecast": "all foggy",
+                "unit": unit,
+            }
+        )
+    elif "paris" in location.lower():
+        return json.dumps(
+            {
+                "location": "Paris",
+                "num_days": num_days,
+                "forecast": "all rainy",
+                "unit": unit,
+            }
+        )
+    else:
+        return json.dumps({"location": location, "forecast": "unknown"})
+
+
+@sgl.function
+def question(s, question1, question2, tools=[]):
+    s += sgl.system("You are a helpful assistant.")
+    s += sgl.user(question1)
+    s += sgl.assistant(
+        sgl.gen(
+            "answer_1",
+            max_tokens=256,
+            tools=tools,
+            tool_choice="auto",
+        )
+    )
+    s += sgl.user(question2)
+    s += sgl.assistant(
+        sgl.gen(
+            "answer_2",
+            max_tokens=256,
+            tools=tools,
+            tool_choice="auto",
+        )
+    )
+
+
+def single():
+    state = question.run(
+        question1="What's the weather like in San Francisco, Tokyo, Paris, and Beijing?",
+        question2="What's the weather like in San Francisco, Tokyo, Paris, and Beijing in the next 5 days?",
+        tools=[get_current_weather, get_n_day_weather_forecast],
+    )
+
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+    print("\n-- answer_1 --\n", state["answer_1"])
+    print("\n-- answer_2 --\n", state["answer_2"])
+
+
+if __name__ == "__main__":
+    sgl.set_default_backend(sgl.OpenAI("gpt-4o"))
+
+    # Run a single request
+    print("\n========== single ==========\n")
+    single()
diff --git a/python/sglang/api.py b/python/sglang/api.py
index c32943963b..3e52977574 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -74,6 +74,8 @@ def gen(
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
     regex: Optional[str] = None,
+    tools: Optional[List[str]] = None,
+    tool_choice: Optional[str] = "auto",
 ):
     """Call the model to generate.
     See the meaning of the arguments in docs/sampling_params.md"""
@@ -103,6 +105,8 @@ def gen(
         return_text_in_logprobs,
         dtype,
         regex,
+        tools,
+        tool_choice,
     )
 
 
diff --git a/python/sglang/lang/backend/openai.py b/python/sglang/lang/backend/openai.py
index 06701cb371..19186e0af5 100644
--- a/python/sglang/lang/backend/openai.py
+++ b/python/sglang/lang/backend/openai.py
@@ -1,4 +1,6 @@
 import dataclasses
+import inspect
+import json
 import logging
 import time
 import warnings
@@ -42,6 +44,27 @@ def create_logit_bias_int(tokenizer):
     "gpt-3.5-turbo-instruct",
 ]
 
+PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES = [
+    "gpt-4o",
+    "gpt-4o-2024-05-13",
+    "gpt-4o-mini",
+    "gpt-4o-mini-2024-07-18",
+    "gpt-4-turbo",
+    "gpt-4-turbo-2024-04-09",
+    "gpt-4-turbo-preview",
+    "gpt-4-0125-preview",
+    "gpt-4-1106-preview",
+    "gpt-3.5-turbo-0125",
+    "gpt-3.5-turbo-1106",
+]
+
+FUNC_CALL_ENABLED_MODEL_NAMES = PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES + [
+    "gpt-4",
+    "gpt-4-0613",
+    "gpt-3.5-turbo",
+    "gpt-3.5-turbo-0613",
+]
+
 
 @dataclasses.dataclass
 class TokenUsage:
@@ -140,6 +163,7 @@ def generate(
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        function_call_messages: List = [],
         spec_var_name: str = None,
     ):
         if sampling_params.dtype is None:
@@ -151,7 +175,7 @@ def generate(
                             "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
                             "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
                         )
-                    prompt = s.messages_
+                    prompt = s.messages_ + function_call_messages
                 else:
                     return self._prepare_spec_execution(
                         sampling_params, s.num_api_spec_tokens, spec_var_name
@@ -229,6 +253,101 @@ def spec_pattern_match(self, comp):
                 return False
         return True
 
+    def build_function_call_messages(
+        self,
+        s: StreamExecutor,
+        tools: List[str],
+        tool_choice: str,
+    ):
+        assert (
+            s.num_api_spec_tokens is None
+        ), "function calling is not supported with api speculative execution"
+        assert (
+            self.model_name in FUNC_CALL_ENABLED_MODEL_NAMES
+        ), "function calling is not supported with the provided model"
+
+        def convert_param_type(type):
+            if type == "int" or type == "integer":
+                return "integer"
+            if type == "str" or type == "string":
+                return "string"
+            return type
+
+        def function_to_json_schema(func):
+            signature = inspect.signature(func)
+            parameters = signature.parameters
+            func_schema = {
+                "type": "function",
+                "function": {
+                    "name": func.__name__,
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            param.name: {
+                                "type": convert_param_type(
+                                    str(param.annotation)
+                                    .replace("<class '", "")
+                                    .replace("'>", "")
+                                )
+                            }
+                            for param in parameters.values()
+                        },
+                    },
+                },
+            }
+            return func_schema
+
+        def build_tool_choice_param():
+            if tool_choice in ["auto", "required", "none"]:
+                return tool_choice
+            else:
+                assert tool_choice in [
+                    tool.__name__ for tool in tools
+                ], "could not find a candidate function that matches the provided tool choice"
+                return {"type": "function", "function": {"name": tool_choice}}
+
+        tools_to_use = []
+        if tools:
+            tools_to_use = [
+                function_to_json_schema(tool_to_use) for tool_to_use in tools
+            ]
+        if tool_choice:
+            tool_choice = build_tool_choice_param()
+
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=s.messages_,
+            tools=tools_to_use,
+            tool_choice=tool_choice,
+            **self.spec_kwargs,
+        )
+        response_message = response.choices[0].message
+        tool_calls = response_message.tool_calls
+        # Check if the model wanted to call a function
+        ret_messages = []
+        if tool_calls:
+            # Call the function
+            # Note: the JSON response may not always be valid;
+            # be sure to handle errors
+            available_functions = {}
+            for tool in tools:
+                available_functions[tool.__name__] = tool
+            ret_messages.append(response_message)
+            # Send the info for each function call and function response to the model
+            for tool_call in tool_calls:
+                function_name = tool_call.function.name
+                function_to_call = available_functions[function_name]
+                function_args = json.loads(tool_call.function.arguments)
+                function_response = function_to_call(**function_args)
+                ret_messages.append(
+                    {
+                        "tool_call_id": tool_call.id,
+                        "role": "tool",
+                        "name": function_name,
+                        "content": str(function_response),
+                    }
+                )
+        return ret_messages
+
     def role_end_generate(
         self,
         s: StreamExecutor,
@@ -266,6 +385,7 @@ def generate_stream(
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        function_call_messages: List = [],
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
@@ -274,7 +394,7 @@ def generate_stream(
                         "This use case is not supported. "
                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
                    )
-                prompt = s.messages_
+                prompt = s.messages_ + function_call_messages
             else:
                 prompt = s.text_
 
diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index 573b9970ba..af3633236d 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -202,6 +202,9 @@ def __init__(
         self.cur_role = None
         self.cur_role_begin_pos = None
 
+        # For function calling
+        self.function_calls_messages = []  # The messages in the OpenAI API format
+
         # For vision
         self.images_ = []
         self.cur_images = []
@@ -487,6 +490,17 @@ def find_stop():
         return comp, meta_info
 
     def _execute_gen(self, expr: SglGen):
+        if expr.tools:
+            # Previous function calls are not remembered; users are expected to
+            # provide all candidate functions in the current generate call
+            self.function_calls_messages = self.backend.build_function_call_messages(
+                self, expr.tools, expr.tool_choice
+            )
+            self._execute_gen_helper(expr)
+        else:
+            self._execute_gen_helper(expr)
+
+    def _execute_gen_helper(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
 
         name = expr.name
@@ -495,6 +509,7 @@ def _execute_gen(self, expr: SglGen):
             comp, meta_info = self.backend.generate(
                 self,
                 sampling_params=sampling_params,
+                function_call_messages=self.function_calls_messages,
             )
         else:
             if self.backend.is_chat_model:
@@ -521,7 +536,9 @@ def _execute_gen(self, expr: SglGen):
                 self.num_api_spec_tokens is None
             ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
-                self, sampling_params=sampling_params
+                self,
+                sampling_params=sampling_params,
+                function_call_messages=self.function_calls_messages,
             )
 
             self.variables[name] = ""
diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py
index e5d5e837aa..e79494094a 100644
--- a/python/sglang/lang/ir.py
+++ b/python/sglang/lang/ir.py
@@ -409,6 +409,8 @@ def __init__(
         return_text_in_logprobs: Optional[bool] = None,
         dtype: Optional[type] = None,
         regex: Optional[str] = None,
+        tools: Optional[List[str]] = None,
+        tool_choice: Optional[str] = "auto",
     ):
         """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
         super().__init__()
@@ -429,6 +431,8 @@ def __init__(
             dtype=dtype,
             regex=regex,
         )
+        self.tools = tools
+        self.tool_choice = tool_choice
 
     def __repr__(self):
         return f"Gen('{self.name}')"
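
For reference, build_function_call_messages above derives each tool schema from inspect.signature alone, so the schema it sends carries no "description" or "required" fields and the model must go by the function and parameter names. A minimal sketch (an illustration, not part of the diff) of the value that function_to_json_schema(get_current_weather) evaluates to:

    # Tool schema built for get_current_weather(location: str, unit: str = "fahrenheit").
    # str(param.annotation) yields "<class 'str'>"; stripping "<class '" and "'>"
    # leaves "str", which convert_param_type() normalizes to "string".
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "unit": {"type": "string"},
                },
            },
        },
    }

The method returns the assistant message carrying tool_calls plus one role "tool" message per executed call; generate() then sends s.messages_ + function_call_messages as the follow-up prompt, which is why answer_1 and answer_2 come back as ordinary text. Note that the tools argument takes the function objects themselves, as in the example script, despite the Optional[List[str]] annotations.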