From c0355ddaba33e8f7f4f77d7ede9b631470105137 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Thu, 3 Oct 2024 15:43:39 +0200 Subject: [PATCH] Add overload with `chatTemplate` argument of type `String` --- Sources/Tokenizers/Tokenizer.swift | 17 +++++++++++++---- Tests/TokenizersTests/ChatTemplateTests.swift | 14 +++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index 34b72150..9c8c3816 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -95,9 +95,9 @@ struct TokenizerModel { } public enum ChatTemplateArgument { - /// A Jinja template to use for the conversation. Normally it is not necessary to provide a template, since it will be read from the tokenizer config file. + /// A Jinja template to use for the conversation. Normally it is not necessary to provide a template, since it will be read from the tokenizer config. case literal(String) - /// For models whose tokenizer config file includes multiple chat templates, the template can be specified by name. Normally this is not necessary. + /// For models whose tokenizer config includes multiple chat templates, the template can be specified by name. Normally this is not necessary. case name(String) } @@ -125,13 +125,18 @@ public protocol Tokenizer { var unknownToken: String? { get } var unknownTokenId: Int? { get } + /// The appropriate chat template is selected from the tokenizer config func applyChatTemplate(messages: [[String: String]]) throws -> [Int] + /// The chat template is provided as a string literal or specified by name func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] + /// The chat template is provided as a string literal + func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] + func applyChatTemplate( messages: [[String: String]], - /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config file. Normally this is not necessary. + /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary. chatTemplate: ChatTemplateArgument?, addGenerationPrompt: Bool, truncation: Bool, @@ -334,6 +339,10 @@ public class PreTrainedTokenizer: Tokenizer { try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true) } + public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] { + try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true) + } + public func applyChatTemplate( messages: [[String: String]], chatTemplate: ChatTemplateArgument? = nil, @@ -366,7 +375,7 @@ public class PreTrainedTokenizer: Tokenizer { if let matchingDictEntry = templateDict[name] { selectedChatTemplate = matchingDictEntry } else { - throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config file") + throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config") } } else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] { // Use tool use chat template from config diff --git a/Tests/TokenizersTests/ChatTemplateTests.swift b/Tests/TokenizersTests/ChatTemplateTests.swift index d8c2306d..3ee7aa15 100644 --- a/Tests/TokenizersTests/ChatTemplateTests.swift +++ b/Tests/TokenizersTests/ChatTemplateTests.swift @@ -34,7 +34,7 @@ class ChatTemplateTests: XCTestCase { XCTAssertEqual(decoded, decodedTarget) } - func testTemplateFromArgument() async throws { + func testTemplateFromArgumentWithEnum() async throws { let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") // Purposely not using the correct template for this model to verify that the template from the config is not being used let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" @@ -46,6 +46,18 @@ class ChatTemplateTests: XCTestCase { XCTAssertEqual(decoded, decodedTarget) } + func testTemplateFromArgumentWithString() async throws { + let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct") + // Purposely not using the correct template for this model to verify that the template from the config is not being used + let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate) + let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962] + let decoded = tokenizer.decode(tokens: encoded) + let decodedTarget = " [INST] Describe the Swift programming language. [/INST]" + XCTAssertEqual(encoded, encodedTarget) + XCTAssertEqual(decoded, decodedTarget) + } + func testNamedTemplateFromArgument() async throws { let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit") // Normally it is not necessary to specify the name `default`, but I'm not aware of models with lists of templates in the config that are not `default` or `tool_use`