Skip to content

Commit

Permalink
Add overload with chatTemplate argument of type String
Browse files Browse the repository at this point in the history
  • Loading branch information
DePasqualeOrg committed Oct 3, 2024
1 parent 62ccb41 commit c0355dd
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
17 changes: 13 additions & 4 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ struct TokenizerModel {
}

public enum ChatTemplateArgument {
/// A Jinja template to use for the conversation. Normally it is not necessary to provide a template, since it will be read from the tokenizer config file.
/// A Jinja template to use for the conversation. Normally it is not necessary to provide a template, since it will be read from the tokenizer config.
case literal(String)
/// For models whose tokenizer config file includes multiple chat templates, the template can be specified by name. Normally this is not necessary.
/// For models whose tokenizer config includes multiple chat templates, the template can be specified by name. Normally this is not necessary.
case name(String)
}

Expand Down Expand Up @@ -125,13 +125,18 @@ public protocol Tokenizer {
var unknownToken: String? { get }
var unknownTokenId: Int? { get }

/// The appropriate chat template is selected from the tokenizer config
func applyChatTemplate(messages: [[String: String]]) throws -> [Int]

/// The chat template is provided as a string literal or specified by name
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]

/// The chat template is provided as a string literal
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]

func applyChatTemplate(
messages: [[String: String]],
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config file. Normally this is not necessary.
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
Expand Down Expand Up @@ -334,6 +339,10 @@ public class PreTrainedTokenizer: Tokenizer {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true)
}

public func applyChatTemplate(
messages: [[String: String]],
chatTemplate: ChatTemplateArgument? = nil,
Expand Down Expand Up @@ -366,7 +375,7 @@ public class PreTrainedTokenizer: Tokenizer {
if let matchingDictEntry = templateDict[name] {
selectedChatTemplate = matchingDictEntry
} else {
throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config file")
throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config")
}
} else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] {
// Use tool use chat template from config
Expand Down
14 changes: 13 additions & 1 deletion Tests/TokenizersTests/ChatTemplateTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ChatTemplateTests: XCTestCase {
XCTAssertEqual(decoded, decodedTarget)
}

func testTemplateFromArgument() async throws {
func testTemplateFromArgumentWithEnum() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
// Purposely not using the correct template for this model to verify that the template from the config is not being used
let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
Expand All @@ -46,6 +46,18 @@ class ChatTemplateTests: XCTestCase {
XCTAssertEqual(decoded, decodedTarget)
}

func testTemplateFromArgumentWithString() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
// Purposely not using the correct template for this model to verify that the template from the config is not being used
let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate)
let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962]
let decoded = tokenizer.decode(tokens: encoded)
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]"
XCTAssertEqual(encoded, encodedTarget)
XCTAssertEqual(decoded, decodedTarget)
}

func testNamedTemplateFromArgument() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
// Normally it is not necessary to specify the name `default`, but I'm not aware of models with lists of templates in the config that are not `default` or `tool_use`
Expand Down

0 comments on commit c0355dd

Please sign in to comment.