-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support multiple chat templates per model #134
Changes from 4 commits
48f9167
ac91113
852ea26
b916247
cb47ba4
19a5da7
de740f8
62ccb41
c0355dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,12 +9,13 @@ import Hub | |
import Foundation | ||
import Jinja | ||
|
||
enum TokenizerError : Error { | ||
enum TokenizerError: Error { | ||
case missingConfig | ||
case missingTokenizerClassInConfig | ||
case unsupportedTokenizer(String) | ||
case missingVocab | ||
case malformedVocab | ||
case noChatTemplateSpecified | ||
|
||
case tooLong(String) | ||
} | ||
|
@@ -125,7 +126,8 @@ public protocol Tokenizer { | |
chatTemplate: String?, | ||
addGenerationPrompt: Bool, | ||
truncation: Bool, | ||
maxLength: Int? | ||
maxLength: Int?, | ||
tools: [[String: Any]]? | ||
) throws -> [Int] | ||
} | ||
|
||
|
@@ -176,8 +178,6 @@ public class PreTrainedTokenizer: Tokenizer { | |
private let tokenizerConfig: Config | ||
|
||
private let cleanUpTokenizationSpaces: Bool | ||
|
||
private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" | ||
|
||
required public init(tokenizerConfig: Config, tokenizerData: Config) throws { | ||
var addedTokens: [String : Int] = [:] | ||
|
@@ -323,12 +323,54 @@ public class PreTrainedTokenizer: Tokenizer { | |
|
||
public func applyChatTemplate( | ||
messages: [[String: String]], | ||
chatTemplate: String?, | ||
/// A Jinja template or the name of a template to use for this conversion. | ||
/// It is usually not necessary to pass anything to this argument, | ||
/// as the model's template will be used by default. | ||
chatTemplate: String? = nil, | ||
addGenerationPrompt: Bool = false, | ||
truncation: Bool = false, | ||
maxLength: Int? | ||
maxLength: Int? = nil, | ||
/// A list of tools (callable functions) that will be accessible to the model. If the template does not | ||
/// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, | ||
/// giving the name, description and argument types for the tool. See the | ||
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) | ||
/// for more information. | ||
DePasqualeOrg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tools: [[String: Any]]? = nil | ||
) throws -> [Int] { | ||
let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate) | ||
var selectedChatTemplate: String? | ||
if let valueFromConfig = tokenizerConfig.chatTemplate { | ||
if let arrayValue = valueFromConfig.arrayValue { | ||
// If the config specifies a list of chat templates, convert them to a dictionary | ||
let templateDict = Dictionary<String, String>(uniqueKeysWithValues: arrayValue.compactMap { item in | ||
guard let name = item.name?.stringValue, let template = item.template?.stringValue else { | ||
return nil | ||
} | ||
return (name, template) | ||
}) | ||
if let chatTemplateArgument = chatTemplate, let matchingDictEntry = templateDict[chatTemplateArgument] { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So if you pass an actual template as an argument to a tokenizer that has multiple ones ( |
||
// Use chat template from config that matches the name specified in the `chatTemplate` argument | ||
selectedChatTemplate = matchingDictEntry | ||
} else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] { | ||
// Use tool use chat template from config | ||
selectedChatTemplate = toolUseTemplate | ||
} else if let defaultChatTemplate = templateDict["default"] { | ||
// Use default chat template from config | ||
selectedChatTemplate = defaultChatTemplate | ||
} | ||
} else if let chatTemplateArgument = chatTemplate { | ||
// Use chat template from argument | ||
selectedChatTemplate = chatTemplateArgument | ||
} else if let stringValue = valueFromConfig.stringValue { | ||
// Use chat template from config | ||
selectedChatTemplate = stringValue | ||
} | ||
} | ||
|
||
guard let selectedChatTemplate else { | ||
throw TokenizerError.noChatTemplateSpecified | ||
} | ||
|
||
let template = try Template(selectedChatTemplate) | ||
var context: [String: Any] = [ | ||
"messages": messages, | ||
"add_generation_prompt": addGenerationPrompt | ||
DePasqualeOrg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
// | ||
// ChatTemplateTests.swift | ||
// swift-transformers | ||
// | ||
// Created by Anthony DePasquale on 2/10/24. | ||
// | ||
|
||
import XCTest | ||
import Tokenizers | ||
|
||
class ChatTemplateTests: XCTestCase { | ||
let messages = [[ | ||
"role": "user", | ||
"content": "Describe the Swift programming language.", | ||
]] | ||
|
||
func testTemplateFromConfig() async throws { | ||
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") | ||
let encoded = try tokenizer.applyChatTemplate(messages: messages) | ||
let encodedTarget = [32010, 4002, 29581, 278, 14156, 8720, 4086, 29889, 32007, 32001] | ||
let decoded = tokenizer.decode(tokens: encoded) | ||
let decodedTarget = "<|user|>Describe the Swift programming language.<|end|><|assistant|>" | ||
XCTAssertEqual(encoded, encodedTarget) | ||
XCTAssertEqual(decoded, decodedTarget) | ||
} | ||
|
||
func testDefaultTemplateFromArrayInConfig() async throws { | ||
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit") | ||
let encoded = try tokenizer.applyChatTemplate(messages: messages) | ||
let encodedTarget = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4] | ||
let decoded = tokenizer.decode(tokens: encoded) | ||
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]" | ||
XCTAssertEqual(encoded, encodedTarget) | ||
XCTAssertEqual(decoded, decodedTarget) | ||
} | ||
|
||
func testTemplateFromArgument() async throws { | ||
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") | ||
// Purposely not using the correct template for this model to verify that the template from the config is not being used | ||
let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" | ||
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate, addGenerationPrompt: false, truncation: false, maxLength: nil, tools: nil) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's a bit unfortunate we have to add all the arguments in this version of the call, perhaps we could add a new one that supports |
||
let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962] | ||
let decoded = tokenizer.decode(tokens: encoded) | ||
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]" | ||
XCTAssertEqual(encoded, encodedTarget) | ||
XCTAssertEqual(decoded, decodedTarget) | ||
} | ||
|
||
func testNamedTemplateFromArgument() async throws { | ||
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit") | ||
// Normally it is not necessary to specify the name `default`, but I'm not aware of models with lists of templates in the config that are not `default` or `tool_use` | ||
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: "default", addGenerationPrompt: false, truncation: false, maxLength: nil, tools: nil) | ||
let encodedTarget = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4] | ||
let decoded = tokenizer.decode(tokens: encoded) | ||
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]" | ||
XCTAssertEqual(encoded, encodedTarget) | ||
XCTAssertEqual(decoded, decodedTarget) | ||
} | ||
|
||
// TODO: Add tests for tool use template | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`addGenerationPrompt` is false by default in this method (which mirrors the Python implementation), but it is currently set to true in the overload method. Should we make the behavior consistent in all `applyChatTemplate` methods? Should it be true or false by default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it should be `true` for the overloads, as I expect that to be the most common use of the API, but I'm not fully sure.