diff --git a/Applications/LLMEval/ContentView.swift b/Applications/LLMEval/ContentView.swift index d6f3ba5..2789aa3 100644 --- a/Applications/LLMEval/ContentView.swift +++ b/Applications/LLMEval/ContentView.swift @@ -159,7 +159,10 @@ class LLMEvaluator { /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on /// more devices. - let modelConfiguration = ModelConfiguration.phi3_5_4bit + + // let modelConfiguration = ModelConfiguration.phi3_5_4bit + // let modelConfiguration = ModelConfiguration.mistral7B4bit + let modelConfiguration = ModelConfiguration.llama3_2_3B_4bit /// parameters controlling the output let generateParameters = GenerateParameters(temperature: 0.6) @@ -217,11 +220,9 @@ class LLMEvaluator { do { let modelContainer = try await load() - // augment the prompt as needed - let prompt = modelConfiguration.prepare(prompt: prompt) - - let promptTokens = await modelContainer.perform { _, tokenizer in - tokenizer.encode(text: prompt) + let messages = [["role": "user", "content": prompt]] + let promptTokens = try await modelContainer.perform { _, tokenizer in + try tokenizer.applyChatTemplate(messages: messages) } // each time you generate you will get something new diff --git a/Applications/LoRATrainingExample/ContentView.swift b/Applications/LoRATrainingExample/ContentView.swift index 1f89186..7813372 100644 --- a/Applications/LoRATrainingExample/ContentView.swift +++ b/Applications/LoRATrainingExample/ContentView.swift @@ -269,10 +269,9 @@ class LoRAEvaluator { let modelContainer = try await loadModel() - // prepare the prompt - let preparedPrompt = modelConfiguration.prepare(prompt: prompt) - let promptTokens = await modelContainer.perform { _, tokenizer in - tokenizer.encode(text: preparedPrompt) + let messages = [["role": "user", "content": prompt]] + let promptTokens = try await modelContainer.perform { _, tokenizer in + try tokenizer.applyChatTemplate(messages: messages) } // evaluate diff --git a/Libraries/LLM/LLMModel.swift b/Libraries/LLM/LLMModel.swift index 83da0a7..5999fb5 100644 --- a/Libraries/LLM/LLMModel.swift +++ b/Libraries/LLM/LLMModel.swift @@ -8,12 +8,13 @@ import Tokenizers /// Container for models that guarantees single threaded access. /// -/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access +/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access /// the model and/or tokenizer: /// /// ```swift -/// let promptTokens = await modelContainer.perform { _, tokenizer in -/// tokenizer.encode(text: prompt) +/// let messages = [["role": "user", "content": prompt]] +/// let promptTokens = try await modelContainer.perform { _, tokenizer in +/// try tokenizer.applyChatTemplate(messages: messages) /// } /// ``` /// diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift index bb5e8c3..1946f3e 100644 --- a/Libraries/LLM/Models.swift +++ b/Libraries/LLM/Models.swift @@ -39,11 +39,6 @@ public struct ModelConfiguration: Sendable { /// Additional tokens to use for end of string public let extraEOSTokens: Set - /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this - /// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt - /// format - private let preparePrompt: (@Sendable (String) -> String)? - public init( id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil, defaultPrompt: String = "hello", @@ -55,25 +50,18 @@ public struct ModelConfiguration: Sendable { self.overrideTokenizer = overrideTokenizer self.defaultPrompt = defaultPrompt self.extraEOSTokens = extraEOSTokens - self.preparePrompt = preparePrompt } public init( directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil, defaultPrompt: String = "hello", - extraEOSTokens: Set = [], - preparePrompt: (@Sendable (String) -> String)? = nil + extraEOSTokens: Set = [] ) { self.id = .directory(directory) self.tokenizerId = tokenizerId self.overrideTokenizer = overrideTokenizer self.defaultPrompt = defaultPrompt self.extraEOSTokens = extraEOSTokens - self.preparePrompt = preparePrompt - } - - public func prepare(prompt: String) -> String { - preparePrompt?(prompt) ?? prompt } public func modelDirectory(hub: HubApi = HubApi()) -> URL { @@ -116,40 +104,26 @@ extension ModelConfiguration { public static let smolLM_135M_4bit = ModelConfiguration( id: "mlx-community/SmolLM-135M-Instruct-4bit", defaultPrompt: "Tell me about the history of Spain." - ) { - prompt in - "<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n" - } + ) public static let mistralNeMo4bit = ModelConfiguration( id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit", defaultPrompt: "Explain quaternions." - ) { prompt in - "[INST] \(prompt) [/INST] " - } + ) public static let mistral7B4bit = ModelConfiguration( id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit", defaultPrompt: "Describe the Swift language." - ) { prompt in - "[INST] \(prompt) [/INST] " - } + ) public static let codeLlama13b4bit = ModelConfiguration( id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX", overrideTokenizer: "PreTrainedTokenizer", defaultPrompt: "func sortArray(_ array: [Int]) -> String { }" - ) { prompt in - // given the prompt: func sortArray(_ array: [Int]) -> String { } - // the python code produces this (via its custom tokenizer): - //
 func sortArray(_ array: [Int]) -> String {   } 
-
-        "
 " + prompt.replacingOccurrences(of: "", with: "") + " "
-    }
+    )
 
     public static let phi4bit = ModelConfiguration(
         id: "mlx-community/phi-2-hf-4bit-mlx",
-
         // https://www.promptingguide.ai/models/phi-2
         defaultPrompt: "Why is the sky blue?"
     )
