Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function calling for OpenAI backend #573

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions examples/quick_start/openai_example_func_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
Usage:
export OPENAI_API_KEY=sk-******
python3 openai_example_func_call.py
"""

import sglang as sgl
import json


def get_current_weather(location: str, unit: str = "fahrenheit"):
    """Return the current weather in a given location as a JSON string.

    Known cities (matched case-insensitively by substring) get a canned
    temperature; any other location falls back to "unknown".

    Args:
        location: Free-form location name, e.g. "Tokyo, Japan".
        unit: Temperature unit label echoed back in the response.

    Returns:
        A JSON-encoded object with "location", "temperature" and "unit" keys.
    """
    city = location.lower()
    if "tokyo" in city:
        return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
    elif "san francisco" in city:
        return json.dumps(
            {"location": "San Francisco", "temperature": "72", "unit": unit}
        )
    elif "paris" in city:
        return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
    else:
        # Keep the schema consistent with the known-city branches: include
        # "unit" so consumers can rely on the same set of keys.
        return json.dumps(
            {"location": location, "temperature": "unknown", "unit": unit}
        )


def get_n_day_weather_forecast(location: str, num_days: int, unit: str = "fahrenheit"):
    """Return an N-day weather forecast for a given location as a JSON string.

    Known cities (matched case-insensitively by substring) get a canned
    forecast; any other location falls back to "unknown".

    Args:
        location: Free-form location name.
        num_days: Number of forecast days, echoed back verbatim.
        unit: Temperature unit label echoed back in the response.

    Returns:
        A JSON-encoded object with "location", "num_days", "forecast" and
        "unit" keys.
    """
    city = location.lower()
    if "tokyo" in city:
        forecast = {
            "location": "Tokyo",
            "num_days": num_days,
            "forecast": "all sunny",
            "unit": unit,
        }
    elif "san francisco" in city:
        forecast = {
            "location": "San Francisco",
            "num_days": num_days,
            "forecast": "all foggy",
            "unit": unit,
        }
    elif "paris" in city:
        forecast = {
            "location": "Paris",
            "num_days": num_days,
            "forecast": "all rainy",
            "unit": unit,
        }
    else:
        # Keep the schema consistent with the known-city branches: include
        # "num_days" and "unit" so consumers can rely on the same keys.
        forecast = {
            "location": location,
            "num_days": num_days,
            "forecast": "unknown",
            "unit": unit,
        }
    return json.dumps(forecast)


@sgl.function
def question(s, question1, question2, tools=None):
    """Two-turn chat where each assistant reply may call the given tools.

    Args:
        s: The sglang program state.
        question1: First user question.
        question2: Follow-up user question.
        tools: Optional list of callables exposed to the model for function
            calling. Defaults to no tools.
    """
    # Avoid the mutable-default-argument pitfall (`tools=[]`): use None as
    # the sentinel and normalize here. Downstream code only truth-tests
    # `tools`, so [] and the old default behave identically.
    tools = [] if tools is None else tools
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question1)
    s += sgl.assistant(
        sgl.gen(
            "answer_1",
            max_tokens=256,
            tools=tools,
            tool_choice="auto",
        )
    )
    s += sgl.user(question2)
    s += sgl.assistant(
        sgl.gen(
            "answer_2",
            max_tokens=256,
            tools=tools,
            tool_choice="auto",
        )
    )


def single():
    """Run one two-question, function-calling request and print the results."""
    state = question.run(
        question1="What's the weather like in San Francisco, Tokyo, Paris, and Beijing?",
        question2="What's the weather like in San Francisco, Tokyo, Paris, and Beijing in the next 5 days?",
        tools=[get_current_weather, get_n_day_weather_forecast],
    )

    # Print the full conversation transcript.
    for message in state.messages():
        print(message["role"], ":", message["content"])

    # Print each captured generation by name.
    for answer_key in ("answer_1", "answer_2"):
        print("\n-- " + answer_key + " --\n", state[answer_key])


if __name__ == "__main__":
    # gpt-4o is in the backend's function-calling-enabled model list, which
    # this example relies on.
    sgl.set_default_backend(sgl.OpenAI("gpt-4o"))

    # Run a single request
    print("\n========== single ==========\n")
    single()
4 changes: 4 additions & 0 deletions python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def gen(
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
regex: Optional[str] = None,
tools: Optional[List[str]] = None,
tool_choice: Optional[str] = "auto",
):
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""

Expand Down Expand Up @@ -103,6 +105,8 @@ def gen(
return_text_in_logprobs,
dtype,
regex,
tools,
tool_choice,
)


Expand Down
124 changes: 122 additions & 2 deletions python/sglang/lang/backend/openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dataclasses
import inspect
import json
import logging
import time
import warnings
Expand Down Expand Up @@ -42,6 +44,27 @@ def create_logit_bias_int(tokenizer):
"gpt-3.5-turbo-instruct",
]

PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES = [
"gpt-4o",
"gpt-4o-2024-05-13",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
]

FUNC_CALL_ENABLED_MODEL_NAMES = PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES + [
"gpt-4",
"gpt-4-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0613",
]


@dataclasses.dataclass
class TokenUsage:
Expand Down Expand Up @@ -140,6 +163,7 @@ def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
function_call_messages: List = [],
spec_var_name: str = None,
):
if sampling_params.dtype is None:
Expand All @@ -151,7 +175,7 @@ def generate(
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
"Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
)
prompt = s.messages_
prompt = s.messages_ + function_call_messages
else:
return self._prepare_spec_execution(
sampling_params, s.num_api_spec_tokens, spec_var_name
Expand Down Expand Up @@ -229,6 +253,101 @@ def spec_pattern_match(self, comp):
return False
return True

def build_function_call_messages(
    self,
    s: StreamExecutor,
    tools: List[str],
    tool_choice: str,
):
    """Run one chat-completion round with tools and execute requested calls.

    Sends the current conversation (``s.messages_``) to the model together
    with JSON-schema descriptions of ``tools``. If the model responds with
    tool calls, each requested function is executed locally and its result
    is collected as a ``"tool"``-role message.

    Args:
        s: The stream executor holding the conversation so far.
        tools: Candidate callables the model may invoke. NOTE(review): the
            annotation says ``List[str]`` but the code uses them as
            functions (``__name__``, ``inspect.signature``, direct calls) —
            annotation should likely be ``List[Callable]``.
        tool_choice: ``"auto"``/``"required"``/``"none"``, or the name of
            one of ``tools`` to force that specific function.

    Returns:
        A list of messages — the assistant tool-call message followed by one
        ``"tool"`` message per executed call — or an empty list if the model
        made no tool calls.
    """
    assert (
        s.num_api_spec_tokens is None
    ), "function calling is not supported with api speculative execution"
    assert (
        self.model_name in FUNC_CALL_ENABLED_MODEL_NAMES
    ), "function calling is not supported with the provided model"

    def convert_param_type(type):
        # Map Python annotation names to JSON-schema type names.
        if type == "int" or type == "integer":
            return "integer"
        if type == "str" or type == "string":
            return "string"
        return type

    def function_to_json_schema(func):
        # Build the OpenAI "function" tool schema from the callable's
        # signature. Only parameter names and types are emitted — no
        # descriptions and no "required" list.
        signature = inspect.signature(func)
        parameters = signature.parameters
        func_schema = {
            "type": "function",
            "function": {
                "name": func.__name__,
                "parameters": {
                    "type": "object",
                    "properties": {
                        param.name: {
                            # str(annotation) renders builtins as
                            # "<class 'int'>"; strip the wrapper to get the
                            # bare type name before converting.
                            "type": convert_param_type(
                                str(param.annotation)
                                .replace("<class '", "")
                                .replace("'>", "")
                            )
                        }
                        for param in parameters.values()
                    },
                },
            },
        }
        return func_schema

    def build_tool_choice_param():
        # Keyword choices pass through as-is; otherwise force the named
        # function, which must be one of the provided tools.
        if tool_choice in ["auto", "required", "none"]:
            return tool_choice
        else:
            assert tool_choice in [
                tool.__name__ for tool in tools
            ], "could not find a candidate function that matches the provided tool choice"
            return {"type": "function", "function": {"name": tool_choice}}

    tools_to_use = []
    if tools:
        tools_to_use = [
            function_to_json_schema(tool_to_use) for tool_to_use in tools
        ]
    if tool_choice:
        tool_choice = build_tool_choice_param()

    # NOTE(review): when `tools` is empty this still sends `tools=[]` plus a
    # tool_choice — confirm the OpenAI API accepts that combination.
    response = self.client.chat.completions.create(
        model=self.model_name,
        messages=s.messages_,
        tools=tools_to_use,
        tool_choice=tool_choice,
        **self.spec_kwargs,
    )
    response_message = response.choices[0].message
    tool_calls = response_message.tool_calls
    # Check if the model wanted to call a function
    ret_messages = []
    if tool_calls:
        # Call the function
        # Note: the JSON response may not always be valid; be sure to handle errors
        available_functions = {}
        for tool in tools:
            available_functions[tool.__name__] = tool
        ret_messages.append(response_message)
        # Send the info for each function call and function response to the model
        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_to_call = available_functions[function_name]
            function_args = json.loads(tool_call.function.arguments)
            function_response = function_to_call(**function_args)
            ret_messages.append(
                {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": str(function_response),
                }
            )
    return ret_messages

def role_end_generate(
self,
s: StreamExecutor,
Expand Down Expand Up @@ -266,6 +385,7 @@ def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
function_call_messages: List = [],
):
if sampling_params.dtype is None:
if self.is_chat_model:
Expand All @@ -274,7 +394,7 @@ def generate_stream(
"This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
)
prompt = s.messages_
prompt = s.messages_ + function_call_messages
else:
prompt = s.text_

Expand Down
19 changes: 18 additions & 1 deletion python/sglang/lang/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ def __init__(
self.cur_role = None
self.cur_role_begin_pos = None

# For function calling
self.function_calls_messages = [] # The messages in the OpenAI API format

# For vision
self.images_ = []
self.cur_images = []
Expand Down Expand Up @@ -487,6 +490,17 @@ def find_stop():
return comp, meta_info

def _execute_gen(self, expr: SglGen):
    """Execute a gen expression, resolving tool (function-call) messages first.

    If the expression carries tools, one function-calling round is run
    against the backend and its messages are stored for the subsequent
    generate call. Previous function calls are not remembered; users are
    expected to provide all candidate functions in the current gen call.
    """
    if expr.tools:
        self.function_calls_messages = self.backend.build_function_call_messages(
            self, expr.tools, expr.tool_choice
        )
    # The helper call was duplicated in both branches of the original
    # if/else; hoist the common call out so it appears exactly once.
    self._execute_gen_helper(expr)

def _execute_gen_helper(self, expr: SglGen):
sampling_params = self._resolve_sampling_params(expr.sampling_params)
name = expr.name

Expand All @@ -495,6 +509,7 @@ def _execute_gen(self, expr: SglGen):
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
function_call_messages=self.function_calls_messages,
)
else:
if self.backend.is_chat_model:
Expand All @@ -521,7 +536,9 @@ def _execute_gen(self, expr: SglGen):
self.num_api_spec_tokens is None
), "stream is not supported with api speculative execution"
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
self,
sampling_params=sampling_params,
function_call_messages=self.function_calls_messages,
)

self.variables[name] = ""
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/lang/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ def __init__(
return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None,
regex: Optional[str] = None,
tools: Optional[List[str]] = None,
tool_choice: Optional[str] = "auto",
):
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
super().__init__()
Expand All @@ -429,6 +431,8 @@ def __init__(
dtype=dtype,
regex=regex,
)
self.tools = tools
self.tool_choice = tool_choice

def __repr__(self):
return f"Gen('{self.name}')"
Expand Down
Loading