
Use chat template
DePasqualeOrg committed Oct 1, 2024
1 parent 0907d7f commit 24223b2
Showing 6 changed files with 33 additions and 96 deletions.
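The change replaces the per-model prompt-preparation closures with the tokenizer's chat template. A minimal sketch of the call pattern introduced across these files (assuming, as in the diff, that `modelContainer` is an already-loaded ModelContainer and `prompt` holds the user's input):

```swift
// Build a single-turn chat and let the tokenizer apply the model's own
// chat template; no hand-written [INST]/ChatML wrapping is needed anymore.
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}
```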
11 changes: 5 additions & 6 deletions Applications/LLMEval/ContentView.swift
@@ -159,7 +159,8 @@ class LLMEvaluator {

/// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
/// more devices.
let modelConfiguration = ModelConfiguration.phi3_5_4bit
// let modelConfiguration = ModelConfiguration.phi3_5_4bit
let modelConfiguration = ModelConfiguration.mistral7B4bit

/// parameters controlling the output
let generateParameters = GenerateParameters(temperature: 0.6)
@@ -212,11 +213,9 @@ class LLMEvaluator {
do {
let modelContainer = try await load()

// augment the prompt as needed
let prompt = modelConfiguration.prepare(prompt: prompt)

let promptTokens = await modelContainer.perform { _, tokenizer in
tokenizer.encode(text: prompt)
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
try tokenizer.applyChatTemplate(messages: messages)
}

// each time you generate you will get something new
7 changes: 3 additions & 4 deletions Applications/LoRATrainingExample/ContentView.swift
@@ -269,10 +269,9 @@ class LoRAEvaluator {

let modelContainer = try await loadModel()

// prepare the prompt
let preparedPrompt = modelConfiguration.prepare(prompt: prompt)
let promptTokens = await modelContainer.perform { _, tokenizer in
tokenizer.encode(text: preparedPrompt)
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
try tokenizer.applyChatTemplate(messages: messages)
}

// evaluate
7 changes: 4 additions & 3 deletions Libraries/LLM/LLMModel.swift
@@ -8,12 +8,13 @@ import Tokenizers

/// Container for models that guarantees single threaded access.
///
/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access
/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access
/// the model and/or tokenizer:
///
/// ```swift
/// let promptTokens = await modelContainer.perform { _, tokenizer in
/// tokenizer.encode(text: prompt)
/// let messages = [["role": "user", "content": prompt]]
/// let promptTokens = try await modelContainer.perform { _, tokenizer in
/// try tokenizer.applyChatTemplate(messages: messages)
/// }
/// ```
///
78 changes: 13 additions & 65 deletions Libraries/LLM/Models.swift
@@ -39,11 +39,6 @@ public struct ModelConfiguration: Sendable {
/// Additional tokens to use for end of string
public let extraEOSTokens: Set<String>

/// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
/// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt
/// format
private let preparePrompt: (@Sendable (String) -> String)?

public init(
id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
defaultPrompt: String = "hello",
@@ -55,25 +50,18 @@ public struct ModelConfiguration: Sendable {
self.overrideTokenizer = overrideTokenizer
self.defaultPrompt = defaultPrompt
self.extraEOSTokens = extraEOSTokens
self.preparePrompt = preparePrompt
}

public init(
directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
defaultPrompt: String = "hello",
extraEOSTokens: Set<String> = [],
preparePrompt: (@Sendable (String) -> String)? = nil
extraEOSTokens: Set<String> = []
) {
self.id = .directory(directory)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
self.defaultPrompt = defaultPrompt
self.extraEOSTokens = extraEOSTokens
self.preparePrompt = preparePrompt
}

public func prepare(prompt: String) -> String {
preparePrompt?(prompt) ?? prompt
}

public func modelDirectory(hub: HubApi = HubApi()) -> URL {
@@ -116,40 +104,26 @@ extension ModelConfiguration {
public static let smolLM_135M_4bit = ModelConfiguration(
id: "mlx-community/SmolLM-135M-Instruct-4bit",
defaultPrompt: "Tell me about the history of Spain."
) {
prompt in
"<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
}
)

public static let mistralNeMo4bit = ModelConfiguration(
id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit",
defaultPrompt: "Explain quaternions."
) { prompt in
"<s>[INST] \(prompt) [/INST] "
}
)

public static let mistral7B4bit = ModelConfiguration(
id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
defaultPrompt: "Describe the Swift language."
) { prompt in
"<s>[INST] \(prompt) [/INST] "
}
)

public static let codeLlama13b4bit = ModelConfiguration(
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
overrideTokenizer: "PreTrainedTokenizer",
defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
) { prompt in
// given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
// the python code produces this (via its custom tokenizer):
// <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>

"<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
}
)

public static let phi4bit = ModelConfiguration(
id: "mlx-community/phi-2-hf-4bit-mlx",

// https://www.promptingguide.ai/models/phi-2
defaultPrompt: "Why is the sky blue?"
)
@@ -158,76 +132,50 @@ extension ModelConfiguration {
id: "mlx-community/Phi-3.5-mini-instruct-4bit",
defaultPrompt: "What is the gravity on Mars and the moon?",
extraEOSTokens: ["<|end|>"]
) {
prompt in
"<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}
)

public static let gemma2bQuantized = ModelConfiguration(
id: "mlx-community/quantized-gemma-2b-it",
overrideTokenizer: "PreTrainedTokenizer",

// https://www.promptingguide.ai/models/gemma
defaultPrompt: "what is the difference between lettuce and cabbage?"

) { prompt in
"<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
}
)

public static let gemma_2_9b_it_4bit = ModelConfiguration(
id: "mlx-community/gemma-2-9b-it-4bit",
overrideTokenizer: "PreTrainedTokenizer",

// https://www.promptingguide.ai/models/gemma
defaultPrompt: "What is the difference between lettuce and cabbage?"

) { prompt in
"<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
}
)

public static let gemma_2_2b_it_4bit = ModelConfiguration(
id: "mlx-community/gemma-2-2b-it-4bit",
overrideTokenizer: "PreTrainedTokenizer",

// https://www.promptingguide.ai/models/gemma
defaultPrompt: "What is the difference between lettuce and cabbage?"

) { prompt in
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
}
)

public static let qwen205b4bit = ModelConfiguration(
id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
overrideTokenizer: "PreTrainedTokenizer",
defaultPrompt: "why is the sky blue?"
) { prompt in
"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
}
)

public static let openelm270m4bit = ModelConfiguration(
id: "mlx-community/OpenELM-270M-Instruct",

// https://huggingface.co/apple/OpenELM
defaultPrompt: "Once upon a time there was"
) { prompt in
"\(prompt)"
}
)

public static let llama3_1_8B_4bit = ModelConfiguration(
id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
defaultPrompt: "What is the difference between a fruit and a vegetable?"
) {
prompt in
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
}
)

public static let llama3_8B_4bit = ModelConfiguration(
id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
defaultPrompt: "What is the difference between a fruit and a vegetable?"
) {
prompt in
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
}
)

private enum BootstrapState: Sendable {
case idle
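With templating handled by the tokenizer, the model entries above no longer carry prompt-formatting closures (ChatML for SmolLM and Qwen, `[INST]` for Mistral, `<start_of_turn>` for Gemma, Llama 3 header tokens). As an after-the-change snapshot, the `mistral7B4bit` entry from this diff reduces to:

```swift
// After this commit the configuration just names the model and a default prompt;
// chat formatting comes from the model's own chat template at runtime.
public static let mistral7B4bit = ModelConfiguration(
    id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
    defaultPrompt: "Describe the Swift language."
)
```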
19 changes: 4 additions & 15 deletions Tools/llm-tool/LLMTool.swift
@@ -84,18 +84,6 @@ struct GenerateArguments: ParsableArguments, Sendable {
}
}

func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) throws -> (
String, [Int]
) {
MLXRandom.seed(seed)

let prompt = try resolvePrompt(configuration: configuration)
let preparedPrompt = configuration.prepare(prompt: prompt)
let promptTokens = tokenizer.encode(text: preparedPrompt)

return (prompt, promptTokens)
}

func generate(
promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
extraEOSTokens: Set<String>? = nil
@@ -221,9 +209,10 @@ struct EvaluateCommand: AsyncParsableCommand {
print("Model loaded -> \(modelConfiguration.id)")
}

let (prompt, promptTokens) = try await modelContainer.perform { [generate] _, tokenizer in
try generate.tokenizePrompt(
configuration: modelConfiguration, tokenizer: tokenizer)
let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
try tokenizer.applyChatTemplate(messages: messages)
}

if !generate.quiet {
7 changes: 4 additions & 3 deletions Tools/llm-tool/LoraCommands.swift
@@ -292,9 +292,10 @@ struct LoRAEvalCommand: AsyncParsableCommand {

memory.start()

let (prompt, promptTokens) = try await modelContainer.perform { [generate] _, tokenizer in
try generate.tokenizePrompt(
configuration: modelConfiguration, tokenizer: tokenizer)
let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
try tokenizer.applyChatTemplate(messages: messages)
}

if !generate.quiet {
