Commit fbe44c2

add call_func api

Yiyun-Liang committed Jun 29, 2024
1 parent 36ac1cf commit fbe44c2

Showing 5 changed files with 99 additions and 40 deletions.
50 changes: 50 additions & 0 deletions examples/quick_start/openai_example_func_call.py
@@ -0,0 +1,50 @@
"""
Usage:
export OPENAI_API_KEY=sk-******
python3 openai_example_func_call.py
"""

import sglang as sgl
import json


def get_current_weather(location: str, unit: str = "fahrenheit"):
"""Get the current weather in a given location"""
if "tokyo" in location.lower():
return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
elif "san francisco" in location.lower():
return json.dumps(
{"location": "San Francisco", "temperature": "72", "unit": unit}
)
elif "paris" in location.lower():
return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
else:
return json.dumps({"location": location, "temperature": "unknown"})


@sgl.function
def multi_turn_question(s, question_1, functions=()):  # tuple default avoids the mutable-default pitfall
s += sgl.system("You are a helpful assistant.")
s += sgl.user(question_1)
s += sgl.func_call("func_call_1", tools=functions, tool_choice="auto")
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))


def single():
state = multi_turn_question.run(
question_1="What's the weather like in San Francisco, Tokyo, Paris, and Beijing?",
functions=[get_current_weather],
)

for m in state.messages():
print(m["role"], ":", m["content"])

print("\n-- answer_1 --\n", state["answer_1"])


if __name__ == "__main__":
sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))

# Run a single request
print("\n========== single ==========\n")
single()
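
Note: OpenAI's function-calling API expects each tool as a JSON schema object. As a point of reference, here is a sketch of the payload that function_to_json_schema (in python/sglang/backend/openai.py below) presumably produces for get_current_weather, following OpenAI's documented tools format; the exact field contents are an assumption, not taken from this commit.

# Hypothetical output of function_to_json_schema(get_current_weather),
# following OpenAI's documented "tools" schema; details are assumed.
get_current_weather_schema = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}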
2 changes: 2 additions & 0 deletions python/sglang/__init__.py
@@ -8,6 +8,7 @@
     assistant_end,
     flush_cache,
     function,
+    func_call,
     gen,
     gen_int,
     gen_string,
@@ -58,4 +59,5 @@
     "user_end",
     "assistant_begin",
     "assistant_end",
+    "func_call",
 ]
9 changes: 9 additions & 0 deletions python/sglang/api.py
@@ -16,6 +16,7 @@
     SglRoleEnd,
     SglSelect,
     SglVideo,
+    SglFuncCall,
 )
@@ -198,3 +199,11 @@ def assistant_begin():

 def assistant_end():
     return SglRoleEnd("assistant")
+
+
+def func_call(
+    name: Optional[str] = None,
+    tools: Optional[List[str]] = None,
+    tool_choice: Optional[str] = "auto",
+):
+    return SglFuncCall(name, tools, tool_choice)
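
Note: the SglFuncCall IR node itself is not shown in this diff. A minimal sketch of what it plausibly looks like, by analogy with the other SglExpr nodes imported above; this is an assumption, not part of the commit.

# Hypothetical definition, modeled on sibling IR nodes such as SglGen;
# the real class lives in python/sglang/lang/ir.py and may differ.
class SglFuncCall(SglExpr):
    def __init__(self, name, tools, tool_choice):
        super().__init__()
        self.name = name  # variable name for this call's result
        self.tools = tools  # Python callables exposed to the model
        self.tool_choice = tool_choice  # "auto", "required", "none", or a tool name

    def __repr__(self):
        return f"SglFuncCall(name={self.name}, tool_choice={self.tool_choice})"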
59 changes: 25 additions & 34 deletions python/sglang/backend/openai.py
@@ -154,6 +154,8 @@ def generate(
                     "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
                 )
                 prompt = s.messages_
+                for function_call in s.function_calls:
+                    prompt.append(function_call)
             else:
                 return self._prepare_spec_execution(
                     sampling_params, s.num_api_spec_tokens, spec_var_name
@@ -237,8 +239,8 @@ def function_calling(
         tools: List[str],
         tool_choice: str,
     ):
-        # chat model vs. non chat model
-        # stream vs non stream
+        assert self.is_chat_model, "function calling only supported on chat model"
+        # TODO: special handling for chat model vs. non chat model, stream vs non stream
         if self.model_name not in [
             "gpt-4o",
             "gpt-4o-2024-05-13",
@@ -290,63 +292,52 @@ def function_to_json_schema(func):
             return func_schema

         tools_to_use = []
-        if self.tools:
+        if tools:
             tools_to_use = [
-                function_to_json_schema(tool_to_use) for tool_to_use in self.tools
+                function_to_json_schema(tool_to_use) for tool_to_use in tools
             ]
-        tool_choice = "auto"
-        if self.tool_choice:
-            tool_choice = (
-                self.tool_choice
-                if self.tool_choice in ["auto", "required", "none"]
-                else {"type": "function", "function": {"name": self.tool_choice}}
+        cur_tool_choice = "auto"
+        if tool_choice:
+            cur_tool_choice = (
+                tool_choice
+                if tool_choice in ["auto", "required", "none"]
+                else {"type": "function", "function": {"name": tool_choice}}
             )

-        # should we append "Never mention what tools you use." or provide a system prompt input argument
-        messages = s.text_
-        comp = openai_completion(
-            client=self.client,
-            token_usage=self.token_usage,
-            is_chat=self.is_chat_model,
+        # TODO: "Never mention what tools you use." or provide a system prompt input argument
+        response = self.client.chat.completions.create(
             model=self.model_name,
-            prompt=messages,
+            messages=s.messages_,
             tools=tools_to_use,
-            tool_choice=tool_choice,
+            tool_choice=cur_tool_choice,
             **self.spec_kwargs,
         )
-        response_message = comp.choices[0].message
+        response_message = response.choices[0].message
         tool_calls = response_message.tool_calls
         # Check if the model wanted to call a function
+        ret_messages = []
         if tool_calls:
             # Call the function
             # Note: the JSON response may not always be valid; be sure to handle errors
-            available_functions = ()
-            for tool_name in tools:
-                available_functions.append({tool_name: globals()[tool_name]})
-            messages.append(response_message)
+            available_functions = {}
+            for tool in tools:
+                available_functions[tool.__name__] = tool
+            ret_messages.append(response_message)
             # Send the info for each function call and function response to the model
             for tool_call in tool_calls:
                 function_name = tool_call.function.name
                 function_to_call = available_functions[function_name]
                 function_args = json.loads(tool_call.function.arguments)
                 function_response = function_to_call(**function_args)
-                messages.append(
+                ret_messages.append(
                     {
                         "tool_call_id": tool_call.id,
                         "role": "tool",
                         "name": function_name,
-                        "content": function_response,
+                        "content": str(function_response),
                     }
                 )
-            comp = openai_completion(
-                client=self.client,
-                token_usage=self.token_usage,
-                is_chat=self.is_chat_model,
-                model=self.model_name,
-                prompt=s.text_,
-                **self.spec_kwargs,
-            )
-            s.text_ += comp
+        return ret_messages

     def role_end_generate(
         self,
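
Note: the body of function_to_json_schema is elided by the diff above. A minimal sketch of one way to build such a schema from a Python function's signature and docstring, under the assumption that it targets OpenAI's tools format; this is illustrative only, not the commit's actual implementation.

import inspect

# Illustrative sketch only: derives an OpenAI-style tool schema from a
# function's signature and docstring. The commit's real helper may differ.
def function_to_json_schema(func):
    type_map = {str: "string", int: "integer", float: "number", bool: "boolean"}
    properties, required = {}, []
    for name, param in inspect.signature(func).parameters.items():
        properties[name] = {"type": type_map.get(param.annotation, "string")}
        if param.default is inspect.Parameter.empty:
            required.append(name)
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": inspect.getdoc(func) or "",
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }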
19 changes: 13 additions & 6 deletions python/sglang/lang/interpreter.py
@@ -203,6 +203,9 @@ def __init__(
         self.cur_role = None
         self.cur_role_begin_pos = None

+        # For function calling
+        self.function_calls = []  # The messages in the OpenAI API format
+
         # For vision
         self.images_ = []
         self.cur_images = []
@@ -491,7 +494,6 @@ def find_stop():
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name
-        print("0-", self)

         if not self.stream:
             if self.num_api_spec_tokens is None:
@@ -514,14 +516,11 @@ def _execute_gen(self, expr: SglGen):
             else:  # Speculative execution on models with completion interface
                 comp, meta_info = self._spec_gen(sampling_params)

-            print("1-", comp)
             self.text_ += comp
-            print("2-", self.text_)

             self.variables[name] = comp
             self.meta_info[name] = meta_info
             self.variable_event[name].set()
-            print("3-", self.variables[name], name, self.meta_info[name])
         else:
             assert (
                 self.num_api_spec_tokens is None
@@ -562,7 +561,10 @@ def _execute_select(self, expr: SglSelect):
         self.text_ += decision

     def _execute_func_call(self, expr: SglFuncCall):
-        self.backend.function_calling(self, expr.tools, expr.tool_choice)
+        # TODO: Should we clear the previous function call states for the next function call
+        self.function_calls = self.backend.function_calling(
+            self, expr.tools, expr.tool_choice
+        )

     def _execute_variable(self, expr: SglVariable):
         src_executor = expr.source_stream_executor
@@ -760,7 +762,12 @@ def text(self):
         return self.stream_executor.text()

     def messages(self):
-        return self.stream_executor.messages()
+        filtered_list = [
+            item
+            for item in self.stream_executor.messages()
+            if item not in self.stream_executor.function_calls
+        ]
+        return filtered_list

     def sync(self):
         return self.stream_executor.sync()
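
Note: taken together, the pieces above give this round trip: function_calling() returns the assistant tool_calls message plus one tool result message per call; the interpreter stores them in self.function_calls; generate() appends them to the prompt so the following sgl.gen sees the tool results; and state.messages() filters them back out. A sketch of the message list the final completion request plausibly sees; the contents are illustrative, not taken from this commit.

# Hypothetical prompt for the final gen() call in the example program;
# the last two entries come from s.function_calls. Values are illustrative.
prompt = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather like in San Francisco, Tokyo, Paris, and Beijing?"},
    # assistant message carrying tool_calls, returned by the first request
    {"role": "assistant", "content": None, "tool_calls": ["..."]},
    # one tool result per call, produced by running get_current_weather locally
    {"role": "tool", "tool_call_id": "call_1", "name": "get_current_weather",
     "content": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'},
]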
