Commit 5489406

PR 2634 CI - Fix the tool_choice format for named choice by adapting OpenAI's scheme (#2645)

* add OpenAI like tool_choice for named choice

* add tests

* fix: run linter and bump api docs

* fix: consolidate changes and remove old tool type

* feat: improve, simplify, and rename tool choice struct; add required support and refactor

* fix: simplify tool choice logic, improve tests, openapi and rust docs

* fix: refactor away prepare_chat_input and improve tool grammar apply control flow

* feat: update docs and add tool choice configuration section

* fix: simplify naming, tool choice default and improve test

* fix: adjust tool choice none logic, add test and small refactors

* fix: add missing snapshot file

* fix: adjust tool choice type in test

* fix: adjust default when json tool choice is

* fix: remove trailing space lint after rebase

* fix: remove mostly mocked unit test

---------

Co-authored-by: Linus Bierhoff <[email protected]>
drbh and linusbierhoff authored Nov 19, 2024
1 parent 2007a94 commit 5489406
Showing 9 changed files with 442 additions and 242 deletions.
19 changes: 9 additions & 10 deletions docs/openapi.json
@@ -1102,6 +1102,7 @@
"$ref": "#/components/schemas/ToolChoice"
}
],
"default": "auto",
"nullable": true
},
"tool_prompt": {
@@ -2294,14 +2295,6 @@
}
},
"ToolChoice": {
"allOf": [
{
"$ref": "#/components/schemas/ToolType"
}
],
"nullable": true
},
"ToolType": {
"oneOf": [
{
"type": "string",
@@ -2317,6 +2310,13 @@
"none"
]
},
{
"type": "string",
"description": "Means the model must call one or more tools.",
"enum": [
"required"
]
},
{
"type": "object",
"required": [
@@ -2329,8 +2329,7 @@
}
}
],
"description": "Controls which (if any) tool is called by the model.",
"example": "auto"
"description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
},
"Url": {
"type": "object",
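Taken together, the updated `ToolChoice` schema accepts either a string variant or a function-object form. As a minimal sketch, the following `tool_choice` values would be accepted (the tool name here is illustrative, not part of the schema):

```python
# Illustrative tool_choice values under the updated ToolChoice schema;
# "get_current_weather" is an example tool name.
tool_choice_variants = [
    "auto",      # model decides; default when tools are provided
    "none",      # never call a tool; default when no tools are provided
    "required",  # the model must call one or more tools
    "get_current_weather",  # force a specific tool by bare function name
    {"type": "function", "function": {"name": "get_current_weather"}},
]
```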
60 changes: 58 additions & 2 deletions docs/source/basic_tutorials/using_guidance.md
@@ -315,8 +315,6 @@ print(chat.choices[0].message.tool_calls)

TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.

However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary.

```python
from openai import OpenAI

@@ -362,3 +360,61 @@ print(called)
# },
# }
```

### Tool Choice Configuration

When configuring how the model interacts with tools during a chat completion, several options control whether and how a tool is called. These options are set through the `tool_choice` parameter, which specifies the model's behavior with respect to tool usage. The following modes are supported:

1. **`auto`**:

- The model decides whether to call a tool or generate a response message based on the user's input.
- If tools are provided, this is the default mode.
- Example usage:
```python
tool_choice="auto"
```

2. **`none`**:

- The model will never call any tools and will only generate a response message.
- If no tools are provided, this is the default mode.
- Example usage:
```python
tool_choice="none"
```

3. **`required`**:

- The model must call one or more tools and will not generate a response message on its own.
- Example usage:
```python
tool_choice="required"
```

4. **Specific Tool Call by Function Name**:
- You can force the model to call a specific tool either by specifying the tool function directly or by using an object definition.
- Two ways to do this:
1. Provide the function name as a string:
```python
tool_choice="get_current_weather"
```
2. Use the function object format:
```python
tool_choice={
"type": "function",
"function": {
"name": "get_current_weather"
}
}
```

These options allow flexibility when integrating tools with the chat completions endpoint: you can let the model rely on tools automatically or force it to follow a predefined behavior, depending on the needs of the task at hand. A complete request sketch follows the summary table below.

---

| **Tool Choice Option** | **Description** | **When to Use** |
| ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------- |
| `auto` | The model decides whether to call a tool or generate a message. This is the default if tools are provided. | Use when you want the model to decide when a tool is necessary. |
| `none` | The model generates a message without calling any tools. This is the default if no tools are provided. | Use when you do not want the model to call any tools. |
| `required` | The model must call one or more tools and will not generate a message on its own. | Use when a tool call is mandatory, and you do not want a regular message generated. |
| Specific Tool Call (`name` or object) | Force the model to call a specific tool either by specifying its name (`tool_choice="get_current_weather"`) or using an object. | Use when you want to restrict the model to calling a particular tool for the response. |
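Putting these modes together, here is a minimal sketch of a request that forces a specific tool call. It assumes a TGI server listening at `http://localhost:8080` and a `get_current_weather` tool definition along the lines of the one used earlier in this guide:

```python
from openai import OpenAI

# Assumes a local TGI deployment; adjust base_url and api_key as needed.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location", "format"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "user", "content": "What is the weather like in Brooklyn, New York?"}
    ],
    tools=tools,
    # Force a call to get_current_weather; "auto", "none", "required",
    # or the bare function name select the other modes.
    tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
)
print(response.choices[0].message.tool_calls)
```

Swapping the `tool_choice` value is all that is needed to switch between the modes described above.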
@@ -0,0 +1,27 @@
{
"choices": [
{
"delta": {
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1729084854,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}
@@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " deep",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "length",
"index": 0,
"logprobs": null
}
],
"created": 1729262528,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}
@@ -0,0 +1,28 @@
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1729084850,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}
143 changes: 142 additions & 1 deletion integration-tests/models/test_tools_llama.py
@@ -1,4 +1,6 @@
import pytest
import requests
import json


@pytest.fixture(scope="module")
@@ -174,7 +176,7 @@ async def test_flash_llama_grammar_tools_choice(
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
},
}
]
@@ -327,3 +329,142 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
)
assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="required",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=True,
)

count = 0
tool_calls_generated = ""
last_response = None
async for response in responses:
count += 1
assert response.choices[0].delta.content is None
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
last_response = response

assert count == 29
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>'
)
assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_none(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="none",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=True,
)

count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
content_generated += response.choices[0].delta.content
last_response = response
assert response.choices[0].delta.tool_calls is None

assert count == 100
print(content_generated)
assert (
content_generated
== "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"
)
assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
flash_llama_grammar_tools, response_snapshot
):
# using `requests` to send the request until the client library supports tool_choice as a function object
responses = requests.post(
f"{flash_llama_grammar_tools.base_url}/v1/chat/completions",
headers=flash_llama_grammar_tools.headers,
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
"tools": tools,
"tool_choice": {
"type": "function",
"function": {"name": "get_n_day_weather_forecast"},
},
"seed": 24,
"max_tokens": 100,
"stream": True,
},
stream=True,
)
# iterate over the response in chunks
count = 0
tool_calls_generated = ""
last_response = None
for chunk in responses.iter_content(chunk_size=1024):
if chunk:
count += 1
# remove the "data: " prefix, trailing newline, and split the chunk into individual lines
lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
            for line in lines:
                # skip empty separator lines between SSE events
                if not line:
                    continue
                if line == "[DONE]":
                    break
response = json.loads(line)
tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
"function"
]["arguments"]
last_response = response

assert count == 39
assert (
tool_calls_generated
== '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>'
)
assert last_response == response_snapshot