diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index e026d71..9c8c381 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -9,13 +9,13 @@ import Hub import Foundation import Jinja -enum TokenizerError : Error { +enum TokenizerError: Error { case missingConfig case missingTokenizerClassInConfig case unsupportedTokenizer(String) case missingVocab case malformedVocab - + case chatTemplate(String) case tooLong(String) } @@ -94,6 +94,13 @@ struct TokenizerModel { } } +public enum ChatTemplateArgument { + /// A Jinja template to use for the conversation. Normally it is not necessary to provide a template, since it will be read from the tokenizer config. + case literal(String) + /// For models whose tokenizer config includes multiple chat templates, the template can be specified by name. Normally this is not necessary. + case name(String) +} + public protocol Tokenizer { func tokenize(text: String) -> [String] @@ -117,15 +124,24 @@ public protocol Tokenizer { var eosTokenId: Int? { get } var unknownToken: String? { get } var unknownTokenId: Int? { get } - + + /// The appropriate chat template is selected from the tokenizer config func applyChatTemplate(messages: [[String: String]]) throws -> [Int] - + + /// The chat template is provided as a string literal or specified by name + func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] + + /// The chat template is provided as a string literal + func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] + func applyChatTemplate( messages: [[String: String]], - chatTemplate: String?, + /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. + chatTemplate: ChatTemplateArgument?, addGenerationPrompt: Bool, truncation: Bool, - maxLength: Int? + maxLength: Int?, + tools: [[String: Any]]? ) throws -> [Int] } @@ -176,8 +192,6 @@ public class PreTrainedTokenizer: Tokenizer { private let tokenizerConfig: Config private let cleanUpTokenizationSpaces: Bool - - private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" required public init(tokenizerConfig: Config, tokenizerData: Config) throws { var addedTokens: [String : Int] = [:] @@ -222,7 +236,7 @@ public class PreTrainedTokenizer: Tokenizer { self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder) self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? true self.tokenizerConfig = tokenizerConfig - + model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens) } @@ -316,22 +330,76 @@ public class PreTrainedTokenizer: Tokenizer { public func convertIdToToken(_ id: Int) -> String? { model.convertIdToToken(id) } - + public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] { - try applyChatTemplate(messages: messages, chatTemplate: nil, addGenerationPrompt: true, maxLength: nil) + try applyChatTemplate(messages: messages, addGenerationPrompt: true) + } + + public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] { + try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true) } - + + public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] { + try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true) + } + public func applyChatTemplate( messages: [[String: String]], - chatTemplate: String?, + chatTemplate: ChatTemplateArgument? = nil, addGenerationPrompt: Bool = false, truncation: Bool = false, - maxLength: Int? + maxLength: Int? = nil, + /// A list of tools (callable functions) that will be accessible to the model. If the template does not + /// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, + /// giving the name, description and argument types for the tool. See the + /// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + /// for more information. + /// Note: tool calling is not supported yet, it will be available in a future update. + tools: [[String: Any]]? = nil ) throws -> [Int] { - let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate) + var selectedChatTemplate: String? + if let chatTemplate, case .literal(let template) = chatTemplate { + // Use chat template from argument + selectedChatTemplate = template + } else if let valueFromConfig = tokenizerConfig.chatTemplate { + if let arrayValue = valueFromConfig.arrayValue { + // If the config specifies a list of chat templates, convert them to a dictionary + let templateDict = Dictionary(uniqueKeysWithValues: arrayValue.compactMap { item in + guard let name = item.name?.stringValue, let template = item.template?.stringValue else { + return nil + } + return (name, template) + }) + if let chatTemplate, case .name(let name) = chatTemplate { + // Select chat template from config by name + if let matchingDictEntry = templateDict[name] { + selectedChatTemplate = matchingDictEntry + } else { + throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config") + } + } else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] { + // Use tool use chat template from config + selectedChatTemplate = toolUseTemplate + } else if let defaultChatTemplate = templateDict["default"] { + // Use default chat template from config + selectedChatTemplate = defaultChatTemplate + } + } else if let stringValue = valueFromConfig.stringValue { + // Use chat template from config + selectedChatTemplate = stringValue + } + } + + guard let selectedChatTemplate else { + throw TokenizerError.chatTemplate("No chat template was specified") + } + + let template = try Template(selectedChatTemplate) var context: [String: Any] = [ "messages": messages, "add_generation_prompt": addGenerationPrompt + // TODO: Add `tools` entry when support is added in Jinja + // "tools": tools ] // TODO: maybe keep NSString here @@ -397,7 +465,7 @@ extension AutoTokenizer { return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) } - + public static func from( modelFolder: URL, hubApi: HubApi = .shared @@ -405,7 +473,7 @@ extension AutoTokenizer { let config = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi) guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig } let tokenizerData = try await config.tokenizerData - + return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) } } diff --git a/Tests/TokenizersTests/ChatTemplateTests.swift b/Tests/TokenizersTests/ChatTemplateTests.swift new file mode 100644 index 0000000..3ee7aa1 --- /dev/null +++ b/Tests/TokenizersTests/ChatTemplateTests.swift @@ -0,0 +1,73 @@ +// +// ChatTemplateTests.swift +// swift-transformers +// +// Created by Anthony DePasquale on 2/10/24. +// + +import XCTest +import Tokenizers + +class ChatTemplateTests: XCTestCase { + let messages = [[ + "role": "user", + "content": "Describe the Swift programming language.", + ]] + + func testTemplateFromConfig() async throws { + let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") + let encoded = try tokenizer.applyChatTemplate(messages: messages) + let encodedTarget = [32010, 4002, 29581, 278, 14156, 8720, 4086, 29889, 32007, 32001] + let decoded = tokenizer.decode(tokens: encoded) + let decodedTarget = "<|user|>Describe the Swift programming language.<|end|><|assistant|>" + XCTAssertEqual(encoded, encodedTarget) + XCTAssertEqual(decoded, decodedTarget) + } + + func testDefaultTemplateFromArrayInConfig() async throws { + let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit") + let encoded = try tokenizer.applyChatTemplate(messages: messages) + let encodedTarget = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4] + let decoded = tokenizer.decode(tokens: encoded) + let decodedTarget = " [INST] Describe the Swift programming language. [/INST]" + XCTAssertEqual(encoded, encodedTarget) + XCTAssertEqual(decoded, decodedTarget) + } + + func testTemplateFromArgumentWithEnum() async throws { + let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") + // Purposely not using the correct template for this model to verify that the template from the config is not being used + let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: .literal(mistral7BDefaultTemplate)) + let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962] + let decoded = tokenizer.decode(tokens: encoded) + let decodedTarget = " [INST] Describe the Swift programming language. [/INST]" + XCTAssertEqual(encoded, encodedTarget) + XCTAssertEqual(decoded, decodedTarget) + } + + func testTemplateFromArgumentWithString() async throws { + let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") + // Purposely not using the correct template for this model to verify that the template from the config is not being used + let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate) + let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962] + let decoded = tokenizer.decode(tokens: encoded) + let decodedTarget = " [INST] Describe the Swift programming language. [/INST]" + XCTAssertEqual(encoded, encodedTarget) + XCTAssertEqual(decoded, decodedTarget) + } + + func testNamedTemplateFromArgument() async throws { + let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit") + // Normally it is not necessary to specify the name `default`, but I'm not aware of models with lists of templates in the config that are not `default` or `tool_use` + let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: .name("default")) + let encodedTarget = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4] + let decoded = tokenizer.decode(tokens: encoded) + let decodedTarget = " [INST] Describe the Swift programming language. [/INST]" + XCTAssertEqual(encoded, encodedTarget) + XCTAssertEqual(decoded, decodedTarget) + } + + // TODO: Add tests for tool use template +}