diff --git a/Examples/Mistral7B/README.md b/Examples/Mistral7B/README.md new file mode 100644 index 00000000..30d64b14 --- /dev/null +++ b/Examples/Mistral7B/README.md @@ -0,0 +1,27 @@ +### Export Mistral 7B Instruct v0.3 + +```shell +✗ uv run export.py + +Loading checkpoint shards: 100%|███████████████████████████| 3/3 [00:12<00:00, 4.11s/it] +Converting PyTorch Frontend ==> MIL Ops: 100%|███| 5575/5575 [00:02<00:00, 2440.66 ops/s] +Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 7.12 passes/s] +Running MIL default pipeline: 100%|█████████████████| 79/79 [02:36<00:00, 1.98s/ passes] +Running MIL backend_mlprogram pipeline: 100%|███████| 12/12 [00:00<00:00, 22.90 passes/s] +Running compression: 100%|███████████████████████████| 296/296 [03:04<00:00, 1.60 ops/s] +... +``` + +### Generate Text + +```shell +✗ swift run transformers-cli "Best recommendations for a place to visit in Paris in August 2024:" --max-length 128 StatefulMistral7BInstructInt4.mlpackage + +Best recommendations for a place to visit in Paris in August 2024: + +1. Palace of Versailles: This iconic palace is a must-visit. It's a short train ride from Paris and offers a glimpse into the opulence of the French monarchy. + +2. Eiffel Tower: No trip to Paris is complete without a visit to the Eiffel Tower. You can take an elevator ride to the top for a stunning view of the city. + +3. Louvre Museum: Home to thousands of works of art, including the Mona Lisa and the Winged Victory of Samothrace, the Louvre is a cultural treasure. +``` diff --git a/Examples/Mistral7B/export.py b/Examples/Mistral7B/export.py new file mode 100644 index 00000000..4509a937 --- /dev/null +++ b/Examples/Mistral7B/export.py @@ -0,0 +1,230 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "coremltools", +# "numpy", +# "sentencepiece", +# "torch", +# "tqdm", +# "transformers", +# ] +# /// +import logging +import os +import warnings +from typing import List, Optional, Tuple + +import coremltools as ct +import numpy as np +import torch +from transformers.cache_utils import Cache +from transformers.models.mistral.modeling_mistral import ( + MISTRAL_ATTENTION_CLASSES, + MistralAttention, + MistralConfig, + MistralForCausalLM, + apply_rotary_pos_emb, + repeat_kv, +) + +warnings.filterwarnings("ignore") +logging.getLogger("coremltools").setLevel(logging.ERROR) +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3 +MODEL_ID: str = "mistralai/Mistral-7B-Instruct-v0.3" +METADATA_TOKENIZER: str = "co.huggingface.exporters.name" + + +class SliceUpdateKeyValueCache(Cache): + def __init__( + self, + shape: Tuple[int, ...], + device="cpu", + dtype=torch.float32, + ) -> None: + """KV cache of shape (#layers, batch_size, #kv_heads, context_size, head_dim).""" + super().__init__() + self.past_seen_tokens: int = 0 + self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device) + self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device) + + def update( + self, + k_state: torch.Tensor, + v_state: torch.Tensor, + layer_idx: int, + slice_indices: torch.LongTensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update key/value cache tensors for slice [slice_indices[0], slice_indices[1]). + Return slice of key/value cache tensors from [0, slice_indices[1]). 
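+
+        Illustrative example (sizes assumed for exposition, not taken from the
+        original): with one new token and slice_indices=(7, 8), the incoming
+        key/value states overwrite cache row 7 and the first 8 cached rows are
+        returned for attention.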
+        """
+        if len(slice_indices) != 2:
+            raise ValueError(f"Expected a tuple of two integers [start, end), got {slice_indices=}.")
+        begin, end = slice_indices
+        self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state
+        self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state
+        k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
+        v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
+        return k_cache, v_cache
+
+    def get_seq_length(self, _: int | None = 0) -> int:
+        """Get the sequence length of the cache."""
+        return self.past_seen_tokens
+
+
+class SliceUpdateMistralAttention(MistralAttention):
+    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
+        super().__init__(config=config, layer_idx=layer_idx)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor | None, ...]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
+            1, 2
+        )
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        # Slice update key/value cache
+        end_step = attention_mask.shape[-1]
+        key_states, value_states = past_key_value.update(
+            key_states,
+            value_states,
+            self.layer_idx,
+            slice_indices=(end_step - q_len, end_step),
+        )
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None, None
+
+
+class StatefulMistralForCausalLM(torch.nn.Module):
+    def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1) -> None:
+        super().__init__()
+
+        # Custom attention implementation for the stateful slice-update key/value cache.
+        # Override "sdpa" so that transformers.modeling_utils._autoset_attn_implementation
+        # resolves to the custom class.
+        MISTRAL_ATTENTION_CLASSES["sdpa"] = SliceUpdateMistralAttention
+        self.model = MistralForCausalLM.from_pretrained(model_path)
+
+        # Register KV cache buffers to be recognized as Core ML states
+        config: MistralConfig = self.model.config
+        self.kv_cache_shape: Tuple[int, ...]
= ( + config.num_hidden_layers, + batch_size, + config.num_key_value_heads, + max_context_size, + config.hidden_size // config.num_attention_heads, + ) + self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape) + self.register_buffer("keyCache", self.kv_cache.k_cache) + self.register_buffer("valueCache", self.kv_cache.v_cache) + + @torch.no_grad() + def forward( + self, + input_ids: torch.LongTensor, + causal_mask: torch.Tensor, + ) -> torch.Tensor: + # Compute past seen tokens used for updating key/value cache slices + self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1] + return self.model( + input_ids, + attention_mask=causal_mask, + past_key_values=self.kv_cache, + use_cache=True, + ).logits + + +def export() -> None: + # Construct model from transformers and trace to TorchScript + max_context_size: int = 2048 + torch_model = StatefulMistralForCausalLM(MODEL_ID, max_context_size=max_context_size) + torch_model.eval() + input_ids: torch.Tensor = torch.zeros((1, 2), dtype=torch.int32) + causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float32) + traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask]) + kv_cache_shape = torch_model.kv_cache_shape + del torch_model + + # Convert traced TorchScript to Core ML format + query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1) + end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1) + inputs: List[ct.TensorType] = [ + ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"), + ct.TensorType( + shape=(1, 1, query_length, end_step_dim), + dtype=np.float16, + name="causalMask", + ), + ] + outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")] + states: List[ct.StateType] = [ + ct.StateType( + wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16), + name="keyCache", + ), + ct.StateType( + wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16), + name="valueCache", + ), + ] + + # Convert model with FP16 precision + mlmodel_fp16: ct.MLModel = ct.convert( + traced_model, + inputs=inputs, + outputs=outputs, + states=states, + minimum_deployment_target=ct.target.iOS18, + skip_model_load=True, + ) + del traced_model + + # Block-wise quantize model weights to int4 + op_config = ct.optimize.coreml.OpLinearQuantizerConfig( + mode="linear_symmetric", + dtype="int4", + granularity="per_block", + block_size=32, + ) + config = ct.optimize.coreml.OptimizationConfig(global_config=op_config) + mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config) + mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID}) + del mlmodel_fp16 + mlmodel_int4.save("StatefulMistral7BInstructInt4.mlpackage") + + +if __name__ == "__main__": + export() diff --git a/Examples/Mistral7B/generate.py b/Examples/Mistral7B/generate.py new file mode 100644 index 00000000..0c90acd0 --- /dev/null +++ b/Examples/Mistral7B/generate.py @@ -0,0 +1,99 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "coremltools", +# "numpy", +# "sentencepiece", +# "torch", +# "tqdm", +# "transformers", +# ] +# /// +import argparse +from typing import Dict, Generator, List, Tuple + +import numpy as np +from coremltools.models import MLModel +from transformers import AutoTokenizer + +from export import METADATA_TOKENIZER + + +def load(model_path: str) -> Tuple[MLModel, AutoTokenizer]: + """Load a Core ML model and corresponding tokenizer.""" + model: 
MLModel = MLModel(model_path) + description = model.get_spec().description + if METADATA_TOKENIZER not in description.metadata.userDefined: + raise ValueError("Model metadata does not contain tokenizer path.") + tokenizer_path: str = description.metadata.userDefined[METADATA_TOKENIZER] + tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + return model, tokenizer + + +def get_next_token(model: MLModel, prompt_tokens: np.ndarray) -> Generator[int, None, None]: + """Generate a sequence of tokens with naive greedy decoding.""" + + def sample(logits: np.ndarray) -> int: + """Perform greedy decoding on the logits array to get the next token.""" + return int(np.argmax(logits[0][-1], axis=-1)) + + def inference(model: MLModel, input_ids: np.ndarray, num_past_tokens: int) -> np.ndarray: + """Perform inference with the given model and input data.""" + causal_mask: np.ndarray = np.triu( + np.full( + (1, 1, input_ids.shape[-1], num_past_tokens + input_ids.shape[-1]), + fill_value=-np.inf if num_past_tokens == 0 else 0, + ), + k=1, + ).astype(np.float16) + outputs: Dict[str, np.ndarray] = model.predict( + data={"inputIds": input_ids, "causalMask": causal_mask}, + state=kv_cache_state, + ) + return outputs["logits"] + + kv_cache_state = model.make_state() + logits: np.ndarray = inference(model, input_ids=prompt_tokens, num_past_tokens=0) + token: int = sample(logits=logits) + num_past_tokens: int = prompt_tokens.shape[-1] + + while True: + yield token + logits: np.ndarray = inference( + model, + input_ids=np.array([[token]], dtype=np.int32), + num_past_tokens=num_past_tokens, + ) + token: int = sample(logits=logits) + num_past_tokens += 1 + + +def generate( + model: MLModel, + prompt: str, + tokenizer: AutoTokenizer, + max_new_tokens: int, +) -> str: + prompt_tokens: np.ndarray = tokenizer(prompt, return_tensors="np").input_ids + extend_tokens: List[int] = [] + for i, token in enumerate(get_next_token(model, prompt_tokens=prompt_tokens.astype(np.int32))): + if token == tokenizer.eos_token_id or i == max_new_tokens: + break + extend_tokens.append(token) + return tokenizer.decode(prompt_tokens[0].tolist() + extend_tokens) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_path", type=str) + parser.add_argument("--prompt", type=str, default="Hello") + parser.add_argument("--max_new_tokens", type=int, default=128) + args = parser.parse_args() + model, tokenizer = load(args.model_path) + extend_text: str = generate( + model, + prompt=args.prompt, + tokenizer=tokenizer, + max_new_tokens=args.max_new_tokens, + ) + print(extend_text) diff --git a/Examples/Mistral7B/requirements.txt b/Examples/Mistral7B/requirements.txt new file mode 100644 index 00000000..3c954fb2 --- /dev/null +++ b/Examples/Mistral7B/requirements.txt @@ -0,0 +1,6 @@ +coremltools +numpy +torch +tqdm +transformers +sentencepiece diff --git a/Examples/transformers-cli/Package.swift b/Examples/transformers-cli/Package.swift new file mode 100644 index 00000000..f360afb1 --- /dev/null +++ b/Examples/transformers-cli/Package.swift @@ -0,0 +1,24 @@ +// swift-tools-version: 6.2 +// The swift-tools-version declares the minimum version of Swift required to build this package. 
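+// Note: the iOS 18 / macOS 15 platform floor declared below is required
+// because the generation path in swift-transformers is built on MLTensor,
+// which is only available starting with those OS releases.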
+
+import PackageDescription
+
+let package = Package(
+    name: "transformers-cli",
+    platforms: [.iOS(.v18), .macOS(.v15)],
+    dependencies: [
+        .package(path: "../.."),
+        // If you copy this manifest as a template, use the following line instead
+        //.package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"),
+        .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"),
+    ],
+    targets: [
+        .executableTarget(
+            name: "transformers-cli",
+            dependencies: [
+                .product(name: "Transformers", package: "swift-transformers"),
+                .product(name: "ArgumentParser", package: "swift-argument-parser"),
+            ]
+        )
+    ]
+)
diff --git a/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift
new file mode 100644
index 00000000..cb24e3ca
--- /dev/null
+++ b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift
@@ -0,0 +1,141 @@
+import ArgumentParser
+import CoreML
+import Foundation
+import Generation
+import Models
+
+@available(macOS 15.0, iOS 18.0, *)
+@main
+struct TransformersCLI: AsyncParsableCommand {
+    static let configuration = CommandConfiguration(
+        abstract: "Run text generation on a Core ML language model",
+        version: "0.0.1"
+    )
+
+    @Argument(help: "Input text")
+    var prompt: String
+
+    @Argument(help: "Path to Core ML mlpackage model")
+    var modelPath: String = "./model.mlpackage"
+
+    @Option(help: "Maximum number of tokens the model should generate")
+    var maxLength: Int = 100
+
+    @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
+    var computeUnits: ComputeUnits = .cpuAndGPU
+
+    @Option(
+        help: """
+        When enabled, two generation passes are run: one to 'warm up' and another to collect \
+        benchmark metrics.
+        """)
+    var warmup: Bool = false
+
+    func generate(
+        model: LanguageModel,
+        config: GenerationConfig,
+        prompt: String,
+        printOutput: Bool = true
+    ) async throws {
+        var tokensReceived = 0
+        var previousIndex: String.Index? = nil
+        var startTime = Date()
+        var promptProcessingTime: Double = 0 // seconds
+        try await model.generate(config: config, prompt: prompt) { inProgressGeneration in
+            if previousIndex == nil { // Prompt pre-filling
+                promptProcessingTime = Date().timeIntervalSince(startTime)
+                // Reset start time to more accurately compute the average tps.
+                startTime = Date()
+            } else { // Extend
+                // Only start counting tokens once the prompt has been processed.
+                tokensReceived += 1
+            }
+            let response = formatResponse(inProgressGeneration)
+            if printOutput {
+                print(response[(previousIndex ??
response.startIndex)...], terminator: "")
+                fflush(stdout)
+            }
+            previousIndex = response.endIndex
+        }
+        // Current time - start time + elapsed time to process the prompt
+        let endTime = Date()
+        let completionTime = endTime.timeIntervalSince(startTime) + promptProcessingTime
+        let tps = Double(tokensReceived) / endTime.timeIntervalSince(startTime)
+        if printOutput {
+            print("")
+            print(
+                """
+                \(tps.formatted("%.2f")) tokens/s, \
+                prompt pre-filling time: \(promptProcessingTime.formatted("%.2f"))s, \
+                total time: \(completionTime.formatted("%.2f"))s
+                """)
+        }
+    }
+
+    func compile(at url: URL) throws -> URL {
+        #if os(watchOS)
+        fatalError("Model compilation is not supported on watchOS")
+        #else
+        if url.pathExtension == "mlmodelc" { return url }
+        print("Compiling model \(url)")
+        return try MLModel.compileModel(at: url)
+        #endif
+    }
+
+    func run() async throws {
+        let url = URL(filePath: modelPath)
+        let compiledURL = try compile(at: url)
+        print("Loading model \(compiledURL)")
+        let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits)
+
+        // Using greedy generation for now
+        var config = model.defaultGenerationConfig
+        config.doSample = false
+        config.maxNewTokens = maxLength
+
+        // Given the size of the out-of-model computation, dispatch all
+        // tensor operations to the CPU.
+
+        if warmup {
+            print("Warming up...")
+            try await withMLTensorComputePolicy(.cpuOnly) {
+                try await generate(model: model, config: config, prompt: prompt, printOutput: false)
+            }
+        }
+
+        print("Generating")
+        try await withMLTensorComputePolicy(.cpuOnly) {
+            try await generate(model: model, config: config, prompt: prompt)
+        }
+    }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
+enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
+    case all, cpuAndGPU, cpuOnly, cpuAndNeuralEngine
+    var asMLComputeUnits: MLComputeUnits {
+        switch self {
+        case .all: return .all
+        case .cpuAndGPU: return .cpuAndGPU
+        case .cpuOnly: return .cpuOnly
+        case .cpuAndNeuralEngine: return .cpuAndNeuralEngine
+        }
+    }
+}
+
+/// Returns a cleaned and formatted version of the response.
+///
+/// - Parameter response: The response to clean and format.
+/// - Returns: A 'user friendly' representation of the generated response.
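+/// - Note: The `<s>` / `</s>` markers stripped below are assumed to be the
+///   model's sentence sentinels (as used by Mistral); models with different
+///   special tokens may need other markers.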
+private func formatResponse(_ response: String) -> String {
+    response
+        .replacingOccurrences(of: "\\n", with: "\n")
+        .replacingOccurrences(of: "<s>", with: "")
+        .replacingOccurrences(of: "</s>", with: "")
+}
+
+extension Double {
+    func formatted(_ format: String) -> String {
+        return String(format: "\(format)", self)
+    }
+}
diff --git a/Package.swift b/Package.swift
index fb69f9f3..dc6c3f5a 100644
--- a/Package.swift
+++ b/Package.swift
@@ -24,7 +24,6 @@ let package = Package(
         .target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings),
         .target(name: "Models", dependencies: ["Tokenizers", "Generation"]),
         .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]),
-        .testTarget(name: "GenerationTests", dependencies: ["Generation"]),
         .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings),
         .testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]),
         .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]),
diff --git a/Sources/Generation/CoreML+Extensions.swift b/Sources/Generation/CoreML+Extensions.swift
deleted file mode 100644
index 1c3c8f7e..00000000
--- a/Sources/Generation/CoreML+Extensions.swift
+++ /dev/null
@@ -1,302 +0,0 @@
-//
-//  CoreML+Extensions.swift
-//  CoreMLBert
-//
-//  Created by Julien Chaumond on 27/06/2019.
-//  Copyright © 2019 Hugging Face. All rights reserved.
-//
-
-#if canImport(CoreML)
-import CoreML
-import Foundation
-
-extension MLMultiArray {
-    /// Creates an MLMultiArray from an array of integers.
-    ///
-    /// All values are stored in the last dimension of the MLMultiArray, with leading
-    /// dimensions set to 1. For example, with dims=2, the shape becomes [1, arr.count].
-    ///
-    /// - Parameters:
-    ///   - arr: Array of integers to convert
-    ///   - dims: Number of dimensions for the resulting MLMultiArray
-    /// - Returns: MLMultiArray containing the integer values
-    static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray {
-        var shape = Array(repeating: 1, count: dims)
-        shape[shape.count - 1] = arr.count
-        // Examples:
-        // dims=1 : [arr.count]
-        // dims=2 : [1, arr.count]
-        //
-        let o = try! MLMultiArray(shape: shape as [NSNumber], dataType: .int32)
-        let ptr = UnsafeMutablePointer<Int32>(OpaquePointer(o.dataPointer))
-        for (i, item) in arr.enumerated() {
-            ptr[i] = Int32(item)
-        }
-        return o
-    }
-
-    /// Creates an MLMultiArray from an array of doubles.
-    ///
-    /// All values are stored in the last dimension of the MLMultiArray, with leading
-    /// dimensions set to 1. For example, with dims=2, the shape becomes [1, arr.count].
-    ///
-    /// - Parameters:
-    ///   - arr: Array of doubles to convert
-    ///   - dims: Number of dimensions for the resulting MLMultiArray
-    /// - Returns: MLMultiArray containing the double values
-    static func from(_ arr: [Double], dims: Int = 1) -> MLMultiArray {
-        var shape = Array(repeating: 1, count: dims)
-        shape[shape.count - 1] = arr.count
-        // Examples:
-        // dims=1 : [arr.count]
-        // dims=2 : [1, arr.count]
-        //
-        let o = try! MLMultiArray(shape: shape as [NSNumber], dataType: .float64)
-        let ptr = UnsafeMutablePointer<Double>(OpaquePointer(o.dataPointer))
-        for (i, item) in arr.enumerated() {
-            ptr[i] = Double(item)
-        }
-        return o
-    }
-
-    /// Converts an MLMultiArray to a flat array of integers.
- /// - /// Concatenates all dimensions into a single one-dimensional array by reading - /// the MLMultiArray data in memory order. - /// - /// - Parameter o: MLMultiArray to convert - /// - Returns: Flat array of integer values - static func toIntArray(_ o: MLMultiArray) -> [Int] { - var arr = Array(repeating: 0, count: o.count) - let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) - for i in 0.. [Int] { Self.toIntArray(self) } - - /// Converts an MLMultiArray to a flat array of doubles. - /// - /// Concatenates all dimensions into a single one-dimensional array by reading - /// the MLMultiArray data in memory order. - /// - /// - Parameter o: MLMultiArray to convert - /// - Returns: Flat array of double values - static func toDoubleArray(_ o: MLMultiArray) -> [Double] { - var arr: [Double] = Array(repeating: 0, count: o.count) - let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) - for i in 0.. [Double] { Self.toDoubleArray(self) } - - /// Creates a test MLMultiArray with sequentially indexed values. - /// - /// Useful for debugging and unit tests. Values are assigned sequentially - /// starting from 0, following the memory layout of the specified shape. - /// - /// Example output for shape [2, 3, 4]: - /// ``` - /// [[[ 0, 1, 2, 3 ], - /// [ 4, 5, 6, 7 ], - /// [ 8, 9, 10, 11 ]], - /// [[ 12, 13, 14, 15 ], - /// [ 16, 17, 18, 19 ], - /// [ 20, 21, 22, 23 ]]] - /// ``` - /// - /// - Parameter shape: Desired shape of the test tensor - /// - Returns: MLMultiArray with sequential values for testing - static func testTensor(shape: [Int]) -> MLMultiArray { - let arr = try! MLMultiArray(shape: shape as [NSNumber], dataType: .double) - let ptr = UnsafeMutablePointer(OpaquePointer(arr.dataPointer)) - for i in 0.. MLMultiArray { - assert( - indexing.count == o.shape.count - ) - assert( - indexing.filter { $0 == Indexing.slice }.count == 1 - ) - var selectDims: [Int: Int] = [:] - for (i, idx) in indexing.enumerated() { - if case let .select(select) = idx { - selectDims[i] = select - } - } - return slice( - o, - sliceDim: indexing.firstIndex { $0 == Indexing.slice }!, - selectDims: selectDims - ) - } - - /// Slice an array according to a list, according to `sliceDim` (which dimension to slice on) - /// and a dictionary of `dim` to `index`. - /// - /// You must select all other dimensions than the slice dimension (cf. the assert). - static func slice(_ o: MLMultiArray, sliceDim: Int, selectDims: [Int: Int]) -> MLMultiArray { - assert( - selectDims.count + 1 == o.shape.count - ) - var shape: [NSNumber] = Array(repeating: 1, count: o.shape.count) - shape[sliceDim] = o.shape[sliceDim] - // print("About to slice ndarray of shape \(o.shape) into ndarray of shape \(shape)") - let arr = try! MLMultiArray(shape: shape, dataType: .double) - - // let srcPtr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) - // TODO: use srcPtr instead of array subscripting. - let dstPtr = UnsafeMutablePointer(OpaquePointer(arr.dataPointer)) - for i in 0.. String { - func indent(_ x: Int) -> String { - String(repeating: " ", count: x) - } - - // This function is called recursively for every dimension. - // Add an entry for this dimension to the end of the array. - var indices = indices + [0] - - let d = indices.count - 1 // the current dimension - let N = shape[d].intValue // how many elements in this dimension - var s = "[" - if indices.count < shape.count { // not last dimension yet? - for i in 0.. { - /// Efficiently extracts float values from the shaped array. 
-    ///
-    /// Uses optimized memory copying when possible (stride=1), falling back to
-    /// slower scalar access for non-contiguous arrays.
-    ///
-    /// - Returns: Array of Float values from the shaped array
-    var floats: [Float] {
-        guard strides.first == 1, strides.count == 1 else {
-            // For some reason this path is slow.
-            // If strides is not 1, we can write a Metal kernel to copy the values properly.
-            return scalars
-        }
-
-        // Fast path: memcpy
-        let mlArray = MLMultiArray(self)
-        return mlArray.floats ?? scalars
-    }
-}
-
-extension MLShapedArraySlice {
-    /// Efficiently extracts float values from the shaped array slice.
-    ///
-    /// Uses optimized memory copying when possible (stride=1), falling back to
-    /// slower scalar access for non-contiguous slices.
-    ///
-    /// - Returns: Array of Float values from the shaped array slice
-    var floats: [Float] {
-        guard strides.first == 1, strides.count == 1 else {
-            // For some reason this path is slow.
-            // If strides is not 1, we can write a Metal kernel to copy the values properly.
-            return scalars
-        }
-
-        // Fast path: memcpy
-        let mlArray = MLMultiArray(self)
-        return mlArray.floats ?? scalars
-    }
-}
-
-extension MLMultiArray {
-    /// Efficiently extracts float values from the MLMultiArray if it contains float32 data.
-    ///
-    /// Uses fast memory copying to extract all float values as a contiguous array.
-    /// Returns nil if the array doesn't contain float32 data.
-    ///
-    /// - Returns: Array of Float values, or nil if not float32 type
-    var floats: [Float]? {
-        guard dataType == .float32 else { return nil }
-
-        var result: [Float] = Array(repeating: 0, count: count)
-        return withUnsafeBytes { ptr in
-            guard let source = ptr.baseAddress else { return nil }
-            result.withUnsafeMutableBytes { resultPtr in
-                let dest = resultPtr.baseAddress!
-                memcpy(dest, source, self.count * MemoryLayout<Float>.stride)
-            }
-            return result
-        }
-    }
-}
-#endif // canImport(CoreML)
diff --git a/Sources/Generation/Decoders.swift b/Sources/Generation/Decoders.swift
new file mode 100644
index 00000000..fafbd1b4
--- /dev/null
+++ b/Sources/Generation/Decoders.swift
@@ -0,0 +1,28 @@
+#if canImport(CoreML)
+import CoreML
+
+// MARK: Greedy Decoding
+
+@available(macOS 15.0, iOS 18.0, *)
+func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor {
+    scores.argmax(alongAxis: -1).reshaped(to: [1, 1])
+}
+
+// MARK: Top-K Sampling
+
+@available(macOS 15.0, iOS 18.0, *)
+func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, topK: Int) -> MLTensor {
+    let temperatureAdjustedScores = scores / temperature
+    let (topKScores, topKIndices) = temperatureAdjustedScores.topK(topK)
+    let topKProbs = topKScores.softmax(alongAxis: -1)
+    let rnd = topKProbs.sum() * Float.random(in: 0..<1)
+    var accumTopKProbs = topKProbs.cumulativeSum(alongAxis: -1)
+    accumTopKProbs += (accumTopKProbs .< rnd) * 100.0
+    let topKIndex = accumTopKProbs.argsort()[..., 0]
+    let nextTokenTensor = topKIndices.gathering(
+        atIndices: topKIndex,
+        alongAxis: topKIndices.rank - 1
+    )
+    return nextTokenTensor.reshaped(to: [1, 1])
+}
+#endif // canImport(CoreML)
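The sampling path in `Decoders.swift` above is compact enough to misread: after the top-k softmax, every cumulative-probability bucket still below the random threshold `rnd` is pushed up by `100.0`, so the ascending `argsort()[..., 0]` lands on the first bucket whose cumulative mass reaches the threshold. A minimal scalar sketch of the same selection rule, with illustrative names that are not part of this patch:

```swift
import Foundation

/// Samples an index from unnormalized probabilities, mirroring the MLTensor
/// masking trick above with an explicit cumulative-sum walk.
func sampleIndex(from probs: [Float]) -> Int {
    // Random threshold in [0, total mass), as in `topKProbs.sum() * Float.random(in: 0..<1)`.
    let threshold = probs.reduce(0, +) * Float.random(in: 0..<1)
    var cumulative: Float = 0
    for (index, probability) in probs.enumerated() {
        cumulative += probability
        // The first bucket whose cumulative mass reaches the threshold wins;
        // it is the same bucket `argsort()[..., 0]` selects after the mask.
        if threshold < cumulative { return index }
    }
    return probs.count - 1 // guard against floating-point round-off
}
```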
diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift
index e56af9cc..0cdfd375 100644
--- a/Sources/Generation/Generation.swift
+++ b/Sources/Generation/Generation.swift
@@ -37,7 +37,8 @@
 /// - Parameter tokens: Input token sequence
 /// - Parameter config: Generation configuration
 /// - Returns: Logits array for next token prediction
-public typealias NextTokenModel = (InputTokens, GenerationConfig) -> any MLShapedArrayProtocol
+@available(macOS 15.0, iOS 18.0, *)
+public typealias NextTokenModel = (MLTensor, GenerationConfig) async -> MLTensor

 /// Callback for receiving generated tokens during streaming.
 public typealias PredictionTokensCallback = (GenerationOutput) -> Void
@@ -46,17 +47,8 @@
 public typealias PredictionStringCallback = (String) -> Void

 /// Protocol for text generation implementations.
+@available(macOS 15.0, iOS 18.0, *)
 public protocol Generation {
-    /// Performs greedy search generation.
-    ///
-    /// - Parameters:
-    ///   - config: Generation configuration
-    ///   - tokens: Input token sequence
-    ///   - model: Model for next token prediction
-    ///   - callback: Optional callback for streaming tokens
-    /// - Returns: Generated token sequence
-    func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback?) async -> GenerationOutput
-
     /// Generates text from a prompt string.
     ///
     /// - Parameters:
@@ -69,89 +61,84 @@
     func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback?) async -> String
 }

-public extension Generation {
-    func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput {
-        // Iterate until we find the eos token or reach the max length
-        // TODO: additional stopping criteria
-        var outputTokens = tokens
-        while outputTokens.count < config.maxLength {
-            let logits = model(outputTokens, config)
-            let (nextToken, _) = Math.argmax(logits)
-            if nextToken == config.eosTokenId { break }
-            outputTokens.append(nextToken)
-            callback?(outputTokens)
+@available(macOS 15.0, iOS 18.0, *)
+extension Generation {
+    public func generate(
+        config: GenerationConfig,
+        tokens: InputTokens,
+        model: NextTokenModel,
+        callback: PredictionTokensCallback? = nil
+    ) async -> GenerationOutput {
+        let tokens = tokens.map { Int32($0) }
+        var outputTokens = MLTensor(tokens).expandingShape(at: 0)
+        while outputTokens.shape[1] < config.maxLength {
+            let nextTokenScores = await model(outputTokens, config)
+            let nextToken =
+                switch config.generationMode {
+                case .greedy:
+                    selectNextTokenUsingGreedyDecoding(from: nextTokenScores)
+                case .sample:
+                    selectNextTokenUsingTopKSampling(
+                        from: nextTokenScores,
+                        temperature: config.temperature,
+                        topK: config.topK
+                    )
+                default:
+                    fatalError("Generation mode \(config.generationMode) not implemented yet")
+                }
+
+            if let nextTokenId = await tensorToGenerationOutput(nextToken).first, nextTokenId == config.eosTokenId {
+                break
+            }
+
+            outputTokens = MLTensor(concatenating: [outputTokens, nextToken], alongAxis: -1)
+            if let callback {
+                let outputTokenIDs = await tensorToGenerationOutput(outputTokens)
+                callback(outputTokenIDs)
+            }
         }
-        return outputTokens
+        return await tensorToGenerationOutput(outputTokens)
     }

-    /// Performs sampling-based text generation with configurable logits warping.
- /// - /// Uses various logits warpers (temperature, top-k, top-p, repetition penalty) to modify - /// token probabilities before sampling, enabling diverse and controllable text generation. + private func tensorToGenerationOutput(_ tensor: MLTensor) async -> GenerationOutput { + await tensor.shapedArray(of: Int32.self).scalars.map { Int($0) } + } +} + +@available(macOS 15.0, iOS 18.0, *) +public extension Generation { + /// Performs greedy or sampling-based text generation based on generation configuration. /// /// - Parameters: /// - config: Generation configuration with sampling parameters - /// - tokens: Input token sequence + /// - prompt: Input string /// - model: Model for next token prediction + /// - tokenizer: Tokenizer to convert prompt to input tokens /// - callback: Optional callback for streaming tokens /// - Returns: Generated token sequence /// /// - Note: Based on https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L1552 - func sample(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput { - // Iterate until we find the eos token or reach the max length - // TODO: additional stopping criteria - var outputTokens = tokens - let logitsProcessor = LogitsProcessor(logitsWarpers: logitsWarpers(config: config)) - while outputTokens.count < config.maxLength { - let outputs = model(outputTokens, config) - // `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case - let logits = (outputs as? MLShapedArraySlice)?.floats ?? outputs.scalars as! [Float] - let (indexes, processedLogits) = logitsProcessor(logits) - let nextToken = Math.sample(indexes: indexes, probs: Math.softmax(processedLogits)) - if nextToken == config.eosTokenId { break } - outputTokens.append(nextToken) - callback?(outputTokens) - } - return outputTokens - } - - func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback? = nil) async -> String { + func generate( + config: GenerationConfig, + prompt: String, + model: NextTokenModel, + tokenizer: Tokenizer, + callback: PredictionStringCallback? 
= nil + ) async -> String { let tokens = tokenizer.encode(text: prompt) var generationConfig = config generationConfig.maxLength = config.maxNewTokens + tokens.count - - let output: GenerationOutput - switch generationConfig.generationMode { - case .greedy: - output = await greedySearch(config: generationConfig, tokens: tokens, model: model) { tokens in - callback?(tokenizer.decode(tokens: tokens)) - } - case .sample: - output = await sample(config: generationConfig, tokens: tokens, model: model) { tokens in - callback?(tokenizer.decode(tokens: tokens)) - } - default: - fatalError("Generation mode \(generationConfig.generationMode) not implemented yet") + generationConfig.eosTokenId = tokenizer.eosTokenId + generationConfig.bosTokenId = tokenizer.bosTokenId + let output = await generate( + config: generationConfig, + tokens: tokens, + model: model + ) { tokens in + callback?(tokenizer.decode(tokens: tokens)) } return tokenizer.decode(tokens: output) } - - private func logitsWarpers(config: GenerationConfig) -> [any LogitsWarper] { - var logitsWarpers = [any LogitsWarper]() - if config.temperature > 0, config.temperature != 1 { - logitsWarpers.append(TemperatureLogitsWarper(temperature: Float(config.temperature))) - } - if config.topK > 0 { - logitsWarpers.append(TopKLogitsWarper(k: config.topK)) - } - if config.topP < 1.0 { - logitsWarpers.append(TopPLogitsWarper(p: Float(config.topP))) - } - if config.repetitionPenalty != 1.0 { - logitsWarpers.append(RepetitionPenaltyWarper(penalty: config.repetitionPenalty)) - } - return logitsWarpers - } } #endif // canImport(CoreML) diff --git a/Sources/Generation/GenerationConfig.swift b/Sources/Generation/GenerationConfig.swift index 978f404a..e7389fbb 100644 --- a/Sources/Generation/GenerationConfig.swift +++ b/Sources/Generation/GenerationConfig.swift @@ -33,7 +33,7 @@ public struct GenerationConfig { public var penaltyAlpha: Double? /// Temperature for sampling (higher values increase randomness). - public var temperature = 1.0 + public var temperature: Float = 1.0 /// Number of top tokens to consider for top-k sampling. public var topK = 50 @@ -66,14 +66,25 @@ public struct GenerationConfig { /// - topK: Top-k sampling parameter /// - topP: Top-p sampling parameter /// - repetitionPenalty: Repetition penalty factor - public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) { + public init( + maxLength: Int = 20, + maxNewTokens: Int, + doSample: Bool = false, + numBeams: Int = 1, + numBeamGroups: Int = 1, + penaltyAlpha: Double? = nil, + temperature: Double = 1.0, + topK: Int = 50, + topP: Double = 1.0, + repetitionPenalty: Double = 1.0 + ) { self.maxLength = maxLength self.maxNewTokens = maxNewTokens self.doSample = doSample self.numBeams = numBeams self.numBeamGroups = numBeamGroups self.penaltyAlpha = penaltyAlpha - self.temperature = temperature + self.temperature = Float(temperature) self.topK = topK self.topP = topP self.repetitionPenalty = repetitionPenalty diff --git a/Sources/Generation/LogitsWarper/LogitsProcessor.swift b/Sources/Generation/LogitsWarper/LogitsProcessor.swift deleted file mode 100644 index 8f16477c..00000000 --- a/Sources/Generation/LogitsWarper/LogitsProcessor.swift +++ /dev/null @@ -1,33 +0,0 @@ -import Foundation - -/// Processes logits by applying a sequence of logits warpers. 
-/// -/// Coordinates the application of multiple logits warpers in sequence, -/// allowing for complex probability transformations during text generation. -public struct LogitsProcessor { - /// Array of logits warpers to apply in sequence. - public var logitsWarpers: [any LogitsWarper] - - /// Creates a new logits processor. - /// - /// - Parameter logitsWarpers: Array of warpers to apply in sequence - public init(logitsWarpers: [any LogitsWarper]) { - self.logitsWarpers = logitsWarpers - } - - /// Processes logits by applying all warpers in sequence. - /// - /// Each warper is applied to the output of the previous warper, allowing - /// for complex chaining of probability transformations. - /// - /// - Parameter arr: Input logits array - /// - Returns: Tuple of processed (indices, logits) - public func callAsFunction(_ arr: [Float]) -> (indices: [Int], logits: [Float]) { - var indices = Array(arr.indices) - var logits = arr - for warper in logitsWarpers { - (indices, logits) = warper(indices, logits) - } - return (indices: indices, logits: logits) - } -} diff --git a/Sources/Generation/LogitsWarper/LogitsWarper.swift b/Sources/Generation/LogitsWarper/LogitsWarper.swift deleted file mode 100644 index 576820c0..00000000 --- a/Sources/Generation/LogitsWarper/LogitsWarper.swift +++ /dev/null @@ -1,35 +0,0 @@ -import Foundation - -/// Protocol for logits warpers that transform token probabilities during generation. -/// -/// Logits warpers modify the probability distribution over tokens before sampling, -/// enabling techniques like temperature scaling, top-k/top-p filtering, and repetition penalties. -public protocol LogitsWarper { - /// Warps the logits and corresponding indices. - /// - /// - Parameters: - /// - indices: Array of token indices corresponding to the logits - /// - logits: Array of logit values to transform - /// - Returns: Tuple of transformed (indices, logits) - func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) - - /// Convenience method that calls the warp function. - /// - /// - Parameters: - /// - indices: Array of token indices - /// - logits: Array of logit values - /// - Returns: Tuple of transformed (indices, logits) - func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float]) -} - -public extension LogitsWarper { - /// Default implementation of callAsFunction that delegates to warp. - /// - /// - Parameters: - /// - indices: Array of token indices - /// - logits: Array of logit values - /// - Returns: Tuple of transformed (indices, logits) - func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float]) { - warp(indices: indices, logits: logits) - } -} diff --git a/Sources/Generation/LogitsWarper/RepetitionPenaltyWarper.swift b/Sources/Generation/LogitsWarper/RepetitionPenaltyWarper.swift deleted file mode 100644 index 7e89d67e..00000000 --- a/Sources/Generation/LogitsWarper/RepetitionPenaltyWarper.swift +++ /dev/null @@ -1,42 +0,0 @@ -import Foundation - -/// Logits warper that prevents repetition of previous tokens through a penalty. -/// -/// Applies a penalty to tokens that have already been generated, reducing their -/// probability of being selected again. The penalty is applied differently based -/// on the sign of the logit value to maintain numerical stability. 
-/// -/// - Note: Based on https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L294 -public struct RepetitionPenaltyWarper: LogitsWarper { - /// Penalty factor applied to repeated tokens. - public var penalty: Float - - /// Creates a repetition penalty warper. - /// - /// - Parameter penalty: Penalty factor (>1.0 discourages repetition, <1.0 encourages it) - public init(penalty: Double) { - self.penalty = Float(penalty) - } - - /// Applies repetition penalty to the logits. - /// - /// For positive logits, divides by penalty. For negative logits, multiplies by penalty. - /// This asymmetric approach maintains numerical stability while effectively penalizing repetition. - /// - /// - Parameters: - /// - indices: Token indices to apply penalty to - /// - logits: Current logits values - /// - Returns: Tuple of (indices, penalized logits) - public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) { - var logits = logits - for index in indices { - if logits[index] < 0 { - logits[index] *= penalty - } else { - logits[index] /= penalty - } - } - - return (indices, logits) - } -} diff --git a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift deleted file mode 100644 index bf985770..00000000 --- a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift +++ /dev/null @@ -1,32 +0,0 @@ -import Foundation - -/// Logits warper that applies temperature scaling to control generation randomness. -/// -/// Temperature scaling modifies the "sharpness" of the probability distribution: -/// - Temperature > 1.0: Makes distribution more uniform (more random) -/// - Temperature < 1.0: Makes distribution more peaked (less random) -/// - Temperature = 1.0: No change -public struct TemperatureLogitsWarper: LogitsWarper { - /// Temperature scaling factor. - public var temperature: Float - - /// Creates a temperature logits warper. - /// - /// - Parameter temperature: Scaling factor (higher values increase randomness) - public init(temperature: Float) { - self.temperature = temperature - } - - /// Applies temperature scaling to the logits. - /// - /// Divides each logit by the temperature value, which affects the final - /// probability distribution after softmax is applied. - /// - /// - Parameters: - /// - indices: Token indices (unchanged) - /// - logits: Current logits values - /// - Returns: Tuple of (indices, temperature-scaled logits) - public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) { - (indices: indices, logits: logits.map { $0 / temperature }) - } -} diff --git a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift deleted file mode 100644 index 0665bbe0..00000000 --- a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift +++ /dev/null @@ -1,74 +0,0 @@ -#if canImport(Accelerate) -import Accelerate -import Foundation - -/// Logits warper that implements top-k filtering for sampling. -/// -/// Selects the k most probable tokens and sets all other token probabilities -/// to zero, effectively limiting the sampling space to the top k candidates. -/// This helps balance diversity and quality in generated text. -public struct TopKLogitsWarper: LogitsWarper { - /// Number of top tokens to keep. - public var k: Int - - /// Creates a top-k logits warper. 
- /// - /// - Parameter k: Number of top tokens to retain (others are filtered out) - public init(k: Int) { - self.k = k - } - - /// Applies top-k filtering to the logits. - /// - /// Uses Accelerate framework's optimized top-k algorithm to efficiently - /// select the k highest-valued logits and their corresponding indices. - /// - /// - Parameters: - /// - indices: Current token indices - /// - logits: Current logits values - /// - Returns: Tuple of (top-k indices, top-k logits) - public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) { - guard !logits.isEmpty else { - return (indices: [], logits: []) - } - let k = min(k, logits.count) - let arrDescriptor = BNNSNDArrayDescriptor.allocate( - initializingFrom: logits, - shape: .vector(logits.count) - ) - defer { - arrDescriptor.deallocate() - } - let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized( - scalarType: Int32.self, - shape: .vector(k) - ) - defer { - bestIndices.deallocate() - } - let bestValues = BNNSNDArrayDescriptor.allocateUninitialized( - scalarType: Float.self, - shape: .vector(k) - ) - defer { - bestValues.deallocate() - } - try! Accelerate.BNNS.applyTopK( - k: k, - input: arrDescriptor, - bestValues: bestValues, - bestIndices: bestIndices, - axis: 0, - batchSize: 1, - filterParameters: nil - ) - let topkLogits = bestValues.data!.withMemoryRebound(to: Float.self, capacity: k) { ptr in - Array(UnsafeBufferPointer(start: ptr, count: k)) - } - let topkIndices = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: k) { ptr in - Array(UnsafeBufferPointer(start: ptr, count: k)) - } - return (indices: topkIndices.map { indices[Int($0)] }, logits: topkLogits) - } -} -#endif // canImport(Accelerate) diff --git a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift deleted file mode 100644 index 114660bd..00000000 --- a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift +++ /dev/null @@ -1,54 +0,0 @@ -import Foundation - -/// Logits warper that implements nucleus (top-p) sampling. -/// -/// Selects the smallest set of tokens whose cumulative probability exceeds -/// the threshold p, providing dynamic vocabulary selection based on the -/// probability distribution rather than a fixed number of tokens. -/// -/// - Note: Based on https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 -public struct TopPLogitsWarper: LogitsWarper { - /// Cumulative probability threshold. - public var p: Float - - /// Creates a top-p (nucleus) logits warper. - /// - /// - Parameter p: Cumulative probability threshold (0.0 to 1.0) - public init(p: Float) { - self.p = p - } - - /// Applies top-p (nucleus) filtering to the logits. - /// - /// Computes softmax probabilities, sorts by probability, and selects tokens - /// until their cumulative probability exceeds the threshold p. 
- /// - /// - Parameters: - /// - indices: Current token indices - /// - logits: Current logits values - /// - Returns: Tuple of (filtered indices, filtered logits) - public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) { - guard !logits.isEmpty else { - return (indices: [], logits: []) - } - - let arrSoftmax = Math.softmax(logits) - var indexLogitProb = [(index: Int, logit: Float, prob: Float)]() - indexLogitProb.reserveCapacity(logits.count) - for (index, data) in zip(logits, arrSoftmax).enumerated() { - indexLogitProb.append((index: index, logit: data.0, prob: data.1)) - } - indexLogitProb.sort { $0.prob > $1.prob } - - let cumsum = Math.cumsum(indexLogitProb.map(\.prob)) - var sliceIndex = cumsum.count - 1 - for (index, element) in cumsum.enumerated() where element > p { - sliceIndex = index - break - } - - let toppIndices = indexLogitProb[0...sliceIndex].map { indices[$0.index] } - let toppLogits = indexLogitProb[0...sliceIndex].map(\.logit) - return (indices: toppIndices, logits: toppLogits) - } -} diff --git a/Sources/Generation/Math.swift b/Sources/Generation/Math.swift deleted file mode 100644 index aadbc3ea..00000000 --- a/Sources/Generation/Math.swift +++ /dev/null @@ -1,214 +0,0 @@ -// -// Math.swift -// CoreMLBert -// -// Created by Julien Chaumond on 27/06/2019. -// Copyright © 2019 Hugging Face. All rights reserved. -// - -#if canImport(CoreML) && canImport(Accelerate) -import Accelerate -import CoreML -import Foundation - -/// Mathematical utilities for text generation and tensor operations. -/// -/// Provides optimized implementations of common mathematical operations -/// used in text generation, including argmax, softmax, sampling, and cumulative sum. -/// -/// - Note: From M.I. Hollemans - https://github.com/hollance/CoreMLHelpers -public enum Math { - /** - Returns the index and value of the largest element in the array. - - - Parameters: - - ptr: Pointer to the first element in memory. - - count: How many elements to look at. - - stride: The distance between two elements in memory. - */ - public static func argmax(_ ptr: UnsafePointer, count: Int, stride: Int = 1) -> (Int, Float) { - var maxValue: Float = 0 - var maxIndex: vDSP_Length = 0 - vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) - return (Int(maxIndex), maxValue) - } - - /** - Returns the index and value of the largest element in the array. - - Parameters: - - ptr: Pointer to the first element in memory. - - count: How many elements to look at. - - stride: The distance between two elements in memory. - */ - public static func argmax(_ ptr: UnsafePointer, count: Int, stride: Int = 1) -> (Int, Double) { - var maxValue: Double = 0 - var maxIndex: vDSP_Length = 0 - vDSP_maxviD(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) - return (Int(maxIndex), maxValue) - } - - /// Returns the index and value of the largest element in a Float array. 
- /// - /// - Parameters: - /// - ptr: Pointer to the first element in memory - /// - count: How many elements to look at - /// - stride: The distance between two elements in memory - /// - Returns: Tuple of (index, value) of the maximum element - public static func argmax32(_ ptr: UnsafePointer, count: Int, stride: Int = 1) -> (Int, Float) { - var maxValue: Float = 0 - var maxIndex: vDSP_Length = 0 - vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) - return (Int(maxIndex), maxValue) - } - - /// Returns the index and value of the largest element in an MLMultiArray of doubles. - /// - /// - Parameter multiArray: Input MLMultiArray with double precision values - /// - Returns: Tuple of (index, value) of the maximum element - public static func argmax(_ multiArray: MLMultiArray) -> (Int, Double) { - assert(multiArray.dataType == .double) - let ptr = UnsafeMutablePointer(OpaquePointer(multiArray.dataPointer)) - return Math.argmax(ptr, count: multiArray.count) - } - - /// Returns the index and value of the largest element in an MLMultiArray of floats. - /// - /// - Parameter multiArray: Input MLMultiArray with single precision values - /// - Returns: Tuple of (index, value) of the maximum element - public static func argmax32(_ multiArray: MLMultiArray) -> (Int, Float) { - assert(multiArray.dataType == .float32) - let ptr = UnsafeMutablePointer(OpaquePointer(multiArray.dataPointer)) - return Math.argmax32(ptr, count: multiArray.count) - } - - /// Returns the cumulative sum of the array. - /// - /// Computes the cumulative sum where each element is the sum of all previous elements - /// plus the current element. - /// - /// - Parameter arr: Input array of Float values - /// - Returns: Array of cumulative sums - public static func cumsum(_ arr: [Float]) -> [Float] { - guard !arr.isEmpty else { - return [] - } - let arrCount = vDSP_Length(arr.count) - var weight: Float = 1.0 - var result: [Float] = Array(repeating: 0.0, count: arr.count) - var firstItem = arr[0] - vDSP_vrsum(arr, 1, &weight, &result, 1, arrCount) - vDSP_vsadd(result, 1, &firstItem, &result, 1, arrCount) - return result - } - - /// Performs multinomial sampling from probability distributions. - /// - /// Selects an index based on probability weights, commonly used for token sampling - /// in text generation after applying logits warpers. - /// - /// - Parameters: - /// - indexes: Array of indices to sample from - /// - probs: Probability weights for each index - /// - Returns: Selected index based on probability distribution - public static func sample(indexes: [Int], probs: [Float]) -> Int { - let i = randomNumber(probabilities: probs) - return indexes[i] - } - - /// Computes the softmax function over an array. - /// - /// Converts logits into a probability distribution by applying exponential normalization. - /// Uses numerical stability techniques by shifting values to prevent overflow. - /// - /// The implementation follows this algorithm: - /// 1. Subtract maximum value for numerical stability - /// 2. Apply exponential function to all elements - /// 3. Normalize by dividing by the sum of all exponentials - /// - /// - Parameter x: Input logits array - /// - Returns: Probability distribution (sums to 1.0) - /// - /// - Note: Based on code from https://github.com/nikolaypavlov/MLPNeuralNet/ - public static func softmax(_ x: [Float]) -> [Float] { - var x = x - let len = vDSP_Length(x.count) - - // Find the maximum value in the input array. 
- var max: Float = 0 - vDSP_maxv(x, 1, &max, len) - - // Subtract the maximum from all the elements in the array. - // Now the highest value in the array is 0. - max = -max - vDSP_vsadd(x, 1, &max, &x, 1, len) - - // Exponentiate all the elements in the array. - var count = Int32(x.count) - vvexpf(&x, x, &count) - - // Compute the sum of all exponentiated values. - var sum: Float = 0 - vDSP_sve(x, 1, &sum, len) - - // Divide each element by the sum. This normalizes the array contents - // so that they all add up to 1. - vDSP_vsdiv(x, 1, &sum, &x, 1, len) - - return x - } - - /// Generates a random index based on probability weights. - /// - /// Uses the roulette wheel selection algorithm to choose an index where - /// the probability of selection is proportional to the weight at that index. - /// - /// - Parameter probabilities: Array of probability weights (need not sum to 1.0) - /// - Returns: Selected index based on probability distribution - /// - /// - Note: From https://stackoverflow.com/questions/30309556/generate-random-numbers-with-a-given-distribution - public static func randomNumber(probabilities: [Float]) -> Int { - // Sum of all probabilities (so that we don't have to require that the sum is 1.0): - let sum = probabilities.reduce(0, +) - // Random number in the range 0.0 <= rnd < sum : - let rnd = sum * Float(arc4random_uniform(UInt32.max)) / Float(UInt32.max) - // Find the first interval of accumulated probabilities into which `rnd` falls: - var accum: Float = 0.0 - for (i, p) in probabilities.enumerated() { - accum += p - if rnd < accum { - return i - } - } - // This point might be reached due to floating point inaccuracies: - return probabilities.count - 1 - } -} - -/// MLShapedArray extensions for Math operations. -public extension Math { - /// Returns the index and value of the largest element in an MLShapedArray of floats. - /// - /// - Parameter shapedArray: Input MLShapedArray containing Float values - /// - Returns: Tuple of (index, value) of the maximum element - static func argmax(_ shapedArray: MLShapedArray) -> (Int, Float) { - shapedArray.withUnsafeShapedBufferPointer { ptr, shape, strides in - assert(shape.count == 1, "Only supported for 1-dimensional arrays or slices") - return Math.argmax32(ptr.baseAddress!, count: shapedArray.count, stride: strides.first!) - } - } - - /// Returns the index and value of the largest element in a generic MLShapedArray. - /// - /// - Parameter shapedArray: Input shaped array conforming to MLShapedArrayProtocol - /// - Returns: Tuple of (index, value) of the maximum element as Float - /// - /// - Note: Currently assumes Float data type - static func argmax(_ shapedArray: some MLShapedArrayProtocol) -> (Int, Float) { - shapedArray.withUnsafeShapedBufferPointer { ptr, shape, strides in - assert(shape.count == 1, "Only supported for 1-dimensional arrays or slices") - let floatsPtr = ptr.baseAddress as! UnsafePointer - return Math.argmax32(floatsPtr, count: shapedArray.count, stride: strides.first!) - } - } -} -#endif // canImport(CoreML) && canImport(Accelerate) diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index ff6d7440..a7129c58 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -27,6 +27,8 @@ public extension Hub { case httpStatusCode(Int) /// Failed to parse server response or configuration data. case parse + /// Expected json response could not be parsed as json. + case jsonSerialization(fileURL: URL, message: String) /// An unexpected error occurred during operation. 
case unexpectedError /// A download operation failed with the specified error message. @@ -52,6 +54,8 @@ public extension Hub { String(localized: "HTTP error with status code: \(code)") case .parse: String(localized: "Failed to parse server response.") + case .jsonSerialization(_, let message): + message case .unexpectedError: String(localized: "An unexpected error occurred.") case let .downloadError(message): diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index 9211f770..e89dab0a 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -314,7 +314,9 @@ public extension HubApi { /// `fileURL` is a complete local file path for the given model func configuration(fileURL: URL) throws -> Config { let data = try Data(contentsOf: fileURL) - let parsed = try JSONSerialization.bomPreservingJsonObject(with: data) + guard let parsed = try? JSONSerialization.bomPreservingJsonObject(with: data) else { + throw Hub.HubClientError.jsonSerialization(fileURL: fileURL, message: "JSON Serialization failed for \(fileURL). Please verify that you have set the HF_TOKEN environment variable.") + } guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse } return Config(dictionary) } diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 8cd23831..59203c6a 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -12,6 +12,7 @@ import Generation import Hub import Tokenizers +@available(macOS 15.0, iOS 18.0, *) /// A high-level interface for language model inference using CoreML. /// /// `LanguageModel` provides a convenient way to load and interact with pre-trained @@ -27,15 +28,6 @@ public class LanguageModel { /// The maximum context length supported by the model. public let maxContextLength: Int - let input_ids = "input_ids" - let attention_mask = "attention_mask" - - struct Configurations { - var modelConfig: Config - var tokenizerConfig: Config? - var tokenizerData: Config - } - private var configuration: LanguageModelConfigurationFromHub? private var _tokenizer: Tokenizer? @@ -45,36 +37,104 @@ public class LanguageModel { /// - Important: Triggers a fatal error if the model doesn't have the expected input shape information public required init(model: MLModel) { self.model = model + (minContextLength, maxContextLength) = Self.contextRange(from: model) + configuration = LanguageModelConfigurationFromHub(modelName: modelName) + } + + public func resetState() async {} + + public func predictNextTokenScores( + _ tokens: MLTensor, + config: GenerationConfig + ) async -> MLTensor { + assert(tokens.rank == 2) // [batch, current sequence length] + let tokenCount = tokens.shape[1] + let padLength = maxContextLength - tokenCount + let padding = MLTensor(repeating: Int32(config.padTokenId ?? 0), shape: [1, padLength]) + let inputIDs = MLTensor(concatenating: [tokens, padding], alongAxis: -1) + var inputDictionary = [inputIdsName: inputIDs] + if isRequiringAttentionMask { + let mask = [Int32](repeating: 1, count: tokenCount) + [Int32](repeating: 0, count: padLength) + let attentionMask = MLTensor(shape: inputIDs.shape, scalars: mask) + inputDictionary[Keys.attentionMask] = attentionMask + } + let outputs = try! await model.prediction(from: inputDictionary) + + assert(outputs.keys.contains(Keys.logits)) + + let scores = outputs[Keys.logits]! 
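+        // Logits come back with shape [batch, sequenceLength, vocabSize]; the row at
+        // the last non-padded position scores every candidate for the next token.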
+        assert(scores.rank == 3)
+        let tokenIndex = tokenCount - 1
+        let nextTokenScores = scores[nil, tokenIndex, nil].expandingShape(at: 0)
+        assert(nextTokenScores.rank == 3)
+        assert(nextTokenScores.shape[0] == 1 && nextTokenScores.shape[1] == 1)
+        return nextTokenScores
+    }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
+private extension LanguageModel {
+    static func contextRange(from model: MLModel) -> (min: Int, max: Int) {
+        contextRange(from: model, inputKey: Keys.inputIds)
+    }
 
-        // We assume inputs named "input_ids" with shape (1, seq_length)
-        // Perhaps we should convert to vectors of shape (seq_length) and use sequenceConstraint instead of shapeConstraint
-        let inputDescription = model.modelDescription.inputDescriptionsByName["input_ids"]
+    static func contextRange(from model: MLModel, inputKey: String) -> (min: Int, max: Int) {
+        let inputDescription = model.modelDescription.inputDescriptionsByName[inputKey]
 
         guard let shapeConstraint = inputDescription?.multiArrayConstraint?.shapeConstraint else {
             fatalError("Cannot obtain shape information")
         }
 
+        var minContextLength = 128
+        var maxContextLength = 128
+
         switch shapeConstraint.type {
         case .enumerated:
-            // TODO: support a set of fixed shapes (keeping the first one here)
             minContextLength = shapeConstraint.enumeratedShapes[0][1].intValue
             maxContextLength = minContextLength
         case .range:
-            let range = inputDescription?.multiArrayConstraint?.shapeConstraint.sizeRangeForDimension[1] as? NSRange
-            minContextLength = range?.location ?? 1
-            maxContextLength = range?.length ?? 128
+            if let sizeRangeForDimension = inputDescription?.multiArrayConstraint?.shapeConstraint.sizeRangeForDimension {
+                let lastAxis = sizeRangeForDimension.count - 1
+                let range = sizeRangeForDimension[lastAxis] as? NSRange
+                minContextLength = range?.location ?? 1
+                maxContextLength = range?.length ?? 128
+            }
         case .unspecified:
-            minContextLength = 128
-            maxContextLength = 128
+            break
        @unknown default:
-            minContextLength = 128
-            maxContextLength = 128
+            break
        }
 
-        configuration = LanguageModelConfigurationFromHub(modelName: modelName)
+        return (minContextLength, maxContextLength)
     }
 }
 
+@available(macOS 15.0, iOS 18.0, *)
+extension LanguageModel {
+    struct Configurations {
+        var modelConfig: Config
+        var tokenizerConfig: Config?
+        var tokenizerData: Config
+    }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
+extension LanguageModel {
+    enum Keys {
+        // Input keys
+        static let inputIds = "inputIds"
+        static let attentionMask = "attentionMask"
+        static let causalMask = "causalMask"
+        static let keyCache = "keyCache"
+        static let valueCache = "valueCache"
+        // Output keys
+        static let logits = "logits"
+        static let presentKeys = "presentKeys"
+        static let presentValues = "presentValues"
+    }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
 public extension LanguageModel {
     /// Loads a compiled CoreML model from disk.
     ///
@@ -87,16 +147,44 @@ public extension LanguageModel {
         let config = MLModelConfiguration()
         config.computeUnits = computeUnits
         let model = try MLModel(contentsOf: url, configuration: config)
-        return LanguageModel(model: model)
+        return switch kvCacheAvailability(for: model) {
+        case .statefulKVCache: LanguageModelWithStatefulKVCache(model: model)
+        default: LanguageModel(model: model)
+        }
+    }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
+extension LanguageModel {
+    enum KVCacheAvailability {
+        /// Language models that support KV cache via state. Implementation details for handling state
+        /// are encapsulated within the Core ML framework.
+        ///
+        /// Input: State
+        /// Output: N/A
+        case statefulKVCache
+    }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
 public extension LanguageModel {
+    /// Metadata fields associated with the Core ML model.
+    var metadata: [MLModelMetadataKey: Any] {
+        model.modelDescription.metadata
+    }
+
+    /// A description of a model containing input, output, and state feature descriptions.
+    ///
+    /// Returns an `MLModelDescription` instance.
+    var modelDescription: MLModelDescription {
+        model.modelDescription
+    }
+
     /// A human-readable description of the model.
     ///
     /// Returns the model's description from its metadata, or the display name if no description is available.
     var description: String {
-        if let description = model.modelDescription.metadata[MLModelMetadataKey.description] as? String,
+        if let description = metadata[MLModelMetadataKey.description] as? String,
            !description.isEmpty
         {
             return description
@@ -109,19 +197,21 @@ public extension LanguageModel {
     /// Returns the model identifier from Hugging Face Hub metadata if available,
     /// otherwise falls back to the model's display name.
     var modelName: String {
-        if let userFields = model.modelDescription.metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String],
+        if let userFields = metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String],
            let name = userFields["co.huggingface.exporters.name"]
         {
             return name
         }
         // This is usually the basename of the file, that's our best bet if no metadata exists
-        guard let modelName = model.configuration.modelDisplayName else { fatalError("Models must have a name that identifies them") }
+        guard let modelName = model.configuration.modelDisplayName else {
+            fatalError("Models must have a name that identifies them")
+        }
         return modelName
     }
 
     /// The feature description for the input_ids input.
     var inputIdsDescription: MLFeatureDescription {
-        model.modelDescription.inputDescriptionsByName[input_ids]!
+        modelDescription.inputDescriptionsByName[Keys.inputIds]!
     }
 
     /// The name of the input_ids feature in the model.
@@ -131,51 +221,82 @@ public extension LanguageModel {
 
     /// The expected shape of the input_ids tensor.
     var inputIdsShape: [Int] {
-        inputIdsDescription.multiArrayConstraint!.shape.map { $0.intValue }
+        inputIdsDescription.multiArrayConstraint!.shape.map(\.intValue)
     }
 
     /// Whether the model requires attention mask inputs.
-    var requiresAttention: Bool {
-        model.modelDescription.inputDescriptionsByName[attention_mask] != nil
+    var isRequiringAttentionMask: Bool {
+        modelDescription.inputDescriptionsByName[Keys.attentionMask] != nil
+    }
+
+    /// Whether the model requires a causal attention mask.
+    var isRequiringCausalMask: Bool {
+        modelDescription.inputDescriptionsByName[Keys.causalMask] != nil
     }
 
-    /// Predicts the next token scores for the given input tokens.
+    /// Determines the type of KV Cache available for the model, if any.
     ///
     /// - Parameters:
-    ///   - tokens: The input token sequence
-    ///   - config: The generation configuration containing model parameters
-    ///   - Returns: A shaped array containing the logits for the next token prediction
-    func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol {
-        // TODO: exceptions
-
-        // Maybe pad or truncate
-        let maxTokens = min(tokens.count, maxContextLength)
-        let padLength = maxTokens >= minContextLength ? 0 : minContextLength - maxTokens
-        let inputTokens = Array(tokens[0..<maxTokens]) + Array(repeating: config.padTokenId ?? 0, count: padLength)
-        let inputIds = MLShapedArray<Int32>(scalars: inputTokens.map { Int32($0) }, shape: inputIdsShape)
-        var inputDictionary = [inputIdsName: MLFeatureValue(shapedArray: inputIds)]
-        if requiresAttention {
-            let mask = Array(repeating: 1, count: maxTokens) + Array(repeating: 0, count: padLength)
-            let attentionMask = MLShapedArray(scalars: mask.map { Int32($0) }, shape: inputIdsShape)
-            inputDictionary[attention_mask] = MLFeatureValue(shapedArray: attentionMask)
+    ///   - model: The Core ML model
+    /// - Returns: The type of KV Cache available.
+    fileprivate static func kvCacheAvailability(for model: MLModel) -> KVCacheAvailability? {
+        func isStatefulKVCacheAvailable(for model: MLModel) -> Bool {
+            let kCacheState = model.modelDescription.stateDescriptionsByName[Keys.keyCache] != nil
+            let vCacheState = model.modelDescription.stateDescriptionsByName[Keys.valueCache] != nil
+            guard Set([kCacheState, vCacheState]).count == 1 else {
+                fatalError("Invalid model configuration, expecting KV cache for states")
+            }
+            return kCacheState && vCacheState
         }
-        let input = try! MLDictionaryFeatureProvider(dictionary: inputDictionary)
-        let output = try! model.prediction(from: input)
+        func isDynamicallyShaped(_ description: MLFeatureDescription) -> Bool {
+            guard let multiArrayConstraint = description.multiArrayConstraint else {
+                return false
+            }
+            return switch multiArrayConstraint.shapeConstraint.type {
+            case .unspecified: true
+            case .enumerated: multiArrayConstraint.shapeConstraint.enumeratedShapes.count > 1
+            case .range: true
+            default: false
+            }
+        }
 
-        // TODO: maybe try to support models with "token_scores" too (after the softmax)
-        assert(output.featureNames.first! == "logits")
+        if isStatefulKVCacheAvailable(for: model) {
+            return .statefulKVCache
+        }
+        let kCacheInput = model.modelDescription.inputDescriptionsByName[Keys.keyCache] != nil
+        let vCacheInput = model.modelDescription.inputDescriptionsByName[Keys.valueCache] != nil
+        let kCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentKeys] != nil
+        let vCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentValues] != nil
 
-        let scores = output.featureValue(for: output.featureNames.first!)!.shapedArrayValue(of: Float.self)!
-        let nextTokenScores = scores[0, maxTokens - 1]
-        return nextTokenScores
+        guard Set([kCacheInput, vCacheInput, kCacheOutput, vCacheOutput]).count == 1 else {
+            fatalError("Invalid model configuration, expecting KV cache for inputs and outputs")
+        }
+        guard kCacheInput else {
+            return nil
+        }
+        // Whether the cache is dynamically or statically shaped, KV cache passed via
+        // model I/O is not supported yet.
+        let kCacheConstraint = model.modelDescription.inputDescriptionsByName[Keys.keyCache]!
+        if isDynamicallyShaped(kCacheConstraint) {
+            fatalError(
+                """
+                KV Cache using IO is currently not supported, please file a feature request on \
+                https://github.com/huggingface/swift-transformers
+                """)
+        } else {
+            fatalError(
+                """
+                KV Cache using IO is currently not supported, please file a feature request on \
+                https://github.com/huggingface/swift-transformers
                """)
+        }
     }
 }
 
 // MARK: - Configuration Properties
 
 /// Asynchronous properties that are downloaded from the Hugging Face Hub configuration.
+@available(macOS 15.0, iOS 18.0, *)
 public extension LanguageModel {
     /// The model configuration dictionary.
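     /// Fetched asynchronously from the Hugging Face Hub the first time it is accessed.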
     ///
@@ -281,13 +402,14 @@ public extension LanguageModel {
 
 // MARK: - TextGenerationModel Conformance
 
+@available(macOS 15.0, iOS 18.0, *)
 extension LanguageModel: TextGenerationModel {
     /// The default generation configuration for this model.
     ///
     /// Provides sensible defaults based on the model type, with model-specific
     /// optimizations for known architectures like GPT models.
     public var defaultGenerationConfig: GenerationConfig {
-        var config = GenerationConfig(maxNewTokens: 30)
+        var config = GenerationConfig(maxNewTokens: 2048)
         switch modelName.lowercased() {
         case let x where x.contains("gpt"):
             config.doSample = true
@@ -298,6 +420,88 @@ extension LanguageModel: TextGenerationModel {
     }
 }
 
+/// Language model implementation with stateful KV Cache.
+///
+/// Maintains a KV Cache as sequence generation progresses,
+/// using stateful Core ML buffers to minimize latency.
+@available(macOS 15.0, iOS 18.0, *)
+public class LanguageModelWithStatefulKVCache: LanguageModel {
+    private enum Mode {
+        case prefilling
+        case extending
+    }
+    private var mode: Mode = .prefilling
+
+    var state: MLState?
+
+    public required init(model: MLModel) {
+        super.init(model: model)
+        // To support pre-filling and extending, the input must support
+        // flexible shapes.
+        guard maxContextLength - minContextLength > 1 else {
+            fatalError("Expecting ranged query sequence length.")
+        }
+    }
+
+    public override func resetState() async {
+        state = model.makeState()
+        mode = .prefilling
+    }
+
+    public override func predictNextTokenScores(
+        _ tokens: MLTensor,
+        config _: GenerationConfig
+    ) async -> MLTensor {
+        assert(tokens.rank == 2) // [batch, current sequence length]
+        let tokenCount = tokens.shape[1]
+        guard let state else {
+            fatalError(
+                """
+                Encountered uninitialized `state`. Ensure `resetState` is called prior to calling \
+                `predictNextTokenScores`.
+                """)
+        }
+        let inputIds =
+            switch mode {
+            case .prefilling: tokens // Pass in all tokens when pre-filling the prompt
+            case .extending: tokens[nil, -1].expandingShape(at: 0) // otherwise just the last token
+            }
+        mode = .extending
+
+        var inputDictionary = [
+            Keys.inputIds: inputIds
+        ]
+        if isRequiringAttentionMask {
+            #if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64))
+            // TODO: Infer scalar type from cache or model I/O descriptors
+            let attentionMask = MLTensor(zeros: [1, 1, 1, tokenCount + 1], scalarType: Float16.self)
+            inputDictionary[Keys.attentionMask] = attentionMask
+            #else
+            fatalError("Float16 attention mask is unsupported on x86_64.")
+            #endif
+        }
+        if isRequiringCausalMask {
+            #if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64))
+            // TODO: Infer scalar type from cache or model I/O descriptors
+            let causalMask = MLTensor(zeros: [1, 1, 1, tokenCount + 1], scalarType: Float16.self)
+            inputDictionary[Keys.causalMask] = causalMask
+            #else
+            fatalError("Float16 causal mask is unsupported on x86_64.")
+            #endif
+        }
+        let outputs = try! await model.prediction(from: inputDictionary, using: state)
+
+        assert(outputs.keys.contains(Keys.logits))
+        let scores = outputs[Keys.logits]!
+        assert(scores.rank == 3)
+        let tokenIndex = inputIds.shape[1] - 1
+        let nextTokenScores = scores[nil, tokenIndex, nil].expandingShape(at: 0)
+        assert(nextTokenScores.rank == 3)
+        assert(nextTokenScores.shape[0] == 1 && nextTokenScores.shape[1] == 1)
+        return nextTokenScores
+    }
+}
+
 /// Errors that can occur during tokenizer operations in language models.
 public enum TokenizerError: LocalizedError {
     /// The tokenizer configuration file could not be found.
diff --git a/Sources/Models/LanguageModelTypes.swift b/Sources/Models/LanguageModelTypes.swift index 3ba9d0f1..0e7e1909 100644 --- a/Sources/Models/LanguageModelTypes.swift +++ b/Sources/Models/LanguageModelTypes.swift @@ -15,6 +15,7 @@ import Tokenizers /// /// This protocol establishes the fundamental requirements for any language model /// that can perform next-token prediction and text generation tasks. +@available(macOS 15.0, iOS 18.0, *) public protocol LanguageModelProtocol { /// The name or path of the model. /// @@ -30,6 +31,11 @@ public protocol LanguageModelProtocol { /// The underlying CoreML model used for inference. var model: MLModel { get } + /// Resets the state of the language model. + /// + /// Call `resetState()` for each new sequence generated. + func resetState() async + /// Creates a new language model instance from a CoreML model. /// /// - Parameter model: The CoreML model to wrap @@ -38,11 +44,14 @@ public protocol LanguageModelProtocol { /// Predicts the next token scores for the given input tokens. /// /// - Parameters: - /// - tokens: The input token sequence - /// - config: The generation configuration containing model parameters - /// - Returns: A shaped array containing the logits for the next token prediction - func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol + /// - input: The input sequence tensor. + /// - config: The generation configuration containing model parameters. + /// - Returns: MLTensor with the raw scores of the next token. + func predictNextTokenScores(_ input: MLTensor, config: GenerationConfig) async -> MLTensor +} +@available(macOS 15.0, iOS 18.0, *) +public extension LanguageModelProtocol { /// Function call syntax for next token prediction. /// /// This provides a more convenient syntax for calling `predictNextTokenScores`. @@ -51,13 +60,8 @@ public protocol LanguageModelProtocol { /// - tokens: The input token sequence /// - config: The generation configuration containing model parameters /// - Returns: A shaped array containing the logits for the next token prediction - func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol -} - -public extension LanguageModelProtocol { - /// Default implementation of function call syntax that delegates to `predictNextTokenScores`. - func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { - predictNextTokenScores(tokens, config: config) + func callAsFunction(_ input: MLTensor, config: GenerationConfig) async -> MLTensor { + await predictNextTokenScores(input, config: config) } } @@ -65,6 +69,7 @@ public extension LanguageModelProtocol { /// /// This protocol extends `LanguageModelProtocol` and `Generation` to provide /// high-level text generation functionality with configurable parameters. +@available(macOS 15.0, iOS 18.0, *) public protocol TextGenerationModel: Generation, LanguageModelProtocol { /// The default generation configuration for this model. /// @@ -80,9 +85,14 @@ public protocol TextGenerationModel: Generation, LanguageModelProtocol { /// - callback: Optional callback to receive intermediate generation results /// - Returns: The generated text as a string /// - Throws: An error if text generation fails - func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback?) async throws -> String + func generate( + config: GenerationConfig, + prompt: String, + callback: PredictionStringCallback? 
+ ) async throws -> String } +@available(macOS 15.0, iOS 18.0, *) public extension TextGenerationModel { /// Default implementation of text generation that uses the underlying generation framework. /// @@ -94,7 +104,17 @@ public extension TextGenerationModel { /// - Throws: An error if text generation fails @discardableResult func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback? = nil) async throws -> String { - try await generate(config: config, prompt: prompt, model: callAsFunction, tokenizer: tokenizer, callback: callback) + // Prepare the language model for a new sequence. + await resetState() + + // Run inference. + return try await generate( + config: config, + prompt: prompt, + model: callAsFunction, + tokenizer: tokenizer, + callback: callback + ) } } #endif // canImport(CoreML) diff --git a/Tests/GenerationTests/LogitsWarperTests.swift b/Tests/GenerationTests/LogitsWarperTests.swift deleted file mode 100644 index 5d97652a..00000000 --- a/Tests/GenerationTests/LogitsWarperTests.swift +++ /dev/null @@ -1,160 +0,0 @@ -// -// LogitsWarperTests.swift -// -// Created by Jan Krukowski on 09/12/2023. -// - -#if canImport(CoreML) -import CoreML -@testable import Generation -import Testing - -@Suite("Logits Warper Tests") -struct LogitsWarperTests { - private let accuracy: Float = 0.00001 - - @Test("Temperature logits warper functionality") - func temperatureLogitsWarper() { - let result1 = TemperatureLogitsWarper(temperature: 0.0)([], []) - #expect(result1.indices.isEmpty) - #expect(result1.logits.isEmpty) - - let result2 = TemperatureLogitsWarper(temperature: 1.0)([], []) - #expect(result2.indices.isEmpty) - #expect(result2.logits.isEmpty) - - let result3 = TemperatureLogitsWarper(temperature: 1.0)([0, 1], [2.0, 1.0]) - #expect(result3.indices == [0, 1]) - #expect(isClose(result3.logits, [2.0, 1.0], accuracy: accuracy)) - - let result4 = TemperatureLogitsWarper(temperature: 2.0)([0, 1], [2.0, 1.0]) - #expect(result4.indices == [0, 1]) - #expect(isClose(result4.logits, [1.0, 0.5], accuracy: accuracy)) - - let result5 = TemperatureLogitsWarper(temperature: 0.5)([0, 1], [2.0, 1.0]) - #expect(result5.indices == [0, 1]) - #expect(isClose(result5.logits, [4.0, 2.0], accuracy: accuracy)) - - let result6 = TemperatureLogitsWarper(temperature: 0.5)([200, 100], [2.0, 1.0]) - #expect(result6.indices == [200, 100]) - #expect(isClose(result6.logits, [4.0, 2.0], accuracy: accuracy)) - } - - @Test("Top-K logits warper functionality") - func topKLogitsWarper() { - let result1 = TopKLogitsWarper(k: 0)([], []) - #expect(result1.indices.isEmpty) - #expect(result1.logits.isEmpty) - - let result2 = TopKLogitsWarper(k: 3)([], []) - #expect(result2.indices.isEmpty) - #expect(result2.logits.isEmpty) - - let result3 = TopKLogitsWarper(k: 3)([0, 1], [2.0, 1.0]) - #expect(result3.indices == [0, 1]) - #expect(isClose(result3.logits, [2.0, 1.0], accuracy: accuracy)) - - let result4 = TopKLogitsWarper(k: 3)([0, 1, 2], [2.0, 1.0, 3.0]) - #expect(result4.indices == [2, 0, 1]) - #expect(isClose(result4.logits, [3.0, 2.0, 1.0], accuracy: accuracy)) - - let result5 = TopKLogitsWarper(k: 4)([0, 1, 2, 3, 4, 5], [2.0, 1.0, 3.0, -1.0, 123.0, 0.0]) - #expect(result5.indices == [4, 2, 0, 1]) - #expect(isClose(result5.logits, [123.0, 3.0, 2.0, 1.0], accuracy: accuracy)) - - let result6 = TopKLogitsWarper(k: 3)([10, 1, 52], [2.0, 1.0, 3.0]) - #expect(result6.indices == [52, 10, 1]) - #expect(isClose(result6.logits, [3.0, 2.0, 1.0], accuracy: accuracy)) - } - - @Test("Top-P logits warper 
functionality") - func topPLogitsWarper() { - let result1 = TopPLogitsWarper(p: 0.99)([], []) - #expect(result1.indices.isEmpty) - #expect(result1.logits.isEmpty) - - let logits = (0..<10).map { Float($0) } - let indices = Array(logits.indices) - let result2 = TopPLogitsWarper(p: 0.99)(indices, logits) - #expect(result2.indices == [9, 8, 7, 6, 5]) - #expect(isClose(result2.logits, [9.0, 8.0, 7.0, 6.0, 5.0], accuracy: accuracy)) - - let result3 = TopPLogitsWarper(p: 0.95)(indices, logits) - #expect(result3.indices == [9, 8, 7]) - #expect(isClose(result3.logits, [9.0, 8.0, 7.0], accuracy: accuracy)) - - let result4 = TopPLogitsWarper(p: 0.6321493)(indices, logits) - #expect(result4.indices == [9, 8]) - #expect(isClose(result4.logits, [9.0, 8.0], accuracy: accuracy)) - - let result5 = TopPLogitsWarper(p: 0.95)([3, 1, 8], [0, 1, 2]) - #expect(result5.indices == [8, 1, 3]) - #expect(isClose(result5.logits, [2, 1, 0], accuracy: accuracy)) - } - - @Test("Repetition penalty warper functionality") - func repetitionPenaltyWarper() { - let indices = Array(0..<10) - let logits = indices.map { Float($0) } - - let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits) - #expect(result1.indices == indices) - #expect(isClose(result1.logits, logits, accuracy: accuracy)) - - let result2 = RepetitionPenaltyWarper(penalty: 3.75)(indices, logits) - #expect(result2.indices == indices) - let logits2 = indices.map { Float($0) / 3.75 } - #expect(isClose(result2.logits, logits2, accuracy: accuracy)) - - let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119]) - #expect(result3.indices == [0, 1, 2]) - #expect(isClose(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4)) - - let result4 = RepetitionPenaltyWarper(penalty: 1.11)([2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158]) - #expect(result4.indices == [2, 3, 4]) - #expect(isClose(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4)) - - let result5 = RepetitionPenaltyWarper(penalty: 0.9)([0, 1, 2], [-0.7433, -0.4738, -0.2966]) - #expect(result5.indices == [0, 1, 2]) - #expect(isClose(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4)) - - let result6 = RepetitionPenaltyWarper(penalty: 1.125)([3, 1, 2], [0.1674, 0.6431, 0.6780, 0.2755]) - #expect(result6.indices == [3, 1, 2]) - #expect(isClose(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4)) - } - - @Test("Logits processor functionality") - func logitsProcessor() { - let processor1 = LogitsProcessor(logitsWarpers: []) - let result1 = processor1([]) - #expect(result1.indices.isEmpty) - #expect(result1.logits.isEmpty) - - let processor2 = LogitsProcessor(logitsWarpers: []) - let result2 = processor2([2.0, 1.0]) - #expect(result2.indices == [0, 1]) - #expect(isClose(result2.logits, [2.0, 1.0], accuracy: accuracy)) - - let processor3 = LogitsProcessor( - logitsWarpers: [TopKLogitsWarper(k: 3)] - ) - let result3 = processor3([2.0, 1.0, 3.0, -5.0]) - #expect(result3.indices == [2, 0, 1]) - #expect(isClose(result3.logits, [3.0, 2.0, 1.0], accuracy: accuracy)) - - let processor4 = LogitsProcessor( - logitsWarpers: [TopKLogitsWarper(k: 3), TopPLogitsWarper(p: 0.99)] - ) - let result4 = processor4([2.0, 1.0, 3.0, -5.0, -23.0, 12.5]) - #expect(result4.indices == [5]) - #expect(isClose(result4.logits, [12.5], accuracy: accuracy)) - - let processor5 = LogitsProcessor( - logitsWarpers: [TopKLogitsWarper(k: 4), TopPLogitsWarper(p: 0.99)] - ) - let result5 = processor5([2.0, 1.0, 3.0, -5.0, -3.0, 4.5]) - 
#expect(result5.indices == [5, 2, 0, 1])
-        #expect(isClose(result5.logits, [4.5, 3.0, 2.0, 1.0], accuracy: accuracy))
-    }
-}
-#endif // canImport(CoreML)
diff --git a/Tests/GenerationTests/MathTests.swift b/Tests/GenerationTests/MathTests.swift
deleted file mode 100644
index 0e255ecd..00000000
--- a/Tests/GenerationTests/MathTests.swift
+++ /dev/null
@@ -1,59 +0,0 @@
-//
-//  MathTests.swift
-//
-//  Created by Jan Krukowski on 25/11/2023.
-//
-
-#if canImport(CoreML)
-import CoreML
-@testable import Generation
-import Testing
-
-@Suite("Math Tests")
-struct MathTests {
-    private let accuracy: Float = 0.00001
-
-    @Test("Cumulative sum functionality")
-    func cumsum() {
-        #expect(Math.cumsum([]).isEmpty)
-        #expect(Math.cumsum([1]) == [1])
-        #expect(Math.cumsum([1, 2, 3, 4]) == [1, 3, 6, 10])
-    }
-
-    @Test("Argmax functionality")
-    func argmax() throws {
-        let result1 = Math.argmax([3.0, 4.0, 1.0, 2.0] as [Float], count: 4)
-        #expect(result1.0 == 1)
-        #expect(result1.1 == 4.0)
-
-        let result2 = Math.argmax32([3.0, 4.0, 1.0, 2.0], count: 4)
-        #expect(result2.0 == 1)
-        #expect(result2.1 == 4.0)
-
-        let result3 = Math.argmax([3.0, 4.0, 1.0, 2.0] as [Double], count: 4)
-        #expect(result3.0 == 1)
-        #expect(result3.1 == 4.0)
-
-        let result4 = try Math.argmax32(MLMultiArray([3.0, 4.0, 1.0, 2.0] as [Float]))
-        #expect(result4.0 == 1)
-        #expect(result4.1 == 4.0)
-
-        let result5 = try Math.argmax(MLMultiArray([3.0, 4.0, 1.0, 2.0] as [Double]))
-        #expect(result5.0 == 1)
-        #expect(result5.1 == 4.0)
-
-        let result6 = Math.argmax(MLShapedArray(scalars: [3.0, 4.0, 1.0, 2.0] as [Float], shape: [4]))
-        #expect(result6.0 == 1)
-        #expect(result6.1 == 4.0)
-    }
-
-    @Test("Softmax functionality")
-    func softmax() {
-        #expect(Math.softmax([]) == [])
-
-        let result1 = Math.softmax([3.0, 4.0, 1.0, 2.0])
-        #expect(isClose(result1, [0.23688284, 0.6439143, 0.032058604, 0.08714432], accuracy: accuracy))
-        #expect(abs(result1.reduce(0, +) - 1.0) < accuracy)
-    }
-}
-#endif // canImport(CoreML)
diff --git a/Tests/GenerationTests/TestUtils.swift b/Tests/GenerationTests/TestUtils.swift
deleted file mode 100644
index 4f9f2a59..00000000
--- a/Tests/GenerationTests/TestUtils.swift
+++ /dev/null
@@ -1,7 +0,0 @@
-import Foundation
-
-/// Check if two floating-point arrays are equal within a given accuracy
-func isClose<T: FloatingPoint>(_ lhs: [T], _ rhs: [T], accuracy: T) -> Bool {
-    guard lhs.count == rhs.count else { return false }
-    return zip(lhs, rhs).allSatisfy { abs($0.0 - $0.1) <= accuracy }
-}
diff --git a/Tests/HubTests/DownloaderTests.swift b/Tests/HubTests/DownloaderTests.swift
index aecfa778..552fbf3e 100644
--- a/Tests/HubTests/DownloaderTests.swift
+++ b/Tests/HubTests/DownloaderTests.swift
@@ -181,11 +181,11 @@ final class DownloaderTests: XCTestCase {
         ) { error in
             XCTAssertEqual((error as NSError).code, 516)
         }
-        XCTAssertEqual(try String(contentsOf: destination), "existing")
+        XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), "existing")
 
         XCTAssertNoThrow(
             try FileManager.default.moveDownloadedFile(from: source2, to: destination, percentEncoded: false)
         )
-        XCTAssertEqual(try String(contentsOf: destination), "v2")
+        XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), "v2")
     }
 }
diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift
index 559306ec..e8c4031e 100644
--- a/Tests/HubTests/HubApiTests.swift
+++ b/Tests/HubTests/HubApiTests.swift
@@ -571,7 +571,7 @@ class SnapshotDownloadTests: XCTestCase {
         // If that's the case we just update the metadata and keep the local file.
XCTAssertEqual(originalTimestamp, thirdDownloadTimestamp) - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) // Updated metadata file needs to have the correct commit hash, etag and timestamp. // This is being updated because the local etag (SHA256 checksum) matches the remote etag @@ -602,7 +602,7 @@ class SnapshotDownloadTests: XCTestCase { ) let metadataFile = metadataDestination.appendingPathComponent("llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata") - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nfc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107" XCTAssertTrue(metadataString.contains(expected)) @@ -632,7 +632,7 @@ class SnapshotDownloadTests: XCTestCase { ) let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" XCTAssertTrue(metadataString.contains(expected)) @@ -662,7 +662,7 @@ class SnapshotDownloadTests: XCTestCase { ) let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let metadataArr = metadataString.components(separatedBy: .newlines) let commitHash = metadataArr[0] @@ -720,7 +720,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" XCTAssertTrue(metadataString.contains(expected)) @@ -771,7 +771,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" XCTAssertTrue(metadataString.contains(expected)) @@ -822,7 +822,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertTrue(originalTimestamp != secondDownloadTimestamp) XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nd6ceb92ce9e3c83ab146dc8e92a93517ac1cc66f" XCTAssertTrue(metadataString.contains(expected)) @@ -1009,7 +1009,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - 
let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path) + let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path, encoding: .utf8) let expected = """ { @@ -1048,7 +1048,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path) + let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path, encoding: .utf8) let expected = """ X diff --git a/Tests/TokenizersTests/BertTokenizerTests.swift b/Tests/TokenizersTests/BertTokenizerTests.swift index e6c269a5..33493b43 100644 --- a/Tests/TokenizersTests/BertTokenizerTests.swift +++ b/Tests/TokenizersTests/BertTokenizerTests.swift @@ -87,7 +87,7 @@ private enum Squad { private let bertTokenizer: BertTokenizer = { let vocab = { let url = Bundle.module.url(forResource: "bert-vocab", withExtension: "txt")! - let vocabTxt = try! String(contentsOf: url) + let vocabTxt = try! String(contentsOf: url, encoding: .utf8) let tokens = vocabTxt.split(separator: "\n").map { String($0) } var vocab: [String: Int] = [:] for (i, token) in tokens.enumerated() {