Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Support multiple chat templates per model #134

Merged
93 changes: 76 additions & 17 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import Hub
import Foundation
import Jinja

/// Errors thrown while loading a tokenizer configuration or applying a chat template.
enum TokenizerError: Error {
    /// The tokenizer config file could not be found.
    case missingConfig
    /// The config file does not declare a `tokenizer_class`.
    case missingTokenizerClassInConfig
    /// The declared tokenizer class is not supported; the payload is the class name.
    case unsupportedTokenizer(String)
    case missingVocab
    case malformedVocab
    /// A chat template could not be resolved; the payload describes the failure.
    case chatTemplate(String)
    case tooLong(String)
}

Expand Down Expand Up @@ -94,6 +94,13 @@ struct TokenizerModel {
}
}

/// Specifies which chat template `applyChatTemplate` should use when converting messages to tokens.
public enum ChatTemplateArgument {
    /// A Jinja template to use for the conversion. Normally it is not necessary to provide a template, since it will be read from the tokenizer config file.
    case literal(String)
    /// For models whose tokenizer config file includes multiple chat templates, the template can be specified by name. Normally this is not necessary.
    case name(String)
}

public protocol Tokenizer {
func tokenize(text: String) -> [String]

Expand All @@ -117,15 +124,19 @@ public protocol Tokenizer {
var eosTokenId: Int? { get }
var unknownToken: String? { get }
var unknownTokenId: Int? { get }

func applyChatTemplate(messages: [[String: String]]) throws -> [Int]


func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this solution. But it's a breaking change, do you know if the current method is used by mlx or others?

Perhaps we could add an overload applyChatTemplate(messages: [[String: String]], chatTemplateName: String) that just builds a ChatTemplateArgument (literal) and forwards the call.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I searched the mlx-swift and mlx-swift-examples repos and didn't find any instances of applyChatTemplate there. I have a draft PR to use the chat templates once this PR and the Jinja PR land.

I think you mean the additional overload should have the argument chatTemplate: String. Sure, I can add that.


func applyChatTemplate(
messages: [[String: String]],
chatTemplate: String?,
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config file. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?
maxLength: Int?,
tools: [[String: Any]]?
) throws -> [Int]
}

Expand Down Expand Up @@ -176,8 +187,6 @@ public class PreTrainedTokenizer: Tokenizer {
private let tokenizerConfig: Config

private let cleanUpTokenizationSpaces: Bool

private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

required public init(tokenizerConfig: Config, tokenizerData: Config) throws {
var addedTokens: [String : Int] = [:]
Expand Down Expand Up @@ -222,7 +231,7 @@ public class PreTrainedTokenizer: Tokenizer {
self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder)
self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? true
self.tokenizerConfig = tokenizerConfig

model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
}

Expand Down Expand Up @@ -316,22 +325,72 @@ public class PreTrainedTokenizer: Tokenizer {
/// Looks up the string form of a single token ID via the underlying model,
/// returning `nil` when the ID has no known token.
public func convertIdToToken(_ id: Int) -> String? {
    return model.convertIdToToken(id)
}

/// Convenience overload: applies the tokenizer's configured chat template to `messages`
/// with a generation prompt appended, returning the encoded token IDs.
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
    try applyChatTemplate(messages: messages, addGenerationPrompt: true)
}

/// Convenience overload: applies the given chat template (a literal Jinja string or a
/// template name from the tokenizer config) with a generation prompt appended.
public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
}

public func applyChatTemplate(
messages: [[String: String]],
chatTemplate: String?,
chatTemplate: ChatTemplateArgument? = nil,
addGenerationPrompt: Bool = false,
Copy link
Contributor Author

@DePasqualeOrg DePasqualeOrg Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addGenerationPrompt is false by default in this method (which mirrors the Python implementation), but it is currently set to true in the overload method. Should we make the behavior consistent in all applyChatTemplate methods? Should it be true or false by default?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be true for the overloads as I expect that to be the most common use of the API, but I'm not fully sure.

truncation: Bool = false,
maxLength: Int?
maxLength: Int? = nil,
/// A list of tools (callable functions) that will be accessible to the model. If the template does not
/// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
/// giving the name, description and argument types for the tool. See the
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
/// for more information.
DePasqualeOrg marked this conversation as resolved.
Show resolved Hide resolved
/// Note: tool calling is not supported yet, it will be available in a future update.
tools: [[String: Any]]? = nil
) throws -> [Int] {
let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
var selectedChatTemplate: String?
if let chatTemplate, case .literal(let template) = chatTemplate {
// Use chat template from argument
selectedChatTemplate = template
} else if let valueFromConfig = tokenizerConfig.chatTemplate {
if let arrayValue = valueFromConfig.arrayValue {
// If the config specifies a list of chat templates, convert them to a dictionary
let templateDict = Dictionary<String, String>(uniqueKeysWithValues: arrayValue.compactMap { item in
guard let name = item.name?.stringValue, let template = item.template?.stringValue else {
return nil
}
return (name, template)
})
if let chatTemplate, case .name(let name) = chatTemplate {
// Select chat template from config by name
if let matchingDictEntry = templateDict[name] {
selectedChatTemplate = matchingDictEntry
} else {
throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config file")
}
} else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] {
// Use tool use chat template from config
selectedChatTemplate = toolUseTemplate
} else if let defaultChatTemplate = templateDict["default"] {
// Use default chat template from config
selectedChatTemplate = defaultChatTemplate
}
} else if let stringValue = valueFromConfig.stringValue {
// Use chat template from config
selectedChatTemplate = stringValue
}
}

guard let selectedChatTemplate else {
throw TokenizerError.chatTemplate("No chat template was specified")
}

let template = try Template(selectedChatTemplate)
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt
DePasqualeOrg marked this conversation as resolved.
Show resolved Hide resolved
// TODO: Add `tools` entry when support is added in Jinja
// "tools": tools
]

