diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index e026d71..0ac0de6 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -125,7 +125,8 @@ public protocol Tokenizer { chatTemplate: String?, addGenerationPrompt: Bool, truncation: Bool, - maxLength: Int? + maxLength: Int?, + tools: [[String: Any]]? ) throws -> [Int] } @@ -323,12 +324,35 @@ public class PreTrainedTokenizer: Tokenizer { public func applyChatTemplate( messages: [[String: String]], - chatTemplate: String?, + chatTemplate: String? = nil, addGenerationPrompt: Bool = false, truncation: Bool = false, - maxLength: Int? + maxLength: Int? = nil, + tools: [[String: Any]]? = nil ) throws -> [Int] { - let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate) + var chatTemplateFromConfig: String? + if let chatTemplateValue = tokenizerConfig.chatTemplate { + if let chatTemplateStringValue = chatTemplateValue.stringValue { + chatTemplateFromConfig = chatTemplateStringValue + } else if let chatTemplateArrayValue = chatTemplateValue.arrayValue { + // If a list of chat templates is specified, convert them to a dict + let templateDict = Dictionary(uniqueKeysWithValues: chatTemplateArrayValue.compactMap { template in + guard let name = template[dynamicMember: "name"]?.stringValue, + let templateString = template[dynamicMember: "template"]?.stringValue else { + return nil + } + return (name, templateString) + }) + // Choose the appropriate template + if let tools = tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] { + chatTemplateFromConfig = toolUseTemplate + } else { + chatTemplateFromConfig = templateDict["default"] + } + } + } + + let template = try Template(chatTemplate ?? chatTemplateFromConfig ?? defaultChatTemplate) var context: [String: Any] = [ "messages": messages, "add_generation_prompt": addGenerationPrompt