Skip to content
Merged
32 changes: 28 additions & 4 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ public protocol Tokenizer {
chatTemplate: String?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?
maxLength: Int?,
tools: [[String: Any]]?
) throws -> [Int]
}

Expand Down Expand Up @@ -323,12 +324,35 @@ public class PreTrainedTokenizer: Tokenizer {

public func applyChatTemplate(
messages: [[String: String]],
chatTemplate: String?,
chatTemplate: String? = nil,
addGenerationPrompt: Bool = false,
Copy link
Contributor Author

@DePasqualeOrg DePasqualeOrg Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addGenerationPrompt is false by default in this method (which mirrors the Python implementation), but it is currently set to true in the overload method. Should we make the behavior consistent in all applyChatTemplate methods? Should it be true or false by default?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be true for the overloads as I expect that to be the most common use of the API, but I'm not fully sure.

truncation: Bool = false,
maxLength: Int?
maxLength: Int? = nil,
tools: [[String: Any]]? = nil
) throws -> [Int] {
let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
var chatTemplateFromConfig: String?
if let chatTemplateValue = tokenizerConfig.chatTemplate {
if let chatTemplateStringValue = chatTemplateValue.stringValue {
chatTemplateFromConfig = chatTemplateStringValue
} else if let chatTemplateArrayValue = chatTemplateValue.arrayValue {
// If a list of chat templates is specified, convert them to a dict
let templateDict = Dictionary<String, String>(uniqueKeysWithValues: chatTemplateArrayValue.compactMap { template in
guard let name = template[dynamicMember: "name"]?.stringValue,
let templateString = template[dynamicMember: "template"]?.stringValue else {
return nil
}
return (name, templateString)
})
// Choose the appropriate template
if let tools = tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] {
chatTemplateFromConfig = toolUseTemplate
} else {
chatTemplateFromConfig = templateDict["default"]
}
}
}

let template = try Template(chatTemplate ?? chatTemplateFromConfig ?? defaultChatTemplate)
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt
Expand Down