// TODO: maybe keep NSString here
Expand Down Expand Up @@ -397,15 +456,15 @@ extension AutoTokenizer {

return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

/// Builds a tokenizer from a local model folder.
/// - Parameters:
///   - modelFolder: Directory containing the tokenizer config and data files.
///   - hubApi: Hub client used to read the configuration; defaults to the shared instance.
/// - Throws: `TokenizerError.missingConfig` when no tokenizer config is present.
public static func from(
    modelFolder: URL,
    hubApi: HubApi = .shared
) async throws -> Tokenizer {
    let hubConfiguration = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi)
    guard let tokenizerConfig = try await hubConfiguration.tokenizerConfig else {
        throw TokenizerError.missingConfig
    }
    return try PreTrainedTokenizer(
        tokenizerConfig: tokenizerConfig,
        tokenizerData: try await hubConfiguration.tokenizerData
    )
}
}
Expand Down
61 changes: 61 additions & 0 deletions Tests/TokenizersTests/ChatTemplateTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//
// ChatTemplateTests.swift
// swift-transformers
//
// Created by Anthony DePasquale on 2/10/24.
//

import XCTest
import Tokenizers

class ChatTemplateTests: XCTestCase {
    /// Single-turn conversation shared by every test case.
    let messages = [[
        "role": "user",
        "content": "Describe the Swift programming language.",
    ]]

    /// Encodes `messages` with the given tokenizer (and optional template argument),
    /// then checks both the token IDs and their decoded string form.
    private func assertTemplateOutput(
        tokenizer: Tokenizer,
        chatTemplate: ChatTemplateArgument? = nil,
        expectedTokens: [Int],
        expectedDecoded: String
    ) throws {
        let encoded: [Int]
        if let chatTemplate {
            encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: chatTemplate)
        } else {
            encoded = try tokenizer.applyChatTemplate(messages: messages)
        }
        XCTAssertEqual(encoded, expectedTokens)
        XCTAssertEqual(tokenizer.decode(tokens: encoded), expectedDecoded)
    }

    func testTemplateFromConfig() async throws {
        let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
        try assertTemplateOutput(
            tokenizer: tokenizer,
            expectedTokens: [32010, 4002, 29581, 278, 14156, 8720, 4086, 29889, 32007, 32001],
            expectedDecoded: "<|user|>Describe the Swift programming language.<|end|><|assistant|>"
        )
    }

    func testDefaultTemplateFromArrayInConfig() async throws {
        let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
        try assertTemplateOutput(
            tokenizer: tokenizer,
            expectedTokens: [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4],
            expectedDecoded: "<s> [INST] Describe the Swift programming language. [/INST]"
        )
    }

    func testTemplateFromArgument() async throws {
        let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
        // Purposely not using the correct template for this model to verify that the template from the config is not being used
        let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
        try assertTemplateOutput(
            tokenizer: tokenizer,
            chatTemplate: .literal(mistral7BDefaultTemplate),
            expectedTokens: [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962],
            expectedDecoded: "<s> [INST] Describe the Swift programming language. [/INST]"
        )
    }

    func testNamedTemplateFromArgument() async throws {
        let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
        // Normally it is not necessary to specify the name `default`, but I'm not aware of models with lists of templates in the config that are not `default` or `tool_use`
        try assertTemplateOutput(
            tokenizer: tokenizer,
            chatTemplate: .name("default"),
            expectedTokens: [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4],
            expectedDecoded: "<s> [INST] Describe the Swift programming language. [/INST]"
        )
    }

    // TODO: Add tests for tool use template
}