Use chat template (#135)
* Use chat template

* Update packages
DePasqualeOrg authored Oct 22, 2024
1 parent 4e5977d commit 5a7a1a4
Showing 8 changed files with 49 additions and 116 deletions.
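In short: instead of wrapping the raw prompt in a hand-written, model-specific format string, each call site now expresses the input as a chat message and lets the tokenizer's bundled chat template do the formatting. A minimal sketch of the pattern that recurs throughout the diff below, assuming `modelContainer` and `prompt` are set up as in the LLMEval example:

    // Sketch of the new tokenization path (see Applications/LLMEval below).
    let messages = [["role": "user", "content": prompt]]
    let promptTokens = try await modelContainer.perform { _, tokenizer in
        // swift-transformers renders the model's chat template
        // (from tokenizer_config.json) and returns token ids.
        try tokenizer.applyChatTemplate(messages: messages)
    }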
13 changes: 7 additions & 6 deletions Applications/LLMEval/ContentView.swift

@@ -159,7 +159,10 @@ class LLMEvaluator {
 
     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+
+    // let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    // let modelConfiguration = ModelConfiguration.mistral7B4bit
+    let modelConfiguration = ModelConfiguration.llama3_2_3B_4bit
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
@@ -217,11 +220,9 @@ class LLMEvaluator {
         do {
             let modelContainer = try await load()
 
-            // augment the prompt as needed
-            let prompt = modelConfiguration.prepare(prompt: prompt)
-
-            let promptTokens = await modelContainer.perform { _, tokenizer in
-                tokenizer.encode(text: prompt)
+            let messages = [["role": "user", "content": prompt]]
+            let promptTokens = try await modelContainer.perform { _, tokenizer in
+                try tokenizer.applyChatTemplate(messages: messages)
             }
 
             // each time you generate you will get something new
7 changes: 3 additions & 4 deletions Applications/LoRATrainingExample/ContentView.swift

@@ -269,10 +269,9 @@ class LoRAEvaluator {
 
         let modelContainer = try await loadModel()
 
-        // prepare the prompt
-        let preparedPrompt = modelConfiguration.prepare(prompt: prompt)
-        let promptTokens = await modelContainer.perform { _, tokenizer in
-            tokenizer.encode(text: preparedPrompt)
+        let messages = [["role": "user", "content": prompt]]
+        let promptTokens = try await modelContainer.perform { _, tokenizer in
+            try tokenizer.applyChatTemplate(messages: messages)
         }
 
         // evaluate
7 changes: 4 additions & 3 deletions Libraries/LLM/LLMModel.swift

@@ -8,12 +8,13 @@ import Tokenizers
 
 /// Container for models that guarantees single threaded access.
 ///
-/// Wrap models used by e.g. the UI in a ModelContainer.  Callers can access
+/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access
 /// the model and/or tokenizer:
 ///
 /// ```swift
-/// let promptTokens = await modelContainer.perform { _, tokenizer in
-///     tokenizer.encode(text: prompt)
+/// let messages = [["role": "user", "content": prompt]]
+/// let promptTokens = try await modelContainer.perform { _, tokenizer in
+///     try tokenizer.applyChatTemplate(messages: messages)
 /// }
 /// ```
 ///
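The same `perform`/`applyChatTemplate` call should extend to multi-turn input. A sketch, under the assumption that the model's chat template accepts system/user/assistant roles (the roles the removed hand-written templates in Models.swift below were emulating):

    // Sketch only: role support depends on the model's chat template.
    let messages: [[String: String]] = [
        ["role": "system", "content": "You are a helpful assistant"],
        ["role": "user", "content": "Describe the Swift language."],
    ]
    let promptTokens = try await modelContainer.perform { _, tokenizer in
        try tokenizer.applyChatTemplate(messages: messages)
    }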
88 changes: 15 additions & 73 deletions Libraries/LLM/Models.swift

@@ -39,11 +39,6 @@ public struct ModelConfiguration: Sendable {
     /// Additional tokens to use for end of string
     public let extraEOSTokens: Set<String>
 
-    /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
-    /// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt
-    /// format
-    private let preparePrompt: (@Sendable (String) -> String)?
-
     public init(
         id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
         defaultPrompt: String = "hello",
@@ -55,25 +50,18 @@ public struct ModelConfiguration: Sendable {
         self.overrideTokenizer = overrideTokenizer
         self.defaultPrompt = defaultPrompt
         self.extraEOSTokens = extraEOSTokens
-        self.preparePrompt = preparePrompt
     }
 
     public init(
         directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
         defaultPrompt: String = "hello",
-        extraEOSTokens: Set<String> = [],
-        preparePrompt: (@Sendable (String) -> String)? = nil
+        extraEOSTokens: Set<String> = []
     ) {
        self.id = .directory(directory)
         self.tokenizerId = tokenizerId
         self.overrideTokenizer = overrideTokenizer
         self.defaultPrompt = defaultPrompt
         self.extraEOSTokens = extraEOSTokens
-        self.preparePrompt = preparePrompt
     }
 
-    public func prepare(prompt: String) -> String {
-        preparePrompt?(prompt) ?? prompt
-    }
-
     public func modelDirectory(hub: HubApi = HubApi()) -> URL {
@@ -116,40 +104,26 @@ extension ModelConfiguration {
     public static let smolLM_135M_4bit = ModelConfiguration(
         id: "mlx-community/SmolLM-135M-Instruct-4bit",
         defaultPrompt: "Tell me about the history of Spain."
-    ) {
-        prompt in
-        "<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
-    }
+    )
 
     public static let mistralNeMo4bit = ModelConfiguration(
         id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit",
         defaultPrompt: "Explain quaternions."
-    ) { prompt in
-        "<s>[INST] \(prompt) [/INST] "
-    }
+    )
 
     public static let mistral7B4bit = ModelConfiguration(
         id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
         defaultPrompt: "Describe the Swift language."
-    ) { prompt in
-        "<s>[INST] \(prompt) [/INST] "
-    }
+    )
 
     public static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
         overrideTokenizer: "PreTrainedTokenizer",
         defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
-    ) { prompt in
-        // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
-        // the python code produces this (via its custom tokenizer):
-        // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
-
-        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
-    }
+    )
 
     public static let phi4bit = ModelConfiguration(
         id: "mlx-community/phi-2-hf-4bit-mlx",
-
         // https://www.promptingguide.ai/models/phi-2
         defaultPrompt: "Why is the sky blue?"
     )
@@ -158,92 +132,60 @@ extension ModelConfiguration {
         id: "mlx-community/Phi-3.5-mini-instruct-4bit",
         defaultPrompt: "What is the gravity on Mars and the moon?",
         extraEOSTokens: ["<|end|>"]
-    ) {
-        prompt in
-        "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
-    }
+    )
 
     public static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
         overrideTokenizer: "PreTrainedTokenizer",
-
         // https://www.promptingguide.ai/models/gemma
         defaultPrompt: "what is the difference between lettuce and cabbage?"
-
-    ) { prompt in
-        "<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
-    }
+    )
 
     public static let gemma_2_9b_it_4bit = ModelConfiguration(
         id: "mlx-community/gemma-2-9b-it-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
-
         // https://www.promptingguide.ai/models/gemma
         defaultPrompt: "What is the difference between lettuce and cabbage?"
-
-    ) { prompt in
-        "<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
-    }
+    )
 
     public static let gemma_2_2b_it_4bit = ModelConfiguration(
         id: "mlx-community/gemma-2-2b-it-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
-
         // https://www.promptingguide.ai/models/gemma
         defaultPrompt: "What is the difference between lettuce and cabbage?"
-
-    ) { prompt in
-        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
-    }
+    )
 
     public static let qwen205b4bit = ModelConfiguration(
         id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
         defaultPrompt: "why is the sky blue?"
-    ) { prompt in
-        "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
-    }
+    )
 
     public static let openelm270m4bit = ModelConfiguration(
         id: "mlx-community/OpenELM-270M-Instruct",
-
         // https://huggingface.co/apple/OpenELM
         defaultPrompt: "Once upon a time there was"
-    ) { prompt in
-        "\(prompt)"
-    }
+    )
 
     public static let llama3_1_8B_4bit = ModelConfiguration(
         id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     public static let llama3_8B_4bit = ModelConfiguration(
         id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     public static let llama3_2_1B_4bit = ModelConfiguration(
         id: "mlx-community/Llama-3.2-1B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     public static let llama3_2_3B_4bit = ModelConfiguration(
         id: "mlx-community/Llama-3.2-3B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     private enum BootstrapState: Sendable {
         case idle
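With the closures gone, a model registration is pure metadata; the prompt format travels with the model's tokenizer. A hypothetical entry using the slimmed-down initializer; the id below is illustrative, not one of the repo's presets:

    extension ModelConfiguration {
        // Hypothetical registration: any MLX-converted instruct model whose
        // tokenizer_config.json ships a chat template should slot in like this.
        public static let myInstructModel4bit = ModelConfiguration(
            id: "mlx-community/Some-Instruct-Model-4bit",  // illustrative id
            defaultPrompt: "Tell me about Swift concurrency.",
            extraEOSTokens: []  // only if the model emits nonstandard stop tokens
        )
    }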
19 changes: 4 additions & 15 deletions Tools/llm-tool/LLMTool.swift

@@ -84,18 +84,6 @@ struct GenerateArguments: ParsableArguments, Sendable {
         }
     }
 
-    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) throws -> (
-        String, [Int]
-    ) {
-        MLXRandom.seed(seed)
-
-        let prompt = try resolvePrompt(configuration: configuration)
-        let preparedPrompt = configuration.prepare(prompt: prompt)
-        let promptTokens = tokenizer.encode(text: preparedPrompt)
-
-        return (prompt, promptTokens)
-    }
-
     func generate(
         promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
         extraEOSTokens: Set<String>? = nil
@@ -221,9 +209,10 @@ struct EvaluateCommand: AsyncParsableCommand {
             print("Model loaded -> \(modelConfiguration.id)")
         }
 
-        let (prompt, promptTokens) = try await modelContainer.perform { [generate] _, tokenizer in
-            try generate.tokenizePrompt(
-                configuration: modelConfiguration, tokenizer: tokenizer)
+        let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
+        let messages = [["role": "user", "content": prompt]]
+        let promptTokens = try await modelContainer.perform { _, tokenizer in
+            try tokenizer.applyChatTemplate(messages: messages)
         }
 
         if !generate.quiet {
7 changes: 4 additions & 3 deletions Tools/llm-tool/LoraCommands.swift

@@ -291,9 +291,10 @@ struct LoRAEvalCommand: AsyncParsableCommand {
 
         memory.start()
 
-        let (prompt, promptTokens) = try await modelContainer.perform { [generate] _, tokenizer in
-            try generate.tokenizePrompt(
-                configuration: modelConfiguration, tokenizer: tokenizer)
+        let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
+        let messages = [["role": "user", "content": prompt]]
+        let promptTokens = try await modelContainer.perform { _, tokenizer in
+            try tokenizer.applyChatTemplate(messages: messages)
         }
 
         if !generate.quiet {
4 changes: 2 additions & 2 deletions mlx-swift-examples.xcodeproj/project.pbxproj

@@ -3555,7 +3555,7 @@
 			repositoryURL = "https://github.com/huggingface/swift-transformers";
 			requirement = {
 				kind = upToNextMajorVersion;
-				minimumVersion = 0.1.12;
+				minimumVersion = 0.1.13;
 			};
 		};
 		C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */ = {
@@ -3571,7 +3571,7 @@
 			repositoryURL = "https://github.com/ml-explore/mlx-swift";
 			requirement = {
 				kind = upToNextMajorVersion;
-				minimumVersion = 0.16.1;
+				minimumVersion = 0.18.0;
 			};
 		};
 /* End XCRemoteSwiftPackageReference section */
20 changes: 10 additions & 10 deletions mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

@@ -15,17 +15,17 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/maiqingqiang/Jinja",
       "state" : {
-        "revision" : "5b0703d19a8901b76948753e5c5e3ca77043d33f",
-        "version" : "1.0.0"
+        "revision" : "b435eb62b0d3d5f34167ec70a128355486981712",
+        "version" : "1.0.5"
       }
     },
     {
       "identity" : "mlx-swift",
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
-        "revision" : "86ad75ab1ee96cd70325732b37cd830f87d7e43f",
-        "version" : "0.16.1"
+        "revision" : "78a7cfe6701d6e9c88e9d4a0d1f7990af84b2146",
+        "version" : "0.18.0"
       }
     },
     {
@@ -51,17 +51,17 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/apple/swift-argument-parser.git",
       "state" : {
-        "revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b",
-        "version" : "1.4.0"
+        "revision" : "41982a3656a71c768319979febd796c6fd111d5c",
+        "version" : "1.5.0"
       }
     },
     {
       "identity" : "swift-markdown-ui",
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/gonzalezreal/swift-markdown-ui",
       "state" : {
-        "revision" : "9a8119b37e09a770367eeb26e05267c75d854053",
-        "version" : "2.3.1"
+        "revision" : "55441810c0f678c78ed7e2ebd46dde89228e02fc",
+        "version" : "2.4.0"
       }
     },
     {
@@ -78,8 +78,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/huggingface/swift-transformers",
       "state" : {
-        "revision" : "0f2306713d48a75b862026ebb291926793773f52",
-        "version" : "0.1.12"
+        "revision" : "4d25d20e49d2269aec1556231f8e278db7b2a4f0",
+        "version" : "0.1.13"
       }
     }
   ],
