Skip to content

Commit 48f9167

Browse files
committed
Improve chat template parsing
1 parent 71963c3 commit 48f9167

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

Sources/Tokenizers/Tokenizer.swift

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ public protocol Tokenizer {
125125
chatTemplate: String?,
126126
addGenerationPrompt: Bool,
127127
truncation: Bool,
128-
maxLength: Int?
128+
maxLength: Int?,
129+
tools: [[String: Any]]?
129130
) throws -> [Int]
130131
}
131132

@@ -323,12 +324,35 @@ public class PreTrainedTokenizer: Tokenizer {
323324

324325
public func applyChatTemplate(
325326
messages: [[String: String]],
326-
chatTemplate: String?,
327+
chatTemplate: String? = nil,
327328
addGenerationPrompt: Bool = false,
328329
truncation: Bool = false,
329-
maxLength: Int?
330+
maxLength: Int? = nil,
331+
tools: [[String: Any]]? = nil
330332
) throws -> [Int] {
331-
let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
333+
var chatTemplateFromConfig: String?
334+
if let chatTemplateValue = tokenizerConfig.chatTemplate {
335+
if let chatTemplateStringValue = chatTemplateValue.stringValue {
336+
chatTemplateFromConfig = chatTemplateStringValue
337+
} else if let chatTemplateArrayValue = chatTemplateValue.arrayValue {
338+
// If a list of chat templates is specified, convert them to a dict
339+
let templateDict = Dictionary<String, String>(uniqueKeysWithValues: chatTemplateArrayValue.compactMap { template in
340+
guard let name = template[dynamicMember: "name"]?.stringValue,
341+
let templateString = template[dynamicMember: "template"]?.stringValue else {
342+
return nil
343+
}
344+
return (name, templateString)
345+
})
346+
// Choose the appropriate template
347+
if let tools = tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] {
348+
chatTemplateFromConfig = toolUseTemplate
349+
} else {
350+
chatTemplateFromConfig = templateDict["default"]
351+
}
352+
}
353+
}
354+
355+
let template = try Template(chatTemplate ?? chatTemplateFromConfig ?? defaultChatTemplate)
332356
var context: [String: Any] = [
333357
"messages": messages,
334358
"add_generation_prompt": addGenerationPrompt

0 commit comments

Comments
 (0)