Skip to content

Commit

Permalink
Add reverse templating for the tools list as well
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Sep 16, 2024
1 parent 8d8c051 commit 0bda395
Showing 1 changed file with 45 additions and 6 deletions.
51 changes: 45 additions & 6 deletions tests/utils/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import os
import tempfile
import unittest
from typing import List, Optional, Tuple, Union
from pathlib import Path
from typing import List, Optional, Tuple, Union

from transformers import AutoTokenizer
from transformers.testing_utils import require_jinja
Expand Down Expand Up @@ -778,17 +778,26 @@ def test_chat_template_dict_saving(self):
# Assert that the serialized list is correctly reconstructed as a single dict
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)


class InverseChatTemplateTest(unittest.TestCase):
def _get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/tiny-gpt2-with-mistral-tool-template")
tokenizer.inverse_template = r"""
{%- set tools = finditer("\[AVAILABLE_TOOLS\] (.*?)\[\/AVAILABLE_TOOLS\]", chat, flags=16) %}
{%- set user_messages = finditer('(?:\[INST\] )(.+?)\[\/INST\]', chat, flags=16, add_tag="user") %}
{%- set asst_messages = finditer('(?:\[\/INST\]|\[\/TOOL_RESULTS\]) (.+?)<\/s>', chat, flags=16, add_tag="assistant") %}
{%- set available_tools = finditer('\[AVAILABLE_TOOLS\] (.*?)\[\/AVAILABLE_TOOLS\]', chat, flags=16, add_tag="available_tools") %}
{%- set tool_calls = finditer('\[TOOL_CALLS\] (.+?\])<\/s>', chat, flags=16, add_tag="tool_calls") %}
{%- set tool_results = finditer('\[TOOL_RESULTS\] (.+?)\[\/TOOL_RESULTS\]', chat, flags=16, add_tag="tool") %}
{%- set combined = sort_by_group_start(user_messages + asst_messages + tool_calls + tool_results, group_idx=1) %}
{{- '{"messages": [' }}
{{- '{' }}
{%- if tools | length > 0 %}
{%- set tools = json_loads(tools[0].group[1]) %}
{{- '"tools": ' }}
{{- tools | tojson }}
{{- ', ' }}
{%- endif %}
{{- '"messages": [' }}
{%- for match in combined %}
{%- if match.tag == 'assistant' or match.tag == 'user' %}
{%- set message_dict = dict(role=match.tag, content=match.group[1]) %}
Expand Down Expand Up @@ -835,19 +844,49 @@ def test_simple_chat_inversion(self):
]
chat_str = tokenizer.apply_chat_template(chat, tokenize=False)
inverted_chat = tokenizer.apply_inverse_template(chat_str)
self.assertEqual(chat, inverted_chat['messages'])
self.assertEqual(chat, inverted_chat["messages"])

def test_chat_inversion_with_tool_calls(self):
tokenizer = self._get_tokenizer()
chat = [
{"role": "user", "content": "user message"},
{"role": "assistant", "tool_calls": [{"type": "function", "id": "9Ae3bDc2F", "function": {"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}}]},
{
"role": "assistant",
"tool_calls": [
{
"type": "function",
"id": "9Ae3bDc2F",
"function": {
"name": "get_current_temperature",
"arguments": {"location": "Paris, France", "unit": "celsius"},
},
}
],
},
{"role": "tool", "content": "22.0", "tool_call_id": "9Ae3bDc2F"},
{"role": "assistant", "content": "assistant message"},
]
chat_str = tokenizer.apply_chat_template(chat, tokenize=False)
inverted_chat = tokenizer.apply_inverse_template(chat_str)
self.assertEqual(chat, inverted_chat['messages'])
self.assertEqual(chat, inverted_chat["messages"])

def test_tool_extraction(self):
# TODO Not done yet!
tokenizer = self._get_tokenizer()
chat = [
{"role": "user", "content": "user message"},
]

def tool_fn(location: str, unit: str):
"""
Get the current temperature
Args:
location: The location to get the temperature from
unit: The unit to return the temperature in
"""
return 22.0

chat_str = tokenizer.apply_chat_template(chat, tools=[tool_fn], tokenize=False)
inverted_chat = tokenizer.apply_inverse_template(chat_str)
self.assertEqual(inverted_chat["messages"], chat)
self.assertEqual(inverted_chat["tools"], [get_json_schema(tool_fn)])

0 comments on commit 0bda395

Please sign in to comment.