@@ -158,92 +132,60 @@ extension ModelConfiguration {
         id: "mlx-community/Phi-3.5-mini-instruct-4bit",
         defaultPrompt: "What is the gravity on Mars and the moon?",
         extraEOSTokens: ["<|end|>"]
-    ) {
-        prompt in
-        "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
-    }
+    )
 
     public static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
         overrideTokenizer: "PreTrainedTokenizer",
-
         // https://www.promptingguide.ai/models/gemma
         defaultPrompt: "what is the difference between lettuce and cabbage?"
-
-    ) { prompt in
-        "user\n\(prompt)\nmodel\n"
-    }
+    )
 
     public static let gemma_2_9b_it_4bit = ModelConfiguration(
         id: "mlx-community/gemma-2-9b-it-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
-
         // https://www.promptingguide.ai/models/gemma
         defaultPrompt: "What is the difference between lettuce and cabbage?"
-
-    ) { prompt in
-        "user\n\(prompt)\nmodel\n"
-    }
+    )
 
     public static let gemma_2_2b_it_4bit = ModelConfiguration(
         id: "mlx-community/gemma-2-2b-it-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
-
         // https://www.promptingguide.ai/models/gemma
         defaultPrompt: "What is the difference between lettuce and cabbage?"
-
-    ) { prompt in
-        "user \(prompt)model"
-    }
+    )
 
     public static let qwen205b4bit = ModelConfiguration(
         id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
         defaultPrompt: "why is the sky blue?"
-    ) { prompt in
-        "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
-    }
+    )
 
     public static let openelm270m4bit = ModelConfiguration(
         id: "mlx-community/OpenELM-270M-Instruct",
-
         // https://huggingface.co/apple/OpenELM
         defaultPrompt: "Once upon a time there was"
-    ) { prompt in
-        "\(prompt)"
-    }
+    )
 
     public static let llama3_1_8B_4bit = ModelConfiguration(
         id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     public static let llama3_8B_4bit = ModelConfiguration(
         id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     public static let llama3_2_1B_4bit = ModelConfiguration(
         id: "mlx-community/Llama-3.2-1B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     public static let llama3_2_3B_4bit = ModelConfiguration(
         id: "mlx-community/Llama-3.2-3B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     private enum BootstrapState: Sendable {
         case idle
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index 84f27c0..3706f2e 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -84,18 +84,6 @@ struct GenerateArguments: ParsableArguments, Sendable {
         }
     }
 
-    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) throws -> (
-        String, [Int]
-    ) {
-        MLXRandom.seed(seed)
-
-        let prompt = try resolvePrompt(configuration: configuration)
-        let preparedPrompt = configuration.prepare(prompt: prompt)
-        let promptTokens = tokenizer.encode(text: preparedPrompt)
-
-        return (prompt, promptTokens)
-    }
-
     func generate(
         promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
         extraEOSTokens: Set? = nil
@@ -221,9 +209,10 @@ struct EvaluateCommand: AsyncParsableCommand {
             print("Model loaded -> \(modelConfiguration.id)")
         }
 
-        let (prompt, promptTokens) = try await modelContainer.perform { [generate] _, tokenizer in
-            try generate.tokenizePrompt(
-                configuration: modelConfiguration, tokenizer: tokenizer)
+        let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
+        let messages = [["role": "user", "content": prompt]]
+        let promptTokens = try await modelContainer.perform { _, tokenizer in
+            try tokenizer.applyChatTemplate(messages: messages)
         }
 
         if !generate.quiet {
diff --git a/Tools/llm-tool/LoraCommands.swift b/Tools/llm-tool/LoraCommands.swift
index 0861f1e..a5f668f 100644
--- a/Tools/llm-tool/LoraCommands.swift
+++ b/Tools/llm-tool/LoraCommands.swift
@@ -291,9 +291,10 @@ struct LoRAEvalCommand: AsyncParsableCommand {
 
         memory.start()
 
-        let (prompt, promptTokens) = try await modelContainer.perform { [generate] _, tokenizer in
-            try generate.tokenizePrompt(
-                configuration: modelConfiguration, tokenizer: tokenizer)
+        let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
+        let messages = [["role": "user", "content": prompt]]
+        let promptTokens = try await modelContainer.perform { _, tokenizer in
+            try tokenizer.applyChatTemplate(messages: messages)
         }
 
         if !generate.quiet {
diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj
index 7e5f974..1993f5c 100644
--- a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -3555,7 +3555,7 @@
 			repositoryURL = "https://github.com/huggingface/swift-transformers";
 			requirement = {
 				kind = upToNextMajorVersion;
-				minimumVersion = 0.1.12;
+				minimumVersion = 0.1.13;
 			};
 		};
 		C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */ = {
@@ -3571,7 +3571,7 @@
 			repositoryURL = "https://github.com/ml-explore/mlx-swift";
 			requirement = {
 				kind = upToNextMajorVersion;
-				minimumVersion = 0.16.1;
+				minimumVersion = 0.18.0;
 			};
 		};
 /* End XCRemoteSwiftPackageReference section */
diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index 6c35c08..0361535 100644
--- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -15,8 +15,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/maiqingqiang/Jinja",
       "state" : {
-        "revision" : "5b0703d19a8901b76948753e5c5e3ca77043d33f",
-        "version" : "1.0.0"
+        "revision" : "b435eb62b0d3d5f34167ec70a128355486981712",
+        "version" : "1.0.5"
       }
     },
     {
@@ -24,8 +24,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
-        "revision" : "86ad75ab1ee96cd70325732b37cd830f87d7e43f",
-        "version" : "0.16.1"
+        "revision" : "78a7cfe6701d6e9c88e9d4a0d1f7990af84b2146",
+        "version" : "0.18.0"
       }
     },
     {
@@ -51,8 +51,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/apple/swift-argument-parser.git",
       "state" : {
-        "revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b",
-        "version" : "1.4.0"
+        "revision" : "41982a3656a71c768319979febd796c6fd111d5c",
+        "version" : "1.5.0"
       }
     },
     {
@@ -60,8 +60,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/gonzalezreal/swift-markdown-ui",
       "state" : {
-        "revision" : "9a8119b37e09a770367eeb26e05267c75d854053",
-        "version" : "2.3.1"
+        "revision" : "55441810c0f678c78ed7e2ebd46dde89228e02fc",
+        "version" : "2.4.0"
       }
     },
     {
@@ -78,8 +78,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/huggingface/swift-transformers",
       "state" : {
-        "revision" : "0f2306713d48a75b862026ebb291926793773f52",
-        "version" : "0.1.12"
+        "revision" : "4d25d20e49d2269aec1556231f8e278db7b2a4f0",
+        "version" : "0.1.13"
       }
     }
   ],