diff --git a/Examples/Mistral7B/README.md b/Examples/Mistral7B/README.md
new file mode 100644
index 00000000..30d64b14
--- /dev/null
+++ b/Examples/Mistral7B/README.md
@@ -0,0 +1,27 @@
+### Export Mistral 7B Instruct v0.3
+
+```shell
+$ uv run export.py
+
+Loading checkpoint shards: 100%|███████████████████████████| 3/3 [00:12<00:00, 4.11s/it]
+Converting PyTorch Frontend ==> MIL Ops: 100%|███| 5575/5575 [00:02<00:00, 2440.66 ops/s]
+Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 7.12 passes/s]
+Running MIL default pipeline: 100%|█████████████████| 79/79 [02:36<00:00, 1.98s/ passes]
+Running MIL backend_mlprogram pipeline: 100%|███████| 12/12 [00:00<00:00, 22.90 passes/s]
+Running compression: 100%|███████████████████████████| 296/296 [03:04<00:00, 1.60 ops/s]
+...
+```
+
+### Generate Text
+
+```shell
+$ swift run transformers-cli "Best recommendations for a place to visit in Paris in August 2024:" --max-length 128 StatefulMistral7BInstructInt4.mlpackage
+
+Best recommendations for a place to visit in Paris in August 2024:
+
+1. Palace of Versailles: This iconic palace is a must-visit. It's a short train ride from Paris and offers a glimpse into the opulence of the French monarchy.
+
+2. Eiffel Tower: No trip to Paris is complete without a visit to the Eiffel Tower. You can take an elevator ride to the top for a stunning view of the city.
+
+3. Louvre Museum: Home to thousands of works of art, including the Mona Lisa and the Winged Victory of Samothrace, the Louvre is a cultural treasure.
+```
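+
+### Generate Text with Python
+
+`generate.py` in this folder drives the same model from Python with naive greedy decoding. Assuming the export step above has produced the `mlpackage`, a typical invocation (using the script's `argparse` options) looks like:
+
+```shell
+$ uv run generate.py StatefulMistral7BInstructInt4.mlpackage --prompt "Best recommendations for a place to visit in Paris in August 2024:" --max_new_tokens 128
+```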
diff --git a/Examples/Mistral7B/export.py b/Examples/Mistral7B/export.py
new file mode 100644
index 00000000..4509a937
--- /dev/null
+++ b/Examples/Mistral7B/export.py
@@ -0,0 +1,230 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "coremltools",
+# "numpy",
+# "sentencepiece",
+# "torch",
+# "tqdm",
+# "transformers",
+# ]
+# ///
+import logging
+import os
+import warnings
+from typing import List, Optional, Tuple
+
+import coremltools as ct
+import numpy as np
+import torch
+from transformers.cache_utils import Cache
+from transformers.models.mistral.modeling_mistral import (
+ MISTRAL_ATTENTION_CLASSES,
+ MistralAttention,
+ MistralConfig,
+ MistralForCausalLM,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+
+warnings.filterwarnings("ignore")
+logging.getLogger("coremltools").setLevel(logging.ERROR)
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3
+MODEL_ID: str = "mistralai/Mistral-7B-Instruct-v0.3"
+METADATA_TOKENIZER: str = "co.huggingface.exporters.name"
+
+
+class SliceUpdateKeyValueCache(Cache):
+ def __init__(
+ self,
+ shape: Tuple[int, ...],
+ device="cpu",
+ dtype=torch.float32,
+ ) -> None:
+ """KV cache of shape (#layers, batch_size, #kv_heads, context_size, head_dim)."""
+ super().__init__()
+ self.past_seen_tokens: int = 0
+ self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
+ self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
+
+ def update(
+ self,
+ k_state: torch.Tensor,
+ v_state: torch.Tensor,
+ layer_idx: int,
+ slice_indices: torch.LongTensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Update key/value cache tensors for slice [slice_indices[0], slice_indices[1]).
+ Return slice of key/value cache tensors from [0, slice_indices[1]).
+ """
+ if len(slice_indices) != 2:
+ raise ValueError(f"Expect tuple of integers [start, end), got {slice_indices=}.")
+ begin, end = slice_indices
+ self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state
+ self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state
+ k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
+ v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
+ return k_cache, v_cache
+
+ def get_seq_length(self, _: int | None = 0) -> int:
+ """Get the sequence length of the cache."""
+ return self.past_seen_tokens
+
+
+class SliceUpdateMistralAttention(MistralAttention):
+ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
+ super().__init__(config=config, layer_idx=layer_idx)
+
+ @torch.no_grad()
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor | None, ...]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
+ 1, 2
+ )
+ value_states = value_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ # Slice update key/value cache
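+        # The causal mask spans past + current tokens, so its last dimension
+        # gives the end position of the slice being written.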
+ end_step = attention_mask.shape[-1]
+ key_states, value_states = past_key_value.update(
+ key_states,
+ value_states,
+ self.layer_idx,
+ slice_indices=(end_step - q_len, end_step),
+ )
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, None, None
+
+
+class StatefulMistralForCausalLM(torch.nn.Module):
+ def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1) -> None:
+ super().__init__()
+
+        # Custom attention implementation for the stateful slice-update key/value
+        # cache; override "sdpa" so the custom class complies with
+        # transformers.modeling_utils._autoset_attn_implementation
+ MISTRAL_ATTENTION_CLASSES["sdpa"] = SliceUpdateMistralAttention
+ self.model = MistralForCausalLM.from_pretrained(model_path)
+
+ # Register KV cache buffers to be recognized as Core ML states
+ config: MistralConfig = self.model.config
+ self.kv_cache_shape: Tuple[int, ...] = (
+ config.num_hidden_layers,
+ batch_size,
+ config.num_key_value_heads,
+ max_context_size,
+ config.hidden_size // config.num_attention_heads,
+ )
+ self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
+ self.register_buffer("keyCache", self.kv_cache.k_cache)
+ self.register_buffer("valueCache", self.kv_cache.v_cache)
+
+ @torch.no_grad()
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ causal_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ # Compute past seen tokens used for updating key/value cache slices
+ self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1]
+ return self.model(
+ input_ids,
+ attention_mask=causal_mask,
+ past_key_values=self.kv_cache,
+ use_cache=True,
+ ).logits
+
+
+def export() -> None:
+ # Construct model from transformers and trace to TorchScript
+ max_context_size: int = 2048
+ torch_model = StatefulMistralForCausalLM(MODEL_ID, max_context_size=max_context_size)
+ torch_model.eval()
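+    # Tiny placeholder inputs are enough for tracing; the ct.RangeDim entries
+    # defined below make the sequence and mask dimensions dynamic in the
+    # converted model.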
+ input_ids: torch.Tensor = torch.zeros((1, 2), dtype=torch.int32)
+ causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float32)
+ traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask])
+ kv_cache_shape = torch_model.kv_cache_shape
+ del torch_model
+
+ # Convert traced TorchScript to Core ML format
+ query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
+ end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
+ inputs: List[ct.TensorType] = [
+ ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"),
+ ct.TensorType(
+ shape=(1, 1, query_length, end_step_dim),
+ dtype=np.float16,
+ name="causalMask",
+ ),
+ ]
+ outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")]
+ states: List[ct.StateType] = [
+ ct.StateType(
+ wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16),
+ name="keyCache",
+ ),
+ ct.StateType(
+ wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16),
+ name="valueCache",
+ ),
+ ]
+
+ # Convert model with FP16 precision
+ mlmodel_fp16: ct.MLModel = ct.convert(
+ traced_model,
+ inputs=inputs,
+ outputs=outputs,
+ states=states,
+ minimum_deployment_target=ct.target.iOS18,
+ skip_model_load=True,
+ )
+ del traced_model
+
+ # Block-wise quantize model weights to int4
+ op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
+ mode="linear_symmetric",
+ dtype="int4",
+ granularity="per_block",
+ block_size=32,
+ )
+ config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
+ mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config)
+ mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID})
+ del mlmodel_fp16
+ mlmodel_int4.save("StatefulMistral7BInstructInt4.mlpackage")
+
+
+if __name__ == "__main__":
+ export()
diff --git a/Examples/Mistral7B/generate.py b/Examples/Mistral7B/generate.py
new file mode 100644
index 00000000..0c90acd0
--- /dev/null
+++ b/Examples/Mistral7B/generate.py
@@ -0,0 +1,99 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "coremltools",
+# "numpy",
+# "sentencepiece",
+# "torch",
+# "tqdm",
+# "transformers",
+# ]
+# ///
+import argparse
+from typing import Dict, Generator, List, Tuple
+
+import numpy as np
+from coremltools.models import MLModel
+from transformers import AutoTokenizer
+
+from export import METADATA_TOKENIZER
+
+
+def load(model_path: str) -> Tuple[MLModel, AutoTokenizer]:
+ """Load a Core ML model and corresponding tokenizer."""
+ model: MLModel = MLModel(model_path)
+ description = model.get_spec().description
+ if METADATA_TOKENIZER not in description.metadata.userDefined:
+ raise ValueError("Model metadata does not contain tokenizer path.")
+ tokenizer_path: str = description.metadata.userDefined[METADATA_TOKENIZER]
+ tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ return model, tokenizer
+
+
+def get_next_token(model: MLModel, prompt_tokens: np.ndarray) -> Generator[int, None, None]:
+ """Generate a sequence of tokens with naive greedy decoding."""
+
+ def sample(logits: np.ndarray) -> int:
+ """Perform greedy decoding on the logits array to get the next token."""
+ return int(np.argmax(logits[0][-1], axis=-1))
+
+ def inference(model: MLModel, input_ids: np.ndarray, num_past_tokens: int) -> np.ndarray:
+ """Perform inference with the given model and input data."""
+ causal_mask: np.ndarray = np.triu(
+ np.full(
+ (1, 1, input_ids.shape[-1], num_past_tokens + input_ids.shape[-1]),
+ fill_value=-np.inf if num_past_tokens == 0 else 0,
+ ),
+ k=1,
+ ).astype(np.float16)
+ outputs: Dict[str, np.ndarray] = model.predict(
+ data={"inputIds": input_ids, "causalMask": causal_mask},
+ state=kv_cache_state,
+ )
+ return outputs["logits"]
+
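+    # Create the stateful model's KV cache once; every predict() call below
+    # reads and updates it in place.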
+ kv_cache_state = model.make_state()
+ logits: np.ndarray = inference(model, input_ids=prompt_tokens, num_past_tokens=0)
+ token: int = sample(logits=logits)
+ num_past_tokens: int = prompt_tokens.shape[-1]
+
+ while True:
+ yield token
+ logits: np.ndarray = inference(
+ model,
+ input_ids=np.array([[token]], dtype=np.int32),
+ num_past_tokens=num_past_tokens,
+ )
+ token: int = sample(logits=logits)
+ num_past_tokens += 1
+
+
+def generate(
+ model: MLModel,
+ prompt: str,
+ tokenizer: AutoTokenizer,
+ max_new_tokens: int,
+) -> str:
+ prompt_tokens: np.ndarray = tokenizer(prompt, return_tensors="np").input_ids
+ extend_tokens: List[int] = []
+ for i, token in enumerate(get_next_token(model, prompt_tokens=prompt_tokens.astype(np.int32))):
+ if token == tokenizer.eos_token_id or i == max_new_tokens:
+ break
+ extend_tokens.append(token)
+ return tokenizer.decode(prompt_tokens[0].tolist() + extend_tokens)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("model_path", type=str)
+ parser.add_argument("--prompt", type=str, default="Hello")
+ parser.add_argument("--max_new_tokens", type=int, default=128)
+ args = parser.parse_args()
+ model, tokenizer = load(args.model_path)
+ extend_text: str = generate(
+ model,
+ prompt=args.prompt,
+ tokenizer=tokenizer,
+ max_new_tokens=args.max_new_tokens,
+ )
+ print(extend_text)
diff --git a/Examples/Mistral7B/requirements.txt b/Examples/Mistral7B/requirements.txt
new file mode 100644
index 00000000..3c954fb2
--- /dev/null
+++ b/Examples/Mistral7B/requirements.txt
@@ -0,0 +1,6 @@
+coremltools
+numpy
+torch
+tqdm
+transformers
+sentencepiece
diff --git a/Examples/transformers-cli/Package.swift b/Examples/transformers-cli/Package.swift
new file mode 100644
index 00000000..f360afb1
--- /dev/null
+++ b/Examples/transformers-cli/Package.swift
@@ -0,0 +1,24 @@
+// swift-tools-version: 6.2
+// The swift-tools-version declares the minimum version of Swift required to build this package.
+
+import PackageDescription
+
+let package = Package(
+ name: "transformers-cli",
+ platforms: [.iOS(.v18), .macOS(.v15)],
+ dependencies: [
+ .package(path: "../.."),
+ // If you copy this manifest as a template, use the following line instead
+        // .package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"),
+ .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"),
+ ],
+ targets: [
+ .executableTarget(
+ name: "transformers-cli",
+ dependencies: [
+ .product(name: "Transformers", package: "swift-transformers"),
+ .product(name: "ArgumentParser", package: "swift-argument-parser"),
+ ]
+ )
+ ]
+)
diff --git a/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift
new file mode 100644
index 00000000..cb24e3ca
--- /dev/null
+++ b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift
@@ -0,0 +1,141 @@
+import ArgumentParser
+import CoreML
+import Foundation
+import Generation
+import Models
+
+@available(macOS 15.0, iOS 18.0, *)
+@main
+struct TransformersCLI: AsyncParsableCommand {
+ static let configuration = CommandConfiguration(
+ abstract: "Run text generation on a Core ML language model",
+ version: "0.0.1"
+ )
+
+ @Argument(help: "Input text")
+ var prompt: String
+
+ @Argument(help: "Path to Core ML mlpackage model")
+ var modelPath: String = "./model.mlpackage"
+
+ @Option(help: "Maximum amount of tokens the model should generate")
+ var maxLength: Int = 100
+
+ @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
+ var computeUnits: ComputeUnits = .cpuAndGPU
+
+ @Option(
+ help: """
+        When enabled, two generation passes are run: one to 'warm up' and another to collect \
+ benchmark metrics.
+ """)
+ var warmup: Bool = false
+
+ func generate(
+ model: LanguageModel,
+ config: GenerationConfig,
+ prompt: String,
+ printOutput: Bool = true
+ ) async throws {
+ var tokensReceived = 0
+ var previousIndex: String.Index? = nil
+ var startTime = Date()
+ var promptProcessingTime: Double = 0 // seconds
+ try await model.generate(config: config, prompt: prompt) { inProgressGeneration in
+ if previousIndex == nil { // Prompt pre-filling
+ promptProcessingTime = Date().timeIntervalSince(startTime)
+ // Reset start time to more accurately compute the average tps.
+ startTime = Date()
+ } else { // Extend
+ // Only start counting tokens once the prompt has been processed.
+ tokensReceived += 1
+ }
+ let response = formatResponse(inProgressGeneration)
+ if printOutput {
+ print(response[(previousIndex ?? response.startIndex)...], terminator: "")
+ fflush(stdout)
+ }
+ previousIndex = response.endIndex
+ }
+ // Current time - start time + elapsed time to process the prompt
+ let endTime = Date()
+ let completionTime = endTime.timeIntervalSince(startTime) + promptProcessingTime
+ let tps = Double(tokensReceived) / endTime.timeIntervalSince(startTime)
+ if printOutput {
+ print("")
+ print(
+ """
+ \(tps.formatted("%.2f")) tokens/s, \
+ prompt pre-filling time: \(promptProcessingTime.formatted("%.2f"))s, \
+ total time: \(completionTime.formatted("%.2f"))s
+ """)
+ }
+ }
+
+ func compile(at url: URL) throws -> URL {
+ #if os(watchOS)
+ fatalError("Model compilation is not supported on watchOS")
+ #else
+ if url.pathExtension == "mlmodelc" { return url }
+ print("Compiling model \(url)")
+ return try MLModel.compileModel(at: url)
+ #endif
+ }
+
+ func run() async throws {
+ let url = URL(filePath: modelPath)
+ let compiledURL = try compile(at: url)
+ print("Loading model \(compiledURL)")
+ let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits)
+
+ // Using greedy generation for now
+ var config = model.defaultGenerationConfig
+ config.doSample = false
+ config.maxNewTokens = maxLength
+
+        // The out-of-model tensor computation is small, so dispatch all
+        // tensor operations to the CPU.
+
+ if warmup {
+ print("Warming up...")
+ try await withMLTensorComputePolicy(.cpuOnly) {
+ try await generate(model: model, config: config, prompt: prompt, printOutput: false)
+ }
+ }
+
+ print("Generating")
+ try await withMLTensorComputePolicy(.cpuOnly) {
+ try await generate(model: model, config: config, prompt: prompt)
+ }
+ }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
+enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
+ case all, cpuAndGPU, cpuOnly, cpuAndNeuralEngine
+ var asMLComputeUnits: MLComputeUnits {
+ switch self {
+ case .all: return .all
+ case .cpuAndGPU: return .cpuAndGPU
+ case .cpuOnly: return .cpuOnly
+ case .cpuAndNeuralEngine: return .cpuAndNeuralEngine
+ }
+ }
+}
+
+/// Returns a cleaned and formatted version of the response.
+///
+/// - Parameter response: The response to clean and format.
+/// - Returns: A 'user friendly' representation of the generated response.
+private func formatResponse(_ response: String) -> String {
+ response
+ .replacingOccurrences(of: "\\n", with: "\n")
+ .replacingOccurrences(of: "", with: "")
+ .replacingOccurrences(of: "", with: "")
+}
+
+extension Double {
+ func formatted(_ format: String) -> String {
+ return String(format: "\(format)", self)
+ }
+}
diff --git a/Package.swift b/Package.swift
index fb69f9f3..dc6c3f5a 100644
--- a/Package.swift
+++ b/Package.swift
@@ -24,7 +24,6 @@ let package = Package(
.target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings),
.target(name: "Models", dependencies: ["Tokenizers", "Generation"]),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]),
- .testTarget(name: "GenerationTests", dependencies: ["Generation"]),
.testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings),
.testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]),
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]),
diff --git a/Sources/Generation/CoreML+Extensions.swift b/Sources/Generation/CoreML+Extensions.swift
deleted file mode 100644
index 1c3c8f7e..00000000
--- a/Sources/Generation/CoreML+Extensions.swift
+++ /dev/null
@@ -1,302 +0,0 @@
-//
-// CoreML+Extensions.swift
-// CoreMLBert
-//
-// Created by Julien Chaumond on 27/06/2019.
-// Copyright © 2019 Hugging Face. All rights reserved.
-//
-
-#if canImport(CoreML)
-import CoreML
-import Foundation
-
-extension MLMultiArray {
- /// Creates an MLMultiArray from an array of integers.
- ///
- /// All values are stored in the last dimension of the MLMultiArray, with leading
- /// dimensions set to 1. For example, with dims=2, the shape becomes [1, arr.count].
- ///
- /// - Parameters:
- /// - arr: Array of integers to convert
- /// - dims: Number of dimensions for the resulting MLMultiArray
- /// - Returns: MLMultiArray containing the integer values
- static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray {
- var shape = Array(repeating: 1, count: dims)
- shape[shape.count - 1] = arr.count
- // Examples:
- // dims=1 : [arr.count]
- // dims=2 : [1, arr.count]
- //
- let o = try! MLMultiArray(shape: shape as [NSNumber], dataType: .int32)
-        let ptr = UnsafeMutablePointer<Int32>(OpaquePointer(o.dataPointer))
- for (i, item) in arr.enumerated() {
- ptr[i] = Int32(item)
- }
- return o
- }
-
- /// Creates an MLMultiArray from an array of doubles.
- ///
- /// All values are stored in the last dimension of the MLMultiArray, with leading
- /// dimensions set to 1. For example, with dims=2, the shape becomes [1, arr.count].
- ///
- /// - Parameters:
- /// - arr: Array of doubles to convert
- /// - dims: Number of dimensions for the resulting MLMultiArray
- /// - Returns: MLMultiArray containing the double values
- static func from(_ arr: [Double], dims: Int = 1) -> MLMultiArray {
- var shape = Array(repeating: 1, count: dims)
- shape[shape.count - 1] = arr.count
- // Examples:
- // dims=1 : [arr.count]
- // dims=2 : [1, arr.count]
- //
- let o = try! MLMultiArray(shape: shape as [NSNumber], dataType: .float64)
-        let ptr = UnsafeMutablePointer<Double>(OpaquePointer(o.dataPointer))
- for (i, item) in arr.enumerated() {
- ptr[i] = Double(item)
- }
- return o
- }
-
- /// Converts an MLMultiArray to a flat array of integers.
- ///
- /// Concatenates all dimensions into a single one-dimensional array by reading
- /// the MLMultiArray data in memory order.
- ///
- /// - Parameter o: MLMultiArray to convert
- /// - Returns: Flat array of integer values
- static func toIntArray(_ o: MLMultiArray) -> [Int] {
- var arr = Array(repeating: 0, count: o.count)
-        let ptr = UnsafeMutablePointer<Int32>(OpaquePointer(o.dataPointer))
-        for i in 0..<o.count {
-            arr[i] = Int(ptr[i])
-        }
-        return arr
-    }
-
-    func toIntArray() -> [Int] { Self.toIntArray(self) }
-
- /// Converts an MLMultiArray to a flat array of doubles.
- ///
- /// Concatenates all dimensions into a single one-dimensional array by reading
- /// the MLMultiArray data in memory order.
- ///
- /// - Parameter o: MLMultiArray to convert
- /// - Returns: Flat array of double values
- static func toDoubleArray(_ o: MLMultiArray) -> [Double] {
- var arr: [Double] = Array(repeating: 0, count: o.count)
-        let ptr = UnsafeMutablePointer<Double>(OpaquePointer(o.dataPointer))
-        for i in 0..<o.count {
-            arr[i] = ptr[i]
-        }
-        return arr
-    }
-
-    func toDoubleArray() -> [Double] { Self.toDoubleArray(self) }
-
- /// Creates a test MLMultiArray with sequentially indexed values.
- ///
- /// Useful for debugging and unit tests. Values are assigned sequentially
- /// starting from 0, following the memory layout of the specified shape.
- ///
- /// Example output for shape [2, 3, 4]:
- /// ```
- /// [[[ 0, 1, 2, 3 ],
- /// [ 4, 5, 6, 7 ],
- /// [ 8, 9, 10, 11 ]],
- /// [[ 12, 13, 14, 15 ],
- /// [ 16, 17, 18, 19 ],
- /// [ 20, 21, 22, 23 ]]]
- /// ```
- ///
- /// - Parameter shape: Desired shape of the test tensor
- /// - Returns: MLMultiArray with sequential values for testing
- static func testTensor(shape: [Int]) -> MLMultiArray {
- let arr = try! MLMultiArray(shape: shape as [NSNumber], dataType: .double)
-        let ptr = UnsafeMutablePointer<Double>(OpaquePointer(arr.dataPointer))
-        for i in 0..<arr.count {
-            ptr[i] = Double(i)
-        }
-        return arr
-    }
-
-    /// Indexing spec for `slice`: select a single index in a dimension, or
-    /// keep the whole dimension.
-    enum Indexing: Equatable {
-        case select(Int)
-        case slice
-    }
-
-    /// Slice an array according to a list of `Indexing` values, one per
-    /// dimension; exactly one entry must be `.slice` (cf. the asserts).
-    static func slice(_ o: MLMultiArray, indexing: [Indexing]) -> MLMultiArray {
- assert(
- indexing.count == o.shape.count
- )
- assert(
- indexing.filter { $0 == Indexing.slice }.count == 1
- )
- var selectDims: [Int: Int] = [:]
- for (i, idx) in indexing.enumerated() {
- if case let .select(select) = idx {
- selectDims[i] = select
- }
- }
- return slice(
- o,
- sliceDim: indexing.firstIndex { $0 == Indexing.slice }!,
- selectDims: selectDims
- )
- }
-
- /// Slice an array according to a list, according to `sliceDim` (which dimension to slice on)
- /// and a dictionary of `dim` to `index`.
- ///
- /// You must select all other dimensions than the slice dimension (cf. the assert).
- static func slice(_ o: MLMultiArray, sliceDim: Int, selectDims: [Int: Int]) -> MLMultiArray {
- assert(
- selectDims.count + 1 == o.shape.count
- )
- var shape: [NSNumber] = Array(repeating: 1, count: o.shape.count)
- shape[sliceDim] = o.shape[sliceDim]
- // print("About to slice ndarray of shape \(o.shape) into ndarray of shape \(shape)")
- let arr = try! MLMultiArray(shape: shape, dataType: .double)
-
- // let srcPtr = UnsafeMutablePointer(OpaquePointer(o.dataPointer))
- // TODO: use srcPtr instead of array subscripting.
-        let dstPtr = UnsafeMutablePointer<Double>(OpaquePointer(arr.dataPointer))
-        for i in 0..<arr.count {
-            var index: [Int] = []
-            for j in 0..<shape.count {
-                index.append(j == sliceDim ? i : selectDims[j]!)
-            }
-            dstPtr[i] = o[index as [NSNumber]].doubleValue
-        }
-        return arr
-    }
-
-    /// Recursively builds a nested-bracket description of the array contents.
-    func debugString(indices: [Int] = []) -> String {
- func indent(_ x: Int) -> String {
- String(repeating: " ", count: x)
- }
-
- // This function is called recursively for every dimension.
- // Add an entry for this dimension to the end of the array.
- var indices = indices + [0]
-
- let d = indices.count - 1 // the current dimension
- let N = shape[d].intValue // how many elements in this dimension
- var s = "["
- if indices.count < shape.count { // not last dimension yet?
-            for i in 0..<N {
-                indices[d] = i
-                s += debugString(indices: indices)
-                if i < N - 1 { s += ",\n" + indent(d + 1) }
-            }
-        } else { // last dimension: append the scalar values
-            for i in 0..<N {
-                indices[d] = i
-                s += "\(self[indices as [NSNumber]])"
-                if i < N - 1 { s += ", " }
-            }
-        }
-        s += "]"
-        return s
-    }
-}
-
-extension MLShapedArray<Float> {
- /// Efficiently extracts float values from the shaped array.
- ///
- /// Uses optimized memory copying when possible (stride=1), falling back to
- /// slower scalar access for non-contiguous arrays.
- ///
- /// - Returns: Array of Float values from the shaped array
- var floats: [Float] {
- guard strides.first == 1, strides.count == 1 else {
- // For some reason this path is slow.
- // If strides is not 1, we can write a Metal kernel to copy the values properly.
- return scalars
- }
-
- // Fast path: memcpy
- let mlArray = MLMultiArray(self)
- return mlArray.floats ?? scalars
- }
-}
-
-extension MLShapedArraySlice<Float> {
- /// Efficiently extracts float values from the shaped array slice.
- ///
- /// Uses optimized memory copying when possible (stride=1), falling back to
- /// slower scalar access for non-contiguous slices.
- ///
- /// - Returns: Array of Float values from the shaped array slice
- var floats: [Float] {
- guard strides.first == 1, strides.count == 1 else {
- // For some reason this path is slow.
- // If strides is not 1, we can write a Metal kernel to copy the values properly.
- return scalars
- }
-
- // Fast path: memcpy
- let mlArray = MLMultiArray(self)
- return mlArray.floats ?? scalars
- }
-}
-
-extension MLMultiArray {
- /// Efficiently extracts float values from the MLMultiArray if it contains float32 data.
- ///
- /// Uses fast memory copying to extract all float values as a contiguous array.
- /// Returns nil if the array doesn't contain float32 data.
- ///
- /// - Returns: Array of Float values, or nil if not float32 type
- var floats: [Float]? {
- guard dataType == .float32 else { return nil }
-
- var result: [Float] = Array(repeating: 0, count: count)
- return withUnsafeBytes { ptr in
- guard let source = ptr.baseAddress else { return nil }
- result.withUnsafeMutableBytes { resultPtr in
- let dest = resultPtr.baseAddress!
-            memcpy(dest, source, self.count * MemoryLayout<Float>.stride)
- }
- return result
- }
- }
-}
-#endif // canImport(CoreML)
diff --git a/Sources/Generation/Decoders.swift b/Sources/Generation/Decoders.swift
new file mode 100644
index 00000000..fafbd1b4
--- /dev/null
+++ b/Sources/Generation/Decoders.swift
@@ -0,0 +1,28 @@
+#if canImport(CoreML)
+import CoreML
+
+// MARK: Greedy Decoding
+
+@available(macOS 15.0, iOS 18.0, *)
+func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor {
+ scores.argmax(alongAxis: -1).reshaped(to: [1, 1])
+}
+
+// MARK: Top-K Sampling
+
+@available(macOS 15.0, iOS 18.0, *)
+func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, topK: Int) -> MLTensor {
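+    // Draw r in [0, total probability mass), then pick the first index whose
+    // cumulative probability reaches r. Indices whose cumulative sum is still
+    // below r are pushed far away (+100) so argsort surfaces the crossing
+    // index first.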
+ let temperatureAdjustedScores = scores / temperature
+ let (topKScores, topKIndices) = temperatureAdjustedScores.topK(topK)
+ let topKProbs = topKScores.softmax(alongAxis: -1)
+ let rnd = topKProbs.sum() * Float.random(in: 0..<1)
+ var accumTopKProbs = topKProbs.cumulativeSum(alongAxis: -1)
+ accumTopKProbs += (accumTopKProbs .< rnd) * 100.0
+ let topKIndex = accumTopKProbs.argsort()[..., 0]
+ let nextTokenTensor = topKIndices.gathering(
+ atIndices: topKIndex,
+ alongAxis: topKIndices.rank - 1
+ )
+ return nextTokenTensor.reshaped(to: [1, 1])
+}
+#endif // canImport(CoreML)
diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift
index e56af9cc..0cdfd375 100644
--- a/Sources/Generation/Generation.swift
+++ b/Sources/Generation/Generation.swift
@@ -8,6 +8,7 @@
#if canImport(CoreML)
import CoreML
import Tokenizers
/// Supported text generation modes.
@@ -37,7 +38,8 @@ public typealias GenerationOutput = [Int]
/// - Parameter tokens: Input token sequence
/// - Parameter config: Generation configuration
/// - Returns: Logits array for next token prediction
-public typealias NextTokenModel = (InputTokens, GenerationConfig) -> any MLShapedArrayProtocol
+@available(macOS 15.0, iOS 18.0, *)
+public typealias NextTokenModel = (MLTensor, GenerationConfig) async -> MLTensor
/// Callback for receiving generated tokens during streaming.
public typealias PredictionTokensCallback = (GenerationOutput) -> Void
@@ -46,17 +48,8 @@ public typealias PredictionTokensCallback = (GenerationOutput) -> Void
public typealias PredictionStringCallback = (String) -> Void
/// Protocol for text generation implementations.
+@available(macOS 15.0, iOS 18.0, *)
public protocol Generation {
- /// Performs greedy search generation.
- ///
- /// - Parameters:
- /// - config: Generation configuration
- /// - tokens: Input token sequence
- /// - model: Model for next token prediction
- /// - callback: Optional callback for streaming tokens
- /// - Returns: Generated token sequence
- func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback?) async -> GenerationOutput
-
/// Generates text from a prompt string.
///
/// - Parameters:
@@ -69,89 +62,84 @@ public protocol Generation {
func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback?) async -> String
}
-public extension Generation {
- func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput {
- // Iterate until we find the eos token or reach the max length
- // TODO: additional stopping criteria
- var outputTokens = tokens
- while outputTokens.count < config.maxLength {
- let logits = model(outputTokens, config)
- let (nextToken, _) = Math.argmax(logits)
- if nextToken == config.eosTokenId { break }
- outputTokens.append(nextToken)
- callback?(outputTokens)
+@available(macOS 15.0, iOS 18.0, *)
+extension Generation {
+ public func generate(
+ config: GenerationConfig,
+ tokens: InputTokens,
+ model: NextTokenModel,
+ callback: PredictionTokensCallback? = nil
+ ) async -> GenerationOutput {
+ let tokens = tokens.map { Int32($0) }
+ var outputTokens = MLTensor(tokens).expandingShape(at: 0)
+ while outputTokens.shape[1] < config.maxLength {
+ let nextTokenScores = await model(outputTokens, config)
+ let nextToken =
+ switch config.generationMode {
+ case .greedy:
+ selectNextTokenUsingGreedyDecoding(from: nextTokenScores)
+ case .sample:
+ selectNextTokenUsingTopKSampling(
+ from: nextTokenScores,
+ temperature: config.temperature,
+ topK: config.topK
+ )
+ default:
+ fatalError("Generation mode \(config.generationMode) not implemented yet")
+ }
+
+ if let nextTokenId = await tensorToGenerationOutput(nextToken).first, nextTokenId == config.eosTokenId {
+ break
+ }
+
+ outputTokens = MLTensor(concatenating: [outputTokens, nextToken], alongAxis: -1)
+ if let callback {
+ let outputTokenIDs = await tensorToGenerationOutput(outputTokens)
+ callback(outputTokenIDs)
+ }
}
- return outputTokens
+ return await tensorToGenerationOutput(outputTokens)
}
- /// Performs sampling-based text generation with configurable logits warping.
- ///
- /// Uses various logits warpers (temperature, top-k, top-p, repetition penalty) to modify
- /// token probabilities before sampling, enabling diverse and controllable text generation.
+ private func tensorToGenerationOutput(_ tensor: MLTensor) async -> GenerationOutput {
+ await tensor.shapedArray(of: Int32.self).scalars.map { Int($0) }
+ }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
+public extension Generation {
+ /// Performs greedy or sampling-based text generation based on generation configuration.
///
/// - Parameters:
/// - config: Generation configuration with sampling parameters
- /// - tokens: Input token sequence
+ /// - prompt: Input string
/// - model: Model for next token prediction
+ /// - tokenizer: Tokenizer to convert prompt to input tokens
/// - callback: Optional callback for streaming tokens
/// - Returns: Generated token sequence
///
/// - Note: Based on https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L1552
- func sample(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput {
- // Iterate until we find the eos token or reach the max length
- // TODO: additional stopping criteria
- var outputTokens = tokens
- let logitsProcessor = LogitsProcessor(logitsWarpers: logitsWarpers(config: config))
- while outputTokens.count < config.maxLength {
- let outputs = model(outputTokens, config)
- // `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
- let logits = (outputs as? MLShapedArraySlice)?.floats ?? outputs.scalars as! [Float]
- let (indexes, processedLogits) = logitsProcessor(logits)
- let nextToken = Math.sample(indexes: indexes, probs: Math.softmax(processedLogits))
- if nextToken == config.eosTokenId { break }
- outputTokens.append(nextToken)
- callback?(outputTokens)
- }
- return outputTokens
- }
-
- func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback? = nil) async -> String {
+ func generate(
+ config: GenerationConfig,
+ prompt: String,
+ model: NextTokenModel,
+ tokenizer: Tokenizer,
+ callback: PredictionStringCallback? = nil
+ ) async -> String {
let tokens = tokenizer.encode(text: prompt)
var generationConfig = config
generationConfig.maxLength = config.maxNewTokens + tokens.count
-
- let output: GenerationOutput
- switch generationConfig.generationMode {
- case .greedy:
- output = await greedySearch(config: generationConfig, tokens: tokens, model: model) { tokens in
- callback?(tokenizer.decode(tokens: tokens))
- }
- case .sample:
- output = await sample(config: generationConfig, tokens: tokens, model: model) { tokens in
- callback?(tokenizer.decode(tokens: tokens))
- }
- default:
- fatalError("Generation mode \(generationConfig.generationMode) not implemented yet")
+ generationConfig.eosTokenId = tokenizer.eosTokenId
+ generationConfig.bosTokenId = tokenizer.bosTokenId
+ let output = await generate(
+ config: generationConfig,
+ tokens: tokens,
+ model: model
+ ) { tokens in
+ callback?(tokenizer.decode(tokens: tokens))
}
return tokenizer.decode(tokens: output)
}
-
- private func logitsWarpers(config: GenerationConfig) -> [any LogitsWarper] {
- var logitsWarpers = [any LogitsWarper]()
- if config.temperature > 0, config.temperature != 1 {
- logitsWarpers.append(TemperatureLogitsWarper(temperature: Float(config.temperature)))
- }
- if config.topK > 0 {
- logitsWarpers.append(TopKLogitsWarper(k: config.topK))
- }
- if config.topP < 1.0 {
- logitsWarpers.append(TopPLogitsWarper(p: Float(config.topP)))
- }
- if config.repetitionPenalty != 1.0 {
- logitsWarpers.append(RepetitionPenaltyWarper(penalty: config.repetitionPenalty))
- }
- return logitsWarpers
- }
}
#endif // canImport(CoreML)
diff --git a/Sources/Generation/GenerationConfig.swift b/Sources/Generation/GenerationConfig.swift
index 978f404a..e7389fbb 100644
--- a/Sources/Generation/GenerationConfig.swift
+++ b/Sources/Generation/GenerationConfig.swift
@@ -33,7 +33,7 @@ public struct GenerationConfig {
public var penaltyAlpha: Double?
/// Temperature for sampling (higher values increase randomness).
- public var temperature = 1.0
+ public var temperature: Float = 1.0
/// Number of top tokens to consider for top-k sampling.
public var topK = 50
@@ -66,14 +66,25 @@ public struct GenerationConfig {
/// - topK: Top-k sampling parameter
/// - topP: Top-p sampling parameter
/// - repetitionPenalty: Repetition penalty factor
- public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) {
+ public init(
+ maxLength: Int = 20,
+ maxNewTokens: Int,
+ doSample: Bool = false,
+ numBeams: Int = 1,
+ numBeamGroups: Int = 1,
+ penaltyAlpha: Double? = nil,
+ temperature: Double = 1.0,
+ topK: Int = 50,
+ topP: Double = 1.0,
+ repetitionPenalty: Double = 1.0
+ ) {
self.maxLength = maxLength
self.maxNewTokens = maxNewTokens
self.doSample = doSample
self.numBeams = numBeams
self.numBeamGroups = numBeamGroups
self.penaltyAlpha = penaltyAlpha
- self.temperature = temperature
+ self.temperature = Float(temperature)
self.topK = topK
self.topP = topP
self.repetitionPenalty = repetitionPenalty
diff --git a/Sources/Generation/LogitsWarper/LogitsProcessor.swift b/Sources/Generation/LogitsWarper/LogitsProcessor.swift
deleted file mode 100644
index 8f16477c..00000000
--- a/Sources/Generation/LogitsWarper/LogitsProcessor.swift
+++ /dev/null
@@ -1,33 +0,0 @@
-import Foundation
-
-/// Processes logits by applying a sequence of logits warpers.
-///
-/// Coordinates the application of multiple logits warpers in sequence,
-/// allowing for complex probability transformations during text generation.
-public struct LogitsProcessor {
- /// Array of logits warpers to apply in sequence.
- public var logitsWarpers: [any LogitsWarper]
-
- /// Creates a new logits processor.
- ///
- /// - Parameter logitsWarpers: Array of warpers to apply in sequence
- public init(logitsWarpers: [any LogitsWarper]) {
- self.logitsWarpers = logitsWarpers
- }
-
- /// Processes logits by applying all warpers in sequence.
- ///
- /// Each warper is applied to the output of the previous warper, allowing
- /// for complex chaining of probability transformations.
- ///
- /// - Parameter arr: Input logits array
- /// - Returns: Tuple of processed (indices, logits)
- public func callAsFunction(_ arr: [Float]) -> (indices: [Int], logits: [Float]) {
- var indices = Array(arr.indices)
- var logits = arr
- for warper in logitsWarpers {
- (indices, logits) = warper(indices, logits)
- }
- return (indices: indices, logits: logits)
- }
-}
diff --git a/Sources/Generation/LogitsWarper/LogitsWarper.swift b/Sources/Generation/LogitsWarper/LogitsWarper.swift
deleted file mode 100644
index 576820c0..00000000
--- a/Sources/Generation/LogitsWarper/LogitsWarper.swift
+++ /dev/null
@@ -1,35 +0,0 @@
-import Foundation
-
-/// Protocol for logits warpers that transform token probabilities during generation.
-///
-/// Logits warpers modify the probability distribution over tokens before sampling,
-/// enabling techniques like temperature scaling, top-k/top-p filtering, and repetition penalties.
-public protocol LogitsWarper {
- /// Warps the logits and corresponding indices.
- ///
- /// - Parameters:
- /// - indices: Array of token indices corresponding to the logits
- /// - logits: Array of logit values to transform
- /// - Returns: Tuple of transformed (indices, logits)
- func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float])
-
- /// Convenience method that calls the warp function.
- ///
- /// - Parameters:
- /// - indices: Array of token indices
- /// - logits: Array of logit values
- /// - Returns: Tuple of transformed (indices, logits)
- func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float])
-}
-
-public extension LogitsWarper {
- /// Default implementation of callAsFunction that delegates to warp.
- ///
- /// - Parameters:
- /// - indices: Array of token indices
- /// - logits: Array of logit values
- /// - Returns: Tuple of transformed (indices, logits)
- func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float]) {
- warp(indices: indices, logits: logits)
- }
-}
diff --git a/Sources/Generation/LogitsWarper/RepetitionPenaltyWarper.swift b/Sources/Generation/LogitsWarper/RepetitionPenaltyWarper.swift
deleted file mode 100644
index 7e89d67e..00000000
--- a/Sources/Generation/LogitsWarper/RepetitionPenaltyWarper.swift
+++ /dev/null
@@ -1,42 +0,0 @@
-import Foundation
-
-/// Logits warper that prevents repetition of previous tokens through a penalty.
-///
-/// Applies a penalty to tokens that have already been generated, reducing their
-/// probability of being selected again. The penalty is applied differently based
-/// on the sign of the logit value to maintain numerical stability.
-///
-/// - Note: Based on https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L294
-public struct RepetitionPenaltyWarper: LogitsWarper {
- /// Penalty factor applied to repeated tokens.
- public var penalty: Float
-
- /// Creates a repetition penalty warper.
- ///
- /// - Parameter penalty: Penalty factor (>1.0 discourages repetition, <1.0 encourages it)
- public init(penalty: Double) {
- self.penalty = Float(penalty)
- }
-
- /// Applies repetition penalty to the logits.
- ///
- /// For positive logits, divides by penalty. For negative logits, multiplies by penalty.
- /// This asymmetric approach maintains numerical stability while effectively penalizing repetition.
- ///
- /// - Parameters:
- /// - indices: Token indices to apply penalty to
- /// - logits: Current logits values
- /// - Returns: Tuple of (indices, penalized logits)
- public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
- var logits = logits
- for index in indices {
- if logits[index] < 0 {
- logits[index] *= penalty
- } else {
- logits[index] /= penalty
- }
- }
-
- return (indices, logits)
- }
-}
diff --git a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift
deleted file mode 100644
index bf985770..00000000
--- a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift
+++ /dev/null
@@ -1,32 +0,0 @@
-import Foundation
-
-/// Logits warper that applies temperature scaling to control generation randomness.
-///
-/// Temperature scaling modifies the "sharpness" of the probability distribution:
-/// - Temperature > 1.0: Makes distribution more uniform (more random)
-/// - Temperature < 1.0: Makes distribution more peaked (less random)
-/// - Temperature = 1.0: No change
-public struct TemperatureLogitsWarper: LogitsWarper {
- /// Temperature scaling factor.
- public var temperature: Float
-
- /// Creates a temperature logits warper.
- ///
- /// - Parameter temperature: Scaling factor (higher values increase randomness)
- public init(temperature: Float) {
- self.temperature = temperature
- }
-
- /// Applies temperature scaling to the logits.
- ///
- /// Divides each logit by the temperature value, which affects the final
- /// probability distribution after softmax is applied.
- ///
- /// - Parameters:
- /// - indices: Token indices (unchanged)
- /// - logits: Current logits values
- /// - Returns: Tuple of (indices, temperature-scaled logits)
- public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
- (indices: indices, logits: logits.map { $0 / temperature })
- }
-}
diff --git a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift
deleted file mode 100644
index 0665bbe0..00000000
--- a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift
+++ /dev/null
@@ -1,74 +0,0 @@
-#if canImport(Accelerate)
-import Accelerate
-import Foundation
-
-/// Logits warper that implements top-k filtering for sampling.
-///
-/// Selects the k most probable tokens and sets all other token probabilities
-/// to zero, effectively limiting the sampling space to the top k candidates.
-/// This helps balance diversity and quality in generated text.
-public struct TopKLogitsWarper: LogitsWarper {
- /// Number of top tokens to keep.
- public var k: Int
-
- /// Creates a top-k logits warper.
- ///
- /// - Parameter k: Number of top tokens to retain (others are filtered out)
- public init(k: Int) {
- self.k = k
- }
-
- /// Applies top-k filtering to the logits.
- ///
- /// Uses Accelerate framework's optimized top-k algorithm to efficiently
- /// select the k highest-valued logits and their corresponding indices.
- ///
- /// - Parameters:
- /// - indices: Current token indices
- /// - logits: Current logits values
- /// - Returns: Tuple of (top-k indices, top-k logits)
- public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
- guard !logits.isEmpty else {
- return (indices: [], logits: [])
- }
- let k = min(k, logits.count)
- let arrDescriptor = BNNSNDArrayDescriptor.allocate(
- initializingFrom: logits,
- shape: .vector(logits.count)
- )
- defer {
- arrDescriptor.deallocate()
- }
- let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized(
- scalarType: Int32.self,
- shape: .vector(k)
- )
- defer {
- bestIndices.deallocate()
- }
- let bestValues = BNNSNDArrayDescriptor.allocateUninitialized(
- scalarType: Float.self,
- shape: .vector(k)
- )
- defer {
- bestValues.deallocate()
- }
- try! Accelerate.BNNS.applyTopK(
- k: k,
- input: arrDescriptor,
- bestValues: bestValues,
- bestIndices: bestIndices,
- axis: 0,
- batchSize: 1,
- filterParameters: nil
- )
- let topkLogits = bestValues.data!.withMemoryRebound(to: Float.self, capacity: k) { ptr in
- Array(UnsafeBufferPointer(start: ptr, count: k))
- }
- let topkIndices = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: k) { ptr in
- Array(UnsafeBufferPointer(start: ptr, count: k))
- }
- return (indices: topkIndices.map { indices[Int($0)] }, logits: topkLogits)
- }
-}
-#endif // canImport(Accelerate)
diff --git a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift
deleted file mode 100644
index 114660bd..00000000
--- a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift
+++ /dev/null
@@ -1,54 +0,0 @@
-import Foundation
-
-/// Logits warper that implements nucleus (top-p) sampling.
-///
-/// Selects the smallest set of tokens whose cumulative probability exceeds
-/// the threshold p, providing dynamic vocabulary selection based on the
-/// probability distribution rather than a fixed number of tokens.
-///
-/// - Note: Based on https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-public struct TopPLogitsWarper: LogitsWarper {
- /// Cumulative probability threshold.
- public var p: Float
-
- /// Creates a top-p (nucleus) logits warper.
- ///
- /// - Parameter p: Cumulative probability threshold (0.0 to 1.0)
- public init(p: Float) {
- self.p = p
- }
-
- /// Applies top-p (nucleus) filtering to the logits.
- ///
- /// Computes softmax probabilities, sorts by probability, and selects tokens
- /// until their cumulative probability exceeds the threshold p.
- ///
- /// - Parameters:
- /// - indices: Current token indices
- /// - logits: Current logits values
- /// - Returns: Tuple of (filtered indices, filtered logits)
- public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
- guard !logits.isEmpty else {
- return (indices: [], logits: [])
- }
-
- let arrSoftmax = Math.softmax(logits)
- var indexLogitProb = [(index: Int, logit: Float, prob: Float)]()
- indexLogitProb.reserveCapacity(logits.count)
- for (index, data) in zip(logits, arrSoftmax).enumerated() {
- indexLogitProb.append((index: index, logit: data.0, prob: data.1))
- }
- indexLogitProb.sort { $0.prob > $1.prob }
-
- let cumsum = Math.cumsum(indexLogitProb.map(\.prob))
- var sliceIndex = cumsum.count - 1
- for (index, element) in cumsum.enumerated() where element > p {
- sliceIndex = index
- break
- }
-
- let toppIndices = indexLogitProb[0...sliceIndex].map { indices[$0.index] }
- let toppLogits = indexLogitProb[0...sliceIndex].map(\.logit)
- return (indices: toppIndices, logits: toppLogits)
- }
-}
diff --git a/Sources/Generation/Math.swift b/Sources/Generation/Math.swift
deleted file mode 100644
index aadbc3ea..00000000
--- a/Sources/Generation/Math.swift
+++ /dev/null
@@ -1,214 +0,0 @@
-//
-// Math.swift
-// CoreMLBert
-//
-// Created by Julien Chaumond on 27/06/2019.
-// Copyright © 2019 Hugging Face. All rights reserved.
-//
-
-#if canImport(CoreML) && canImport(Accelerate)
-import Accelerate
-import CoreML
-import Foundation
-
-/// Mathematical utilities for text generation and tensor operations.
-///
-/// Provides optimized implementations of common mathematical operations
-/// used in text generation, including argmax, softmax, sampling, and cumulative sum.
-///
-/// - Note: From M.I. Hollemans - https://github.com/hollance/CoreMLHelpers
-public enum Math {
- /**
- Returns the index and value of the largest element in the array.
-
- - Parameters:
- - ptr: Pointer to the first element in memory.
- - count: How many elements to look at.
- - stride: The distance between two elements in memory.
- */
-    public static func argmax(_ ptr: UnsafePointer<Float>, count: Int, stride: Int = 1) -> (Int, Float) {
- var maxValue: Float = 0
- var maxIndex: vDSP_Length = 0
- vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count))
- return (Int(maxIndex), maxValue)
- }
-
- /**
- Returns the index and value of the largest element in the array.
- - Parameters:
- - ptr: Pointer to the first element in memory.
- - count: How many elements to look at.
- - stride: The distance between two elements in memory.
- */
-    public static func argmax(_ ptr: UnsafePointer<Double>, count: Int, stride: Int = 1) -> (Int, Double) {
- var maxValue: Double = 0
- var maxIndex: vDSP_Length = 0
- vDSP_maxviD(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count))
- return (Int(maxIndex), maxValue)
- }
-
- /// Returns the index and value of the largest element in a Float array.
- ///
- /// - Parameters:
- /// - ptr: Pointer to the first element in memory
- /// - count: How many elements to look at
- /// - stride: The distance between two elements in memory
- /// - Returns: Tuple of (index, value) of the maximum element
-    public static func argmax32(_ ptr: UnsafePointer<Float>, count: Int, stride: Int = 1) -> (Int, Float) {
- var maxValue: Float = 0
- var maxIndex: vDSP_Length = 0
- vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count))
- return (Int(maxIndex), maxValue)
- }
-
- /// Returns the index and value of the largest element in an MLMultiArray of doubles.
- ///
- /// - Parameter multiArray: Input MLMultiArray with double precision values
- /// - Returns: Tuple of (index, value) of the maximum element
- public static func argmax(_ multiArray: MLMultiArray) -> (Int, Double) {
- assert(multiArray.dataType == .double)
-        let ptr = UnsafeMutablePointer<Double>(OpaquePointer(multiArray.dataPointer))
- return Math.argmax(ptr, count: multiArray.count)
- }
-
- /// Returns the index and value of the largest element in an MLMultiArray of floats.
- ///
- /// - Parameter multiArray: Input MLMultiArray with single precision values
- /// - Returns: Tuple of (index, value) of the maximum element
- public static func argmax32(_ multiArray: MLMultiArray) -> (Int, Float) {
- assert(multiArray.dataType == .float32)
-        let ptr = UnsafeMutablePointer<Float>(OpaquePointer(multiArray.dataPointer))
- return Math.argmax32(ptr, count: multiArray.count)
- }
-
- /// Returns the cumulative sum of the array.
- ///
- /// Computes the cumulative sum where each element is the sum of all previous elements
- /// plus the current element.
- ///
- /// - Parameter arr: Input array of Float values
- /// - Returns: Array of cumulative sums
- public static func cumsum(_ arr: [Float]) -> [Float] {
- guard !arr.isEmpty else {
- return []
- }
- let arrCount = vDSP_Length(arr.count)
- var weight: Float = 1.0
- var result: [Float] = Array(repeating: 0.0, count: arr.count)
- var firstItem = arr[0]
- vDSP_vrsum(arr, 1, &weight, &result, 1, arrCount)
- vDSP_vsadd(result, 1, &firstItem, &result, 1, arrCount)
- return result
- }
-
- /// Performs multinomial sampling from probability distributions.
- ///
- /// Selects an index based on probability weights, commonly used for token sampling
- /// in text generation after applying logits warpers.
- ///
- /// - Parameters:
- /// - indexes: Array of indices to sample from
- /// - probs: Probability weights for each index
- /// - Returns: Selected index based on probability distribution
- public static func sample(indexes: [Int], probs: [Float]) -> Int {
- let i = randomNumber(probabilities: probs)
- return indexes[i]
- }
-
- /// Computes the softmax function over an array.
- ///
- /// Converts logits into a probability distribution by applying exponential normalization.
- /// Uses numerical stability techniques by shifting values to prevent overflow.
- ///
- /// The implementation follows this algorithm:
- /// 1. Subtract maximum value for numerical stability
- /// 2. Apply exponential function to all elements
- /// 3. Normalize by dividing by the sum of all exponentials
- ///
- /// - Parameter x: Input logits array
- /// - Returns: Probability distribution (sums to 1.0)
- ///
- /// - Note: Based on code from https://github.com/nikolaypavlov/MLPNeuralNet/
- public static func softmax(_ x: [Float]) -> [Float] {
- var x = x
- let len = vDSP_Length(x.count)
-
- // Find the maximum value in the input array.
- var max: Float = 0
- vDSP_maxv(x, 1, &max, len)
-
- // Subtract the maximum from all the elements in the array.
- // Now the highest value in the array is 0.
- max = -max
- vDSP_vsadd(x, 1, &max, &x, 1, len)
-
- // Exponentiate all the elements in the array.
- var count = Int32(x.count)
- vvexpf(&x, x, &count)
-
- // Compute the sum of all exponentiated values.
- var sum: Float = 0
- vDSP_sve(x, 1, &sum, len)
-
- // Divide each element by the sum. This normalizes the array contents
- // so that they all add up to 1.
- vDSP_vsdiv(x, 1, &sum, &x, 1, len)
-
- return x
- }
-
- /// Generates a random index based on probability weights.
- ///
- /// Uses the roulette wheel selection algorithm to choose an index where
- /// the probability of selection is proportional to the weight at that index.
- ///
- /// - Parameter probabilities: Array of probability weights (need not sum to 1.0)
- /// - Returns: Selected index based on probability distribution
- ///
- /// - Note: From https://stackoverflow.com/questions/30309556/generate-random-numbers-with-a-given-distribution
- public static func randomNumber(probabilities: [Float]) -> Int {
- // Sum of all probabilities (so that we don't have to require that the sum is 1.0):
- let sum = probabilities.reduce(0, +)
- // Random number in the range 0.0 <= rnd < sum :
- let rnd = sum * Float(arc4random_uniform(UInt32.max)) / Float(UInt32.max)
- // Find the first interval of accumulated probabilities into which `rnd` falls:
- var accum: Float = 0.0
- for (i, p) in probabilities.enumerated() {
- accum += p
- if rnd < accum {
- return i
- }
- }
- // This point might be reached due to floating point inaccuracies:
- return probabilities.count - 1
- }
-}
-
-/// MLShapedArray extensions for Math operations.
-public extension Math {
- /// Returns the index and value of the largest element in an MLShapedArray of floats.
- ///
- /// - Parameter shapedArray: Input MLShapedArray containing Float values
- /// - Returns: Tuple of (index, value) of the maximum element
-    static func argmax(_ shapedArray: MLShapedArray<Float>) -> (Int, Float) {
- shapedArray.withUnsafeShapedBufferPointer { ptr, shape, strides in
- assert(shape.count == 1, "Only supported for 1-dimensional arrays or slices")
- return Math.argmax32(ptr.baseAddress!, count: shapedArray.count, stride: strides.first!)
- }
- }
-
- /// Returns the index and value of the largest element in a generic MLShapedArray.
- ///
- /// - Parameter shapedArray: Input shaped array conforming to MLShapedArrayProtocol
- /// - Returns: Tuple of (index, value) of the maximum element as Float
- ///
- /// - Note: Currently assumes Float data type
- static func argmax(_ shapedArray: some MLShapedArrayProtocol) -> (Int, Float) {
- shapedArray.withUnsafeShapedBufferPointer { ptr, shape, strides in
- assert(shape.count == 1, "Only supported for 1-dimensional arrays or slices")
-            let floatsPtr = ptr.baseAddress as! UnsafePointer<Float>
- return Math.argmax32(floatsPtr, count: shapedArray.count, stride: strides.first!)
- }
- }
-}
-#endif // canImport(CoreML) && canImport(Accelerate)
diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift
index ff6d7440..a7129c58 100644
--- a/Sources/Hub/Hub.swift
+++ b/Sources/Hub/Hub.swift
@@ -27,6 +27,8 @@ public extension Hub {
case httpStatusCode(Int)
/// Failed to parse server response or configuration data.
case parse
+ /// Expected JSON response could not be parsed as JSON.
+ case jsonSerialization(fileURL: URL, message: String)
/// An unexpected error occurred during operation.
case unexpectedError
/// A download operation failed with the specified error message.
@@ -52,6 +54,8 @@ public extension Hub {
String(localized: "HTTP error with status code: \(code)")
case .parse:
String(localized: "Failed to parse server response.")
+ case .jsonSerialization(_, let message):
+ message
case .unexpectedError:
String(localized: "An unexpected error occurred.")
case let .downloadError(message):
diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift
index 9211f770..e89dab0a 100644
--- a/Sources/Hub/HubApi.swift
+++ b/Sources/Hub/HubApi.swift
@@ -314,7 +314,9 @@ public extension HubApi {
/// `fileURL` is a complete local file path for the given model
func configuration(fileURL: URL) throws -> Config {
let data = try Data(contentsOf: fileURL)
- let parsed = try JSONSerialization.bomPreservingJsonObject(with: data)
+ guard let parsed = try? JSONSerialization.bomPreservingJsonObject(with: data) else {
+ throw Hub.HubClientError.jsonSerialization(fileURL: fileURL, message: "JSON Serialization failed for \(fileURL). Please verify that you have set the HF_TOKEN environment variable.")
+ }
guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse }
return Config(dictionary)
}
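
A usage sketch of the new error case, assuming the `HubApi.shared` singleton and with a placeholder config path. Any non-JSON payload at that path — such as an HTML error page saved by an unauthenticated download — now surfaces as `jsonSerialization` instead of a bare `parse` error:

```swift
import Foundation
import Hub

let configFile = URL(fileURLWithPath: "/tmp/model/config.json") // placeholder path
do {
    let config = try HubApi.shared.configuration(fileURL: configFile)
    print(config)
} catch Hub.HubClientError.jsonSerialization(let fileURL, let message) {
    // The error carries the offending file so callers can inspect or delete it.
    print("Could not parse \(fileURL.path): \(message)")
} catch {
    print("Unexpected error: \(error)")
}
```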
diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift
index 8cd23831..59203c6a 100644
--- a/Sources/Models/LanguageModel.swift
+++ b/Sources/Models/LanguageModel.swift
@@ -12,6 +12,7 @@ import Generation
import Hub
import Tokenizers
+@available(macOS 15.0, iOS 18.0, *)
/// A high-level interface for language model inference using CoreML.
///
/// `LanguageModel` provides a convenient way to load and interact with pre-trained
@@ -27,15 +28,6 @@ public class LanguageModel {
/// The maximum context length supported by the model.
public let maxContextLength: Int
- let input_ids = "input_ids"
- let attention_mask = "attention_mask"
-
- struct Configurations {
- var modelConfig: Config
- var tokenizerConfig: Config?
- var tokenizerData: Config
- }
-
private var configuration: LanguageModelConfigurationFromHub?
private var _tokenizer: Tokenizer?
@@ -45,36 +37,104 @@ public class LanguageModel {
/// - Important: Triggers a fatal error if the model doesn't have the expected input shape information
public required init(model: MLModel) {
self.model = model
+ (minContextLength, maxContextLength) = Self.contextRange(from: model)
+ configuration = LanguageModelConfigurationFromHub(modelName: modelName)
+ }
+
+ public func resetState() async {}
+
+ public func predictNextTokenScores(
+ _ tokens: MLTensor,
+ config: GenerationConfig
+ ) async -> MLTensor {
+ assert(tokens.rank == 2) // [batch, current sequence length]
+ let tokenCount = tokens.shape[1]
+ let padLength = maxContextLength - tokenCount
+ let padding = MLTensor(repeating: Int32(config.padTokenId ?? 0), shape: [1, padLength])
+ let inputIDs = MLTensor(concatenating: [tokens, padding], alongAxis: -1)
+ var inputDictionary = [inputIdsName: inputIDs]
+ if isRequiringAttentionMask {
+ let mask = [Int32](repeating: 1, count: tokenCount) + [Int32](repeating: 0, count: padLength)
+ let attentionMask = MLTensor(shape: inputIDs.shape, scalars: mask)
+ inputDictionary[Keys.attentionMask] = attentionMask
+ }
+ let outputs = try! await model.prediction(from: inputDictionary)
+
+ assert(outputs.keys.contains(Keys.logits))
+
+ let scores = outputs[Keys.logits]!
+ assert(scores.rank == 3)
+ let tokenIndex = tokenCount - 1
+ let nextTokenScores = scores[nil, tokenIndex, nil].expandingShape(at: 0)
+ assert(nextTokenScores.rank == 3)
+ assert(nextTokenScores.shape[0] == 1 && nextTokenScores.shape[1] == 1)
+ return nextTokenScores
+ }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
+private extension LanguageModel {
+ static func contextRange(from model: MLModel) -> (min: Int, max: Int) {
+ contextRange(from: model, inputKey: Keys.inputIds)
+ }
- // We assume inputs named "input_ids" with shape (1, seq_length)
- // Perhaps we should convert to vectors of shape (seq_length) and use sequenceConstraint instead of shapeConstraint
- let inputDescription = model.modelDescription.inputDescriptionsByName["input_ids"]
+ static func contextRange(from model: MLModel, inputKey: String) -> (min: Int, max: Int) {
+ let inputDescription = model.modelDescription.inputDescriptionsByName[inputKey]
guard let shapeConstraint = inputDescription?.multiArrayConstraint?.shapeConstraint else {
fatalError("Cannot obtain shape information")
}
+ var minContextLength = 128
+ var maxContextLength = 128
+
switch shapeConstraint.type {
case .enumerated:
- // TODO: support a set of fixed shapes (keeping the first one here)
minContextLength = shapeConstraint.enumeratedShapes[0][1].intValue
maxContextLength = minContextLength
case .range:
- let range = inputDescription?.multiArrayConstraint?.shapeConstraint.sizeRangeForDimension[1] as? NSRange
- minContextLength = range?.location ?? 1
- maxContextLength = range?.length ?? 128
+ if let sizeRangeForDimension = inputDescription?.multiArrayConstraint?.shapeConstraint.sizeRangeForDimension {
+ let lastAxis = sizeRangeForDimension.count - 1
+ let range = sizeRangeForDimension[lastAxis] as? NSRange
+ minContextLength = range?.location ?? 1
+ maxContextLength = range?.length ?? 128
+ }
case .unspecified:
- minContextLength = 128
- maxContextLength = 128
+ break
@unknown default:
- minContextLength = 128
- maxContextLength = 128
+ break
}
- configuration = LanguageModelConfigurationFromHub(modelName: modelName)
+ return (minContextLength, maxContextLength)
}
}
+@available(macOS 15.0, iOS 18.0, *)
+extension LanguageModel {
+ struct Configurations {
+ var modelConfig: Config
+ var tokenizerConfig: Config?
+ var tokenizerData: Config
+ }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
+extension LanguageModel {
+ enum Keys {
+ // Input keys
+ static let inputIds = "inputIds"
+ static let attentionMask = "attentionMask"
+ static let causalMask = "causalMask"
+ static let keyCache = "keyCache"
+ static let valueCache = "valueCache"
+ // Output keys
+ static let logits = "logits"
+ static let presentKeys = "presentKeys"
+ static let presentValues = "presentValues"
+ }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
public extension LanguageModel {
/// Loads a compiled CoreML model from disk.
///
@@ -87,16 +147,44 @@ public extension LanguageModel {
let config = MLModelConfiguration()
config.computeUnits = computeUnits
let model = try MLModel(contentsOf: url, configuration: config)
- return LanguageModel(model: model)
+ return switch kvCacheAvailability(for: model) {
+ case .statefulKVCache: LanguageModelWithStatefulKVCache(model: model)
+ default: LanguageModel(model: model)
+ }
+ }
+}
+
+@available(macOS 15.0, iOS 18.0, *)
+extension LanguageModel {
+ enum KVCacheAvailability {
+ /// Language models that support KV cache via state. Implementation details for handling state
+ /// are encapsulated within the Core ML framework.
+ ///
+ /// Input: State
+ /// Output: N/A
+ case statefulKVCache
}
}
+@available(macOS 15.0, iOS 18.0, *)
public extension LanguageModel {
+ /// Metadata fields associated with the Core ML model.
+ var metadata: [MLModelMetadataKey: Any] {
+ model.modelDescription.metadata
+ }
+
+ /// A description of a model containing input, output, and state feature descriptions.
+ ///
+ /// Returns an MLModelDescription instance.
+ var modelDescription: MLModelDescription {
+ model.modelDescription
+ }
+
/// A human-readable description of the model.
///
/// Returns the model's description from its metadata, or the display name if no description is available.
var description: String {
- if let description = model.modelDescription.metadata[MLModelMetadataKey.description] as? String,
+ if let description = metadata[MLModelMetadataKey.description] as? String,
!description.isEmpty
{
return description
@@ -109,19 +197,21 @@ public extension LanguageModel {
/// Returns the model identifier from Hugging Face Hub metadata if available,
/// otherwise falls back to the model's display name.
var modelName: String {
- if let userFields = model.modelDescription.metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String],
+ if let userFields = metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String],
let name = userFields["co.huggingface.exporters.name"]
{
return name
}
// This is usually the basename of the file, that's our best bet if no metadata exists
- guard let modelName = model.configuration.modelDisplayName else { fatalError("Models must have a name that identifies them") }
+ guard let modelName = model.configuration.modelDisplayName else {
+ fatalError("Models must have a name that identifies them")
+ }
return modelName
}
/// The feature description for the input_ids input.
var inputIdsDescription: MLFeatureDescription {
- model.modelDescription.inputDescriptionsByName[input_ids]!
+ modelDescription.inputDescriptionsByName[Keys.inputIds]!
}
/// The name of the input_ids feature in the model.
@@ -131,51 +221,82 @@ public extension LanguageModel {
/// The expected shape of the input_ids tensor.
var inputIdsShape: [Int] {
- inputIdsDescription.multiArrayConstraint!.shape.map { $0.intValue }
+ inputIdsDescription.multiArrayConstraint!.shape.map(\.intValue)
}
/// Whether the model requires attention mask inputs.
- var requiresAttention: Bool {
- model.modelDescription.inputDescriptionsByName[attention_mask] != nil
+ var isRequiringAttentionMask: Bool {
+ modelDescription.inputDescriptionsByName[Keys.attentionMask] != nil
+ }
+
+ /// Whether the model requires a causal attention mask.
+ var isRequiringCausalMask: Bool {
+ modelDescription.inputDescriptionsByName[Keys.causalMask] != nil
}
- /// Predicts the next token scores for the given input tokens.
+ /// Determines the type of KV Cache available for the model, if any.
///
/// - Parameters:
- /// - tokens: The input token sequence
- /// - config: The generation configuration containing model parameters
- /// - Returns: A shaped array containing the logits for the next token prediction
- func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol {
- // TODO: exceptions
-
- // Maybe pad or truncate
- let maxTokens = min(tokens.count, maxContextLength)
- let padLength = maxTokens >= minContextLength ? 0 : minContextLength - maxTokens
- let inputTokens = Array(tokens[0..<maxTokens]) + Array(repeating: config.padTokenId ?? 0, count: padLength)
- let inputIds = MLShapedArray<Int32>(scalars: inputTokens.map { Int32($0) }, shape: inputIdsShape)
- var inputDictionary = [inputIdsName: MLFeatureValue(shapedArray: inputIds)]
- if requiresAttention {
- let mask = Array(repeating: 1, count: maxTokens) + Array(repeating: 0, count: padLength)
- let attentionMask = MLShapedArray(scalars: mask.map { Int32($0) }, shape: inputIdsShape)
- inputDictionary[attention_mask] = MLFeatureValue(shapedArray: attentionMask)
+ /// - model: The Core ML model
+ /// - Returns: The type of KV Cache available.
+ fileprivate static func kvCacheAvailability(for model: MLModel) -> KVCacheAvailability? {
+ func isStatefulKVCacheAvailable(for model: MLModel) -> Bool {
+ let kCacheState = model.modelDescription.stateDescriptionsByName[Keys.keyCache] != nil
+ let vCacheState = model.modelDescription.stateDescriptionsByName[Keys.valueCache] != nil
+ guard Set([kCacheState, vCacheState]).count == 1 else {
+ fatalError("Invalid model configuration, expecting KV cache for states")
+ }
+ return kCacheState && vCacheState
}
- let input = try! MLDictionaryFeatureProvider(dictionary: inputDictionary)
- let output = try! model.prediction(from: input)
+ func isDynamicallyShaped(_ description: MLFeatureDescription) -> Bool {
+ guard let multiArrayConstraint = description.multiArrayConstraint else {
+ return false
+ }
+ return switch multiArrayConstraint.shapeConstraint.type {
+ case .unspecified: true
+ case .enumerated: multiArrayConstraint.shapeConstraint.enumeratedShapes.count > 1
+ case .range: true
+ default: false
+ }
+ }
- // TODO: maybe try to support models with "token_scores" too (after the softmax)
- assert(output.featureNames.first! == "logits")
+ if isStatefulKVCacheAvailable(for: model) {
+ return .statefulKVCache
+ }
+ let kCacheInput = model.modelDescription.inputDescriptionsByName[Keys.keyCache] != nil
+ let vCacheInput = model.modelDescription.inputDescriptionsByName[Keys.valueCache] != nil
+ let kCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentKeys] != nil
+ let vCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentValues] != nil
- let scores = output.featureValue(for: output.featureNames.first!)!.shapedArrayValue(of: Float.self)!
- let nextTokenScores = scores[0, maxTokens - 1]
- return nextTokenScores
+ guard Set([kCacheInput, vCacheInput, kCacheOutput, vCacheOutput]).count == 1 else {
+ fatalError("Invalid model configuration, expecting KV cache for inputs and outputs")
+ }
+ guard kCacheInput else {
+ return nil
+ }
+ // Check whether the cache is dynamically shaped. Neither variant is supported
+ // yet; the branch is kept to distinguish the two cases in future work.
+ let kCacheConstraint = model.modelDescription.inputDescriptionsByName[Keys.keyCache]!
+ if isDynamicallyShaped(kCacheConstraint) {
+ fatalError(
+ """
+ KV cache via model I/O is currently not supported; please file a feature request at \
+ https://github.com/huggingface/swift-transformers
+ """)
+ } else {
+ fatalError(
+ """
+ KV cache via model I/O is currently not supported; please file a feature request at \
+ https://github.com/huggingface/swift-transformers
+ """)
+ }
}
}
// MARK: - Configuration Properties
/// Asynchronous properties that are downloaded from the Hugging Face Hub configuration.
+@available(macOS 15.0, iOS 18.0, *)
public extension LanguageModel {
/// The model configuration dictionary.
///
@@ -281,13 +402,14 @@ public extension LanguageModel {
// MARK: - TextGenerationModel Conformance
+@available(macOS 15.0, iOS 18.0, *)
extension LanguageModel: TextGenerationModel {
/// The default generation configuration for this model.
///
/// Provides sensible defaults based on the model type, with model-specific
/// optimizations for known architectures like GPT models.
public var defaultGenerationConfig: GenerationConfig {
- var config = GenerationConfig(maxNewTokens: 30)
+ var config = GenerationConfig(maxNewTokens: 2048)
switch modelName.lowercased() {
case let x where x.contains("gpt"):
config.doSample = true
@@ -298,6 +420,88 @@ extension LanguageModel: TextGenerationModel {
}
}
+/// Language model implementation with stateful KV Cache.
+///
+/// Maintains a KV Cache as sequence generation progresses,
+/// using stateful Core ML buffers to minimize latency.
+@available(macOS 15.0, iOS 18.0, *)
+public class LanguageModelWithStatefulKVCache: LanguageModel {
+ private enum Mode {
+ case prefilling
+ case extending
+ }
+ private var mode: Mode = .prefilling
+
+ var state: MLState?
+
+ public required init(model: MLModel) {
+ super.init(model: model)
+ // To support both pre-filling and extending, the input must
+ // support flexible shapes.
+ guard maxContextLength - minContextLength > 1 else {
+ fatalError("Expecting ranged query sequence length.")
+ }
+ }
+
+ public override func resetState() async {
+ state = model.makeState()
+ mode = .prefilling
+ }
+
+ public override func predictNextTokenScores(
+ _ tokens: MLTensor,
+ config _: GenerationConfig
+ ) async -> MLTensor {
+ assert(tokens.rank == 2) // [batch, current sequence length]
+ let tokenCount = tokens.shape[1]
+ guard let state else {
+ fatalError(
+ """
+ Encountered uninitialized `state`. Ensure `resetState` is called prior to calling \
+ `predictNextTokenScores`.
+ """)
+ }
+ let inputIds =
+ switch mode {
+ case .prefilling: tokens // Pass in all tokens when pre-filling the prompt
+ case .extending: tokens[nil, -1].expandingShape(at: 0) // otherwise just the last token
+ }
+ mode = .extending
+
+ var inputDictionary = [
+ Keys.inputIds: inputIds
+ ]
+ if isRequiringAttentionMask {
+ #if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64))
+ // TODO: Infer scalar type from cache or model I/O descriptors
+ let attentionMask = MLTensor(zeros: [1, 1, 1, tokenCount + 1], scalarType: Float16.self)
+ inputDictionary[Keys.attentionMask] = attentionMask
+ #else
+ fatalError("Float16 attention mask is not supported on x86_64.")
+ #endif
+ }
+ if isRequiringCausalMask {
+ #if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64))
+ // TODO: Infer scalar type from cache or model I/O descriptors
+ let causalMask = MLTensor(zeros: [1, 1, 1, tokenCount + 1], scalarType: Float16.self)
+ inputDictionary[Keys.causalMask] = causalMask
+ #else
+ fatalError("Float16 causal mask is not supported on x86_64.")
+ #endif
+ }
+ let outputs = try! await model.prediction(from: inputDictionary, using: state)
+
+ assert(outputs.keys.contains(Keys.logits))
+ let scores = outputs[Keys.logits]!
+ assert(scores.rank == 3)
+ let tokenIndex = inputIds.shape[1] - 1
+ let nextTokenScores = scores[nil, tokenIndex, nil].expandingShape(at: 0)
+ assert(nextTokenScores.rank == 3)
+ assert(nextTokenScores.shape[0] == 1 && nextTokenScores.shape[1] == 1)
+ return nextTokenScores
+ }
+}
+
/// Errors that can occur during tokenizer operations in language models.
public enum TokenizerError: LocalizedError {
/// The tokenizer configuration file could not be found.
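
Taken together, the `LanguageModel` changes make inference async and MLTensor-based, with stateful KV-cache models detected at load time. A minimal end-to-end sketch under those assumptions — the model path and prompt are placeholders; `loadCompiled(url:computeUnits:)`, `defaultGenerationConfig`, and `generate(config:prompt:)` come from this package:

```swift
import CoreML
import Generation
import Models

@available(macOS 15.0, iOS 18.0, *)
func run() async throws {
    // Placeholder path to a compiled Core ML language model.
    let url = URL(fileURLWithPath: "StatefulMistral7BInstructInt4.mlmodelc")
    // `loadCompiled` returns a `LanguageModelWithStatefulKVCache` automatically
    // when the model exposes `keyCache`/`valueCache` states.
    let model = try LanguageModel.loadCompiled(url: url, computeUnits: .cpuAndGPU)
    var config = model.defaultGenerationConfig
    config.maxNewTokens = 64
    // `generate` resets the model state before running inference.
    let text = try await model.generate(config: config, prompt: "Best places in Paris:")
    print(text)
}
```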
diff --git a/Sources/Models/LanguageModelTypes.swift b/Sources/Models/LanguageModelTypes.swift
index 3ba9d0f1..0e7e1909 100644
--- a/Sources/Models/LanguageModelTypes.swift
+++ b/Sources/Models/LanguageModelTypes.swift
@@ -15,6 +15,7 @@ import Tokenizers
///
/// This protocol establishes the fundamental requirements for any language model
/// that can perform next-token prediction and text generation tasks.
+@available(macOS 15.0, iOS 18.0, *)
public protocol LanguageModelProtocol {
/// The name or path of the model.
///
@@ -30,6 +31,11 @@ public protocol LanguageModelProtocol {
/// The underlying CoreML model used for inference.
var model: MLModel { get }
+ /// Resets the state of the language model.
+ ///
+ /// Call `resetState()` before generating each new sequence.
+ func resetState() async
+
/// Creates a new language model instance from a CoreML model.
///
/// - Parameter model: The CoreML model to wrap
@@ -38,11 +44,14 @@ public protocol LanguageModelProtocol {
/// Predicts the next token scores for the given input tokens.
///
/// - Parameters:
- /// - tokens: The input token sequence
- /// - config: The generation configuration containing model parameters
- /// - Returns: A shaped array containing the logits for the next token prediction
- func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol
+ /// - input: The input sequence tensor.
+ /// - config: The generation configuration containing model parameters.
+ /// - Returns: MLTensor with the raw scores of the next token.
+ func predictNextTokenScores(_ input: MLTensor, config: GenerationConfig) async -> MLTensor
+}
+@available(macOS 15.0, iOS 18.0, *)
+public extension LanguageModelProtocol {
/// Function call syntax for next token prediction.
///
/// This provides a more convenient syntax for calling `predictNextTokenScores`.
@@ -51,13 +60,8 @@ public protocol LanguageModelProtocol {
/// - tokens: The input token sequence
/// - config: The generation configuration containing model parameters
/// - Returns: A shaped array containing the logits for the next token prediction
- func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol
-}
-
-public extension LanguageModelProtocol {
- /// Default implementation of function call syntax that delegates to `predictNextTokenScores`.
- func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol {
- predictNextTokenScores(tokens, config: config)
+ func callAsFunction(_ input: MLTensor, config: GenerationConfig) async -> MLTensor {
+ await predictNextTokenScores(input, config: config)
}
}
@@ -65,6 +69,7 @@ public extension LanguageModelProtocol {
///
/// This protocol extends `LanguageModelProtocol` and `Generation` to provide
/// high-level text generation functionality with configurable parameters.
+@available(macOS 15.0, iOS 18.0, *)
public protocol TextGenerationModel: Generation, LanguageModelProtocol {
/// The default generation configuration for this model.
///
@@ -80,9 +85,14 @@ public protocol TextGenerationModel: Generation, LanguageModelProtocol {
/// - callback: Optional callback to receive intermediate generation results
/// - Returns: The generated text as a string
/// - Throws: An error if text generation fails
- func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback?) async throws -> String
+ func generate(
+ config: GenerationConfig,
+ prompt: String,
+ callback: PredictionStringCallback?
+ ) async throws -> String
}
+@available(macOS 15.0, iOS 18.0, *)
public extension TextGenerationModel {
/// Default implementation of text generation that uses the underlying generation framework.
///
@@ -94,7 +104,17 @@ public extension TextGenerationModel {
/// - Throws: An error if text generation fails
@discardableResult
func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback? = nil) async throws -> String {
- try await generate(config: config, prompt: prompt, model: callAsFunction, tokenizer: tokenizer, callback: callback)
+ // Prepare the language model for a new sequence.
+ await resetState()
+
+ // Run inference.
+ return try await generate(
+ config: config,
+ prompt: prompt,
+ model: callAsFunction,
+ tokenizer: tokenizer,
+ callback: callback
+ )
}
}
#endif // canImport(CoreML)
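
Since `callAsFunction` now lives in a protocol extension, any `LanguageModelProtocol` value can be invoked directly. A small sketch of the call syntax — the token IDs are placeholders, and `MLTensor(shape:scalars:)` matches the usage in the patch above:

```swift
import CoreML
import Generation
import Models

@available(macOS 15.0, iOS 18.0, *)
func scoreNextToken(
    model: some LanguageModelProtocol,
    config: GenerationConfig
) async -> MLTensor {
    // Placeholder [1, seqLen] tensor of token IDs.
    let inputIds = MLTensor(shape: [1, 4], scalars: [1, 2, 3, 4] as [Int32])
    // Function-call syntax delegates to `predictNextTokenScores(_:config:)`.
    return await model(inputIds, config: config)
}
```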
diff --git a/Tests/GenerationTests/LogitsWarperTests.swift b/Tests/GenerationTests/LogitsWarperTests.swift
deleted file mode 100644
index 5d97652a..00000000
--- a/Tests/GenerationTests/LogitsWarperTests.swift
+++ /dev/null
@@ -1,160 +0,0 @@
-//
-// LogitsWarperTests.swift
-//
-// Created by Jan Krukowski on 09/12/2023.
-//
-
-#if canImport(CoreML)
-import CoreML
-@testable import Generation
-import Testing
-
-@Suite("Logits Warper Tests")
-struct LogitsWarperTests {
- private let accuracy: Float = 0.00001
-
- @Test("Temperature logits warper functionality")
- func temperatureLogitsWarper() {
- let result1 = TemperatureLogitsWarper(temperature: 0.0)([], [])
- #expect(result1.indices.isEmpty)
- #expect(result1.logits.isEmpty)
-
- let result2 = TemperatureLogitsWarper(temperature: 1.0)([], [])
- #expect(result2.indices.isEmpty)
- #expect(result2.logits.isEmpty)
-
- let result3 = TemperatureLogitsWarper(temperature: 1.0)([0, 1], [2.0, 1.0])
- #expect(result3.indices == [0, 1])
- #expect(isClose(result3.logits, [2.0, 1.0], accuracy: accuracy))
-
- let result4 = TemperatureLogitsWarper(temperature: 2.0)([0, 1], [2.0, 1.0])
- #expect(result4.indices == [0, 1])
- #expect(isClose(result4.logits, [1.0, 0.5], accuracy: accuracy))
-
- let result5 = TemperatureLogitsWarper(temperature: 0.5)([0, 1], [2.0, 1.0])
- #expect(result5.indices == [0, 1])
- #expect(isClose(result5.logits, [4.0, 2.0], accuracy: accuracy))
-
- let result6 = TemperatureLogitsWarper(temperature: 0.5)([200, 100], [2.0, 1.0])
- #expect(result6.indices == [200, 100])
- #expect(isClose(result6.logits, [4.0, 2.0], accuracy: accuracy))
- }
-
- @Test("Top-K logits warper functionality")
- func topKLogitsWarper() {
- let result1 = TopKLogitsWarper(k: 0)([], [])
- #expect(result1.indices.isEmpty)
- #expect(result1.logits.isEmpty)
-
- let result2 = TopKLogitsWarper(k: 3)([], [])
- #expect(result2.indices.isEmpty)
- #expect(result2.logits.isEmpty)
-
- let result3 = TopKLogitsWarper(k: 3)([0, 1], [2.0, 1.0])
- #expect(result3.indices == [0, 1])
- #expect(isClose(result3.logits, [2.0, 1.0], accuracy: accuracy))
-
- let result4 = TopKLogitsWarper(k: 3)([0, 1, 2], [2.0, 1.0, 3.0])
- #expect(result4.indices == [2, 0, 1])
- #expect(isClose(result4.logits, [3.0, 2.0, 1.0], accuracy: accuracy))
-
- let result5 = TopKLogitsWarper(k: 4)([0, 1, 2, 3, 4, 5], [2.0, 1.0, 3.0, -1.0, 123.0, 0.0])
- #expect(result5.indices == [4, 2, 0, 1])
- #expect(isClose(result5.logits, [123.0, 3.0, 2.0, 1.0], accuracy: accuracy))
-
- let result6 = TopKLogitsWarper(k: 3)([10, 1, 52], [2.0, 1.0, 3.0])
- #expect(result6.indices == [52, 10, 1])
- #expect(isClose(result6.logits, [3.0, 2.0, 1.0], accuracy: accuracy))
- }
-
- @Test("Top-P logits warper functionality")
- func topPLogitsWarper() {
- let result1 = TopPLogitsWarper(p: 0.99)([], [])
- #expect(result1.indices.isEmpty)
- #expect(result1.logits.isEmpty)
-
- let logits = (0..<10).map { Float($0) }
- let indices = Array(logits.indices)
- let result2 = TopPLogitsWarper(p: 0.99)(indices, logits)
- #expect(result2.indices == [9, 8, 7, 6, 5])
- #expect(isClose(result2.logits, [9.0, 8.0, 7.0, 6.0, 5.0], accuracy: accuracy))
-
- let result3 = TopPLogitsWarper(p: 0.95)(indices, logits)
- #expect(result3.indices == [9, 8, 7])
- #expect(isClose(result3.logits, [9.0, 8.0, 7.0], accuracy: accuracy))
-
- let result4 = TopPLogitsWarper(p: 0.6321493)(indices, logits)
- #expect(result4.indices == [9, 8])
- #expect(isClose(result4.logits, [9.0, 8.0], accuracy: accuracy))
-
- let result5 = TopPLogitsWarper(p: 0.95)([3, 1, 8], [0, 1, 2])
- #expect(result5.indices == [8, 1, 3])
- #expect(isClose(result5.logits, [2, 1, 0], accuracy: accuracy))
- }
-
- @Test("Repetition penalty warper functionality")
- func repetitionPenaltyWarper() {
- let indices = Array(0..<10)
- let logits = indices.map { Float($0) }
-
- let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits)
- #expect(result1.indices == indices)
- #expect(isClose(result1.logits, logits, accuracy: accuracy))
-
- let result2 = RepetitionPenaltyWarper(penalty: 3.75)(indices, logits)
- #expect(result2.indices == indices)
- let logits2 = indices.map { Float($0) / 3.75 }
- #expect(isClose(result2.logits, logits2, accuracy: accuracy))
-
- let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119])
- #expect(result3.indices == [0, 1, 2])
- #expect(isClose(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4))
-
- let result4 = RepetitionPenaltyWarper(penalty: 1.11)([2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158])
- #expect(result4.indices == [2, 3, 4])
- #expect(isClose(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4))
-
- let result5 = RepetitionPenaltyWarper(penalty: 0.9)([0, 1, 2], [-0.7433, -0.4738, -0.2966])
- #expect(result5.indices == [0, 1, 2])
- #expect(isClose(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4))
-
- let result6 = RepetitionPenaltyWarper(penalty: 1.125)([3, 1, 2], [0.1674, 0.6431, 0.6780, 0.2755])
- #expect(result6.indices == [3, 1, 2])
- #expect(isClose(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4))
- }
-
- @Test("Logits processor functionality")
- func logitsProcessor() {
- let processor1 = LogitsProcessor(logitsWarpers: [])
- let result1 = processor1([])
- #expect(result1.indices.isEmpty)
- #expect(result1.logits.isEmpty)
-
- let processor2 = LogitsProcessor(logitsWarpers: [])
- let result2 = processor2([2.0, 1.0])
- #expect(result2.indices == [0, 1])
- #expect(isClose(result2.logits, [2.0, 1.0], accuracy: accuracy))
-
- let processor3 = LogitsProcessor(
- logitsWarpers: [TopKLogitsWarper(k: 3)]
- )
- let result3 = processor3([2.0, 1.0, 3.0, -5.0])
- #expect(result3.indices == [2, 0, 1])
- #expect(isClose(result3.logits, [3.0, 2.0, 1.0], accuracy: accuracy))
-
- let processor4 = LogitsProcessor(
- logitsWarpers: [TopKLogitsWarper(k: 3), TopPLogitsWarper(p: 0.99)]
- )
- let result4 = processor4([2.0, 1.0, 3.0, -5.0, -23.0, 12.5])
- #expect(result4.indices == [5])
- #expect(isClose(result4.logits, [12.5], accuracy: accuracy))
-
- let processor5 = LogitsProcessor(
- logitsWarpers: [TopKLogitsWarper(k: 4), TopPLogitsWarper(p: 0.99)]
- )
- let result5 = processor5([2.0, 1.0, 3.0, -5.0, -3.0, 4.5])
- #expect(result5.indices == [5, 2, 0, 1])
- #expect(isClose(result5.logits, [4.5, 3.0, 2.0, 1.0], accuracy: accuracy))
- }
-}
-#endif // canImport(CoreML)
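
The deleted suite pinned down the warper semantics: keep the strongest logits in descending order, preserving their original indices. For the record, a dependency-free sketch of the top-k behavior those expectations describe; `topK` here is illustrative, not the deleted `TopKLogitsWarper`:

```swift
// Keep the k largest logits, in descending order, with their original indices.
func topK(k: Int, indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
    let kept = zip(indices, logits).sorted { $0.1 > $1.1 }.prefix(k)
    return (kept.map { $0.0 }, kept.map { $0.1 })
}

// Mirrors the deleted expectation:
// k=3 over logits [2, 1, 3] -> indices [2, 0, 1], logits [3, 2, 1].
let result = topK(k: 3, indices: [0, 1, 2], logits: [2.0, 1.0, 3.0])
print(result)
```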
diff --git a/Tests/GenerationTests/MathTests.swift b/Tests/GenerationTests/MathTests.swift
deleted file mode 100644
index 0e255ecd..00000000
--- a/Tests/GenerationTests/MathTests.swift
+++ /dev/null
@@ -1,59 +0,0 @@
-//
-// MathTests.swift
-//
-// Created by Jan Krukowski on 25/11/2023.
-//
-
-#if canImport(CoreML)
-import CoreML
-@testable import Generation
-import Testing
-
-@Suite("Math Tests")
-struct MathTests {
- private let accuracy: Float = 0.00001
-
- @Test("Cumulative sum functionality")
- func cumsum() {
- #expect(Math.cumsum([]).isEmpty)
- #expect(Math.cumsum([1]) == [1])
- #expect(Math.cumsum([1, 2, 3, 4]) == [1, 3, 6, 10])
- }
-
- @Test("Argmax functionality")
- func argmax() throws {
- let result1 = Math.argmax([3.0, 4.0, 1.0, 2.0] as [Float], count: 4)
- #expect(result1.0 == 1)
- #expect(result1.1 == 4.0)
-
- let result2 = Math.argmax32([3.0, 4.0, 1.0, 2.0], count: 4)
- #expect(result2.0 == 1)
- #expect(result2.1 == 4.0)
-
- let result3 = Math.argmax([3.0, 4.0, 1.0, 2.0] as [Double], count: 4)
- #expect(result3.0 == 1)
- #expect(result3.1 == 4.0)
-
- let result4 = try Math.argmax32(MLMultiArray([3.0, 4.0, 1.0, 2.0] as [Float]))
- #expect(result4.0 == 1)
- #expect(result4.1 == 4.0)
-
- let result5 = try Math.argmax(MLMultiArray([3.0, 4.0, 1.0, 2.0] as [Double]))
- #expect(result5.0 == 1)
- #expect(result5.1 == 4.0)
-
- let result6 = Math.argmax(MLShapedArray(scalars: [3.0, 4.0, 1.0, 2.0] as [Float], shape: [4]))
- #expect(result6.0 == 1)
- #expect(result6.1 == 4.0)
- }
-
- @Test("Softmax functionality")
- func softmax() {
- #expect(Math.softmax([]) == [])
-
- let result1 = Math.softmax([3.0, 4.0, 1.0, 2.0])
- #expect(isClose(result1, [0.23688284, 0.6439143, 0.032058604, 0.08714432], accuracy: accuracy))
- #expect(abs(result1.reduce(0, +) - 1.0) < accuracy)
- }
-}
-#endif // canImport(CoreML)
diff --git a/Tests/GenerationTests/TestUtils.swift b/Tests/GenerationTests/TestUtils.swift
deleted file mode 100644
index 4f9f2a59..00000000
--- a/Tests/GenerationTests/TestUtils.swift
+++ /dev/null
@@ -1,7 +0,0 @@
-import Foundation
-
-/// Check if two floating-point arrays are equal within a given accuracy
-func isClose<T: FloatingPoint>(_ lhs: [T], _ rhs: [T], accuracy: T) -> Bool {
- guard lhs.count == rhs.count else { return false }
- return zip(lhs, rhs).allSatisfy { abs($0.0 - $0.1) <= accuracy }
-}
diff --git a/Tests/HubTests/DownloaderTests.swift b/Tests/HubTests/DownloaderTests.swift
index aecfa778..552fbf3e 100644
--- a/Tests/HubTests/DownloaderTests.swift
+++ b/Tests/HubTests/DownloaderTests.swift
@@ -181,11 +181,11 @@ final class DownloaderTests: XCTestCase {
) { error in
XCTAssertEqual((error as NSError).code, 516)
}
- XCTAssertEqual(try String(contentsOf: destination), "existing")
+ XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), "existing")
XCTAssertNoThrow(
try FileManager.default.moveDownloadedFile(from: source2, to: destination, percentEncoded: false)
)
- XCTAssertEqual(try String(contentsOf: destination), "v2")
+ XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), "v2")
}
}
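
These test updates track an SDK change rather than a behavior change: `String(contentsOf:)` without an explicit encoding is deprecated in recent SDKs because it has to guess the file's encoding at runtime. The encoding-explicit form the tests now use, with a placeholder path:

```swift
import Foundation

let url = URL(fileURLWithPath: "/tmp/example.txt") // placeholder path
// Deprecated in recent SDKs: try String(contentsOf: url)
let contents = try String(contentsOf: url, encoding: .utf8)
print(contents)
```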
diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift
index 559306ec..e8c4031e 100644
--- a/Tests/HubTests/HubApiTests.swift
+++ b/Tests/HubTests/HubApiTests.swift
@@ -571,7 +571,7 @@ class SnapshotDownloadTests: XCTestCase {
// If that's the case we just update the metadata and keep the local file.
XCTAssertEqual(originalTimestamp, thirdDownloadTimestamp)
- let metadataString = try String(contentsOfFile: metadataFile.path)
+ let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8)
// Updated metadata file needs to have the correct commit hash, etag and timestamp.
// This is being updated because the local etag (SHA256 checksum) matches the remote etag
@@ -602,7 +602,7 @@ class SnapshotDownloadTests: XCTestCase {
)
let metadataFile = metadataDestination.appendingPathComponent("llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata")
- let metadataString = try String(contentsOfFile: metadataFile.path)
+ let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8)
let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nfc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107"
XCTAssertTrue(metadataString.contains(expected))
@@ -632,7 +632,7 @@ class SnapshotDownloadTests: XCTestCase {
)
let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata")
- let metadataString = try String(contentsOfFile: metadataFile.path)
+ let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8)
let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4"
XCTAssertTrue(metadataString.contains(expected))
@@ -662,7 +662,7 @@ class SnapshotDownloadTests: XCTestCase {
)
let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata")
- let metadataString = try String(contentsOfFile: metadataFile.path)
+ let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8)
let metadataArr = metadataString.components(separatedBy: .newlines)
let commitHash = metadataArr[0]
@@ -720,7 +720,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertTrue(originalTimestamp == secondDownloadTimestamp)
XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path))
- let metadataString = try String(contentsOfFile: metadataFile.path)
+ let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8)
let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4"
XCTAssertTrue(metadataString.contains(expected))
@@ -771,7 +771,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertTrue(originalTimestamp == secondDownloadTimestamp)
XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path))
- let metadataString = try String(contentsOfFile: metadataFile.path)
+ let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8)
let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4"
XCTAssertTrue(metadataString.contains(expected))
@@ -822,7 +822,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertTrue(originalTimestamp != secondDownloadTimestamp)
XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path))
- let metadataString = try String(contentsOfFile: metadataFile.path)
+ let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8)
let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nd6ceb92ce9e3c83ab146dc8e92a93517ac1cc66f"
XCTAssertTrue(metadataString.contains(expected))
@@ -1009,7 +1009,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertEqual(lastProgress?.completedUnitCount, 1)
XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))
- let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path)
+ let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path, encoding: .utf8)
let expected = """
{
@@ -1048,7 +1048,7 @@ class SnapshotDownloadTests: XCTestCase {
XCTAssertEqual(lastProgress?.completedUnitCount, 1)
XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))
- let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path)
+ let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path, encoding: .utf8)
let expected = """
X
diff --git a/Tests/TokenizersTests/BertTokenizerTests.swift b/Tests/TokenizersTests/BertTokenizerTests.swift
index e6c269a5..33493b43 100644
--- a/Tests/TokenizersTests/BertTokenizerTests.swift
+++ b/Tests/TokenizersTests/BertTokenizerTests.swift
@@ -87,7 +87,7 @@ private enum Squad {
private let bertTokenizer: BertTokenizer = {
let vocab = {
let url = Bundle.module.url(forResource: "bert-vocab", withExtension: "txt")!
- let vocabTxt = try! String(contentsOf: url)
+ let vocabTxt = try! String(contentsOf: url, encoding: .utf8)
let tokens = vocabTxt.split(separator: "\n").map { String($0) }
var vocab: [String: Int] = [:]
for (i, token) in tokens.enumerated() {