Skip to content

Commit

Permalink
Improve chat template parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
DePasqualeOrg committed Oct 2, 2024
1 parent 71963c3 commit 48f9167
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ public protocol Tokenizer {
chatTemplate: String?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?
maxLength: Int?,
tools: [[String: Any]]?
) throws -> [Int]
}

Expand Down Expand Up @@ -323,12 +324,35 @@ public class PreTrainedTokenizer: Tokenizer {

public func applyChatTemplate(
messages: [[String: String]],
chatTemplate: String?,
chatTemplate: String? = nil,
addGenerationPrompt: Bool = false,
truncation: Bool = false,
maxLength: Int?
maxLength: Int? = nil,
tools: [[String: Any]]? = nil
) throws -> [Int] {
let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
var chatTemplateFromConfig: String?
if let chatTemplateValue = tokenizerConfig.chatTemplate {
if let chatTemplateStringValue = chatTemplateValue.stringValue {
chatTemplateFromConfig = chatTemplateStringValue
} else if let chatTemplateArrayValue = chatTemplateValue.arrayValue {
// If a list of chat templates is specified, convert them to a dict
let templateDict = Dictionary<String, String>(uniqueKeysWithValues: chatTemplateArrayValue.compactMap { template in
guard let name = template[dynamicMember: "name"]?.stringValue,
let templateString = template[dynamicMember: "template"]?.stringValue else {
return nil
}
return (name, templateString)
})
// Choose the appropriate template
if let tools = tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] {
chatTemplateFromConfig = toolUseTemplate
} else {
chatTemplateFromConfig = templateDict["default"]
}
}
}

let template = try Template(chatTemplate ?? chatTemplateFromConfig ?? defaultChatTemplate)
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt
Expand Down

0 comments on commit 48f9167

Please sign in to comment.