Use chat template #135

Merged (2 commits) on Oct 22, 2024
13 changes: 7 additions & 6 deletions Applications/LLMEval/ContentView.swift
@@ -159,7 +159,10 @@ class LLMEvaluator

     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    // let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    // let modelConfiguration = ModelConfiguration.mistral7B4bit
+    let modelConfiguration = ModelConfiguration.llama3_2_3B_4bit
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
@@ -217,11 +220,9 @@ class LLMEvaluator
         do {
             let modelContainer = try await load()
 
-            // augment the prompt as needed
-            let prompt = modelConfiguration.prepare(prompt: prompt)
-
-            let promptTokens = await modelContainer.perform { _, tokenizer in
-                tokenizer.encode(text: prompt)
+            let messages = [["role": "user", "content": prompt]]
+            let promptTokens = try await modelContainer.perform { _, tokenizer in
+                try tokenizer.applyChatTemplate(messages: messages)
             }
 
             // each time you generate you will get something new
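This hunk is the core of the PR: instead of formatting prompts by hand per model, callers build a chat-style message list and let the tokenizer's chat template render it into model-specific prompt tokens. A minimal sketch of the new flow, assuming the repo's `LLM` module and a model whose `tokenizer_config.json` bundles a chat template; the helper function is illustrative, not part of the diff:

```swift
import LLM  // this repo's LLM library
import Tokenizers  // swift-transformers

/// Hypothetical helper mirroring the diff above: turn a user prompt
/// into prompt tokens via the model's chat template.
func chatPromptTokens(for prompt: String, in container: ModelContainer) async throws -> [Int] {
    // Role/content pairs; the Jinja template decides how roles and
    // model-specific special tokens are rendered.
    let messages = [["role": "user", "content": prompt]]
    return try await container.perform { _, tokenizer in
        try tokenizer.applyChatTemplate(messages: messages)
    }
}
```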
7 changes: 3 additions & 4 deletions Applications/LoRATrainingExample/ContentView.swift
@@ -269,10 +269,9 @@ class LoRAEvaluator {

         let modelContainer = try await loadModel()
 
-        // prepare the prompt
-        let preparedPrompt = modelConfiguration.prepare(prompt: prompt)
-        let promptTokens = await modelContainer.perform { _, tokenizer in
-            tokenizer.encode(text: preparedPrompt)
+        let messages = [["role": "user", "content": prompt]]
+        let promptTokens = try await modelContainer.perform { _, tokenizer in
+            try tokenizer.applyChatTemplate(messages: messages)
         }
 
         // evaluate
7 changes: 4 additions & 3 deletions Libraries/LLM/LLMModel.swift
@@ -8,12 +8,13 @@ import Tokenizers

 /// Container for models that guarantees single threaded access.
 ///
 /// Wrap models used by e.g. the UI in a ModelContainer. Callers can access
 /// the model and/or tokenizer:
 ///
 /// ```swift
-/// let promptTokens = await modelContainer.perform { _, tokenizer in
-///     tokenizer.encode(text: prompt)
+/// let messages = [["role": "user", "content": prompt]]
+/// let promptTokens = try await modelContainer.perform { _, tokenizer in
+///     try tokenizer.applyChatTemplate(messages: messages)
 /// }
 /// ```
 ///
88 changes: 15 additions & 73 deletions Libraries/LLM/Models.swift
@@ -39,11 +39,6 @@ public struct ModelConfiguration: Sendable {
     /// Additional tokens to use for end of string
     public let extraEOSTokens: Set<String>
 
-    /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
-    /// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt
-    /// format
-    private let preparePrompt: (@Sendable (String) -> String)?
-
     public init(
         id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
         defaultPrompt: String = "hello",
@@ -55,25 +50,18 @@
         self.overrideTokenizer = overrideTokenizer
         self.defaultPrompt = defaultPrompt
         self.extraEOSTokens = extraEOSTokens
-        self.preparePrompt = preparePrompt
     }
 
     public init(
         directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
         defaultPrompt: String = "hello",
-        extraEOSTokens: Set<String> = [],
-        preparePrompt: (@Sendable (String) -> String)? = nil
+        extraEOSTokens: Set<String> = []
     ) {
         self.id = .directory(directory)
         self.tokenizerId = tokenizerId
         self.overrideTokenizer = overrideTokenizer
         self.defaultPrompt = defaultPrompt
         self.extraEOSTokens = extraEOSTokens
-        self.preparePrompt = preparePrompt
     }
 
-    public func prepare(prompt: String) -> String {
-        preparePrompt?(prompt) ?? prompt
-    }
-
     public func modelDirectory(hub: HubApi = HubApi()) -> URL {
@@ -116,40 +104,26 @@ extension ModelConfiguration {
     public static let smolLM_135M_4bit = ModelConfiguration(
         id: "mlx-community/SmolLM-135M-Instruct-4bit",
         defaultPrompt: "Tell me about the history of Spain."
-    ) {
-        prompt in
-        "<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
-    }
+    )
 
     public static let mistralNeMo4bit = ModelConfiguration(
         id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit",
         defaultPrompt: "Explain quaternions."
-    ) { prompt in
-        "<s>[INST] \(prompt) [/INST] "
-    }
+    )
 
     public static let mistral7B4bit = ModelConfiguration(
         id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
         defaultPrompt: "Describe the Swift language."
-    ) { prompt in
-        "<s>[INST] \(prompt) [/INST] "
-    }
+    )
 
     public static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
         overrideTokenizer: "PreTrainedTokenizer",
         defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
-    ) { prompt in
-        // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
-        // the python code produces this (via its custom tokenizer):
-        // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
-
-        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
-    }
+    )
 
     public static let phi4bit = ModelConfiguration(
         id: "mlx-community/phi-2-hf-4bit-mlx",
 
         // https://www.promptingguide.ai/models/phi-2
         defaultPrompt: "Why is the sky blue?"
     )
@@ -158,92 +132,60 @@ public static let phi3_5_4bit = ModelConfiguration(
id: "mlx-community/Phi-3.5-mini-instruct-4bit",
defaultPrompt: "What is the gravity on Mars and the moon?",
extraEOSTokens: ["<|end|>"]
) {
prompt in
"<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}
)

public static let gemma2bQuantized = ModelConfiguration(
id: "mlx-community/quantized-gemma-2b-it",
overrideTokenizer: "PreTrainedTokenizer",

// https://www.promptingguide.ai/models/gemma
defaultPrompt: "what is the difference between lettuce and cabbage?"

) { prompt in
"<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
}
)

public static let gemma_2_9b_it_4bit = ModelConfiguration(
id: "mlx-community/gemma-2-9b-it-4bit",
overrideTokenizer: "PreTrainedTokenizer",

// https://www.promptingguide.ai/models/gemma
defaultPrompt: "What is the difference between lettuce and cabbage?"

) { prompt in
"<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
}
)

public static let gemma_2_2b_it_4bit = ModelConfiguration(
id: "mlx-community/gemma-2-2b-it-4bit",
overrideTokenizer: "PreTrainedTokenizer",

// https://www.promptingguide.ai/models/gemma
defaultPrompt: "What is the difference between lettuce and cabbage?"

) { prompt in
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
}
)

public static let qwen205b4bit = ModelConfiguration(
id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
overrideTokenizer: "PreTrainedTokenizer",
defaultPrompt: "why is the sky blue?"
) { prompt in
"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
}
)

public static let openelm270m4bit = ModelConfiguration(
id: "mlx-community/OpenELM-270M-Instruct",

// https://huggingface.co/apple/OpenELM
defaultPrompt: "Once upon a time there was"
) { prompt in
"\(prompt)"
}
)

public static let llama3_1_8B_4bit = ModelConfiguration(
id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
defaultPrompt: "What is the difference between a fruit and a vegetable?"
) {
prompt in
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
}
)

public static let llama3_8B_4bit = ModelConfiguration(
id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
defaultPrompt: "What is the difference between a fruit and a vegetable?"
) {
prompt in
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
}
)

public static let llama3_2_1B_4bit = ModelConfiguration(
id: "mlx-community/Llama-3.2-1B-Instruct-4bit",
defaultPrompt: "What is the difference between a fruit and a vegetable?"
) {
prompt in
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
}
)

public static let llama3_2_3B_4bit = ModelConfiguration(
id: "mlx-community/Llama-3.2-3B-Instruct-4bit",
defaultPrompt: "What is the difference between a fruit and a vegetable?"
) {
prompt in
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
}
)

private enum BootstrapState: Sendable {
case idle
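With `preparePrompt` removed, a `ModelConfiguration` is pure metadata; the wrapping these closures used to do (`<|im_start|>`, `[INST]`, `<start_of_turn>` markers and so on) now comes from each model's own chat template. A sketch of what a new entry looks like under the reduced initializer; the model id is hypothetical:

```swift
// Hypothetical configuration in the style of the entries above. No
// prompt-formatting closure: the tokenizer's chat template handles it.
public static let myModel4bit = ModelConfiguration(
    id: "mlx-community/MyModel-Instruct-4bit",  // illustrative id
    defaultPrompt: "Why is the sky blue?",
    extraEOSTokens: ["<|end|>"]  // only needed for nonstandard EOS markers
)
```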
19 changes: 4 additions & 15 deletions Tools/llm-tool/LLMTool.swift
@@ -84,18 +84,6 @@ struct GenerateArguments: ParsableArguments, Sendable {
         }
     }
 
-    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) throws -> (
-        String, [Int]
-    ) {
-        MLXRandom.seed(seed)
-
-        let prompt = try resolvePrompt(configuration: configuration)
-        let preparedPrompt = configuration.prepare(prompt: prompt)
-        let promptTokens = tokenizer.encode(text: preparedPrompt)
-
-        return (prompt, promptTokens)
-    }
-
     func generate(
         promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
         extraEOSTokens: Set<String>? = nil
@@ -221,9 +209,10 @@ struct EvaluateCommand: AsyncParsableCommand {
print("Model loaded -> \(modelConfiguration.id)")
}

let (prompt, promptTokens) = try await modelContainer.perform { [generate] _, tokenizer in
try generate.tokenizePrompt(
configuration: modelConfiguration, tokenizer: tokenizer)
let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
try tokenizer.applyChatTemplate(messages: messages)
}

if !generate.quiet {
7 changes: 4 additions & 3 deletions Tools/llm-tool/LoraCommands.swift
@@ -291,9 +291,10 @@ struct LoRAEvalCommand: AsyncParsableCommand {

         memory.start()
 
-        let (prompt, promptTokens) = try await modelContainer.perform { [generate] _, tokenizer in
-            try generate.tokenizePrompt(
-                configuration: modelConfiguration, tokenizer: tokenizer)
+        let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
+        let messages = [["role": "user", "content": prompt]]
+        let promptTokens = try await modelContainer.perform { _, tokenizer in
+            try tokenizer.applyChatTemplate(messages: messages)
         }
 
         if !generate.quiet {
4 changes: 2 additions & 2 deletions mlx-swift-examples.xcodeproj/project.pbxproj
@@ -3555,7 +3555,7 @@
repositoryURL = "https://github.com/huggingface/swift-transformers";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 0.1.12;
minimumVersion = 0.1.13;
};
};
C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */ = {
@@ -3571,7 +3571,7 @@
repositoryURL = "https://github.com/ml-explore/mlx-swift";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 0.16.1;
minimumVersion = 0.18.0;
};
};
/* End XCRemoteSwiftPackageReference section */
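The dependency bumps match the code change: swift-transformers 0.1.13 carries the chat-template support used above (with the Jinja package as its template engine), and mlx-swift moves to 0.18.0. For a standalone SwiftPM package tracking the same versions, the requirements might look like the sketch below; the package name and platforms are illustrative:

```swift
// swift-tools-version: 5.9
// Hypothetical Package.swift mirroring the version requirements above.
import PackageDescription

let package = Package(
    name: "ChatTemplateExample",  // illustrative
    platforms: [.macOS(.v14), .iOS(.v16)],
    dependencies: [
        .package(url: "https://github.com/huggingface/swift-transformers", .upToNextMajor(from: "0.1.13")),
        .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMajor(from: "0.18.0")),
    ]
)
```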
@@ -15,17 +15,17 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/maiqingqiang/Jinja",
"state" : {
"revision" : "5b0703d19a8901b76948753e5c5e3ca77043d33f",
"version" : "1.0.0"
"revision" : "b435eb62b0d3d5f34167ec70a128355486981712",
"version" : "1.0.5"
}
},
{
"identity" : "mlx-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"revision" : "86ad75ab1ee96cd70325732b37cd830f87d7e43f",
"version" : "0.16.1"
"revision" : "78a7cfe6701d6e9c88e9d4a0d1f7990af84b2146",
"version" : "0.18.0"
}
},
{
@@ -51,17 +51,17 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-argument-parser.git",
"state" : {
"revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b",
"version" : "1.4.0"
"revision" : "41982a3656a71c768319979febd796c6fd111d5c",
"version" : "1.5.0"
}
},
{
"identity" : "swift-markdown-ui",
"kind" : "remoteSourceControl",
"location" : "https://github.com/gonzalezreal/swift-markdown-ui",
"state" : {
"revision" : "9a8119b37e09a770367eeb26e05267c75d854053",
"version" : "2.3.1"
"revision" : "55441810c0f678c78ed7e2ebd46dde89228e02fc",
"version" : "2.4.0"
}
},
{
Expand All @@ -78,8 +78,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers",
"state" : {
"revision" : "0f2306713d48a75b862026ebb291926793773f52",
"version" : "0.1.12"
"revision" : "4d25d20e49d2269aec1556231f8e278db7b2a4f0",
"version" : "0.1.13"
}
}
],