Commit edc30d2
update function call code structure
Yiyun-Liang committed Jun 30, 2024
1 parent 071cedf commit edc30d2
Showing 6 changed files with 41 additions and 53 deletions.
examples/quick_start/openai_example_func_call.py (8 additions, 7 deletions)
@@ -23,17 +23,18 @@ def get_current_weather(location: str, unit: str = "fahrenheit"):


 @sgl.function
-def multi_turn_question(s, question_1, functions=[]):
+def question(s, question, tools=[]):
     s += sgl.system("You are a helpful assistant.")
-    s += sgl.user(question_1)
-    s += sgl.func_call("func_call_1", tools=functions, tool_choice="auto")
-    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
+    s += sgl.user(question)
+    s += sgl.assistant(
+        sgl.gen("answer_1", max_tokens=256, tools=tools, tool_choice="auto")
+    )


 def single():
-    state = multi_turn_question.run(
-        question_1="What's the weather like in San Francisco, Tokyo, Paris, and Beijing?",
-        functions=[get_current_weather],
+    state = question.run(
+        question="What's the weather like in San Francisco, Tokyo, Paris, and Beijing?",
+        tools=[get_current_weather],
     )

     for m in state.messages():
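For call sites migrating off sgl.func_call, a minimal runnable sketch of the new pattern follows, assuming an OpenAI backend is configured (OPENAI_API_KEY set); the model name "gpt-4o" and the print loop are illustrative choices, not part of the commit.

import sglang as sgl


def get_current_weather(location: str, unit: str = "fahrenheit"):
    """Toy tool: return a canned weather report for the given location."""
    return f"The weather in {location} is 72 degrees {unit}."


@sgl.function
def question(s, question, tools=[]):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question)
    # Tool use now rides on sgl.gen via tools/tool_choice instead of a
    # separate sgl.func_call expression.
    s += sgl.assistant(
        sgl.gen("answer_1", max_tokens=256, tools=tools, tool_choice="auto")
    )


if __name__ == "__main__":
    sgl.set_default_backend(sgl.OpenAI("gpt-4o"))  # assumed backend/model setup
    state = question.run(
        question="What's the weather like in San Francisco and Tokyo?",
        tools=[get_current_weather],
    )
    for m in state.messages():
        print(m["role"], ":", m["content"])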
python/sglang/__init__.py (0 additions, 2 deletions)
@@ -8,7 +8,6 @@
     assistant_end,
     flush_cache,
     function,
-    func_call,
     gen,
     gen_int,
     gen_string,
@@ -59,5 +58,4 @@
     "user_end",
     "assistant_begin",
     "assistant_end",
-    "func_call",
 ]
python/sglang/api.py (4 additions, 9 deletions)
@@ -16,7 +16,6 @@
     SglRoleEnd,
     SglSelect,
     SglVideo,
-    SglFuncCall,
 )


@@ -133,6 +132,8 @@ def gen_string(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    tools: Optional[List[str]] = None,
+    tool_choice: Optional[str] = "auto",
 ):
     return SglGen(
         name,
@@ -146,6 +147,8 @@ def gen_string(
         ignore_eos,
         str,
         None,
+        tools,
+        tool_choice,
     )


@@ -199,11 +202,3 @@ def assistant_begin():

 def assistant_end():
     return SglRoleEnd("assistant")
-
-
-def func_call(
-    name: Optional[str] = None,
-    tools: Optional[List[str]] = None,
-    tool_choice: Optional[str] = "auto",
-):
-    return SglFuncCall(name, tools, tool_choice)
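To make the plumbing above concrete, here is a self-contained toy sketch (stand-in names ToyGenNode and toy_gen, not the library source) of how the wrapper now threads tools and tool_choice straight into the gen IR node, which is what replaces the deleted func_call() helper.

from typing import Callable, List, Optional


class ToyGenNode:
    """Stand-in for SglGen, reduced to the fields touched by this commit."""

    def __init__(self, name, max_tokens, tools, tool_choice):
        self.name = name
        self.max_tokens = max_tokens
        self.tools = tools              # new field carried on the gen node
        self.tool_choice = tool_choice  # new field carried on the gen node


def toy_gen(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    tools: Optional[List[Callable]] = None,
    tool_choice: Optional[str] = "auto",
) -> ToyGenNode:
    # The wrapper simply forwards the two new keyword arguments to the node;
    # the backend decides later how to turn them into an API request.
    return ToyGenNode(name, max_tokens, tools, tool_choice)


def get_current_weather(location: str, unit: str = "fahrenheit"):
    """Toy tool used only to have something to pass in."""


node = toy_gen("answer_1", max_tokens=256, tools=[get_current_weather])
assert node.tools == [get_current_weather] and node.tool_choice == "auto"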
python/sglang/backend/openai.py (11 additions, 12 deletions)
@@ -154,10 +154,6 @@ def generate(
                     "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
                 )
                 prompt = s.messages_
-                # Open AI model requires function call information to be sent to the model
-                # along with the prompt.
-                for function_call in s.function_calls:
-                    prompt.append(function_call)
             else:
                 return self._prepare_spec_execution(
                     sampling_params, s.num_api_spec_tokens, spec_var_name
@@ -241,7 +237,6 @@ def function_calling(
         tools: List[str],
         tool_choice: str,
     ):
-        assert self.is_chat_model, "function calling only supported on chat model"
         # TODO: special handling for chat model vs. non chat model, stream vs non stream
         if self.model_name not in [
             "gpt-4o",
@@ -293,25 +288,29 @@ def function_to_json_schema(func):
             }
             return func_schema

+        def build_tool_choice_param():
+            if tool_choice in ["auto", "required", "none"]:
+                return tool_choice
+            else:
+                assert (
+                    tool_choice in tools
+                ), "could not find a candidate function that matches the provided tool choice"
+                return {"type": "function", "function": {"name": tool_choice}}
+
         tools_to_use = []
         if tools:
             tools_to_use = [
                 function_to_json_schema(tool_to_use) for tool_to_use in tools
             ]
-        cur_tool_choice = "auto"
-        if tool_choice:
-            cur_tool_choice = (
-                tool_choice
-                if tool_choice in ["auto", "required", "none"]
-                else {"type": "function", "function": {"name": tool_choice}}
-            )
+        tool_choice = build_tool_choice_param()

         # TODO: "Never mention what tools you use." or provide a system prompt input argument
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=s.messages_,
             tools=tools_to_use,
-            tool_choice=cur_tool_choice,
+            tool_choice=tool_choice,
             **self.spec_kwargs,
         )
         response_message = response.choices[0].message
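The standalone sketch below reproduces the request-building behaviour of the hunk above so it can be run without an API key. The helper names follow the diff, but the JSON-schema conversion is deliberately simplified (every parameter is typed as a string) and the name-based membership check in build_tool_choice_param is an assumption made for illustration.

import inspect
from typing import Callable, Dict, List, Union


def function_to_json_schema(func: Callable) -> Dict:
    # Simplified: describe every parameter as a plain string. The real
    # converter in the backend may inspect annotations more carefully.
    params = {name: {"type": "string"} for name in inspect.signature(func).parameters}
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": func.__doc__ or "",
            "parameters": {"type": "object", "properties": params},
        },
    }


def build_tool_choice_param(tool_choice: str, tools: List[Callable]) -> Union[str, Dict]:
    # Reserved modes pass through unchanged; anything else is treated as the
    # name of one tool and pinned via the OpenAI tool_choice object form.
    if tool_choice in ["auto", "required", "none"]:
        return tool_choice
    assert tool_choice in [t.__name__ for t in tools], (
        "could not find a candidate function that matches the provided tool choice"
    )
    return {"type": "function", "function": {"name": tool_choice}}


def get_current_weather(location: str, unit: str = "fahrenheit"):
    """Get the current weather in a given location."""


tools_to_use = [function_to_json_schema(get_current_weather)]
print(build_tool_choice_param("auto", [get_current_weather]))
# -> "auto"
print(build_tool_choice_param("get_current_weather", [get_current_weather]))
# -> {'type': 'function', 'function': {'name': 'get_current_weather'}}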
python/sglang/lang/interpreter.py (14 additions, 10 deletions)
@@ -23,7 +23,6 @@
     SglFunction,
     SglGen,
     SglImage,
-    SglFuncCall,
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
@@ -371,8 +370,6 @@ def _execute(self, other):
         elif isinstance(other, SglExprList):
             for x in other.expr_list:
                 self._execute(x)
-        elif isinstance(other, SglFuncCall):
-            self._execute_func_call(other)
         elif isinstance(other, SglRoleBegin):
             self._execute_role_begin(other)
         elif isinstance(other, SglRoleEnd):
@@ -491,10 +488,23 @@ def find_stop():

         return comp, meta_info

+    def _build_and_append_func_call_messages(self, expr: SglGen):
+        # TODO: Should we clear the previous function call states for the next function call
+        if self.backend.is_chat_model:
+            self.function_calls = self.backend.function_calling(
+                self, expr.tools, expr.tool_choice
+            )
+            for function_call in self.function_calls:
+                self.messages_.append(function_call)
+        # TODO: handle text appending
+
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name

+        if expr.tools:
+            self._build_and_append_func_call_messages(expr)
+
         if not self.stream:
             if self.num_api_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
@@ -560,12 +570,6 @@ def _execute_select(self, expr: SglSelect):
             self.variable_event[name].set()
         self.text_ += decision

-    def _execute_func_call(self, expr: SglFuncCall):
-        # TODO: Should we clear the previous function call states for the next function call
-        self.function_calls = self.backend.function_calling(
-            self, expr.tools, expr.tool_choice
-        )
-
     def _execute_variable(self, expr: SglVariable):
         src_executor = expr.source_stream_executor
         value = src_executor.get_var(expr.name)
@@ -762,7 +766,7 @@ def text(self):
         return self.stream_executor.text()

     def messages(self):
-        # We do not want to expose tool use information to users in the final response, 
+        # We do not want to expose tool use information to users in the final response,
         # so removing the auxillary information from final messages.
         filtered_list = [
             item
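To illustrate the control-flow change in _execute_gen without running the interpreter, here is a toy mock (the Toy* class and method names are invented for the sketch): when a gen expression carries tools and the backend is a chat model, the executor first asks the backend for tool-call messages and appends them to the message history, then runs the normal generation step.

from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class ToyGen:
    name: str
    tools: Optional[List[Callable]] = None
    tool_choice: str = "auto"


class ToyBackend:
    is_chat_model = True

    def function_calling(self, executor, tools, tool_choice):
        # Pretend the model asked to call the first tool.
        return [{"role": "assistant", "tool_calls": [{"name": tools[0].__name__}]}]

    def generate(self, messages):
        return "72 degrees and sunny."


@dataclass
class ToyExecutor:
    backend: ToyBackend
    messages_: List[dict] = field(default_factory=list)

    def execute_gen(self, expr: ToyGen):
        # New flow: tool-call messages are built and appended as part of gen.
        if expr.tools and self.backend.is_chat_model:
            for msg in self.backend.function_calling(self, expr.tools, expr.tool_choice):
                self.messages_.append(msg)
        self.messages_.append(
            {"role": "assistant", "content": self.backend.generate(self.messages_)}
        )


def get_current_weather(location: str, unit: str = "fahrenheit"):
    """Toy tool."""


ex = ToyExecutor(backend=ToyBackend())
ex.execute_gen(ToyGen("answer_1", tools=[get_current_weather]))
print(ex.messages_)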
python/sglang/lang/ir.py (4 additions, 13 deletions)
@@ -366,6 +366,8 @@ def __init__(
         ignore_eos,
         dtype,
         regex,
+        tools,
+        tool_choice,
     ):
         super().__init__()
         self.name = name
@@ -381,6 +383,8 @@ def __init__(
             dtype=dtype,
             regex=regex,
         )
+        self.tools = tools
+        self.tool_choice = tool_choice

     def __repr__(self):
         return f"Gen('{self.name}')"
@@ -424,19 +428,6 @@ def __repr__(self):
         return f"Select({self.name}, choices={self.choices})"


-class SglFuncCall(SglExpr):
-    def __init__(self, name, tools, tool_choice):
-        super().__init__()
-        self.name = name
-        self.tools = tools
-        self.tool_choice = tool_choice
-
-    def __repr__(self):
-        return (
-            f"FuncCall({self.name}, tools={self.tools}, tool_choice={self.tool_choice})"
-        )
-
-
 class SglFork(SglExpr):
     def __init__(self, number, position_ids_offset=None):
         super().__init__()
