diff --git a/Package.swift b/Package.swift index 55acf44e..ae688c74 100644 --- a/Package.swift +++ b/Package.swift @@ -15,15 +15,15 @@ let package = Package( .library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"]) ], dependencies: [ - .package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.3.0")) + .package(url: "https://github.com/huggingface/swift-jinja.git", from: "2.0.0") ], targets: [ .target(name: "Generation", dependencies: ["Tokenizers"]), - .target(name: "Hub", resources: [.process("Resources")], swiftSettings: swiftSettings), + .target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings), .target(name: "Models", dependencies: ["Tokenizers", "Generation"]), - .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]), + .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]), .testTarget(name: "GenerationTests", dependencies: ["Generation"]), - .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")], swiftSettings: swiftSettings), + .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings), .testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]), .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]), ] diff --git a/Sources/Hub/Config.swift b/Sources/Hub/Config.swift index 8348e5f6..4cf86dd1 100644 --- a/Sources/Hub/Config.swift +++ b/Sources/Hub/Config.swift @@ -5,6 +5,7 @@ // Created by Piotr Kowalczuk on 06.03.25. import Foundation +import Jinja // MARK: - Configuration files with dynamic lookup @@ -433,28 +434,28 @@ public struct Config: Hashable, Sendable, self.dictionary(or: or) } - public func toJinjaCompatible() -> Any? { + public func jinjaValue() -> Jinja.Value { switch self.value { case let .array(val): - return val.map { $0.toJinjaCompatible() } + return .array(val.map { $0.jinjaValue() }) case let .dictionary(val): - var result: [String: Any?] = [:] + var result: [String: Jinja.Value] = [:] for (key, config) in val { - result[key.string] = config.toJinjaCompatible() + result[key.string] = config.jinjaValue() } - return result + return .object(.init(uniqueKeysWithValues: result)) case let .boolean(val): - return val + return .boolean(val) case let .floating(val): - return val + return .double(Double(String(val)) ?? Double(val)) case let .integer(val): - return val + return .int(val) case let .string(val): - return val.string + return .string(val.string) case let .token(val): - return [String(val.0): val.1.string] as [String: String] + return [String(val.0): .string(val.1.string)] case .null: - return nil + return .null } } diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index 1f084c63..af0bc231 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -769,32 +769,34 @@ public class PreTrainedTokenizer: Tokenizer { } let template = try compiledTemplate(for: selectedChatTemplate) - var context: [String: Any] = [ - "messages": messages, - "add_generation_prompt": addGenerationPrompt, + var context: [String: Jinja.Value] = try [ + "messages": .array(messages.map { try Value(any: $0) }), + "add_generation_prompt": .boolean(addGenerationPrompt), ] if let tools { - context["tools"] = tools + context["tools"] = try .array(tools.map { try Value(any: $0) }) } if let additionalContext { // Additional keys and values to be added to the context provided to the prompt templating engine. // For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided. // The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message. for (key, value) in additionalContext { - context[key] = value + context[key] = try Value(any: value) } } for (key, value) in tokenizerConfig.dictionary(or: [:]) { if specialTokenAttributes.contains(key.string), !value.isNull() { if let stringValue = value.string() { - context[key.string] = stringValue + context[key.string] = .string(stringValue) } else if let dictionary = value.dictionary() { - context[key.string] = addedTokenAsString(Config(dictionary)) + if let addedTokenString = addedTokenAsString(Config(dictionary)) { + context[key.string] = .string(addedTokenString) + } } else if let array: [String] = value.get() { - context[key.string] = array + context[key.string] = .array(array.map { .string($0) }) } else { - context[key.string] = value + context[key.string] = try Value(any: value) } } } diff --git a/Tests/HubTests/ConfigTests.swift b/Tests/HubTests/ConfigTests.swift index 8a313381..a082bf2a 100644 --- a/Tests/HubTests/ConfigTests.swift +++ b/Tests/HubTests/ConfigTests.swift @@ -435,7 +435,7 @@ struct ConfigTests { """ let got = try Template(template).render([ - "config": cfg.toJinjaCompatible() + "config": cfg.jinjaValue() ]) #expect(got == exp) diff --git a/Tests/TokenizersTests/ChatTemplateTests.swift b/Tests/TokenizersTests/ChatTemplateTests.swift index c9245af1..fec1ec1c 100644 --- a/Tests/TokenizersTests/ChatTemplateTests.swift +++ b/Tests/TokenizersTests/ChatTemplateTests.swift @@ -257,6 +257,7 @@ struct ChatTemplateTests { What is the weather in Paris today?<|im_end|> <|im_start|>assistant + """ #expect(