-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support multiple chat templates per model #134
Changes from all commits
48f9167
ac91113
852ea26
b916247
cb47ba4
19a5da7
de740f8
62ccb41
c0355dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,13 +9,13 @@ import Hub | |
import Foundation | ||
import Jinja | ||
|
||
enum TokenizerError : Error { | ||
enum TokenizerError: Error { | ||
case missingConfig | ||
case missingTokenizerClassInConfig | ||
case unsupportedTokenizer(String) | ||
case missingVocab | ||
case malformedVocab | ||
|
||
case chatTemplate(String) | ||
case tooLong(String) | ||
} | ||
|
||
|
@@ -94,6 +94,13 @@ struct TokenizerModel { | |
} | ||
} | ||
|
||
public enum ChatTemplateArgument { | ||
/// A Jinja template to use for the conversation. Normally it is not necessary to provide a template, since it will be read from the tokenizer config. | ||
case literal(String) | ||
/// For models whose tokenizer config includes multiple chat templates, the template can be specified by name. Normally this is not necessary. | ||
case name(String) | ||
} | ||
|
||
public protocol Tokenizer { | ||
func tokenize(text: String) -> [String] | ||
|
||
|
@@ -117,15 +124,24 @@ public protocol Tokenizer { | |
var eosTokenId: Int? { get } | ||
var unknownToken: String? { get } | ||
var unknownTokenId: Int? { get } | ||
|
||
|
||
/// The appropriate chat template is selected from the tokenizer config | ||
func applyChatTemplate(messages: [[String: String]]) throws -> [Int] | ||
|
||
|
||
/// The chat template is provided as a string literal or specified by name | ||
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] | ||
|
||
/// The chat template is provided as a string literal | ||
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] | ||
|
||
func applyChatTemplate( | ||
messages: [[String: String]], | ||
chatTemplate: String?, | ||
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. | ||
chatTemplate: ChatTemplateArgument?, | ||
addGenerationPrompt: Bool, | ||
truncation: Bool, | ||
maxLength: Int? | ||
maxLength: Int?, | ||
tools: [[String: Any]]? | ||
) throws -> [Int] | ||
} | ||
|
||
|
@@ -176,8 +192,6 @@ public class PreTrainedTokenizer: Tokenizer { | |
private let tokenizerConfig: Config | ||
|
||
private let cleanUpTokenizationSpaces: Bool | ||
|
||
private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" | ||
|
||
required public init(tokenizerConfig: Config, tokenizerData: Config) throws { | ||
var addedTokens: [String : Int] = [:] | ||
|
@@ -222,7 +236,7 @@ public class PreTrainedTokenizer: Tokenizer { | |
self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder) | ||
self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? true | ||
self.tokenizerConfig = tokenizerConfig | ||
|
||
model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens) | ||
} | ||
|
||
|
@@ -316,22 +330,76 @@ public class PreTrainedTokenizer: Tokenizer { | |
public func convertIdToToken(_ id: Int) -> String? { | ||
model.convertIdToToken(id) | ||
} | ||
|
||
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] { | ||
try applyChatTemplate(messages: messages, chatTemplate: nil, addGenerationPrompt: true, maxLength: nil) | ||
try applyChatTemplate(messages: messages, addGenerationPrompt: true) | ||
} | ||
|
||
public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] { | ||
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true) | ||
} | ||
|
||
|
||
public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] { | ||
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true) | ||
} | ||
|
||
public func applyChatTemplate( | ||
messages: [[String: String]], | ||
chatTemplate: String?, | ||
chatTemplate: ChatTemplateArgument? = nil, | ||
addGenerationPrompt: Bool = false, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it should be |
||
truncation: Bool = false, | ||
maxLength: Int? | ||
maxLength: Int? = nil, | ||
/// A list of tools (callable functions) that will be accessible to the model. If the template does not | ||
/// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, | ||
/// giving the name, description and argument types for the tool. See the | ||
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) | ||
/// for more information. | ||
DePasqualeOrg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// Note: tool calling is not supported yet, it will be available in a future update. | ||
tools: [[String: Any]]? = nil | ||
) throws -> [Int] { | ||
let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate) | ||
var selectedChatTemplate: String? | ||
if let chatTemplate, case .literal(let template) = chatTemplate { | ||
// Use chat template from argument | ||
selectedChatTemplate = template | ||
} else if let valueFromConfig = tokenizerConfig.chatTemplate { | ||
if let arrayValue = valueFromConfig.arrayValue { | ||
// If the config specifies a list of chat templates, convert them to a dictionary | ||
let templateDict = Dictionary<String, String>(uniqueKeysWithValues: arrayValue.compactMap { item in | ||
guard let name = item.name?.stringValue, let template = item.template?.stringValue else { | ||
return nil | ||
} | ||
return (name, template) | ||
}) | ||
if let chatTemplate, case .name(let name) = chatTemplate { | ||
// Select chat template from config by name | ||
if let matchingDictEntry = templateDict[name] { | ||
selectedChatTemplate = matchingDictEntry | ||
} else { | ||
throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config") | ||
} | ||
} else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] { | ||
// Use tool use chat template from config | ||
selectedChatTemplate = toolUseTemplate | ||
} else if let defaultChatTemplate = templateDict["default"] { | ||
// Use default chat template from config | ||
selectedChatTemplate = defaultChatTemplate | ||
} | ||
} else if let stringValue = valueFromConfig.stringValue { | ||
// Use chat template from config | ||
selectedChatTemplate = stringValue | ||
} | ||
} | ||
|
||
guard let selectedChatTemplate else { | ||
throw TokenizerError.chatTemplate("No chat template was specified") | ||
} | ||
|
||
let template = try Template(selectedChatTemplate) | ||
var context: [String: Any] = [ | ||
"messages": messages, | ||
"add_generation_prompt": addGenerationPrompt | ||
DePasqualeOrg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// TODO: Add `tools` entry when support is added in Jinja | ||
// "tools": tools | ||
] | ||
|
||
// TODO: maybe keep NSString here | ||
|
@@ -397,15 +465,15 @@ extension AutoTokenizer { | |
|
||
return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) | ||
} | ||
|
||
public static func from( | ||
modelFolder: URL, | ||
hubApi: HubApi = .shared | ||
) async throws -> Tokenizer { | ||
let config = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi) | ||
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig } | ||
let tokenizerData = try await config.tokenizerData | ||
|
||
return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// | ||
// ChatTemplateTests.swift | ||
// swift-transformers | ||
// | ||
// Created by Anthony DePasquale on 2/10/24. | ||
// | ||
|
||
import XCTest | ||
import Tokenizers | ||
|
||
class ChatTemplateTests: XCTestCase { | ||
let messages = [[ | ||
"role": "user", | ||
"content": "Describe the Swift programming language.", | ||
]] | ||
|
||
func testTemplateFromConfig() async throws { | ||
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") | ||
let encoded = try tokenizer.applyChatTemplate(messages: messages) | ||
let encodedTarget = [32010, 4002, 29581, 278, 14156, 8720, 4086, 29889, 32007, 32001] | ||
let decoded = tokenizer.decode(tokens: encoded) | ||
let decodedTarget = "<|user|>Describe the Swift programming language.<|end|><|assistant|>" | ||
XCTAssertEqual(encoded, encodedTarget) | ||
XCTAssertEqual(decoded, decodedTarget) | ||
} | ||
|
||
func testDefaultTemplateFromArrayInConfig() async throws { | ||
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit") | ||
let encoded = try tokenizer.applyChatTemplate(messages: messages) | ||
let encodedTarget = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4] | ||
let decoded = tokenizer.decode(tokens: encoded) | ||
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]" | ||
XCTAssertEqual(encoded, encodedTarget) | ||
XCTAssertEqual(decoded, decodedTarget) | ||
} | ||
|
||
func testTemplateFromArgumentWithEnum() async throws { | ||
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") | ||
// Purposely not using the correct template for this model to verify that the template from the config is not being used | ||
let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" | ||
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: .literal(mistral7BDefaultTemplate)) | ||
let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962] | ||
let decoded = tokenizer.decode(tokens: encoded) | ||
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]" | ||
XCTAssertEqual(encoded, encodedTarget) | ||
XCTAssertEqual(decoded, decodedTarget) | ||
} | ||
|
||
func testTemplateFromArgumentWithString() async throws { | ||
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") | ||
// Purposely not using the correct template for this model to verify that the template from the config is not being used | ||
let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" | ||
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate) | ||
let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962] | ||
let decoded = tokenizer.decode(tokens: encoded) | ||
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]" | ||
XCTAssertEqual(encoded, encodedTarget) | ||
XCTAssertEqual(decoded, decodedTarget) | ||
} | ||
|
||
func testNamedTemplateFromArgument() async throws { | ||
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit") | ||
// Normally it is not necessary to specify the name `default`, but I'm not aware of models with lists of templates in the config that are not `default` or `tool_use` | ||
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: .name("default")) | ||
let encodedTarget = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4] | ||
let decoded = tokenizer.decode(tokens: encoded) | ||
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]" | ||
XCTAssertEqual(encoded, encodedTarget) | ||
XCTAssertEqual(decoded, decodedTarget) | ||
} | ||
|
||
// TODO: Add tests for tool use template | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this solution. But it's a breaking change, do you know if the current method is used by mlx or others?
Perhaps we could add an overload
applyChatTemplate(messages: [[String: String]], chatTemplateName: String)
that just builds aChatTemplateArgument
(literal
) and forwards the call.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I searched the mlx-swift and mlx-swift-examples repos and didn't find any instances of
applyChatTemplate
there. I have a draft PR to use the chat templates once this PR and the Jinja PR land.I think you mean the additional overload should have the argument
chatTemplate: String
. Sure, I can add that.