
Support multiple chat templates per model #134

Merged
56 changes: 49 additions & 7 deletions Sources/Tokenizers/Tokenizer.swift
@@ -9,12 +9,13 @@ import Hub
 import Foundation
 import Jinja
 
-enum TokenizerError : Error {
+enum TokenizerError: Error {
     case missingConfig
     case missingTokenizerClassInConfig
     case unsupportedTokenizer(String)
     case missingVocab
     case malformedVocab
+    case noChatTemplateSpecified
 
     case tooLong(String)
 }
@@ -125,7 +126,8 @@ public protocol Tokenizer {
         chatTemplate: String?,
         addGenerationPrompt: Bool,
         truncation: Bool,
-        maxLength: Int?
+        maxLength: Int?,
+        tools: [[String: Any]]?
     ) throws -> [Int]
 }

@@ -176,8 +178,6 @@ public class PreTrainedTokenizer: Tokenizer {
     private let tokenizerConfig: Config
 
     private let cleanUpTokenizationSpaces: Bool
-
-    private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 
     required public init(tokenizerConfig: Config, tokenizerData: Config) throws {
         var addedTokens: [String : Int] = [:]
@@ -323,12 +323,54 @@ public class PreTrainedTokenizer: Tokenizer {
 
     public func applyChatTemplate(
         messages: [[String: String]],
-        chatTemplate: String?,
+        /// A Jinja template or the name of a template to use for this conversion.
+        /// It is usually not necessary to pass anything to this argument,
+        /// as the model's template will be used by default.
+        chatTemplate: String? = nil,
         addGenerationPrompt: Bool = false,
Review comment — @DePasqualeOrg (Contributor, Author), Oct 2, 2024:

    addGenerationPrompt is false by default in this method (which mirrors the Python implementation), but it is currently set to true in the overload method. Should we make the behavior consistent in all applyChatTemplate methods? Should it be true or false by default?

Reply — Member:

    I think it should be true for the overloads, as I expect that to be the most common use of the API, but I'm not fully sure.

         truncation: Bool = false,
-        maxLength: Int?
+        maxLength: Int? = nil,
+        /// A list of tools (callable functions) that will be accessible to the model. If the template does not
+        /// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
+        /// giving the name, description and argument types for the tool. See the
+        /// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+        /// for more information.
+        tools: [[String: Any]]? = nil
     ) throws -> [Int] {
-        let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
+        var selectedChatTemplate: String?
+        if let valueFromConfig = tokenizerConfig.chatTemplate {
+            if let arrayValue = valueFromConfig.arrayValue {
+                // If the config specifies a list of chat templates, convert them to a dictionary
+                let templateDict = Dictionary<String, String>(uniqueKeysWithValues: arrayValue.compactMap { item in
+                    guard let name = item.name?.stringValue, let template = item.template?.stringValue else {
+                        return nil
+                    }
+                    return (name, template)
+                })
+                if let chatTemplateArgument = chatTemplate, let matchingDictEntry = templateDict[chatTemplateArgument] {
Review comment — Member:

    So if you pass an actual template as an argument to a tokenizer that has multiple ones (mlx-community/Mistral-7B-Instruct-v0.3-4bit, for example), it will be silently ignored and replaced with the default one, right? Perhaps we could fail if the tokenizer has several templates and the argument provided does not match any of them. What do you think?
+                    // Use chat template from config that matches the name specified in the `chatTemplate` argument
+                    selectedChatTemplate = matchingDictEntry
+                } else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] {
+                    // Use tool use chat template from config
+                    selectedChatTemplate = toolUseTemplate
+                } else if let defaultChatTemplate = templateDict["default"] {
+                    // Use default chat template from config
+                    selectedChatTemplate = defaultChatTemplate
+                }
+            } else if let chatTemplateArgument = chatTemplate {
+                // Use chat template from argument
+                selectedChatTemplate = chatTemplateArgument
+            } else if let stringValue = valueFromConfig.stringValue {
+                // Use chat template from config
+                selectedChatTemplate = stringValue
+            }
+        }
+
+        guard let selectedChatTemplate else {
+            throw TokenizerError.noChatTemplateSpecified
+        }
+
+        let template = try Template(selectedChatTemplate)
         var context: [String: Any] = [
             "messages": messages,
             "add_generation_prompt": addGenerationPrompt
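For context, a config that carries multiple named templates stores them as an array of name/template pairs under `chat_template`, which the selection logic turns into a dictionary keyed by name. An abridged, illustrative sketch of such a `tokenizer_config.json` (the Jinja bodies here are placeholders, not any model's actual templates):

```json
{
  "chat_template": [
    {
      "name": "default",
      "template": "{% for message in messages %}...{% endfor %}"
    },
    {
      "name": "tool_use",
      "template": "{% for tool in tools %}...{% endfor %}"
    }
  ]
}
```

With a config like this, passing `chatTemplate: "tool_use"` (or a non-empty `tools` array) selects the second entry, and everything else falls back to `"default"`.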
61 changes: 61 additions & 0 deletions Tests/TokenizersTests/ChatTemplateTests.swift
@@ -0,0 +1,61 @@
//
// ChatTemplateTests.swift
// swift-transformers
//
// Created by Anthony DePasquale on 2/10/24.
//

import XCTest
import Tokenizers

class ChatTemplateTests: XCTestCase {
let messages = [[
"role": "user",
"content": "Describe the Swift programming language.",
]]

func testTemplateFromConfig() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
let encoded = try tokenizer.applyChatTemplate(messages: messages)
let encodedTarget = [32010, 4002, 29581, 278, 14156, 8720, 4086, 29889, 32007, 32001]
let decoded = tokenizer.decode(tokens: encoded)
let decodedTarget = "<|user|>Describe the Swift programming language.<|end|><|assistant|>"
XCTAssertEqual(encoded, encodedTarget)
XCTAssertEqual(decoded, decodedTarget)
}

func testDefaultTemplateFromArrayInConfig() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
let encoded = try tokenizer.applyChatTemplate(messages: messages)
let encodedTarget = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4]
let decoded = tokenizer.decode(tokens: encoded)
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]"
XCTAssertEqual(encoded, encodedTarget)
XCTAssertEqual(decoded, decodedTarget)
}

func testTemplateFromArgument() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
// Purposely not using the correct template for this model to verify that the template from the config is not being used
let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate, addGenerationPrompt: false, truncation: false, maxLength: nil, tools: nil)
Review comment — Member:

    It's a bit unfortunate we have to add all the arguments in this version of the call; perhaps we could add a new one that supports messages and chatTemplate, if you think it's interesting to do so.
let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962]
let decoded = tokenizer.decode(tokens: encoded)
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]"
XCTAssertEqual(encoded, encodedTarget)
XCTAssertEqual(decoded, decodedTarget)
}

func testNamedTemplateFromArgument() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
// Normally it is not necessary to specify the name `default`, but I'm not aware of models with lists of templates in the config that are not `default` or `tool_use`
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: "default", addGenerationPrompt: false, truncation: false, maxLength: nil, tools: nil)
let encodedTarget = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4]
let decoded = tokenizer.decode(tokens: encoded)
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]"
XCTAssertEqual(encoded, encodedTarget)
XCTAssertEqual(decoded, decodedTarget)
}

// TODO: Add tests for tool use template
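    // A possible shape for such a test, sketched here as a starting point. It assumes the
    // Mistral checkpoint above, whose config provides a `tool_use` template; the tool schema
    // is a hypothetical example, and no hard-coded token targets are asserted because the
    // exact encoding depends on the model.
    func testToolUseTemplateFromConfig() async throws {
        let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
        // A minimal JSON Schema-style tool definition (illustrative only)
        let getWeatherTool: [String: Any] = [
            "type": "function",
            "function": [
                "name": "get_current_weather",
                "description": "Get the current weather for a location",
                "parameters": [
                    "type": "object",
                    "properties": [
                        "location": ["type": "string", "description": "City name"],
                    ],
                    "required": ["location"],
                ],
            ],
        ]
        // A non-empty `tools` array should select the `tool_use` template from the config
        let encoded = try tokenizer.applyChatTemplate(
            messages: messages,
            chatTemplate: nil,
            addGenerationPrompt: false,
            truncation: false,
            maxLength: nil,
            tools: [getWeatherTool]
        )
        let decoded = tokenizer.decode(tokens: encoded)
        // Assert only that the tool definition was rendered into the prompt
        XCTAssertTrue(decoded.contains("get_current_weather"))
    }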
}