From e0b36379f96bcc33e5e7dad97e1f77b38cd541db Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 24 Sep 2025 08:46:33 +0200 Subject: [PATCH 01/16] feat: preview --- Examples/Mistral7B/README.md | 27 + Examples/Mistral7B/export.py | 215 ++++ Examples/Mistral7B/generate.py | 88 ++ Examples/Mistral7B/requirements.txt | 6 + Package.swift | 2 +- Sources/Generation/Decoders.swift | 24 + Sources/Generation/Generation.swift | 123 +- Sources/Generation/GenerationConfig.swift | 17 +- .../LogitsWarper/LogitsProcessor.swift | 18 - .../LogitsWarper/LogitsWarper.swift | 13 - .../TemperatureLogitsWarper.swift | 13 - .../LogitsWarper/TopKLogitsWarper.swift | 60 - .../LogitsWarper/TopPLogitsWarper.swift | 37 - Sources/Generation/MLMultiArray+Utils.swift | 200 ---- Sources/Generation/MLShapedArray+Utils.swift | 54 - Sources/Generation/Math.swift | 172 --- Sources/Generation/Random.swift | 1000 +++++++++++++++++ Sources/Models/LanguageModel.swift | 308 +++-- Sources/Models/LanguageModelTypes.swift | 43 +- Sources/TransformersCLI/Transformers.swift | 140 +++ 20 files changed, 1854 insertions(+), 706 deletions(-) create mode 100644 Examples/Mistral7B/README.md create mode 100644 Examples/Mistral7B/export.py create mode 100644 Examples/Mistral7B/generate.py create mode 100644 Examples/Mistral7B/requirements.txt create mode 100644 Sources/Generation/Decoders.swift delete mode 100644 Sources/Generation/LogitsWarper/LogitsProcessor.swift delete mode 100644 Sources/Generation/LogitsWarper/LogitsWarper.swift delete mode 100644 Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift delete mode 100644 Sources/Generation/LogitsWarper/TopKLogitsWarper.swift delete mode 100644 Sources/Generation/LogitsWarper/TopPLogitsWarper.swift delete mode 100644 Sources/Generation/MLMultiArray+Utils.swift delete mode 100644 Sources/Generation/MLShapedArray+Utils.swift delete mode 100644 Sources/Generation/Math.swift create mode 100644 Sources/Generation/Random.swift create mode 100644 Sources/TransformersCLI/Transformers.swift diff --git a/Examples/Mistral7B/README.md b/Examples/Mistral7B/README.md new file mode 100644 index 00000000..61123251 --- /dev/null +++ b/Examples/Mistral7B/README.md @@ -0,0 +1,27 @@ +### Export Mistral 7B Instruct v0.3 + +```shell +✗ python export.py + +Loading checkpoint shards: 100%|███████████████████████████| 3/3 [00:12<00:00, 4.11s/it] +Converting PyTorch Frontend ==> MIL Ops: 100%|███| 5575/5575 [00:02<00:00, 2440.66 ops/s] +Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 7.12 passes/s] +Running MIL default pipeline: 100%|█████████████████| 79/79 [02:36<00:00, 1.98s/ passes] +Running MIL backend_mlprogram pipeline: 100%|███████| 12/12 [00:00<00:00, 22.90 passes/s] +Running compression: 100%|███████████████████████████| 296/296 [03:04<00:00, 1.60 ops/s] +... +``` + +### Generate Text + +```shell +✗ swift run transformers "Best recommendations for a place to visit in Paris in August 2024:" --max-length 128 StatefulMistral7BInstructInt4.mlpackage + +Best recommendations for a place to visit in Paris in August 2024: + +1. Palace of Versailles: This iconic palace is a must-visit. It's a short train ride from Paris and offers a glimpse into the opulence of the French monarchy. + +2. Eiffel Tower: No trip to Paris is complete without a visit to the Eiffel Tower. You can take an elevator ride to the top for a stunning view of the city. + +3. Louvre Museum: Home to thousands of works of art, including the Mona Lisa and the Winged Victory of Samothrace, the Louvre is a cultural treasure. +``` diff --git a/Examples/Mistral7B/export.py b/Examples/Mistral7B/export.py new file mode 100644 index 00000000..fdebfd13 --- /dev/null +++ b/Examples/Mistral7B/export.py @@ -0,0 +1,215 @@ +import logging +import os +import warnings +from typing import List, Optional, Tuple + +import coremltools as ct +import numpy as np +import torch +from transformers.cache_utils import Cache +from transformers.models.mistral.modeling_mistral import ( + MISTRAL_ATTENTION_CLASSES, + MistralAttention, + MistralConfig, + MistralForCausalLM, + apply_rotary_pos_emb, + repeat_kv, +) + +warnings.filterwarnings("ignore") +logging.getLogger("coremltools").setLevel(logging.ERROR) +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3 +MODEL_ID: str = "mistralai/Mistral-7B-Instruct-v0.3" +METADATA_TOKENIZER: str = "co.huggingface.exporters.name" + + +class SliceUpdateKeyValueCache(Cache): + def __init__( + self, + shape: Tuple[int, ...], + device="cpu", + dtype=torch.float32, + ) -> None: + """KV cache of shape (#layers, batch_size, #kv_heads, context_size, head_dim).""" + super().__init__() + self.past_seen_tokens: int = 0 + self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device) + self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device) + + def update( + self, + k_state: torch.Tensor, + v_state: torch.Tensor, + layer_idx: int, + slice_indices: torch.LongTensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update key/value cache tensors for slice [slice_indices[0], slice_indices[1]). + Return slice of key/value cache tensors from [0, slice_indices[1]). + """ + if len(slice_indices) != 2: + raise ValueError(f"Expect tuple of integers [start, end), got {slice_indices=}.") + begin, end = slice_indices + self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state + self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state + k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :] + v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :] + return k_cache, v_cache + + def get_seq_length(self, _: int | None = 0) -> int: + """Get the sequence length of the cache.""" + return self.past_seen_tokens + + +class SliceUpdateMistralAttention(MistralAttention): + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + super().__init__(config=config, layer_idx=layer_idx) + + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + **kwargs, + ) -> Tuple[torch.Tensor | None, ...]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( + 1, 2 + ) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # Slice update key/value cache + end_step = attention_mask.shape[-1] + key_states, value_states = past_key_value.update( + key_states, + value_states, + self.layer_idx, + slice_indices=(end_step - q_len, end_step), + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, None, None + + +class StatefulMistralForCausalLM(torch.nn.Module): + def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1) -> None: + super().__init__() + + # Custom attention implementation for stateful slice update key/value cache, override + # "sdpa" to compliance with transformers.modeling_utils._autoset_attn_implementation + MISTRAL_ATTENTION_CLASSES["sdpa"] = SliceUpdateMistralAttention + self.model = MistralForCausalLM.from_pretrained(model_path) + + # Register KV cache buffers to be recognized as Core ML states + config: MistralConfig = self.model.config + self.kv_cache_shape: Tuple[int, ...] = ( + config.num_hidden_layers, + batch_size, + config.num_key_value_heads, + max_context_size, + config.hidden_size // config.num_attention_heads, + ) + self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape) + self.register_buffer("keyCache", self.kv_cache.k_cache) + self.register_buffer("valueCache", self.kv_cache.v_cache) + + @torch.no_grad() + def forward( + self, + input_ids: torch.LongTensor, + causal_mask: torch.Tensor, + ) -> torch.Tensor: + # Compute past seen tokens used for updating key/value cache slices + self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1] + return self.model( + input_ids, + attention_mask=causal_mask, + past_key_values=self.kv_cache, + use_cache=True, + ).logits + + +def export() -> None: + # Construct model from transformers and trace to TorchScript + max_context_size: int = 2048 + torch_model = StatefulMistralForCausalLM(MODEL_ID, max_context_size=max_context_size) + torch_model.eval() + input_ids: torch.Tensor = torch.zeros((1, 2), dtype=torch.int32) + causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float32) + traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask]) + + # Convert traced TorchScript to Core ML format + query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1) + end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1) + inputs: List[ct.TensorType] = [ + ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"), + ct.TensorType( + shape=(1, 1, query_length, end_step_dim), + dtype=np.float16, + name="causalMask", + ), + ] + outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")] + states: List[ct.StateType] = [ + ct.StateType( + wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16), + name="keyCache", + ), + ct.StateType( + wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16), + name="valueCache", + ), + ] + + # Convert model with FP16 precision + mlmodel_fp16: ct.MLModel = ct.convert( + traced_model, + inputs=inputs, + outputs=outputs, + states=states, + minimum_deployment_target=ct.target.iOS18, + skip_model_load=True, + ) + + # Block-wise quantize model weights to int4 + op_config = ct.optimize.coreml.OpLinearQuantizerConfig( + mode="linear_symmetric", + dtype="int4", + granularity="per_block", + block_size=32, + ) + config = ct.optimize.coreml.OptimizationConfig(global_config=op_config) + mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config) + mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID}) + mlmodel_int4.save("StatefulMistral7BInstructInt4.mlpackage") + + +if __name__ == "__main__": + export() diff --git a/Examples/Mistral7B/generate.py b/Examples/Mistral7B/generate.py new file mode 100644 index 00000000..9e373592 --- /dev/null +++ b/Examples/Mistral7B/generate.py @@ -0,0 +1,88 @@ +import argparse +from typing import Dict, Generator, List, Tuple + +import numpy as np +from coremltools.models import MLModel +from transformers import AutoTokenizer + +from export import METADATA_TOKENIZER + + +def load(model_path: str) -> Tuple[MLModel, AutoTokenizer]: + """Load a Core ML model and corresponding tokenizer.""" + model: MLModel = MLModel(model_path) + description = model.get_spec().description + if METADATA_TOKENIZER not in description.metadata.userDefined: + raise ValueError("Model metadata does not contain tokenizer path.") + tokenizer_path: str = description.metadata.userDefined[METADATA_TOKENIZER] + tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + return model, tokenizer + + +def get_next_token(model: MLModel, prompt_tokens: np.ndarray) -> Generator[int, None, None]: + """Generate a sequence of tokens with naive greedy decoding.""" + + def sample(logits: np.ndarray) -> int: + """Perform greedy decoding on the logits array to get the next token.""" + return int(np.argmax(logits[0][-1], axis=-1)) + + def inference(model: MLModel, input_ids: np.ndarray, num_past_tokens: int) -> np.ndarray: + """Perform inference with the given model and input data.""" + causal_mask: np.ndarray = np.triu( + np.full( + (1, 1, input_ids.shape[-1], num_past_tokens + input_ids.shape[-1]), + fill_value=-np.inf if num_past_tokens == 0 else 0, + ), + k=1, + ).astype(np.float16) + outputs: Dict[str, np.ndarray] = model.predict( + data={"inputIds": input_ids, "causalMask": causal_mask}, + state=kv_cache_state, + ) + return outputs["logits"] + + kv_cache_state = model.make_state() + logits: np.ndarray = inference(model, input_ids=prompt_tokens, num_past_tokens=0) + token: int = sample(logits=logits) + num_past_tokens: int = prompt_tokens.shape[-1] + + while True: + yield token + logits: np.ndarray = inference( + model, + input_ids=np.array([[token]], dtype=np.int32), + num_past_tokens=num_past_tokens, + ) + token: int = sample(logits=logits) + num_past_tokens += 1 + + +def generate( + model: MLModel, + prompt: str, + tokenizer: AutoTokenizer, + max_new_tokens: int, +) -> str: + prompt_tokens: np.ndarray = tokenizer(prompt, return_tensors="np").input_ids + extend_tokens: List[int] = [] + for i, token in enumerate(get_next_token(model, prompt_tokens=prompt_tokens.astype(np.int32))): + if token == tokenizer.eos_token_id or i == max_new_tokens: + break + extend_tokens.append(token) + return tokenizer.decode(prompt_tokens[0].tolist() + extend_tokens) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_path", type=str) + parser.add_argument("--prompt", type=str, default="Hello") + parser.add_argument("--max_new_tokens", type=int, default=128) + args = parser.parse_args() + model, tokenizer = load(args.model_path) + extend_text: str = generate( + model, + prompt=args.prompt, + tokenizer=tokenizer, + max_new_tokens=args.max_new_tokens, + ) + print(extend_text) diff --git a/Examples/Mistral7B/requirements.txt b/Examples/Mistral7B/requirements.txt new file mode 100644 index 00000000..f0f1fa68 --- /dev/null +++ b/Examples/Mistral7B/requirements.txt @@ -0,0 +1,6 @@ +coremltools==8.0b1 +numpy==1.26.4 +torch==2.3.1 +tqdm==4.66.4 +transformers==4.42.3 +sentencepiece==0.2.0 diff --git a/Package.swift b/Package.swift index 55acf44e..5f26e1f1 100644 --- a/Package.swift +++ b/Package.swift @@ -10,7 +10,7 @@ let swiftSettings: [SwiftSetting] = [ let package = Package( name: "swift-transformers", - platforms: [.iOS(.v16), .macOS(.v13)], + platforms: [.iOS("18.0"), .macOS("15.0")], products: [ .library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"]) ], diff --git a/Sources/Generation/Decoders.swift b/Sources/Generation/Decoders.swift new file mode 100644 index 00000000..173ff97e --- /dev/null +++ b/Sources/Generation/Decoders.swift @@ -0,0 +1,24 @@ +import CoreML + +// MARK: Greedy Decoding + +func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor { + scores.argmax(alongAxis: -1).reshaped(to: [1, 1]) +} + +// MARK: Top-K Sampling + +func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, topK: Int) -> MLTensor { + let temperatureAdjustedScores = scores / temperature + let (topKScores, topKIndices) = temperatureAdjustedScores.topK(topK) + let topKProbs = topKScores.softmax(alongAxis: -1) + let rnd = topKProbs.sum() * Float.random(in: 0 ..< 1) + var accumTopKProbs = topKProbs.cumulativeSum(alongAxis: -1) + accumTopKProbs += (accumTopKProbs .< rnd) * 100.0 + let topKIndex = accumTopKProbs.argsort()[..., 0] + let nextTokenTensor = topKIndices.gathering( + atIndices: topKIndex, + alongAxis: topKIndices.rank - 1 + ) + return nextTokenTensor.reshaped(to: [1, 1]) +} diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift index 43dfd435..f08fcb5e 100644 --- a/Sources/Generation/Generation.swift +++ b/Sources/Generation/Generation.swift @@ -8,6 +8,7 @@ #if canImport(CoreML) import CoreML +import CoreML import Tokenizers public enum GenerationMode { @@ -23,89 +24,83 @@ public typealias InputTokens = [Int] public typealias GenerationOutput = [Int] /// A callable (a model, usually), that predicts the next token after a given sequence -public typealias NextTokenModel = (InputTokens, GenerationConfig) -> any MLShapedArrayProtocol +@available(macOS 15.0, iOS 18.0, *) +public typealias NextTokenModel = (MLTensor, GenerationConfig) async -> MLTensor public typealias PredictionTokensCallback = (GenerationOutput) -> Void public typealias PredictionStringCallback = (String) -> Void -// TODO: callbacks (for streaming) +@available(macOS 15.0, iOS 18.0, *) public protocol Generation { - func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback?) async -> GenerationOutput - func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback?) async -> String } -public extension Generation { - func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput { - // Iterate until we find the eos token or reach the max length - // TODO: additional stopping criteria - var outputTokens = tokens - while outputTokens.count < config.maxLength { - let logits = model(outputTokens, config) - let (nextToken, _) = Math.argmax(logits) - if nextToken == config.eosTokenId { break } - outputTokens.append(nextToken) - callback?(outputTokens) +@available(macOS 15.0, iOS 18.0, *) +extension Generation { + public func generate( + config: GenerationConfig, + tokens: InputTokens, + model: NextTokenModel, + callback: PredictionTokensCallback? = nil + ) async -> GenerationOutput { + let tokens = tokens.map { Int32($0) } + var outputTokens = MLTensor(tokens).expandingShape(at: 0) + while outputTokens.shape[1] < config.maxLength { + let nextTokenScores = await model(outputTokens, config) + let nextToken = switch config.generationMode { + case .greedy: + selectNextTokenUsingGreedyDecoding(from: nextTokenScores) + case .sample: + selectNextTokenUsingTopKSampling( + from: nextTokenScores, + temperature: config.temperature, + topK: config.topK + ) + default: + fatalError("Generation mode \(config.generationMode) not implemented yet") + } + + if let nextTokenId = await tensorToGenerationOutput(nextToken).first, nextTokenId == config.eosTokenId { + break + } + + outputTokens = MLTensor(concatenating: [outputTokens, nextToken], alongAxis: -1) + if let callback { + let outputTokenIDs = await tensorToGenerationOutput(outputTokens) + callback(outputTokenIDs) + } } - return outputTokens + return await tensorToGenerationOutput(outputTokens) } - /// https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L1552 - func sample(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput { - // Iterate until we find the eos token or reach the max length - // TODO: additional stopping criteria - var outputTokens = tokens - let logitsProcessor = LogitsProcessor(logitsWarpers: logitsWarpers(config: config)) - while outputTokens.count < config.maxLength { - let outputs = model(outputTokens, config) - // `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case - let logits = (outputs as? MLShapedArraySlice)?.floats ?? outputs.scalars as! [Float] - let (indexes, processedLogits) = logitsProcessor(logits) - let nextToken = Math.sample(indexes: indexes, probs: Math.softmax(processedLogits)) - if nextToken == config.eosTokenId { break } - outputTokens.append(nextToken) - callback?(outputTokens) - } - return outputTokens + private func tensorToGenerationOutput(_ tensor: MLTensor) async -> GenerationOutput { + await tensor.shapedArray(of: Int32.self).scalars.map { Int($0) } } +} - func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback? = nil) async -> String { +@available(macOS 15.0, iOS 18.0, *) +public extension Generation { + func generate( + config: GenerationConfig, + prompt: String, + model: NextTokenModel, + tokenizer: Tokenizer, + callback: PredictionStringCallback? = nil + ) async -> String { let tokens = tokenizer.encode(text: prompt) var generationConfig = config generationConfig.maxLength = config.maxNewTokens + tokens.count - - let output: GenerationOutput - switch generationConfig.generationMode { - case .greedy: - output = await greedySearch(config: generationConfig, tokens: tokens, model: model) { tokens in - callback?(tokenizer.decode(tokens: tokens)) - } - case .sample: - output = await sample(config: generationConfig, tokens: tokens, model: model) { tokens in - callback?(tokenizer.decode(tokens: tokens)) - } - default: - fatalError("Generation mode \(generationConfig.generationMode) not implemented yet") + generationConfig.eosTokenId = tokenizer.eosTokenId + generationConfig.bosTokenId = tokenizer.bosTokenId + let output = await generate( + config: generationConfig, + tokens: tokens, + model: model + ) { tokens in + callback?(tokenizer.decode(tokens: tokens)) } return tokenizer.decode(tokens: output) } - - private func logitsWarpers(config: GenerationConfig) -> [any LogitsWarper] { - var logitsWarpers = [any LogitsWarper]() - if config.temperature > 0, config.temperature != 1 { - logitsWarpers.append(TemperatureLogitsWarper(temperature: Float(config.temperature))) - } - if config.topK > 0 { - logitsWarpers.append(TopKLogitsWarper(k: config.topK)) - } - if config.topP < 1.0 { - logitsWarpers.append(TopPLogitsWarper(p: Float(config.topP))) - } - if config.repetitionPenalty != 1.0 { - logitsWarpers.append(RepetitionPenaltyWarper(penalty: config.repetitionPenalty)) - } - return logitsWarpers - } } #endif // canImport(CoreML) diff --git a/Sources/Generation/GenerationConfig.swift b/Sources/Generation/GenerationConfig.swift index 9223648e..2f2a6152 100644 --- a/Sources/Generation/GenerationConfig.swift +++ b/Sources/Generation/GenerationConfig.swift @@ -15,7 +15,7 @@ public struct GenerationConfig { public var numBeams = 1 public var numBeamGroups = 1 public var penaltyAlpha: Double? - public var temperature = 1.0 + public var temperature: Float = 1.0 public var topK = 50 public var topP = 1.0 public var repetitionPenalty = 1.0 @@ -24,14 +24,25 @@ public struct GenerationConfig { public var bosTokenId: Int? public var eosTokenId: Int? - public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) { + public init( + maxLength: Int = 20, + maxNewTokens: Int, + doSample: Bool = false, + numBeams: Int = 1, + numBeamGroups: Int = 1, + penaltyAlpha: Double? = nil, + temperature: Double = 1.0, + topK: Int = 50, + topP: Double = 1.0, + repetitionPenalty: Double = 1.0 + ) { self.maxLength = maxLength self.maxNewTokens = maxNewTokens self.doSample = doSample self.numBeams = numBeams self.numBeamGroups = numBeamGroups self.penaltyAlpha = penaltyAlpha - self.temperature = temperature + self.temperature = Float(temperature) self.topK = topK self.topP = topP self.repetitionPenalty = repetitionPenalty diff --git a/Sources/Generation/LogitsWarper/LogitsProcessor.swift b/Sources/Generation/LogitsWarper/LogitsProcessor.swift deleted file mode 100644 index 1c7f3a72..00000000 --- a/Sources/Generation/LogitsWarper/LogitsProcessor.swift +++ /dev/null @@ -1,18 +0,0 @@ -import Foundation - -public struct LogitsProcessor { - public var logitsWarpers: [any LogitsWarper] - - public init(logitsWarpers: [any LogitsWarper]) { - self.logitsWarpers = logitsWarpers - } - - public func callAsFunction(_ arr: [Float]) -> (indices: [Int], logits: [Float]) { - var indices = Array(arr.indices) - var logits = arr - for warper in logitsWarpers { - (indices, logits) = warper(indices, logits) - } - return (indices: indices, logits: logits) - } -} diff --git a/Sources/Generation/LogitsWarper/LogitsWarper.swift b/Sources/Generation/LogitsWarper/LogitsWarper.swift deleted file mode 100644 index 17fc64ef..00000000 --- a/Sources/Generation/LogitsWarper/LogitsWarper.swift +++ /dev/null @@ -1,13 +0,0 @@ -import Foundation - -/// Protocol for all logit warpers that can be applied during generation -public protocol LogitsWarper { - func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) - func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float]) -} - -public extension LogitsWarper { - func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float]) { - warp(indices: indices, logits: logits) - } -} diff --git a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift deleted file mode 100644 index 44e495b1..00000000 --- a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift +++ /dev/null @@ -1,13 +0,0 @@ -import Foundation - -public struct TemperatureLogitsWarper: LogitsWarper { - public var temperature: Float - - public init(temperature: Float) { - self.temperature = temperature - } - - public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) { - (indices: indices, logits: logits.map { $0 / temperature }) - } -} diff --git a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift deleted file mode 100644 index 18ee54da..00000000 --- a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift +++ /dev/null @@ -1,60 +0,0 @@ -#if canImport(Accelerate) -import Accelerate -import Foundation - -/// Top-K. -/// Select the k most-probable element indices from `arr` -/// and return both the indices (from the original array) -/// and their probabilities. -public struct TopKLogitsWarper: LogitsWarper { - public var k: Int - - public init(k: Int) { - self.k = k - } - - public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) { - guard !logits.isEmpty else { - return (indices: [], logits: []) - } - let k = min(k, logits.count) - let arrDescriptor = BNNSNDArrayDescriptor.allocate( - initializingFrom: logits, - shape: .vector(logits.count) - ) - defer { - arrDescriptor.deallocate() - } - let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized( - scalarType: Int32.self, - shape: .vector(k) - ) - defer { - bestIndices.deallocate() - } - let bestValues = BNNSNDArrayDescriptor.allocateUninitialized( - scalarType: Float.self, - shape: .vector(k) - ) - defer { - bestValues.deallocate() - } - try! Accelerate.BNNS.applyTopK( - k: k, - input: arrDescriptor, - bestValues: bestValues, - bestIndices: bestIndices, - axis: 0, - batchSize: 1, - filterParameters: nil - ) - let topkLogits = bestValues.data!.withMemoryRebound(to: Float.self, capacity: k) { ptr in - Array(UnsafeBufferPointer(start: ptr, count: k)) - } - let topkIndices = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: k) { ptr in - Array(UnsafeBufferPointer(start: ptr, count: k)) - } - return (indices: topkIndices.map { indices[Int($0)] }, logits: topkLogits) - } -} -#endif // canImport(Accelerate) diff --git a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift deleted file mode 100644 index bc08952c..00000000 --- a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift +++ /dev/null @@ -1,37 +0,0 @@ -import Foundation - -/// Top-P. -/// Select the smallest set of elements whose cumulative probability exceeds the probability `p`. -/// Based on https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 -public struct TopPLogitsWarper: LogitsWarper { - public var p: Float - - public init(p: Float) { - self.p = p - } - - public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) { - guard !logits.isEmpty else { - return (indices: [], logits: []) - } - - let arrSoftmax = Math.softmax(logits) - var indexLogitProb = [(index: Int, logit: Float, prob: Float)]() - indexLogitProb.reserveCapacity(logits.count) - for (index, data) in zip(logits, arrSoftmax).enumerated() { - indexLogitProb.append((index: index, logit: data.0, prob: data.1)) - } - indexLogitProb.sort { $0.prob > $1.prob } - - let cumsum = Math.cumsum(indexLogitProb.map(\.prob)) - var sliceIndex = cumsum.count - 1 - for (index, element) in cumsum.enumerated() where element > p { - sliceIndex = index - break - } - - let toppIndices = indexLogitProb[0...sliceIndex].map { indices[$0.index] } - let toppLogits = indexLogitProb[0...sliceIndex].map(\.logit) - return (indices: toppIndices, logits: toppLogits) - } -} diff --git a/Sources/Generation/MLMultiArray+Utils.swift b/Sources/Generation/MLMultiArray+Utils.swift deleted file mode 100644 index f3592233..00000000 --- a/Sources/Generation/MLMultiArray+Utils.swift +++ /dev/null @@ -1,200 +0,0 @@ -// -// MLMultiArray+Utils.swift -// CoreMLBert -// -// Created by Julien Chaumond on 27/06/2019. -// Copyright © 2019 Hugging Face. All rights reserved. -// - -#if canImport(CoreML) -import CoreML -import Foundation - -public extension MLMultiArray { - /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) - static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray { - var shape = Array(repeating: 1, count: dims) - shape[shape.count - 1] = arr.count - // Examples: - // dims=1 : [arr.count] - // dims=2 : [1, arr.count] - // - let o = try! MLMultiArray(shape: shape as [NSNumber], dataType: .int32) - let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) - for (i, item) in arr.enumerated() { - ptr[i] = Int32(item) - } - return o - } - - /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) - static func from(_ arr: [Double], dims: Int = 1) -> MLMultiArray { - var shape = Array(repeating: 1, count: dims) - shape[shape.count - 1] = arr.count - // Examples: - // dims=1 : [arr.count] - // dims=2 : [1, arr.count] - // - let o = try! MLMultiArray(shape: shape as [NSNumber], dataType: .float64) - let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) - for (i, item) in arr.enumerated() { - ptr[i] = Double(item) - } - return o - } - - /// This will concatenate all dimensions into one one-dim array. - static func toIntArray(_ o: MLMultiArray) -> [Int] { - var arr = Array(repeating: 0, count: o.count) - let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) - for i in 0.. [Int] { Self.toIntArray(self) } - - /// This will concatenate all dimensions into one one-dim array. - static func toDoubleArray(_ o: MLMultiArray) -> [Double] { - var arr: [Double] = Array(repeating: 0, count: o.count) - let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) - for i in 0.. [Double] { Self.toDoubleArray(self) } - - /// Helper to construct a sequentially-indexed multi array, - /// useful for debugging and unit tests - /// Example in 3 dimensions: - /// ``` - /// [[[ 0, 1, 2, 3 ], - /// [ 4, 5, 6, 7 ], - /// [ 8, 9, 10, 11 ]], - /// [[ 12, 13, 14, 15 ], - /// [ 16, 17, 18, 19 ], - /// [ 20, 21, 22, 23 ]]] - /// ``` - static func testTensor(shape: [Int]) -> MLMultiArray { - let arr = try! MLMultiArray(shape: shape as [NSNumber], dataType: .double) - let ptr = UnsafeMutablePointer(OpaquePointer(arr.dataPointer)) - for i in 0.. MLMultiArray { - assert( - indexing.count == o.shape.count - ) - assert( - indexing.filter { $0 == Indexing.slice }.count == 1 - ) - var selectDims: [Int: Int] = [:] - for (i, idx) in indexing.enumerated() { - if case let .select(select) = idx { - selectDims[i] = select - } - } - return slice( - o, - sliceDim: indexing.firstIndex { $0 == Indexing.slice }!, - selectDims: selectDims - ) - } - - /// Slice an array according to a list, according to `sliceDim` (which dimension to slice on) - /// and a dictionary of `dim` to `index`. - /// - /// You must select all other dimensions than the slice dimension (cf. the assert). - static func slice(_ o: MLMultiArray, sliceDim: Int, selectDims: [Int: Int]) -> MLMultiArray { - assert( - selectDims.count + 1 == o.shape.count - ) - var shape: [NSNumber] = Array(repeating: 1, count: o.shape.count) - shape[sliceDim] = o.shape[sliceDim] - // print("About to slice ndarray of shape \(o.shape) into ndarray of shape \(shape)") - let arr = try! MLMultiArray(shape: shape, dataType: .double) - - // let srcPtr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) - // TODO: use srcPtr instead of array subscripting. - let dstPtr = UnsafeMutablePointer(OpaquePointer(arr.dataPointer)) - for i in 0.. String { - func indent(_ x: Int) -> String { - String(repeating: " ", count: x) - } - - // This function is called recursively for every dimension. - // Add an entry for this dimension to the end of the array. - var indices = indices + [0] - - let d = indices.count - 1 // the current dimension - let N = shape[d].intValue // how many elements in this dimension - var s = "[" - if indices.count < shape.count { // not last dimension yet? - for i in 0.. { - var floats: [Float] { - guard strides.first == 1, strides.count == 1 else { - // For some reason this path is slow. - // If strides is not 1, we can write a Metal kernel to copy the values properly. - return scalars - } - - // Fast path: memcpy - let mlArray = MLMultiArray(self) - return mlArray.floats ?? scalars - } -} - -public extension MLShapedArraySlice { - var floats: [Float] { - guard strides.first == 1, strides.count == 1 else { - // For some reason this path is slow. - // If strides is not 1, we can write a Metal kernel to copy the values properly. - return scalars - } - - // Fast path: memcpy - let mlArray = MLMultiArray(self) - return mlArray.floats ?? scalars - } -} - -public extension MLMultiArray { - var floats: [Float]? { - guard dataType == .float32 else { return nil } - - var result: [Float] = Array(repeating: 0, count: count) - return withUnsafeBytes { ptr in - guard let source = ptr.baseAddress else { return nil } - result.withUnsafeMutableBytes { resultPtr in - let dest = resultPtr.baseAddress! - memcpy(dest, source, self.count * MemoryLayout.stride) - } - return result - } - } -} -#endif // canImport(CoreML) diff --git a/Sources/Generation/Math.swift b/Sources/Generation/Math.swift deleted file mode 100644 index 8c1b06b9..00000000 --- a/Sources/Generation/Math.swift +++ /dev/null @@ -1,172 +0,0 @@ -// -// Math.swift -// CoreMLBert -// -// Created by Julien Chaumond on 27/06/2019. -// Copyright © 2019 Hugging Face. All rights reserved. -// - -#if canImport(CoreML) && canImport(Accelerate) -import Accelerate -import CoreML -import Foundation - -/// -/// From M.I. Hollemans -/// -/// https://github.com/hollance/CoreMLHelpers -/// -public enum Math { - /** - Returns the index and value of the largest element in the array. - - - Parameters: - - ptr: Pointer to the first element in memory. - - count: How many elements to look at. - - stride: The distance between two elements in memory. - */ - public static func argmax(_ ptr: UnsafePointer, count: Int, stride: Int = 1) -> (Int, Float) { - var maxValue: Float = 0 - var maxIndex: vDSP_Length = 0 - vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) - return (Int(maxIndex), maxValue) - } - - /** - Returns the index and value of the largest element in the array. - - Parameters: - - ptr: Pointer to the first element in memory. - - count: How many elements to look at. - - stride: The distance between two elements in memory. - */ - public static func argmax(_ ptr: UnsafePointer, count: Int, stride: Int = 1) -> (Int, Double) { - var maxValue: Double = 0 - var maxIndex: vDSP_Length = 0 - vDSP_maxviD(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) - return (Int(maxIndex), maxValue) - } - - public static func argmax32(_ ptr: UnsafePointer, count: Int, stride: Int = 1) -> (Int, Float) { - var maxValue: Float = 0 - var maxIndex: vDSP_Length = 0 - vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) - return (Int(maxIndex), maxValue) - } - - /// MLMultiArray helper. - /// Works in our specific use case. - public static func argmax(_ multiArray: MLMultiArray) -> (Int, Double) { - assert(multiArray.dataType == .double) - let ptr = UnsafeMutablePointer(OpaquePointer(multiArray.dataPointer)) - return Math.argmax(ptr, count: multiArray.count) - } - - /// MLMultiArray helper. - /// Works in our specific use case. - public static func argmax32(_ multiArray: MLMultiArray) -> (Int, Float) { - assert(multiArray.dataType == .float32) - let ptr = UnsafeMutablePointer(OpaquePointer(multiArray.dataPointer)) - return Math.argmax32(ptr, count: multiArray.count) - } - - /// Returns the cumulative sum of the array. - public static func cumsum(_ arr: [Float]) -> [Float] { - guard !arr.isEmpty else { - return [] - } - let arrCount = vDSP_Length(arr.count) - var weight: Float = 1.0 - var result: [Float] = Array(repeating: 0.0, count: arr.count) - var firstItem = arr[0] - vDSP_vrsum(arr, 1, &weight, &result, 1, arrCount) - vDSP_vsadd(result, 1, &firstItem, &result, 1, arrCount) - return result - } - - /// Multinomial sampling from an array of probs. Works well with topK - public static func sample(indexes: [Int], probs: [Float]) -> Int { - let i = randomNumber(probabilities: probs) - return indexes[i] - } - - /** - Computes the "softmax" function over an array. - Based on code from https://github.com/nikolaypavlov/MLPNeuralNet/ - This is what softmax looks like in "pseudocode" (actually using Python - and numpy): - x -= np.max(x) - exp_scores = np.exp(x) - softmax = exp_scores / np.sum(exp_scores) - First we shift the values of x so that the highest value in the array is 0. - This ensures numerical stability with the exponents, so they don't blow up. - */ - public static func softmax(_ x: [Float]) -> [Float] { - var x = x - let len = vDSP_Length(x.count) - - // Find the maximum value in the input array. - var max: Float = 0 - vDSP_maxv(x, 1, &max, len) - - // Subtract the maximum from all the elements in the array. - // Now the highest value in the array is 0. - max = -max - vDSP_vsadd(x, 1, &max, &x, 1, len) - - // Exponentiate all the elements in the array. - var count = Int32(x.count) - vvexpf(&x, x, &count) - - // Compute the sum of all exponentiated values. - var sum: Float = 0 - vDSP_sve(x, 1, &sum, len) - - // Divide each element by the sum. This normalizes the array contents - // so that they all add up to 1. - vDSP_vsdiv(x, 1, &sum, &x, 1, len) - - return x - } - - /// Multinomial sampling - /// - /// From https://stackoverflow.com/questions/30309556/generate-random-numbers-with-a-given-distribution - /// - public static func randomNumber(probabilities: [Float]) -> Int { - // Sum of all probabilities (so that we don't have to require that the sum is 1.0): - let sum = probabilities.reduce(0, +) - // Random number in the range 0.0 <= rnd < sum : - let rnd = sum * Float(arc4random_uniform(UInt32.max)) / Float(UInt32.max) - // Find the first interval of accumulated probabilities into which `rnd` falls: - var accum: Float = 0.0 - for (i, p) in probabilities.enumerated() { - accum += p - if rnd < accum { - return i - } - } - // This point might be reached due to floating point inaccuracies: - return probabilities.count - 1 - } -} - -// MLShapedArray versions - -public extension Math { - static func argmax(_ shapedArray: MLShapedArray) -> (Int, Float) { - shapedArray.withUnsafeShapedBufferPointer { ptr, shape, strides in - assert(shape.count == 1, "Only supported for 1-dimensional arrays or slices") - return Math.argmax32(ptr.baseAddress!, count: shapedArray.count, stride: strides.first!) - } - } - - // TODO: handle Double, etc. - static func argmax(_ shapedArray: some MLShapedArrayProtocol) -> (Int, Float) { - shapedArray.withUnsafeShapedBufferPointer { ptr, shape, strides in - assert(shape.count == 1, "Only supported for 1-dimensional arrays or slices") - let floatsPtr = ptr.baseAddress as! UnsafePointer - return Math.argmax32(floatsPtr, count: shapedArray.count, stride: strides.first!) - } - } -} -#endif // canImport(CoreML) && canImport(Accelerate) diff --git a/Sources/Generation/Random.swift b/Sources/Generation/Random.swift new file mode 100644 index 00000000..4f4ae693 --- /dev/null +++ b/Sources/Generation/Random.swift @@ -0,0 +1,1000 @@ +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import Darwin +#else +import Glibc +#endif + +/// Type-erased random number generator. +internal class AnyRandomNumberGenerator: RandomNumberGenerator { + private var rng: RandomNumberGenerator + + /// Creates a type-erased random number generator. + /// + /// - Parameters: + /// - rng: A random number generator. + init(_ rng: RandomNumberGenerator) { + self.rng = rng + } + + func next() -> UInt64 { + rng.next() + } +} + +extension AnyRandomNumberGenerator: ParallelRandomNumberGenerator { + func next(count: Int) -> [UInt64] { + if let rng = rng as? ParallelRandomNumberGenerator { + return rng.next(count: count) + } + return (0..(count: Int, upperBound: T) -> [T] { + if let rng = rng as? ParallelRandomNumberGenerator { + return rng.next(count: count, upperBound: upperBound) + } + return (0..(seed: T) +} + +extension SeedableRandomNumberGenerator { + init(seed: T) { + var newSeed: [UInt8] = [] + for i in 0..> (UInt8.bitWidth * i))) + } + self.init(seed: newSeed) + } +} + +extension RandomNumberGenerator { + mutating func next(count: Int) -> [UInt64] { + if let generator = self as? ParallelRandomNumberGenerator { + return generator.next(count: count) + } else { + return (0..( + count: Int, + upperBound: T + ) -> [T] { + if let generator = self as? ParallelRandomNumberGenerator { + return generator.next(count: count, upperBound: upperBound) + } else { + return (0.. 0, "Length of seed must be positive") + precondition(seed.count <= 256, "Length of seed must be at most 256") + var j: UInt8 = 0 + for i: UInt8 in 0...255 { + j &+= S(i) &+ seed[Int(i) % seed.count] + swapAt(i, j) + } + } + + // Produce the next random UInt64 from the stream, and advance the internal + // state. + public mutating func next() -> UInt64 { + var result: UInt64 = 0 + for _ in 0.. UInt8 { + return state[Int(index)] + } + + // Helper to swap elements of the state. + private mutating func swapAt(_ i: UInt8, _ j: UInt8) { + state.swapAt(Int(i), Int(j)) + } + + // Generates the next byte in the keystream. + private mutating func nextByte() -> UInt8 { + iPos &+= 1 + jPos &+= S(iPos) + swapAt(iPos, jPos) + return S(S(iPos) &+ S(jPos)) + } +} + +/// An implementation of `SeedableRandomNumberGenerator` using Threefry. +/// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +/// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +/// +/// This struct implements a 20-round Threefry2x32 PRNG. It must be seeded with +/// a 64-bit value. +/// +/// An individual generator is not thread-safe, but distinct generators do not +/// share state. The random data generated is of high-quality, but is not +/// suitable for cryptographic applications. +struct ThreefryRandomNumberGenerator: SeedableRandomNumberGenerator { + private let rot: (UInt32, UInt32, UInt32, UInt32, UInt32, UInt32, UInt32, UInt32) + = (13, 15, 26, 6, 17, 29, 16, 24) + + private func rotl32(value: UInt32, n: UInt32) -> UInt32 { + return (value << (n & 31)) | (value >> ((32 - n) & 31)) + } + + private var ctr: UInt64 = 0 + private let key: SIMD2 + + private func random(forCtr ctr: SIMD2, key: SIMD2) -> SIMD2 { + let skeinKsParity32: UInt32 = 0x1BD11BDA + + let ks0 = key.x + let ks1 = key.y + let ks2 = skeinKsParity32 ^ key.x ^ key.y + var X0 = ctr.x + var X1 = ctr.y + + // 20 rounds + // Key injection (r = 0) + X0 &+= ks0 + X1 &+= ks1 + // R1 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.0) + X1 ^= X0 + // R2 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.1) + X1 ^= X0 + // R3 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.2) + X1 ^= X0 + // R4 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.3) + X1 ^= X0 + // Key injection (r = 1) + X0 &+= ks1 + X1 &+= (ks2 + 1) + // R5 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.4) + X1 ^= X0 + // R6 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.5) + X1 ^= X0 + // R7 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.6) + X1 ^= X0 + // R8 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.7) + X1 ^= X0 + // Key injection (r = 2) + X0 &+= ks2 + X1 &+= (ks0 + 2) + // R9 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.0) + X1 ^= X0 + // R10 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.1) + X1 ^= X0 + // R11 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.2) + X1 ^= X0 + // R12 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.3) + X1 ^= X0 + // Key injection (r = 3) + X0 &+= ks0 + X1 &+= (ks1 + 3) + // R13 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.4) + X1 ^= X0 + // R14 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.5) + X1 ^= X0 + // R15 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.6) + X1 ^= X0 + // R16 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.7) + X1 ^= X0 + // Key injection (r = 4) + X0 &+= ks1 + X1 &+= (ks2 + 4) + // R17 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.0) + X1 ^= X0 + // R18 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.1) + X1 ^= X0 + // R19 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.2) + X1 ^= X0 + // R20 + X0 &+= X1 + X1 = rotl32(value: X1, n: rot.3) + X1 ^= X0 + // Key injection (r = 5) + X0 &+= ks2 + X1 &+= (ks0 + 5) + + return [X0, X1] + } + + internal init(uint64Seed seed: UInt64) { + key = seed.vector2 + } + + public init(seed: [UInt8]) { + precondition(seed.count > 0, "Length of seed must be positive") + precondition(seed.count <= 8, "Length of seed must be at most 8") + var combinedSeed: UInt64 = 0 + for (i, byte) in seed.enumerated() { + combinedSeed += UInt64(byte) << UInt64(8 * i) + } + self.init(uint64Seed: combinedSeed) + } + + public mutating func next() -> UInt64 { + defer { ctr += 1 } + return UInt64(highAndLow: random(forCtr: ctr.vector2, key: key)) + } +} + +/// An implementation of `SeedableRandomNumberGenerator` using Philox. +/// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +/// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +/// +/// This struct implements a 10-round Philox4x32 PRNG. It must be seeded with +/// a 64-bit value. +/// +/// An individual generator is not thread-safe, but distinct generators do not +/// share state. The random data generated is of high-quality, but is not +/// suitable for cryptographic applications. +struct PhiloxRandomNumberGenerator: SeedableRandomNumberGenerator { + @usableFromInline + var counter: UInt64 = 0 + @usableFromInline + let key: SIMD2 + + // Since we generate two 64-bit values at a time, we only need to run the + // generator every other invocation. + @usableFromInline + var useNextValue = false + @usableFromInline + var nextValue: UInt64 = 0 + + @inlinable + func bump(key: SIMD2) -> SIMD2 { + SIMD2(0x9E3779B9, 0xBB67AE85) &+ key + } + + @inlinable + func round(counter: SIMD4, key: SIMD2) -> SIMD4 { + let roundConstants = SIMD2(0xD2511F53, 0xCD9E8D57) + let products = roundConstants &* SIMD2(UInt64(counter[0]), UInt64(counter[2])) + + let hi = SIMD2(truncatingIfNeeded: products &>> 32) + let lo = SIMD2(truncatingIfNeeded: products & 0x0000_0000_FFFF_FFFF) + return [ + hi[1] ^ counter[1] ^ key[0], + lo[1], + hi[0] ^ counter[3] ^ key[1], + lo[0] + ] + } + + @inlinable + func random( + forCounter initialCounter: SIMD4, + key initialKey: SIMD2 + ) -> SIMD4 { + var counter = initialCounter + var key = initialKey + // 10 rounds + // R1 + counter = round(counter: counter, key: key) + // R2 + key = bump(key: key) + counter = round(counter: counter, key: key) + // R3 + key = bump(key: key) + counter = round(counter: counter, key: key) + // R4 + key = bump(key: key) + counter = round(counter: counter, key: key) + // R5 + key = bump(key: key) + counter = round(counter: counter, key: key) + // R6 + key = bump(key: key) + counter = round(counter: counter, key: key) + // R7 + key = bump(key: key) + counter = round(counter: counter, key: key) + // R8 + key = bump(key: key) + counter = round(counter: counter, key: key) + // R9 + key = bump(key: key) + counter = round(counter: counter, key: key) + // R10 + key = bump(key: key) + counter = round(counter: counter, key: key) + + return counter + } + + @inlinable + public init(uint64Seed seed: UInt64) { + key = seed.vector2 + } + + @inlinable + public init(seed: [UInt8]) { + precondition(seed.count > 0, "Length of seed must be positive") + precondition(seed.count <= 8, "Length of seed must be at most 8") + var combinedSeed: UInt64 = 0 + for (i, byte) in seed.enumerated() { + combinedSeed += UInt64(byte) << UInt64(8 * i) + } + self.init(uint64Seed: combinedSeed) + } + + @inlinable + public mutating func next() -> UInt64 { + if useNextValue { + useNextValue = false + return nextValue + } + let pair = random(forCounter: counter.vector4, key: key).reinterpretedUInt64Vector + useNextValue = true + nextValue = pair.y + counter += 1 + return pair.x + } +} + +/// Private helpers. +extension UInt64 { + @inlinable + var vector2: SIMD2 { + let msb = UInt32(truncatingIfNeeded: self >> 32) + let lsb = UInt32(truncatingIfNeeded: self & 0x0000_0000_FFFF_FFFF) + return [msb, lsb] + } + + @inlinable + var vector4: SIMD4 { + let msb = UInt32(truncatingIfNeeded: self >> 32) + let lsb = UInt32(truncatingIfNeeded: self) + return [0, 0, msb, lsb] + } + + @inlinable + init(highAndLow: SIMD2) { + self = (UInt64(highAndLow.x) << 32) + UInt64(highAndLow.y) + } +} + +extension SIMD4 where Scalar == UInt32 { + @inlinable + var reinterpretedUInt64Vector: SIMD2 { + let a = (UInt64(x) << 32) + UInt64(y) + let b = (UInt64(z) << 32) + UInt64(w) + return [a, b] + } +} + +// MARK: - Random distributions + +import Dispatch + +protocol RandomDistribution { + associatedtype Sample + func next(using generator: inout G) -> Sample + func next(_ count: Int, using generator: inout G) -> [Sample] +} + +extension RandomDistribution { + @_specialize( + where Self == UniformFloatingPointDistribution, + G == DefaultRandomNumberGeneratorForTensor) + public func next(_ count: Int, using generator: inout G) -> [Sample] { + return Array( + unsafeUninitializedCapacity: count + ) { buffer, initializedCount in + for i in 0..(using generator: inout G) -> Bool { + Bool.random(using: &generator) + } + + public func next(_ count: Int, using generator: inout G) -> [Bool] { + Array.random(count: count, using: &generator) + } +} + +struct UniformIntegerDistribution: RandomDistribution { + public let bounds: ClosedRange + + public init(bounds: ClosedRange = T.min...T.max) { + self.bounds = bounds + } + + public func next(using generator: inout G) -> T { + return T.random(in: bounds, using: &generator) + } + + public func next(_ count: Int, using generator: inout G) -> [T] { + Array.random(count: count, in: bounds, using: &generator) + } +} + +struct UniformFloatingPointDistribution: RandomDistribution + where T.RawSignificand: FixedWidthInteger +{ + public let bounds: ClosedRange + + public init(bounds: ClosedRange = 0...1) { + self.bounds = bounds + } + + @_specialize(where T == Float, G == DefaultRandomNumberGeneratorForTensor) + public func next(using generator: inout G) -> T { + return T.random(in: bounds, using: &generator) + } + + @_specialize(where T == Float, G == DefaultRandomNumberGeneratorForTensor) + public func next(_ count: Int, using generator: inout G) -> [T] { + Array.random(count: count, in: bounds, using: &generator) + } +} + +struct NormalDistribution: RandomDistribution + where T.RawSignificand: FixedWidthInteger +{ + public let mean: T + public let standardDeviation: T + @usableFromInline + let uniformDistribution = UniformFloatingPointDistribution() + + public init(mean: T = 0, standardDeviation: T = 1) { + self.mean = mean + self.standardDeviation = standardDeviation + } + + @_specialize(where T == Float) + @inlinable + func normalized(_ u1: T, _ u2: T) -> T { + let r = (-2 * T(log(Float(u1)))).squareRoot() + let theta = 2 * T.pi * u2 + let normal01 = r * T(cos(Float(theta))) + return mean + standardDeviation * normal01 + } + + @_specialize(where T == Float, G == DefaultRandomNumberGeneratorForTensor) + @inlinable + public func next(using generator: inout G) -> T { + // FIXME: Box-Muller can generate two values for only a little more than the + // cost of one. + normalized( + uniformDistribution.next(using: &generator), + uniformDistribution.next(using: &generator)) + } + + @_specialize(where T == Float, G == DefaultRandomNumberGeneratorForTensor) + @inlinable + public func next(_ count: Int, using generator: inout G) -> [T] { + let uniformNumbers = uniformDistribution.next(count * 2, using: &generator) + return Array(unsafeUninitializedCapacity: count) { buffer, initializedCount in + DispatchQueue.concurrentPerform(iterations: count) { i in + let offset = i * 2 + buffer[i] = normalized(uniformNumbers[offset], uniformNumbers[offset + 1]) + } + initializedCount = count + } + } +} + +struct TruncatedNormalDistribution: RandomDistribution + where T.RawSignificand: FixedWidthInteger +{ + public let mean: T + public let standardDeviation: T + private let normalDistribution = NormalDistribution(mean: 0, standardDeviation: 1) + + public init(mean: T = 0, standardDeviation: T = 1) { + self.mean = mean + self.standardDeviation = standardDeviation + } + + public func next(using generator: inout G) -> T { + // FIXME: Implement this. See + // https://github.com/tensorflow/tensorflow/blob/b1a6b315a63bb29b4593bfb98095da4397d8cd5a/tensorflow/compiler/tf2xla/lib/random.cc#L42. + fatalError("Unimplemented") + } + + public func next(_ count: Int, using generator: inout G) -> [T] { + // FIXME: Implement this. See + // https://github.com/tensorflow/tensorflow/blob/b1a6b315a63bb29b4593bfb98095da4397d8cd5a/tensorflow/compiler/tf2xla/lib/random.cc#L42. + fatalError("Unimplemented") + } +} + +struct BetaDistribution: RandomDistribution { + public let alpha: Float + public let beta: Float + private let uniformDistribution = UniformFloatingPointDistribution() + + public init(alpha: Float = 0, beta: Float = 1) { + self.alpha = alpha + self.beta = beta + } + + public func next(using generator: inout G) -> Float { + // Generate a sample using Cheng's sampling algorithm from: + // R. C. H. Cheng, "Generating beta variates with nonintegral shape + // parameters.". Communications of the ACM, 21, 317-322, 1978. + let a = min(alpha, beta) + let b = max(alpha, beta) + if a > 1 { + return BetaDistribution.chengsAlgorithmBB(alpha, a, b, using: &generator) + } else { + return BetaDistribution.chengsAlgorithmBC(alpha, b, a, using: &generator) + } + } + + /// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BB + /// algorithm, when both alpha and beta are greater than 1. + /// + /// - Parameters: + /// - alpha: First Beta distribution shape parameter. + /// - a: `min(alpha, beta)`. + /// - b: `max(alpha, beta)`. + /// - generator: Random number generator. + /// + /// - Returns: Sample obtained using Cheng's BB algorithm. + private static func chengsAlgorithmBB( + _ alpha0: Float, + _ a: Float, + _ b: Float, + using generator: inout G + ) -> Float { + let alpha = a + b + let beta = sqrt((alpha - 2) / (2 * a * b - alpha)) + let gamma = a + 1 / beta + + var r: Float = 0.0 + var w: Float = 0.0 + var t: Float = 0.0 + + repeat { + let u1 = Float.random(in: 0.0...1.0, using: &generator) + let u2 = Float.random(in: 0.0...1.0, using: &generator) + let v = beta * (log(u1) - log1p(-u1)) + r = gamma * v - 1.3862944 + let z = u1 * u1 * u2 + w = a * exp(v) + + let s = a + r - w + if s + 2.609438 >= 5 * z { + break + } + + t = log(z) + if s >= t { + break + } + } while r + alpha * (log(alpha) - log(b + w)) < t + + w = min(w, Float.greatestFiniteMagnitude) + return a == alpha0 ? w / (b + w) : b / (b + w) + } + + /// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BC + /// algorithm, when at least one of alpha and beta is less than 1. + /// + /// - Parameters: + /// - alpha: First Beta distribution shape parameter. + /// - a: `max(alpha, beta)`. + /// - b: `min(alpha, beta)`. + /// - generator: Random number generator. + /// + /// - Returns: Sample obtained using Cheng's BB algorithm. + private static func chengsAlgorithmBC( + _ alpha0: Float, + _ a: Float, + _ b: Float, + using generator: inout G + ) -> Float { + let alpha = a + b + let beta = 1 / b + let delta = 1 + a - b + let k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778) + let k2 = 0.25 + (0.5 + 0.25 / delta) * b + + var w: Float = 0.0 + + while true { + let u1 = Float.random(in: 0.0...1.0, using: &generator) + let u2 = Float.random(in: 0.0...1.0, using: &generator) + let y = u1 * u2 + let z = u1 * y + + if u1 < 0.5 { + if 0.25 * u2 + z - y >= k1 { + continue + } + } else { + if z <= 0.25 { + let v = beta * (log(u1) - log1p(-u1)) + w = a * exp(v) + break + } + if z >= k2 { + continue + } + } + + let v = beta * (log(u1) - log1p(-u1)) + w = a * exp(v) + if alpha * (log(alpha) - log(b + 1) + v) - 1.3862944 >= log(z) { + break + } + } + + w = min(w, Float.greatestFiniteMagnitude) + return a == alpha0 ? w / (b + w): b / (b + w) + } +} + +// MARK: Parallel Random Number Generators + +protocol ParallelRandomNumberGenerator: RandomNumberGenerator { + func next(count: Int) -> [UInt64] + func next(count: Int, upperBound: T) -> [T] +} + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +public struct SystemArc4RandomNumberGenerator: ParallelRandomNumberGenerator { + public mutating func next() -> UInt64 { + var result: UInt64 = 0 + arc4random_buf(&result, MemoryLayout.size) + return result + } + + public func next(count: Int) -> [UInt64] { + return Array(unsafeUninitializedCapacity: count) { buffer, size in + size = count + arc4random_buf( + UnsafeMutableRawPointer(buffer.baseAddress), + count * MemoryLayout.stride) + } + } + + public func next(count: Int, upperBound: T) -> [T] { + let rands = next(count: count) + return Array(unsafeUninitializedCapacity: count) { + bufferPointer, initializedCount in + DispatchQueue.concurrentPerform(iterations: count) { + bufferPointer[$0] = self.upperBound( + rands[$0], + to: upperBound) + } + initializedCount = count + } + } + + // Implementation based on https://github.com/apple/swift/blob/master/stdlib/public/core/Random.swift#L93 + private func upperBound( + _ val: UInt64, + to upperBound: T + ) -> T { + precondition(upperBound != 0, "upperBound cannot be zero.") + #if arch(i386) || arch(arm) || arch(arm64_32) // TODO(FIXME) SR-10912 + let tmp = (T.max % upperBound) + 1 + let range = tmp == upperBound ? 0 : tmp + var random = T(truncatingIfNeeded: val) + + while random < range { + withUnsafeMutablePointer(to: &random) { + arc4random_buf($0, MemoryLayout.size) + } + } + + return random % upperBound + #else + var random = T(truncatingIfNeeded: val) + var m = random.multipliedFullWidth(by: upperBound) + if m.low < upperBound { + let t = (0 &- upperBound) % upperBound + while m.low < t { + withUnsafeMutablePointer(to: &random) { + arc4random_buf($0, MemoryLayout.size) + } + m = random.multipliedFullWidth(by: upperBound) + } + } + return m.high + #endif + } +} +#endif + +// MARK: Random Array Generators + +extension Array where Element == Bool { + static func random( + count: Int, + using generator: inout RNG + ) -> Self { + let rands = generator.next(count: count) + return Self(unsafeUninitializedCapacity: count) { bufferPointer, initializedCount in + DispatchQueue.concurrentPerform(iterations: count) { + bufferPointer[$0] = (rands[$0] >> 17) & 1 == 0 + } + initializedCount = count + } + } +} + +extension Array where Element: BinaryFloatingPoint, Element.RawSignificand: FixedWidthInteger { + // Implementation based on https://github.com/apple/swift/blob/master/stdlib/public/core/FloatingPoint.swift#L2052 + static func random( + count: Int, + in range: Range, + using generator: inout RNG + ) -> Self { + let delta = range.upperBound - range.lowerBound + precondition(delta.isFinite, "There is no uniform distribution on an infinite range") + func randArray(_ count: Int) -> [UInt64] { + if Element.RawSignificand.bitWidth == Element.significandBitCount + 1 { + return generator.next(count: count) + } else { + let significandCount = Element.significandBitCount + 1 + let maxSignificand: Element.RawSignificand = 1 << significandCount + return generator.next(count: count).map { $0 & UInt64(maxSignificand - 1) } + } + } + return Self(unsafeUninitializedCapacity: count) { + resultBufferPointer, initializedCount in + var indicesNeedingResampling = [Int](resultBufferPointer.indices) + while !indicesNeedingResampling.isEmpty { + let rands = randArray(indicesNeedingResampling.count) + indicesNeedingResampling.withUnsafeMutableBufferPointer { + indicesNeedingResamplingBufferPtr in + DispatchQueue.concurrentPerform( + iterations: indicesNeedingResamplingBufferPtr.count + ) { i in + let rand = rands[i] + let unitRandom = Element(rand) * (Element.ulpOfOne / 2) + let randFloat = delta * unitRandom + range.lowerBound + if randFloat != range.upperBound { + let index = indicesNeedingResamplingBufferPtr[i] + resultBufferPointer[index] = randFloat + indicesNeedingResamplingBufferPtr[i] = -1 + } + } + } + indicesNeedingResampling.removeAll(where: { $0 == -1 }) + } + initializedCount = count + } + } + + // Implementation based on https://github.com/apple/swift/blob/master/stdlib/public/core/FloatingPoint.swift#L2152 + static func random( + count: Int, + in range: ClosedRange, + using generator: inout RNG + ) -> Self { + let delta = range.upperBound - range.lowerBound + precondition(delta.isFinite, "There is no uniform distribution on an infinite range") + func randArrays(_ count: Int) -> (rand: [UInt64], tmp: [UInt64]?) { + let rands: [UInt64] + var tmp: [UInt64]? = nil + if Element.RawSignificand.bitWidth == Element.significandBitCount + 1 { + rands = generator.next(count: count) + tmp = generator.next(count: count) + } else { + let significandCount = Element.significandBitCount + 1 + let maxSignificand: Element.RawSignificand = 1 << significandCount + rands = generator.next(count: count, upperBound: UInt64(maxSignificand + 1)) + } + return (rands, tmp) + } + + return Self(unsafeUninitializedCapacity: count) { + resultBufferPointer, initializedCount in + let indicesNeedingResampling = [Int](resultBufferPointer.indices) + let (rands, tmp) = randArrays(indicesNeedingResampling.count) + DispatchQueue.concurrentPerform( + iterations: indicesNeedingResampling.count + ) { i in + if Element.RawSignificand.bitWidth == Element.significandBitCount + 1 { + guard let tmp = tmp else { + fatalError("Expected the 'tmp' array to be initialized.") + } + if rands[i] == Element.RawSignificand.max && (tmp[i] & 1) == 1 { + let index = indicesNeedingResampling[i] + resultBufferPointer[index] = range.upperBound + return + } + } else { + let significandCount = Element.significandBitCount + 1 + let maxSignificand: Element.RawSignificand = 1 << significandCount + if rands[i] == maxSignificand { + let index = indicesNeedingResampling[i] + resultBufferPointer[index] = range.upperBound + return + } + } + let unitRandom = Element(rands[i]) * (Element.ulpOfOne / 2) + let randFloat = delta * unitRandom + range.lowerBound + let index = indicesNeedingResampling[i] + resultBufferPointer[index] = randFloat + } + initializedCount = count + } + } +} + +extension Array where Element: FixedWidthInteger { + // Implementation based on https://github.com/apple/swift/blob/master/stdlib/public/core/Integers.swift#L2663 + static func random( + count: Int, + in range: Range, + using generator: inout RNG + ) -> Self { + precondition(!range.isEmpty, "Can't get random value with an empty range") + // Compute delta, the distance between the lower and upper bounds. This + // value may not representable by the type Bound if Bound is signed, but + // is always representable as Bound.Magnitude. + let delta = Element.Magnitude(truncatingIfNeeded: range.upperBound &- range.lowerBound) + let rands = generator.next(count: count, upperBound: UInt64(delta)) + return Self(unsafeUninitializedCapacity: count) { bufferPointer, initializedCount in + DispatchQueue.concurrentPerform(iterations: count) { + // The mathematical result we want is lowerBound plus a random value in + // 0 ..< delta. We need to be slightly careful about how we do this + // arithmetic; the Bound type cannot generally represent the random value, + // so we use a wrapping addition on Bound.Magnitude. This will often + // overflow, but produces the correct bit pattern for the result when + // converted back to Bound. + bufferPointer[$0] = Element(truncatingIfNeeded: + Element.Magnitude(truncatingIfNeeded: range.lowerBound) &+ + Element.Magnitude(rands[$0])) + } + initializedCount = count + } + } + + // Implementation based on https://github.com/apple/swift/blob/master/stdlib/public/core/Integers.swift#L2732 + static func random( + count: Int, + in range: ClosedRange, + using generator: inout RNG + ) -> Self { + precondition(!range.isEmpty, "Can't get random value with an empty range") + // Compute delta, the distance between the lower and upper bounds. This + // value may not representable by the type Bound if Bound is signed, but + // is always representable as Bound.Magnitude. + var delta = Element.Magnitude(truncatingIfNeeded: range.upperBound &- range.lowerBound) + // Subtle edge case: if the range is the whole set of representable values, + // then adding one to delta to account for a closed range will overflow. + // If we used &+ instead, the result would be zero, which isn't helpful, + // so we actually need to handle this case separately. + if delta == Element.Magnitude.max { + return Self(unsafeUninitializedCapacity: count) { + bufferPointer, initializedCount in + DispatchQueue.concurrentPerform(iterations: count) { + bufferPointer[$0] = + Element(truncatingIfNeeded: generator.next() as Element.Magnitude) + } + initializedCount = count + } + } + // Need to widen delta to account for the right-endpoint of a closed range. + delta += 1 + let rands = generator.next(count: count, upperBound: UInt64(delta)) + return Self(unsafeUninitializedCapacity: count) { + bufferPointer, initializedCount in + DispatchQueue.concurrentPerform(iterations: count) { + // The mathematical result we want is lowerBound plus a random value in + // 0 ..< delta. We need to be slightly careful about how we do this + // arithmetic; the Bound type cannot generally represent the random value, + // so we use a wrapping addition on Bound.Magnitude. This will often + // overflow, but produces the correct bit pattern for the result when + // converted back to Bound. + bufferPointer[$0] = Element(truncatingIfNeeded: + Element.Magnitude(truncatingIfNeeded: range.lowerBound) &+ + Element.Magnitude(rands[$0])) + } + initializedCount = count + } + } +} + +extension RandomNumberGenerator { + /// Returns a random value within the specified range. + mutating func next(as type: T.Type, in bounds: Range) -> T + where T: BinaryFloatingPoint, T.RawSignificand: FixedWidthInteger { + return T.random(in: bounds, using: &self) + } + + /// Returns a random value within the specified range. + mutating func next(as type: T.Type, in bounds: ClosedRange) -> T + where T: BinaryFloatingPoint, T.RawSignificand: FixedWidthInteger { + return T.random(in: bounds, using: &self) + } +} diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 40d870f5..724cb315 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -12,69 +12,155 @@ import Generation import Hub import Tokenizers +@available(macOS 15.0, iOS 18.0, *) public class LanguageModel { public let model: MLModel public let minContextLength: Int public let maxContextLength: Int - let input_ids = "input_ids" - let attention_mask = "attention_mask" - - struct Configurations { - var modelConfig: Config - var tokenizerConfig: Config? - var tokenizerData: Config - } - private var configuration: LanguageModelConfigurationFromHub? private var _tokenizer: Tokenizer? public required init(model: MLModel) { self.model = model + (minContextLength, maxContextLength) = Self.contextRange(from: model) + configuration = LanguageModelConfigurationFromHub(modelName: modelName) + } - // We assume inputs named "input_ids" with shape (1, seq_length) - // Perhaps we should convert to vectors of shape (seq_length) and use sequenceConstraint instead of shapeConstraint - let inputDescription = model.modelDescription.inputDescriptionsByName["input_ids"] + public func resetState() async { } + + public func predictNextTokenScores( + _ tokens: MLTensor, + config: GenerationConfig + ) async -> MLTensor { + assert(tokens.rank == 2) // [batch, current sequence length] + let tokenCount = tokens.shape[1] + let padLength = maxContextLength - tokenCount + let padding = MLTensor(repeating: Int32(config.padTokenId ?? 0), shape: [1, padLength]) + let inputIDs = MLTensor(concatenating: [tokens, padding], alongAxis: -1) + var inputDictionary = [inputIdsName: inputIDs] + if isRequiringAttentionMask { + let mask = [Int32](repeating: 1, count: tokenCount) + [Int32](repeating: 0, count: padLength) + let attentionMask = MLTensor(shape: inputIDs.shape, scalars: mask) + inputDictionary[Keys.attentionMask] = attentionMask + } + let outputs = try! await model.prediction(from: inputDictionary) + + assert(outputs.keys.contains(Keys.logits)) + + let scores = outputs[Keys.logits]! + assert(scores.rank == 3) + let tokenIndex = tokenCount - 1 + let nextTokenScores = scores[nil, tokenIndex, nil].expandingShape(at: 0) + assert(nextTokenScores.rank == 3) + assert(nextTokenScores.shape[0] == 1 && nextTokenScores.shape[1] == 1) + return nextTokenScores + } +} + +@available(macOS 15.0, iOS 18.0, *) +private extension LanguageModel { + static func contextRange(from model: MLModel) -> (min: Int, max: Int) { + contextRange(from: model, inputKey: Keys.inputIds) + } + + static func contextRange(from model: MLModel, inputKey: String) -> (min: Int, max: Int) { + let inputDescription = model.modelDescription.inputDescriptionsByName[inputKey] guard let shapeConstraint = inputDescription?.multiArrayConstraint?.shapeConstraint else { fatalError("Cannot obtain shape information") } + var minContextLength = 128 + var maxContextLength = 128 + switch shapeConstraint.type { case .enumerated: - // TODO: support a set of fixed shapes (keeping the first one here) minContextLength = shapeConstraint.enumeratedShapes[0][1].intValue maxContextLength = minContextLength case .range: - let range = inputDescription?.multiArrayConstraint?.shapeConstraint.sizeRangeForDimension[1] as? NSRange - minContextLength = range?.location ?? 1 - maxContextLength = range?.length ?? 128 + if let sizeRangeForDimension = inputDescription?.multiArrayConstraint?.shapeConstraint.sizeRangeForDimension { + let lastAxis = sizeRangeForDimension.count - 1 + let range = sizeRangeForDimension[lastAxis] as? NSRange + minContextLength = range?.location ?? 1 + maxContextLength = range?.length ?? 128 + } case .unspecified: - minContextLength = 128 - maxContextLength = 128 + break @unknown default: - minContextLength = 128 - maxContextLength = 128 + break } - configuration = LanguageModelConfigurationFromHub(modelName: modelName) + return (minContextLength, maxContextLength) + } +} + +@available(macOS 15.0, iOS 18.0, *) +extension LanguageModel { + struct Configurations { + var modelConfig: Config + var tokenizerConfig: Config? + var tokenizerData: Config + } +} + +@available(macOS 15.0, iOS 18.0, *) +extension LanguageModel { + enum Keys { + // Input keys + static let inputIds = "inputIds" + static let attentionMask = "attentionMask" + static let causalMask = "causalMask" + static let keyCache = "keyCache" + static let valueCache = "valueCache" + // Output keys + static let logits = "logits" + static let presentKeys = "presentKeys" + static let presentValues = "presentValues" } } +@available(macOS 15.0, iOS 18.0, *) public extension LanguageModel { - static func loadCompiled(url: URL, computeUnits: MLComputeUnits = .cpuAndGPU) throws -> LanguageModel { + static func loadCompiled( + url: URL, + computeUnits: MLComputeUnits = .cpuAndGPU + ) throws -> LanguageModel { let config = MLModelConfiguration() config.computeUnits = computeUnits let model = try MLModel(contentsOf: url, configuration: config) - return LanguageModel(model: model) + return switch kvCacheAvailability(for: model) { + case .statefulKVCache: LanguageModelWithStatefulKVCache(model: model) + default: LanguageModel(model: model) + } } } +extension LanguageModel { + enum KVCacheAvailability { + /// Language models that support KV cache via state. Implementation details for handling state + /// encapsulated within the Core ML framework. + /// + /// Input: State + /// Output: N/A + case statefulKVCache + } +} + +@available(macOS 15.0, iOS 18.0, *) public extension LanguageModel { + var metadata: [MLModelMetadataKey: Any] { + model.modelDescription.metadata + } + + var modelDescription: MLModelDescription { + model.modelDescription + } + var description: String { - if let description = model.modelDescription.metadata[MLModelMetadataKey.description] as? String, - !description.isEmpty + if let description = metadata[MLModelMetadataKey.description] as? String, + !description.isEmpty { return description } @@ -83,18 +169,20 @@ public extension LanguageModel { /// `name_or_path` in the Python world var modelName: String { - if let userFields = model.modelDescription.metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String], - let name = userFields["co.huggingface.exporters.name"] + if let userFields = metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String], + let name = userFields["co.huggingface.exporters.name"] { return name } // This is usually the basename of the file, that's our best bet if no metadata exists - guard let modelName = model.configuration.modelDisplayName else { fatalError("Models must have a name that identifies them") } + guard let modelName = model.configuration.modelDisplayName else { + fatalError("Models must have a name that identifies them") + } return modelName } var inputIdsDescription: MLFeatureDescription { - model.modelDescription.inputDescriptionsByName[input_ids]! + modelDescription.inputDescriptionsByName[Keys.inputIds]! } var inputIdsName: String { @@ -103,43 +191,71 @@ public extension LanguageModel { /// The expected shape of the models latent sample input var inputIdsShape: [Int] { - inputIdsDescription.multiArrayConstraint!.shape.map { $0.intValue } + inputIdsDescription.multiArrayConstraint!.shape.map(\.intValue) } - var requiresAttention: Bool { - model.modelDescription.inputDescriptionsByName[attention_mask] != nil + var isRequiringAttentionMask: Bool { + modelDescription.inputDescriptionsByName[Keys.attentionMask] != nil } - /// MLShapedArrayProtocol is either a MLShapedArray or a MLShapedArraySlice - func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { - // TODO: exceptions - - // Maybe pad or truncate - let maxTokens = min(tokens.count, maxContextLength) - let padLength = maxTokens >= minContextLength ? 0 : minContextLength - maxTokens - let inputTokens = Array(tokens[0..(scalars: inputTokens.map { Int32($0) }, shape: inputIdsShape) - var inputDictionary = [inputIdsName: MLFeatureValue(shapedArray: inputIds)] - if requiresAttention { - let mask = Array(repeating: 1, count: maxTokens) + Array(repeating: 0, count: padLength) - let attentionMask = MLShapedArray(scalars: mask.map { Int32($0) }, shape: inputIdsShape) - inputDictionary[attention_mask] = MLFeatureValue(shapedArray: attentionMask) + fileprivate static func kvCacheAvailability(for model: MLModel) -> KVCacheAvailability? { + func isStatefulKVCacheAvailable(for model: MLModel) -> Bool { + let kCacheState = model.modelDescription.stateDescriptionsByName[Keys.keyCache] != nil + let vCacheState = model.modelDescription.stateDescriptionsByName[Keys.valueCache] != nil + guard Set([kCacheState, vCacheState]).count == 1 else { + fatalError("Invalid model configuration, expecting KV cache for states") + } + return kCacheState && kCacheState } - let input = try! MLDictionaryFeatureProvider(dictionary: inputDictionary) - let output = try! model.prediction(from: input) + func isDynamicallyShaped(_ description: MLFeatureDescription) -> Bool { + guard let multiArrayConstraint = description.multiArrayConstraint else { + return false + } + return switch multiArrayConstraint.shapeConstraint.type { + case .unspecified: true + case .enumerated: multiArrayConstraint.shapeConstraint.enumeratedShapes.count > 1 + case .range: true + default: false + } + } - // TODO: maybe try to support models with "token_scores" too (after the softmax) - assert(output.featureNames.first! == "logits") + if isStatefulKVCacheAvailable(for: model) { + return .statefulKVCache + } + let kCacheInput = model.modelDescription.inputDescriptionsByName[Keys.keyCache] != nil + let vCacheInput = model.modelDescription.inputDescriptionsByName[Keys.valueCache] != nil + let kCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentKeys] != nil + let vCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentValues] != nil - let scores = output.featureValue(for: output.featureNames.first!)!.shapedArrayValue(of: Float.self)! - let nextTokenScores = scores[0, maxTokens - 1] - return nextTokenScores + guard Set([kCacheInput, vCacheInput, kCacheOutput, vCacheOutput]).count == 1 else { + fatalError("Invalid model configuration, expecting KV cache for inputs and outputs") + } + guard kCacheInput else { + return nil + } + // Check if cache is dynamic or not. + let kCacheConstraint = model.modelDescription.inputDescriptionsByName[Keys.keyCache]! + if isDynamicallyShaped(kCacheConstraint) { + fatalError(""" + KV Cache using IO is currently not supported, please file a feature request on \ + https://github.com/huggingface/swift-transformers + """) + } else { + fatalError(""" + KV Cache using IO is currently not supported, please file a feature request on \ + https://github.com/huggingface/swift-transformers + """) + } } } /// async properties downloaded from the configuration +@available(macOS 15.0, iOS 18.0, *) public extension LanguageModel { var modelConfig: Config { get async throws { @@ -193,21 +309,26 @@ public extension LanguageModel { var tokenizer: Tokenizer { get async throws { - guard _tokenizer == nil else { return _tokenizer! } + if let _tokenizer { + return _tokenizer + } guard let tokenizerConfig = try await tokenizerConfig else { - throw TokenizerError.tokenizerConfigNotFound + throw "Cannot retrieve Tokenizer configuration" } let tokenizerData = try await tokenizerData - _tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) + _tokenizer = try AutoTokenizer.from( + tokenizerConfig: tokenizerConfig, + tokenizerData: tokenizerData + ) return _tokenizer! } } } +@available(macOS 15.0, iOS 18.0, *) extension LanguageModel: TextGenerationModel { - // TODO: retrieve from the json: https://huggingface.co/nlpcloud/instruct-gpt-j-fp16/blob/main/config.json#L26 public var defaultGenerationConfig: GenerationConfig { - var config = GenerationConfig(maxNewTokens: 30) + var config = GenerationConfig(maxNewTokens: 2048) switch modelName.lowercased() { case let x where x.contains("gpt"): config.doSample = true @@ -218,15 +339,74 @@ extension LanguageModel: TextGenerationModel { } } -public enum TokenizerError: LocalizedError { - case tokenizerConfigNotFound +@available(macOS 15.0, iOS 18.0, *) +public class LanguageModelWithStatefulKVCache: LanguageModel { + private enum Mode { + case prefilling + case extending + } + private var mode: Mode = .prefilling + + var state: MLState? + + public required init(model: MLModel) { + super.init(model: model) + // To support pre-filling and extend, the input must support + // flexible shapes. + guard maxContextLength - minContextLength > 1 else { + fatalError("Expecting ranged query sequence length.") + } + } + + public override func resetState() async { + state = model.makeState() + mode = .prefilling + } - public var errorDescription: String? { - switch self { - case .tokenizerConfigNotFound: - String(localized: "Tokenizer configuration could not be found. The model may be missing required tokenizer files.", comment: "Error when tokenizer configuration is missing") + public override func predictNextTokenScores( + _ tokens: MLTensor, + config _: GenerationConfig + ) async -> MLTensor { + assert(tokens.rank == 2) // [batch, current sequence length] + let tokenCount = tokens.shape[1] + guard let state else { + fatalError(""" + Encountered uninitialized `state`. Ensure `resetState` is called prior to calling \ + `predictNextTokenScores`. + """) + } + let inputIds = switch mode { + case .prefilling: tokens // Pass in all takens if pre-filling prompt + case .extending: tokens[nil, -1].expandingShape(at: 0) // otherwise just the last token } + mode = .extending + + var inputDictionary = [ + Keys.inputIds: inputIds, + ] + if isRequiringAttentionMask { + // TODO: Infer scalar type from cache or model I/O descriptors + let attentionMask = MLTensor(zeros: [1, 1, 1, tokenCount + 1], scalarType: Float16.self) + inputDictionary[Keys.attentionMask] = attentionMask + } + if isRequiringCausalMask { + // TODO: Infer scalar type from cache or model I/O descriptors + let causalMask = MLTensor(zeros: [1, 1, 1, tokenCount + 1], scalarType: Float16.self) + inputDictionary[Keys.causalMask] = causalMask + } + let outputs = try! await model.prediction(from: inputDictionary, using: state) + + assert(outputs.keys.contains(Keys.logits)) + let scores = outputs[Keys.logits]! + assert(scores.rank == 3) + let tokenIndex = inputIds.shape[1] - 1 + let nextTokenScores = scores[nil, tokenIndex, nil].expandingShape(at: 0) + assert(nextTokenScores.rank == 3) + assert(nextTokenScores.shape[0] == 1 && nextTokenScores.shape[1] == 1) + return nextTokenScores } } +extension String: @retroactive Error {} + #endif // canImport(CoreML) diff --git a/Sources/Models/LanguageModelTypes.swift b/Sources/Models/LanguageModelTypes.swift index 5c9cc68b..0a639184 100644 --- a/Sources/Models/LanguageModelTypes.swift +++ b/Sources/Models/LanguageModelTypes.swift @@ -11,6 +11,9 @@ import CoreML import Generation import Tokenizers + +/// A causal language model. +@available(macOS 15.0, iOS 18.0, *) public protocol LanguageModelProtocol { /// `name_or_path` in the Python world var modelName: String { get } @@ -18,28 +21,54 @@ public protocol LanguageModelProtocol { var tokenizer: Tokenizer { get async throws } var model: MLModel { get } + /// Resets the state of the language model. + /// + /// Call `resetState()` for each new sequence generated. + func resetState() async + init(model: MLModel) - /// Make prediction callable (this works like __call__ in Python) - func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol - func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol + /// Returns the next token conditioned on the given input. + /// - Parameters: + /// - input: The input sequence to condition the language model. + /// - config: The generation configuration. + /// - Returns: The raw scores of the next token. + func predictNextTokenScores(_ input: MLTensor, config: GenerationConfig) async -> MLTensor } +@available(macOS 15.0, iOS 18.0, *) public extension LanguageModelProtocol { - func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol { - predictNextTokenScores(tokens, config: config) + func callAsFunction(_ input: MLTensor, config: GenerationConfig) async -> MLTensor { + await predictNextTokenScores(input, config: config) } } +@available(macOS 15.0, iOS 18.0, *) public protocol TextGenerationModel: Generation, LanguageModelProtocol { var defaultGenerationConfig: GenerationConfig { get } - func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback?) async throws -> String + + func generate( + config: GenerationConfig, + prompt: String, + callback: PredictionStringCallback? + ) async throws -> String } +@available(macOS 15.0, iOS 18.0, *) public extension TextGenerationModel { @discardableResult func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback? = nil) async throws -> String { - try await generate(config: config, prompt: prompt, model: callAsFunction, tokenizer: tokenizer, callback: callback) + // Prepare the language model for a new sequence. + await resetState() + + // Run inference. + return try await generate( + config: config, + prompt: prompt, + model: callAsFunction, + tokenizer: tokenizer, + callback: callback + ) } } #endif // canImport(CoreML) diff --git a/Sources/TransformersCLI/Transformers.swift b/Sources/TransformersCLI/Transformers.swift new file mode 100644 index 00000000..75c243aa --- /dev/null +++ b/Sources/TransformersCLI/Transformers.swift @@ -0,0 +1,140 @@ +import ArgumentParser +import CoreML +import Foundation + +import Models +import Generation + +@available(macOS 15.0, iOS 18.0, *) +@main +struct TransformersCLI: AsyncParsableCommand { + static let configuration = CommandConfiguration( + abstract: "Run text generation on a Core ML language model", + version: "0.0.1" + ) + + @Argument(help: "Input text") + var prompt: String + + @Argument(help: "Path to Core ML mlpackage model") + var modelPath: String = "./model.mlpackage" + + @Option(help: "Maximum amount of tokens the model should generate") + var maxLength: Int = 100 + + @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") + var computeUnits: ComputeUnits = .cpuAndGPU + + @Option(help: """ + When enabled, two generation passes are ran, one to 'warm up' and another to collect \ + benchmark metrics. + """) + var warmup: Bool = false + + func generate( + model: LanguageModel, + config: GenerationConfig, + prompt: String, + printOutput: Bool = true + ) async throws { + var tokensReceived = 0 + var previousIndex: String.Index? = nil + var startTime = Date() + var promptProcessingTime: Double = 0 // seconds + try await model.generate(config: config, prompt: prompt) { inProgressGeneration in + if previousIndex == nil { // Prompt pre-filling + promptProcessingTime = Date().timeIntervalSince(startTime) + // Reset start time to more accurately compute the average tps. + startTime = Date() + } else { // Extend + // Only start counting tokens once the prompt has been processed. + tokensReceived += 1 + } + let response = formatResponse(inProgressGeneration) + if printOutput { + print(response[(previousIndex ?? response.startIndex)...], terminator: "") + fflush(stdout) + } + previousIndex = response.endIndex + } + // Current time - start time + elapsed time to process the prompt + let endTime = Date() + let completionTime = endTime.timeIntervalSince(startTime) + promptProcessingTime + let tps = Double(tokensReceived) / endTime.timeIntervalSince(startTime) + if printOutput { + print("") + print(""" + \(tps.formatted("%.2f")) tokens/s, \ + prompt pre-filling time: \(promptProcessingTime.formatted("%.2f"))s, \ + total time: \(completionTime.formatted("%.2f"))s + """) + } + } + + func compile(at url: URL) throws -> URL { + #if os(watchOS) + fatalError("Model compilation is not supported on watchOS") + #else + if url.pathExtension == "mlmodelc" { return url } + print("Compiling model \(url)") + return try MLModel.compileModel(at: url) + #endif + } + + func run() async throws { + let url = URL(filePath: modelPath) + let compiledURL = try compile(at: url) + print("Loading model \(compiledURL)") + let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits) + + // Using greedy generation for now + var config = model.defaultGenerationConfig + config.doSample = false + config.maxNewTokens = maxLength + + // Given the size of the out-of-model computation, dispatch all + // tensor operations to the CPU. + + if warmup { + print("Warming up...") + try await withMLTensorComputePolicy(.cpuOnly) { + try await generate(model: model, config: config, prompt: prompt, printOutput: false) + } + } + + print("Generating") + try await withMLTensorComputePolicy(.cpuOnly) { + try await generate(model: model, config: config, prompt: prompt) + } + } +} + +@available(macOS 15.0, iOS 18.0, *) +enum ComputeUnits: String, ExpressibleByArgument, CaseIterable { + case all, cpuAndGPU, cpuOnly, cpuAndNeuralEngine + var asMLComputeUnits: MLComputeUnits { + switch self { + case .all: return .all + case .cpuAndGPU: return .cpuAndGPU + case .cpuOnly: return .cpuOnly + case .cpuAndNeuralEngine: return .cpuAndNeuralEngine + } + } +} + +/// Returns a cleaned and formatted version of the response. +/// +/// - Parameter respone: The response to clean and format. +/// - Returns: A 'user friendly' representation of the generated response. +fileprivate func formatResponse(_ response: String) -> String { + response + .replacingOccurrences(of: "\\n", with: "\n") + .replacingOccurrences(of: "", with: "") + .replacingOccurrences(of: "", with: "") +} + +extension Double { + func formatted(_ format: String) -> String { + return String(format: "\(format)", self) + } +} From 4b3f7f6062a2c2bc37d68c5ef84a230f684e7695 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 8 Aug 2024 18:09:16 +0200 Subject: [PATCH 02/16] Remove Random (#115) --- Sources/Generation/Random.swift | 1000 ------------------------------- 1 file changed, 1000 deletions(-) delete mode 100644 Sources/Generation/Random.swift diff --git a/Sources/Generation/Random.swift b/Sources/Generation/Random.swift deleted file mode 100644 index 4f4ae693..00000000 --- a/Sources/Generation/Random.swift +++ /dev/null @@ -1,1000 +0,0 @@ -#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) -import Darwin -#else -import Glibc -#endif - -/// Type-erased random number generator. -internal class AnyRandomNumberGenerator: RandomNumberGenerator { - private var rng: RandomNumberGenerator - - /// Creates a type-erased random number generator. - /// - /// - Parameters: - /// - rng: A random number generator. - init(_ rng: RandomNumberGenerator) { - self.rng = rng - } - - func next() -> UInt64 { - rng.next() - } -} - -extension AnyRandomNumberGenerator: ParallelRandomNumberGenerator { - func next(count: Int) -> [UInt64] { - if let rng = rng as? ParallelRandomNumberGenerator { - return rng.next(count: count) - } - return (0..(count: Int, upperBound: T) -> [T] { - if let rng = rng as? ParallelRandomNumberGenerator { - return rng.next(count: count, upperBound: upperBound) - } - return (0..(seed: T) -} - -extension SeedableRandomNumberGenerator { - init(seed: T) { - var newSeed: [UInt8] = [] - for i in 0..> (UInt8.bitWidth * i))) - } - self.init(seed: newSeed) - } -} - -extension RandomNumberGenerator { - mutating func next(count: Int) -> [UInt64] { - if let generator = self as? ParallelRandomNumberGenerator { - return generator.next(count: count) - } else { - return (0..( - count: Int, - upperBound: T - ) -> [T] { - if let generator = self as? ParallelRandomNumberGenerator { - return generator.next(count: count, upperBound: upperBound) - } else { - return (0.. 0, "Length of seed must be positive") - precondition(seed.count <= 256, "Length of seed must be at most 256") - var j: UInt8 = 0 - for i: UInt8 in 0...255 { - j &+= S(i) &+ seed[Int(i) % seed.count] - swapAt(i, j) - } - } - - // Produce the next random UInt64 from the stream, and advance the internal - // state. - public mutating func next() -> UInt64 { - var result: UInt64 = 0 - for _ in 0.. UInt8 { - return state[Int(index)] - } - - // Helper to swap elements of the state. - private mutating func swapAt(_ i: UInt8, _ j: UInt8) { - state.swapAt(Int(i), Int(j)) - } - - // Generates the next byte in the keystream. - private mutating func nextByte() -> UInt8 { - iPos &+= 1 - jPos &+= S(iPos) - swapAt(iPos, jPos) - return S(S(iPos) &+ S(jPos)) - } -} - -/// An implementation of `SeedableRandomNumberGenerator` using Threefry. -/// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -/// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -/// -/// This struct implements a 20-round Threefry2x32 PRNG. It must be seeded with -/// a 64-bit value. -/// -/// An individual generator is not thread-safe, but distinct generators do not -/// share state. The random data generated is of high-quality, but is not -/// suitable for cryptographic applications. -struct ThreefryRandomNumberGenerator: SeedableRandomNumberGenerator { - private let rot: (UInt32, UInt32, UInt32, UInt32, UInt32, UInt32, UInt32, UInt32) - = (13, 15, 26, 6, 17, 29, 16, 24) - - private func rotl32(value: UInt32, n: UInt32) -> UInt32 { - return (value << (n & 31)) | (value >> ((32 - n) & 31)) - } - - private var ctr: UInt64 = 0 - private let key: SIMD2 - - private func random(forCtr ctr: SIMD2, key: SIMD2) -> SIMD2 { - let skeinKsParity32: UInt32 = 0x1BD11BDA - - let ks0 = key.x - let ks1 = key.y - let ks2 = skeinKsParity32 ^ key.x ^ key.y - var X0 = ctr.x - var X1 = ctr.y - - // 20 rounds - // Key injection (r = 0) - X0 &+= ks0 - X1 &+= ks1 - // R1 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.0) - X1 ^= X0 - // R2 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.1) - X1 ^= X0 - // R3 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.2) - X1 ^= X0 - // R4 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.3) - X1 ^= X0 - // Key injection (r = 1) - X0 &+= ks1 - X1 &+= (ks2 + 1) - // R5 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.4) - X1 ^= X0 - // R6 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.5) - X1 ^= X0 - // R7 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.6) - X1 ^= X0 - // R8 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.7) - X1 ^= X0 - // Key injection (r = 2) - X0 &+= ks2 - X1 &+= (ks0 + 2) - // R9 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.0) - X1 ^= X0 - // R10 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.1) - X1 ^= X0 - // R11 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.2) - X1 ^= X0 - // R12 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.3) - X1 ^= X0 - // Key injection (r = 3) - X0 &+= ks0 - X1 &+= (ks1 + 3) - // R13 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.4) - X1 ^= X0 - // R14 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.5) - X1 ^= X0 - // R15 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.6) - X1 ^= X0 - // R16 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.7) - X1 ^= X0 - // Key injection (r = 4) - X0 &+= ks1 - X1 &+= (ks2 + 4) - // R17 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.0) - X1 ^= X0 - // R18 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.1) - X1 ^= X0 - // R19 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.2) - X1 ^= X0 - // R20 - X0 &+= X1 - X1 = rotl32(value: X1, n: rot.3) - X1 ^= X0 - // Key injection (r = 5) - X0 &+= ks2 - X1 &+= (ks0 + 5) - - return [X0, X1] - } - - internal init(uint64Seed seed: UInt64) { - key = seed.vector2 - } - - public init(seed: [UInt8]) { - precondition(seed.count > 0, "Length of seed must be positive") - precondition(seed.count <= 8, "Length of seed must be at most 8") - var combinedSeed: UInt64 = 0 - for (i, byte) in seed.enumerated() { - combinedSeed += UInt64(byte) << UInt64(8 * i) - } - self.init(uint64Seed: combinedSeed) - } - - public mutating func next() -> UInt64 { - defer { ctr += 1 } - return UInt64(highAndLow: random(forCtr: ctr.vector2, key: key)) - } -} - -/// An implementation of `SeedableRandomNumberGenerator` using Philox. -/// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -/// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -/// -/// This struct implements a 10-round Philox4x32 PRNG. It must be seeded with -/// a 64-bit value. -/// -/// An individual generator is not thread-safe, but distinct generators do not -/// share state. The random data generated is of high-quality, but is not -/// suitable for cryptographic applications. -struct PhiloxRandomNumberGenerator: SeedableRandomNumberGenerator { - @usableFromInline - var counter: UInt64 = 0 - @usableFromInline - let key: SIMD2 - - // Since we generate two 64-bit values at a time, we only need to run the - // generator every other invocation. - @usableFromInline - var useNextValue = false - @usableFromInline - var nextValue: UInt64 = 0 - - @inlinable - func bump(key: SIMD2) -> SIMD2 { - SIMD2(0x9E3779B9, 0xBB67AE85) &+ key - } - - @inlinable - func round(counter: SIMD4, key: SIMD2) -> SIMD4 { - let roundConstants = SIMD2(0xD2511F53, 0xCD9E8D57) - let products = roundConstants &* SIMD2(UInt64(counter[0]), UInt64(counter[2])) - - let hi = SIMD2(truncatingIfNeeded: products &>> 32) - let lo = SIMD2(truncatingIfNeeded: products & 0x0000_0000_FFFF_FFFF) - return [ - hi[1] ^ counter[1] ^ key[0], - lo[1], - hi[0] ^ counter[3] ^ key[1], - lo[0] - ] - } - - @inlinable - func random( - forCounter initialCounter: SIMD4, - key initialKey: SIMD2 - ) -> SIMD4 { - var counter = initialCounter - var key = initialKey - // 10 rounds - // R1 - counter = round(counter: counter, key: key) - // R2 - key = bump(key: key) - counter = round(counter: counter, key: key) - // R3 - key = bump(key: key) - counter = round(counter: counter, key: key) - // R4 - key = bump(key: key) - counter = round(counter: counter, key: key) - // R5 - key = bump(key: key) - counter = round(counter: counter, key: key) - // R6 - key = bump(key: key) - counter = round(counter: counter, key: key) - // R7 - key = bump(key: key) - counter = round(counter: counter, key: key) - // R8 - key = bump(key: key) - counter = round(counter: counter, key: key) - // R9 - key = bump(key: key) - counter = round(counter: counter, key: key) - // R10 - key = bump(key: key) - counter = round(counter: counter, key: key) - - return counter - } - - @inlinable - public init(uint64Seed seed: UInt64) { - key = seed.vector2 - } - - @inlinable - public init(seed: [UInt8]) { - precondition(seed.count > 0, "Length of seed must be positive") - precondition(seed.count <= 8, "Length of seed must be at most 8") - var combinedSeed: UInt64 = 0 - for (i, byte) in seed.enumerated() { - combinedSeed += UInt64(byte) << UInt64(8 * i) - } - self.init(uint64Seed: combinedSeed) - } - - @inlinable - public mutating func next() -> UInt64 { - if useNextValue { - useNextValue = false - return nextValue - } - let pair = random(forCounter: counter.vector4, key: key).reinterpretedUInt64Vector - useNextValue = true - nextValue = pair.y - counter += 1 - return pair.x - } -} - -/// Private helpers. -extension UInt64 { - @inlinable - var vector2: SIMD2 { - let msb = UInt32(truncatingIfNeeded: self >> 32) - let lsb = UInt32(truncatingIfNeeded: self & 0x0000_0000_FFFF_FFFF) - return [msb, lsb] - } - - @inlinable - var vector4: SIMD4 { - let msb = UInt32(truncatingIfNeeded: self >> 32) - let lsb = UInt32(truncatingIfNeeded: self) - return [0, 0, msb, lsb] - } - - @inlinable - init(highAndLow: SIMD2) { - self = (UInt64(highAndLow.x) << 32) + UInt64(highAndLow.y) - } -} - -extension SIMD4 where Scalar == UInt32 { - @inlinable - var reinterpretedUInt64Vector: SIMD2 { - let a = (UInt64(x) << 32) + UInt64(y) - let b = (UInt64(z) << 32) + UInt64(w) - return [a, b] - } -} - -// MARK: - Random distributions - -import Dispatch - -protocol RandomDistribution { - associatedtype Sample - func next(using generator: inout G) -> Sample - func next(_ count: Int, using generator: inout G) -> [Sample] -} - -extension RandomDistribution { - @_specialize( - where Self == UniformFloatingPointDistribution, - G == DefaultRandomNumberGeneratorForTensor) - public func next(_ count: Int, using generator: inout G) -> [Sample] { - return Array( - unsafeUninitializedCapacity: count - ) { buffer, initializedCount in - for i in 0..(using generator: inout G) -> Bool { - Bool.random(using: &generator) - } - - public func next(_ count: Int, using generator: inout G) -> [Bool] { - Array.random(count: count, using: &generator) - } -} - -struct UniformIntegerDistribution: RandomDistribution { - public let bounds: ClosedRange - - public init(bounds: ClosedRange = T.min...T.max) { - self.bounds = bounds - } - - public func next(using generator: inout G) -> T { - return T.random(in: bounds, using: &generator) - } - - public func next(_ count: Int, using generator: inout G) -> [T] { - Array.random(count: count, in: bounds, using: &generator) - } -} - -struct UniformFloatingPointDistribution: RandomDistribution - where T.RawSignificand: FixedWidthInteger -{ - public let bounds: ClosedRange - - public init(bounds: ClosedRange = 0...1) { - self.bounds = bounds - } - - @_specialize(where T == Float, G == DefaultRandomNumberGeneratorForTensor) - public func next(using generator: inout G) -> T { - return T.random(in: bounds, using: &generator) - } - - @_specialize(where T == Float, G == DefaultRandomNumberGeneratorForTensor) - public func next(_ count: Int, using generator: inout G) -> [T] { - Array.random(count: count, in: bounds, using: &generator) - } -} - -struct NormalDistribution: RandomDistribution - where T.RawSignificand: FixedWidthInteger -{ - public let mean: T - public let standardDeviation: T - @usableFromInline - let uniformDistribution = UniformFloatingPointDistribution() - - public init(mean: T = 0, standardDeviation: T = 1) { - self.mean = mean - self.standardDeviation = standardDeviation - } - - @_specialize(where T == Float) - @inlinable - func normalized(_ u1: T, _ u2: T) -> T { - let r = (-2 * T(log(Float(u1)))).squareRoot() - let theta = 2 * T.pi * u2 - let normal01 = r * T(cos(Float(theta))) - return mean + standardDeviation * normal01 - } - - @_specialize(where T == Float, G == DefaultRandomNumberGeneratorForTensor) - @inlinable - public func next(using generator: inout G) -> T { - // FIXME: Box-Muller can generate two values for only a little more than the - // cost of one. - normalized( - uniformDistribution.next(using: &generator), - uniformDistribution.next(using: &generator)) - } - - @_specialize(where T == Float, G == DefaultRandomNumberGeneratorForTensor) - @inlinable - public func next(_ count: Int, using generator: inout G) -> [T] { - let uniformNumbers = uniformDistribution.next(count * 2, using: &generator) - return Array(unsafeUninitializedCapacity: count) { buffer, initializedCount in - DispatchQueue.concurrentPerform(iterations: count) { i in - let offset = i * 2 - buffer[i] = normalized(uniformNumbers[offset], uniformNumbers[offset + 1]) - } - initializedCount = count - } - } -} - -struct TruncatedNormalDistribution: RandomDistribution - where T.RawSignificand: FixedWidthInteger -{ - public let mean: T - public let standardDeviation: T - private let normalDistribution = NormalDistribution(mean: 0, standardDeviation: 1) - - public init(mean: T = 0, standardDeviation: T = 1) { - self.mean = mean - self.standardDeviation = standardDeviation - } - - public func next(using generator: inout G) -> T { - // FIXME: Implement this. See - // https://github.com/tensorflow/tensorflow/blob/b1a6b315a63bb29b4593bfb98095da4397d8cd5a/tensorflow/compiler/tf2xla/lib/random.cc#L42. - fatalError("Unimplemented") - } - - public func next(_ count: Int, using generator: inout G) -> [T] { - // FIXME: Implement this. See - // https://github.com/tensorflow/tensorflow/blob/b1a6b315a63bb29b4593bfb98095da4397d8cd5a/tensorflow/compiler/tf2xla/lib/random.cc#L42. - fatalError("Unimplemented") - } -} - -struct BetaDistribution: RandomDistribution { - public let alpha: Float - public let beta: Float - private let uniformDistribution = UniformFloatingPointDistribution() - - public init(alpha: Float = 0, beta: Float = 1) { - self.alpha = alpha - self.beta = beta - } - - public func next(using generator: inout G) -> Float { - // Generate a sample using Cheng's sampling algorithm from: - // R. C. H. Cheng, "Generating beta variates with nonintegral shape - // parameters.". Communications of the ACM, 21, 317-322, 1978. - let a = min(alpha, beta) - let b = max(alpha, beta) - if a > 1 { - return BetaDistribution.chengsAlgorithmBB(alpha, a, b, using: &generator) - } else { - return BetaDistribution.chengsAlgorithmBC(alpha, b, a, using: &generator) - } - } - - /// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BB - /// algorithm, when both alpha and beta are greater than 1. - /// - /// - Parameters: - /// - alpha: First Beta distribution shape parameter. - /// - a: `min(alpha, beta)`. - /// - b: `max(alpha, beta)`. - /// - generator: Random number generator. - /// - /// - Returns: Sample obtained using Cheng's BB algorithm. - private static func chengsAlgorithmBB( - _ alpha0: Float, - _ a: Float, - _ b: Float, - using generator: inout G - ) -> Float { - let alpha = a + b - let beta = sqrt((alpha - 2) / (2 * a * b - alpha)) - let gamma = a + 1 / beta - - var r: Float = 0.0 - var w: Float = 0.0 - var t: Float = 0.0 - - repeat { - let u1 = Float.random(in: 0.0...1.0, using: &generator) - let u2 = Float.random(in: 0.0...1.0, using: &generator) - let v = beta * (log(u1) - log1p(-u1)) - r = gamma * v - 1.3862944 - let z = u1 * u1 * u2 - w = a * exp(v) - - let s = a + r - w - if s + 2.609438 >= 5 * z { - break - } - - t = log(z) - if s >= t { - break - } - } while r + alpha * (log(alpha) - log(b + w)) < t - - w = min(w, Float.greatestFiniteMagnitude) - return a == alpha0 ? w / (b + w) : b / (b + w) - } - - /// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BC - /// algorithm, when at least one of alpha and beta is less than 1. - /// - /// - Parameters: - /// - alpha: First Beta distribution shape parameter. - /// - a: `max(alpha, beta)`. - /// - b: `min(alpha, beta)`. - /// - generator: Random number generator. - /// - /// - Returns: Sample obtained using Cheng's BB algorithm. - private static func chengsAlgorithmBC( - _ alpha0: Float, - _ a: Float, - _ b: Float, - using generator: inout G - ) -> Float { - let alpha = a + b - let beta = 1 / b - let delta = 1 + a - b - let k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778) - let k2 = 0.25 + (0.5 + 0.25 / delta) * b - - var w: Float = 0.0 - - while true { - let u1 = Float.random(in: 0.0...1.0, using: &generator) - let u2 = Float.random(in: 0.0...1.0, using: &generator) - let y = u1 * u2 - let z = u1 * y - - if u1 < 0.5 { - if 0.25 * u2 + z - y >= k1 { - continue - } - } else { - if z <= 0.25 { - let v = beta * (log(u1) - log1p(-u1)) - w = a * exp(v) - break - } - if z >= k2 { - continue - } - } - - let v = beta * (log(u1) - log1p(-u1)) - w = a * exp(v) - if alpha * (log(alpha) - log(b + 1) + v) - 1.3862944 >= log(z) { - break - } - } - - w = min(w, Float.greatestFiniteMagnitude) - return a == alpha0 ? w / (b + w): b / (b + w) - } -} - -// MARK: Parallel Random Number Generators - -protocol ParallelRandomNumberGenerator: RandomNumberGenerator { - func next(count: Int) -> [UInt64] - func next(count: Int, upperBound: T) -> [T] -} - -#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) -public struct SystemArc4RandomNumberGenerator: ParallelRandomNumberGenerator { - public mutating func next() -> UInt64 { - var result: UInt64 = 0 - arc4random_buf(&result, MemoryLayout.size) - return result - } - - public func next(count: Int) -> [UInt64] { - return Array(unsafeUninitializedCapacity: count) { buffer, size in - size = count - arc4random_buf( - UnsafeMutableRawPointer(buffer.baseAddress), - count * MemoryLayout.stride) - } - } - - public func next(count: Int, upperBound: T) -> [T] { - let rands = next(count: count) - return Array(unsafeUninitializedCapacity: count) { - bufferPointer, initializedCount in - DispatchQueue.concurrentPerform(iterations: count) { - bufferPointer[$0] = self.upperBound( - rands[$0], - to: upperBound) - } - initializedCount = count - } - } - - // Implementation based on https://github.com/apple/swift/blob/master/stdlib/public/core/Random.swift#L93 - private func upperBound( - _ val: UInt64, - to upperBound: T - ) -> T { - precondition(upperBound != 0, "upperBound cannot be zero.") - #if arch(i386) || arch(arm) || arch(arm64_32) // TODO(FIXME) SR-10912 - let tmp = (T.max % upperBound) + 1 - let range = tmp == upperBound ? 0 : tmp - var random = T(truncatingIfNeeded: val) - - while random < range { - withUnsafeMutablePointer(to: &random) { - arc4random_buf($0, MemoryLayout.size) - } - } - - return random % upperBound - #else - var random = T(truncatingIfNeeded: val) - var m = random.multipliedFullWidth(by: upperBound) - if m.low < upperBound { - let t = (0 &- upperBound) % upperBound - while m.low < t { - withUnsafeMutablePointer(to: &random) { - arc4random_buf($0, MemoryLayout.size) - } - m = random.multipliedFullWidth(by: upperBound) - } - } - return m.high - #endif - } -} -#endif - -// MARK: Random Array Generators - -extension Array where Element == Bool { - static func random( - count: Int, - using generator: inout RNG - ) -> Self { - let rands = generator.next(count: count) - return Self(unsafeUninitializedCapacity: count) { bufferPointer, initializedCount in - DispatchQueue.concurrentPerform(iterations: count) { - bufferPointer[$0] = (rands[$0] >> 17) & 1 == 0 - } - initializedCount = count - } - } -} - -extension Array where Element: BinaryFloatingPoint, Element.RawSignificand: FixedWidthInteger { - // Implementation based on https://github.com/apple/swift/blob/master/stdlib/public/core/FloatingPoint.swift#L2052 - static func random( - count: Int, - in range: Range, - using generator: inout RNG - ) -> Self { - let delta = range.upperBound - range.lowerBound - precondition(delta.isFinite, "There is no uniform distribution on an infinite range") - func randArray(_ count: Int) -> [UInt64] { - if Element.RawSignificand.bitWidth == Element.significandBitCount + 1 { - return generator.next(count: count) - } else { - let significandCount = Element.significandBitCount + 1 - let maxSignificand: Element.RawSignificand = 1 << significandCount - return generator.next(count: count).map { $0 & UInt64(maxSignificand - 1) } - } - } - return Self(unsafeUninitializedCapacity: count) { - resultBufferPointer, initializedCount in - var indicesNeedingResampling = [Int](resultBufferPointer.indices) - while !indicesNeedingResampling.isEmpty { - let rands = randArray(indicesNeedingResampling.count) - indicesNeedingResampling.withUnsafeMutableBufferPointer { - indicesNeedingResamplingBufferPtr in - DispatchQueue.concurrentPerform( - iterations: indicesNeedingResamplingBufferPtr.count - ) { i in - let rand = rands[i] - let unitRandom = Element(rand) * (Element.ulpOfOne / 2) - let randFloat = delta * unitRandom + range.lowerBound - if randFloat != range.upperBound { - let index = indicesNeedingResamplingBufferPtr[i] - resultBufferPointer[index] = randFloat - indicesNeedingResamplingBufferPtr[i] = -1 - } - } - } - indicesNeedingResampling.removeAll(where: { $0 == -1 }) - } - initializedCount = count - } - } - - // Implementation based on https://github.com/apple/swift/blob/master/stdlib/public/core/FloatingPoint.swift#L2152 - static func random( - count: Int, - in range: ClosedRange, - using generator: inout RNG - ) -> Self { - let delta = range.upperBound - range.lowerBound - precondition(delta.isFinite, "There is no uniform distribution on an infinite range") - func randArrays(_ count: Int) -> (rand: [UInt64], tmp: [UInt64]?) { - let rands: [UInt64] - var tmp: [UInt64]? = nil - if Element.RawSignificand.bitWidth == Element.significandBitCount + 1 { - rands = generator.next(count: count) - tmp = generator.next(count: count) - } else { - let significandCount = Element.significandBitCount + 1 - let maxSignificand: Element.RawSignificand = 1 << significandCount - rands = generator.next(count: count, upperBound: UInt64(maxSignificand + 1)) - } - return (rands, tmp) - } - - return Self(unsafeUninitializedCapacity: count) { - resultBufferPointer, initializedCount in - let indicesNeedingResampling = [Int](resultBufferPointer.indices) - let (rands, tmp) = randArrays(indicesNeedingResampling.count) - DispatchQueue.concurrentPerform( - iterations: indicesNeedingResampling.count - ) { i in - if Element.RawSignificand.bitWidth == Element.significandBitCount + 1 { - guard let tmp = tmp else { - fatalError("Expected the 'tmp' array to be initialized.") - } - if rands[i] == Element.RawSignificand.max && (tmp[i] & 1) == 1 { - let index = indicesNeedingResampling[i] - resultBufferPointer[index] = range.upperBound - return - } - } else { - let significandCount = Element.significandBitCount + 1 - let maxSignificand: Element.RawSignificand = 1 << significandCount - if rands[i] == maxSignificand { - let index = indicesNeedingResampling[i] - resultBufferPointer[index] = range.upperBound - return - } - } - let unitRandom = Element(rands[i]) * (Element.ulpOfOne / 2) - let randFloat = delta * unitRandom + range.lowerBound - let index = indicesNeedingResampling[i] - resultBufferPointer[index] = randFloat - } - initializedCount = count - } - } -} - -extension Array where Element: FixedWidthInteger { - // Implementation based on https://github.com/apple/swift/blob/master/stdlib/public/core/Integers.swift#L2663 - static func random( - count: Int, - in range: Range, - using generator: inout RNG - ) -> Self { - precondition(!range.isEmpty, "Can't get random value with an empty range") - // Compute delta, the distance between the lower and upper bounds. This - // value may not representable by the type Bound if Bound is signed, but - // is always representable as Bound.Magnitude. - let delta = Element.Magnitude(truncatingIfNeeded: range.upperBound &- range.lowerBound) - let rands = generator.next(count: count, upperBound: UInt64(delta)) - return Self(unsafeUninitializedCapacity: count) { bufferPointer, initializedCount in - DispatchQueue.concurrentPerform(iterations: count) { - // The mathematical result we want is lowerBound plus a random value in - // 0 ..< delta. We need to be slightly careful about how we do this - // arithmetic; the Bound type cannot generally represent the random value, - // so we use a wrapping addition on Bound.Magnitude. This will often - // overflow, but produces the correct bit pattern for the result when - // converted back to Bound. - bufferPointer[$0] = Element(truncatingIfNeeded: - Element.Magnitude(truncatingIfNeeded: range.lowerBound) &+ - Element.Magnitude(rands[$0])) - } - initializedCount = count - } - } - - // Implementation based on https://github.com/apple/swift/blob/master/stdlib/public/core/Integers.swift#L2732 - static func random( - count: Int, - in range: ClosedRange, - using generator: inout RNG - ) -> Self { - precondition(!range.isEmpty, "Can't get random value with an empty range") - // Compute delta, the distance between the lower and upper bounds. This - // value may not representable by the type Bound if Bound is signed, but - // is always representable as Bound.Magnitude. - var delta = Element.Magnitude(truncatingIfNeeded: range.upperBound &- range.lowerBound) - // Subtle edge case: if the range is the whole set of representable values, - // then adding one to delta to account for a closed range will overflow. - // If we used &+ instead, the result would be zero, which isn't helpful, - // so we actually need to handle this case separately. - if delta == Element.Magnitude.max { - return Self(unsafeUninitializedCapacity: count) { - bufferPointer, initializedCount in - DispatchQueue.concurrentPerform(iterations: count) { - bufferPointer[$0] = - Element(truncatingIfNeeded: generator.next() as Element.Magnitude) - } - initializedCount = count - } - } - // Need to widen delta to account for the right-endpoint of a closed range. - delta += 1 - let rands = generator.next(count: count, upperBound: UInt64(delta)) - return Self(unsafeUninitializedCapacity: count) { - bufferPointer, initializedCount in - DispatchQueue.concurrentPerform(iterations: count) { - // The mathematical result we want is lowerBound plus a random value in - // 0 ..< delta. We need to be slightly careful about how we do this - // arithmetic; the Bound type cannot generally represent the random value, - // so we use a wrapping addition on Bound.Magnitude. This will often - // overflow, but produces the correct bit pattern for the result when - // converted back to Bound. - bufferPointer[$0] = Element(truncatingIfNeeded: - Element.Magnitude(truncatingIfNeeded: range.lowerBound) &+ - Element.Magnitude(rands[$0])) - } - initializedCount = count - } - } -} - -extension RandomNumberGenerator { - /// Returns a random value within the specified range. - mutating func next(as type: T.Type, in bounds: Range) -> T - where T: BinaryFloatingPoint, T.RawSignificand: FixedWidthInteger { - return T.random(in: bounds, using: &self) - } - - /// Returns a random value within the specified range. - mutating func next(as type: T.Type, in bounds: ClosedRange) -> T - where T: BinaryFloatingPoint, T.RawSignificand: FixedWidthInteger { - return T.random(in: bounds, using: &self) - } -} From 9ba8173a2d1239279cc2d7f57d3b95692b90c463 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 24 Sep 2025 08:56:47 +0200 Subject: [PATCH 03/16] Throwing error when the configs fail JSON serialization (#114) * Added error for JSON serialization errors * Fix merge commit --------- Co-authored-by: Pedro Cuenca --- Sources/Hub/Hub.swift | 3 +++ Sources/Hub/HubApi.swift | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index 4f03e822..330b6bab 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -14,6 +14,7 @@ public extension Hub { case authorizationRequired case httpStatusCode(Int) case parse + case jsonSerialization(fileURL: URL, message: String) case unexpectedError case downloadError(String) case fileNotFound(String) @@ -31,6 +32,8 @@ public extension Hub { String(localized: "HTTP error with status code: \(code)") case .parse: String(localized: "Failed to parse server response.") + case .jsonSerialization(_, let message): + return message case .unexpectedError: String(localized: "An unexpected error occurred.") case let .downloadError(message): diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index 7def999e..e347bed5 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -268,7 +268,9 @@ public extension HubApi { /// `fileURL` is a complete local file path for the given model func configuration(fileURL: URL) throws -> Config { let data = try Data(contentsOf: fileURL) - let parsed = try JSONSerialization.bomPreservingJsonObject(with: data) + guard let parsed = try? JSONSerialization.bomPreservingJsonObject(with: data) else { + throw Hub.HubClientError.jsonSerialization(fileURL: fileURL, message: "JSON Serialization failed for \(fileURL). Please verify that you have set the HF_TOKEN environment variable.") + } guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse } return Config(dictionary) } From e6b6dd448519f28e287d51669190b6bd5b0bc2f3 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 9 Sep 2024 17:19:39 +0200 Subject: [PATCH 04/16] Allow archiving for Mac (#121) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SPM dependencies are always compiled for the standard architectures, but Float16 is not available for `x86_64`. Thanks @joshnewnham for the workaround 🙌 --- Sources/Models/LanguageModel.swift | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 724cb315..65de45d0 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -385,14 +385,22 @@ public class LanguageModelWithStatefulKVCache: LanguageModel { Keys.inputIds: inputIds, ] if isRequiringAttentionMask { + #if !((os(macOS) || (macCatalyst)) && arch(x86_64)) // TODO: Infer scalar type from cache or model I/O descriptors let attentionMask = MLTensor(zeros: [1, 1, 1, tokenCount + 1], scalarType: Float16.self) inputDictionary[Keys.attentionMask] = attentionMask + #else + fatalError() + #endif } if isRequiringCausalMask { + #if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) // TODO: Infer scalar type from cache or model I/O descriptors let causalMask = MLTensor(zeros: [1, 1, 1, tokenCount + 1], scalarType: Float16.self) inputDictionary[Keys.causalMask] = causalMask + #else + fatalError() + #endif } let outputs = try! await model.prediction(from: inputDictionary, using: state) From d2bc390a6066629b9239b5e6faa04fcd531a8813 Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Tue, 17 Dec 2024 12:43:08 +0000 Subject: [PATCH 05/16] chore: strategic deletes avoid OOM --- Examples/Mistral7B/export.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/Examples/Mistral7B/export.py b/Examples/Mistral7B/export.py index fdebfd13..e1bc44a0 100644 --- a/Examples/Mistral7B/export.py +++ b/Examples/Mistral7B/export.py @@ -164,6 +164,8 @@ def export() -> None: input_ids: torch.Tensor = torch.zeros((1, 2), dtype=torch.int32) causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float32) traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask]) + kv_cache_shape = torch_model.kv_cache_shape + del torch_model # Convert traced TorchScript to Core ML format query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1) @@ -179,11 +181,11 @@ def export() -> None: outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")] states: List[ct.StateType] = [ ct.StateType( - wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16), + wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16), name="keyCache", ), ct.StateType( - wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16), + wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16), name="valueCache", ), ] @@ -197,6 +199,7 @@ def export() -> None: minimum_deployment_target=ct.target.iOS18, skip_model_load=True, ) + del traced_model # Block-wise quantize model weights to int4 op_config = ct.optimize.coreml.OpLinearQuantizerConfig( @@ -208,6 +211,7 @@ def export() -> None: config = ct.optimize.coreml.OptimizationConfig(global_config=op_config) mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config) mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID}) + del mlmodel_fp16 mlmodel_int4.save("StatefulMistral7BInstructInt4.mlpackage") From 7d7870b94579a73b27287126646e2b6c44dea917 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 24 Sep 2025 09:02:24 +0200 Subject: [PATCH 06/16] Remove RepetitionPenaltyWarper, fix build --- .../RepetitionPenaltyWarper.swift | 25 ------------------- Sources/Hub/Hub.swift | 2 +- 2 files changed, 1 insertion(+), 26 deletions(-) delete mode 100644 Sources/Generation/LogitsWarper/RepetitionPenaltyWarper.swift diff --git a/Sources/Generation/LogitsWarper/RepetitionPenaltyWarper.swift b/Sources/Generation/LogitsWarper/RepetitionPenaltyWarper.swift deleted file mode 100644 index cbc5c707..00000000 --- a/Sources/Generation/LogitsWarper/RepetitionPenaltyWarper.swift +++ /dev/null @@ -1,25 +0,0 @@ -import Foundation - -/// `RepetitionPenaltyWarper` prevents the repetition of previous tokens through a penalty. -/// This penalty is applied at most once per token. -/// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L294 -public struct RepetitionPenaltyWarper: LogitsWarper { - public var penalty: Float - - public init(penalty: Double) { - self.penalty = Float(penalty) - } - - public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) { - var logits = logits - for index in indices { - if logits[index] < 0 { - logits[index] *= penalty - } else { - logits[index] /= penalty - } - } - - return (indices, logits) - } -} diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index 330b6bab..50e5d5fa 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -33,7 +33,7 @@ public extension Hub { case .parse: String(localized: "Failed to parse server response.") case .jsonSerialization(_, let message): - return message + message case .unexpectedError: String(localized: "An unexpected error occurred.") case let .downloadError(message): From 551ae06655da9049a657f506872365a53f9c22ed Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 24 Sep 2025 09:31:35 +0200 Subject: [PATCH 07/16] Remove GenerationTests --- Package.swift | 1 - Tests/GenerationTests/LogitsWarperTests.swift | 160 ------------------ Tests/GenerationTests/MathTests.swift | 59 ------- Tests/GenerationTests/TestUtils.swift | 7 - 4 files changed, 227 deletions(-) delete mode 100644 Tests/GenerationTests/LogitsWarperTests.swift delete mode 100644 Tests/GenerationTests/MathTests.swift delete mode 100644 Tests/GenerationTests/TestUtils.swift diff --git a/Package.swift b/Package.swift index 5f26e1f1..5d4f32d4 100644 --- a/Package.swift +++ b/Package.swift @@ -22,7 +22,6 @@ let package = Package( .target(name: "Hub", resources: [.process("Resources")], swiftSettings: swiftSettings), .target(name: "Models", dependencies: ["Tokenizers", "Generation"]), .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]), - .testTarget(name: "GenerationTests", dependencies: ["Generation"]), .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")], swiftSettings: swiftSettings), .testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]), .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]), diff --git a/Tests/GenerationTests/LogitsWarperTests.swift b/Tests/GenerationTests/LogitsWarperTests.swift deleted file mode 100644 index 5d97652a..00000000 --- a/Tests/GenerationTests/LogitsWarperTests.swift +++ /dev/null @@ -1,160 +0,0 @@ -// -// LogitsWarperTests.swift -// -// Created by Jan Krukowski on 09/12/2023. -// - -#if canImport(CoreML) -import CoreML -@testable import Generation -import Testing - -@Suite("Logits Warper Tests") -struct LogitsWarperTests { - private let accuracy: Float = 0.00001 - - @Test("Temperature logits warper functionality") - func temperatureLogitsWarper() { - let result1 = TemperatureLogitsWarper(temperature: 0.0)([], []) - #expect(result1.indices.isEmpty) - #expect(result1.logits.isEmpty) - - let result2 = TemperatureLogitsWarper(temperature: 1.0)([], []) - #expect(result2.indices.isEmpty) - #expect(result2.logits.isEmpty) - - let result3 = TemperatureLogitsWarper(temperature: 1.0)([0, 1], [2.0, 1.0]) - #expect(result3.indices == [0, 1]) - #expect(isClose(result3.logits, [2.0, 1.0], accuracy: accuracy)) - - let result4 = TemperatureLogitsWarper(temperature: 2.0)([0, 1], [2.0, 1.0]) - #expect(result4.indices == [0, 1]) - #expect(isClose(result4.logits, [1.0, 0.5], accuracy: accuracy)) - - let result5 = TemperatureLogitsWarper(temperature: 0.5)([0, 1], [2.0, 1.0]) - #expect(result5.indices == [0, 1]) - #expect(isClose(result5.logits, [4.0, 2.0], accuracy: accuracy)) - - let result6 = TemperatureLogitsWarper(temperature: 0.5)([200, 100], [2.0, 1.0]) - #expect(result6.indices == [200, 100]) - #expect(isClose(result6.logits, [4.0, 2.0], accuracy: accuracy)) - } - - @Test("Top-K logits warper functionality") - func topKLogitsWarper() { - let result1 = TopKLogitsWarper(k: 0)([], []) - #expect(result1.indices.isEmpty) - #expect(result1.logits.isEmpty) - - let result2 = TopKLogitsWarper(k: 3)([], []) - #expect(result2.indices.isEmpty) - #expect(result2.logits.isEmpty) - - let result3 = TopKLogitsWarper(k: 3)([0, 1], [2.0, 1.0]) - #expect(result3.indices == [0, 1]) - #expect(isClose(result3.logits, [2.0, 1.0], accuracy: accuracy)) - - let result4 = TopKLogitsWarper(k: 3)([0, 1, 2], [2.0, 1.0, 3.0]) - #expect(result4.indices == [2, 0, 1]) - #expect(isClose(result4.logits, [3.0, 2.0, 1.0], accuracy: accuracy)) - - let result5 = TopKLogitsWarper(k: 4)([0, 1, 2, 3, 4, 5], [2.0, 1.0, 3.0, -1.0, 123.0, 0.0]) - #expect(result5.indices == [4, 2, 0, 1]) - #expect(isClose(result5.logits, [123.0, 3.0, 2.0, 1.0], accuracy: accuracy)) - - let result6 = TopKLogitsWarper(k: 3)([10, 1, 52], [2.0, 1.0, 3.0]) - #expect(result6.indices == [52, 10, 1]) - #expect(isClose(result6.logits, [3.0, 2.0, 1.0], accuracy: accuracy)) - } - - @Test("Top-P logits warper functionality") - func topPLogitsWarper() { - let result1 = TopPLogitsWarper(p: 0.99)([], []) - #expect(result1.indices.isEmpty) - #expect(result1.logits.isEmpty) - - let logits = (0..<10).map { Float($0) } - let indices = Array(logits.indices) - let result2 = TopPLogitsWarper(p: 0.99)(indices, logits) - #expect(result2.indices == [9, 8, 7, 6, 5]) - #expect(isClose(result2.logits, [9.0, 8.0, 7.0, 6.0, 5.0], accuracy: accuracy)) - - let result3 = TopPLogitsWarper(p: 0.95)(indices, logits) - #expect(result3.indices == [9, 8, 7]) - #expect(isClose(result3.logits, [9.0, 8.0, 7.0], accuracy: accuracy)) - - let result4 = TopPLogitsWarper(p: 0.6321493)(indices, logits) - #expect(result4.indices == [9, 8]) - #expect(isClose(result4.logits, [9.0, 8.0], accuracy: accuracy)) - - let result5 = TopPLogitsWarper(p: 0.95)([3, 1, 8], [0, 1, 2]) - #expect(result5.indices == [8, 1, 3]) - #expect(isClose(result5.logits, [2, 1, 0], accuracy: accuracy)) - } - - @Test("Repetition penalty warper functionality") - func repetitionPenaltyWarper() { - let indices = Array(0..<10) - let logits = indices.map { Float($0) } - - let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits) - #expect(result1.indices == indices) - #expect(isClose(result1.logits, logits, accuracy: accuracy)) - - let result2 = RepetitionPenaltyWarper(penalty: 3.75)(indices, logits) - #expect(result2.indices == indices) - let logits2 = indices.map { Float($0) / 3.75 } - #expect(isClose(result2.logits, logits2, accuracy: accuracy)) - - let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119]) - #expect(result3.indices == [0, 1, 2]) - #expect(isClose(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4)) - - let result4 = RepetitionPenaltyWarper(penalty: 1.11)([2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158]) - #expect(result4.indices == [2, 3, 4]) - #expect(isClose(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4)) - - let result5 = RepetitionPenaltyWarper(penalty: 0.9)([0, 1, 2], [-0.7433, -0.4738, -0.2966]) - #expect(result5.indices == [0, 1, 2]) - #expect(isClose(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4)) - - let result6 = RepetitionPenaltyWarper(penalty: 1.125)([3, 1, 2], [0.1674, 0.6431, 0.6780, 0.2755]) - #expect(result6.indices == [3, 1, 2]) - #expect(isClose(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4)) - } - - @Test("Logits processor functionality") - func logitsProcessor() { - let processor1 = LogitsProcessor(logitsWarpers: []) - let result1 = processor1([]) - #expect(result1.indices.isEmpty) - #expect(result1.logits.isEmpty) - - let processor2 = LogitsProcessor(logitsWarpers: []) - let result2 = processor2([2.0, 1.0]) - #expect(result2.indices == [0, 1]) - #expect(isClose(result2.logits, [2.0, 1.0], accuracy: accuracy)) - - let processor3 = LogitsProcessor( - logitsWarpers: [TopKLogitsWarper(k: 3)] - ) - let result3 = processor3([2.0, 1.0, 3.0, -5.0]) - #expect(result3.indices == [2, 0, 1]) - #expect(isClose(result3.logits, [3.0, 2.0, 1.0], accuracy: accuracy)) - - let processor4 = LogitsProcessor( - logitsWarpers: [TopKLogitsWarper(k: 3), TopPLogitsWarper(p: 0.99)] - ) - let result4 = processor4([2.0, 1.0, 3.0, -5.0, -23.0, 12.5]) - #expect(result4.indices == [5]) - #expect(isClose(result4.logits, [12.5], accuracy: accuracy)) - - let processor5 = LogitsProcessor( - logitsWarpers: [TopKLogitsWarper(k: 4), TopPLogitsWarper(p: 0.99)] - ) - let result5 = processor5([2.0, 1.0, 3.0, -5.0, -3.0, 4.5]) - #expect(result5.indices == [5, 2, 0, 1]) - #expect(isClose(result5.logits, [4.5, 3.0, 2.0, 1.0], accuracy: accuracy)) - } -} -#endif // canImport(CoreML) diff --git a/Tests/GenerationTests/MathTests.swift b/Tests/GenerationTests/MathTests.swift deleted file mode 100644 index 0e255ecd..00000000 --- a/Tests/GenerationTests/MathTests.swift +++ /dev/null @@ -1,59 +0,0 @@ -// -// MathTests.swift -// -// Created by Jan Krukowski on 25/11/2023. -// - -#if canImport(CoreML) -import CoreML -@testable import Generation -import Testing - -@Suite("Math Tests") -struct MathTests { - private let accuracy: Float = 0.00001 - - @Test("Cumulative sum functionality") - func cumsum() { - #expect(Math.cumsum([]).isEmpty) - #expect(Math.cumsum([1]) == [1]) - #expect(Math.cumsum([1, 2, 3, 4]) == [1, 3, 6, 10]) - } - - @Test("Argmax functionality") - func argmax() throws { - let result1 = Math.argmax([3.0, 4.0, 1.0, 2.0] as [Float], count: 4) - #expect(result1.0 == 1) - #expect(result1.1 == 4.0) - - let result2 = Math.argmax32([3.0, 4.0, 1.0, 2.0], count: 4) - #expect(result2.0 == 1) - #expect(result2.1 == 4.0) - - let result3 = Math.argmax([3.0, 4.0, 1.0, 2.0] as [Double], count: 4) - #expect(result3.0 == 1) - #expect(result3.1 == 4.0) - - let result4 = try Math.argmax32(MLMultiArray([3.0, 4.0, 1.0, 2.0] as [Float])) - #expect(result4.0 == 1) - #expect(result4.1 == 4.0) - - let result5 = try Math.argmax(MLMultiArray([3.0, 4.0, 1.0, 2.0] as [Double])) - #expect(result5.0 == 1) - #expect(result5.1 == 4.0) - - let result6 = Math.argmax(MLShapedArray(scalars: [3.0, 4.0, 1.0, 2.0] as [Float], shape: [4])) - #expect(result6.0 == 1) - #expect(result6.1 == 4.0) - } - - @Test("Softmax functionality") - func softmax() { - #expect(Math.softmax([]) == []) - - let result1 = Math.softmax([3.0, 4.0, 1.0, 2.0]) - #expect(isClose(result1, [0.23688284, 0.6439143, 0.032058604, 0.08714432], accuracy: accuracy)) - #expect(abs(result1.reduce(0, +) - 1.0) < accuracy) - } -} -#endif // canImport(CoreML) diff --git a/Tests/GenerationTests/TestUtils.swift b/Tests/GenerationTests/TestUtils.swift deleted file mode 100644 index 4f9f2a59..00000000 --- a/Tests/GenerationTests/TestUtils.swift +++ /dev/null @@ -1,7 +0,0 @@ -import Foundation - -/// Check if two floating-point arrays are equal within a given accuracy -func isClose(_ lhs: [T], _ rhs: [T], accuracy: T) -> Bool { - guard lhs.count == rhs.count else { return false } - return zip(lhs, rhs).allSatisfy { abs($0.0 - $0.1) <= accuracy } -} From 67f1b08ba6fcfaab3668924df33208674dac4b58 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 24 Sep 2025 09:35:44 +0200 Subject: [PATCH 08/16] Restore TokenizerError --- Sources/Models/LanguageModel.swift | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 65de45d0..f29a660b 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -309,17 +309,12 @@ public extension LanguageModel { var tokenizer: Tokenizer { get async throws { - if let _tokenizer { - return _tokenizer - } + guard _tokenizer == nil else { return _tokenizer! } guard let tokenizerConfig = try await tokenizerConfig else { - throw "Cannot retrieve Tokenizer configuration" + throw TokenizerError.tokenizerConfigNotFound } let tokenizerData = try await tokenizerData - _tokenizer = try AutoTokenizer.from( - tokenizerConfig: tokenizerConfig, - tokenizerData: tokenizerData - ) + _tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) return _tokenizer! } } @@ -415,6 +410,15 @@ public class LanguageModelWithStatefulKVCache: LanguageModel { } } -extension String: @retroactive Error {} +public enum TokenizerError: LocalizedError { + case tokenizerConfigNotFound + + public var errorDescription: String? { + switch self { + case .tokenizerConfigNotFound: + String(localized: "Tokenizer configuration could not be found. The model may be missing required tokenizer files.", comment: "Error when tokenizer configuration is missing") + } + } +} #endif // canImport(CoreML) From b9c8f0c70da0eea5dfc9325926990cbdf0838003 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 24 Sep 2025 11:18:57 +0200 Subject: [PATCH 09/16] Fix deprecation warnings in tests --- Tests/HubTests/DownloaderTests.swift | 4 ++-- Tests/HubTests/HubApiTests.swift | 18 +++++++++--------- Tests/TokenizersTests/BertTokenizerTests.swift | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Tests/HubTests/DownloaderTests.swift b/Tests/HubTests/DownloaderTests.swift index aecfa778..552fbf3e 100644 --- a/Tests/HubTests/DownloaderTests.swift +++ b/Tests/HubTests/DownloaderTests.swift @@ -181,11 +181,11 @@ final class DownloaderTests: XCTestCase { ) { error in XCTAssertEqual((error as NSError).code, 516) } - XCTAssertEqual(try String(contentsOf: destination), "existing") + XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), "existing") XCTAssertNoThrow( try FileManager.default.moveDownloadedFile(from: source2, to: destination, percentEncoded: false) ) - XCTAssertEqual(try String(contentsOf: destination), "v2") + XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), "v2") } } diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index aee71a6d..10781d47 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -570,7 +570,7 @@ class SnapshotDownloadTests: XCTestCase { // If that's the case we just update the metadata and keep the local file. XCTAssertEqual(originalTimestamp, thirdDownloadTimestamp) - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) // Updated metadata file needs to have the correct commit hash, etag and timestamp. // This is being updated because the local etag (SHA256 checksum) matches the remote etag @@ -601,7 +601,7 @@ class SnapshotDownloadTests: XCTestCase { ) let metadataFile = metadataDestination.appendingPathComponent("llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata") - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nfc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107" XCTAssertTrue(metadataString.contains(expected)) @@ -631,7 +631,7 @@ class SnapshotDownloadTests: XCTestCase { ) let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" XCTAssertTrue(metadataString.contains(expected)) @@ -661,7 +661,7 @@ class SnapshotDownloadTests: XCTestCase { ) let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let metadataArr = metadataString.components(separatedBy: .newlines) let commitHash = metadataArr[0] @@ -719,7 +719,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" XCTAssertTrue(metadataString.contains(expected)) @@ -770,7 +770,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" XCTAssertTrue(metadataString.contains(expected)) @@ -821,7 +821,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertTrue(originalTimestamp != secondDownloadTimestamp) XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) - let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataString = try String(contentsOfFile: metadataFile.path, encoding: .utf8) let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nd6ceb92ce9e3c83ab146dc8e92a93517ac1cc66f" XCTAssertTrue(metadataString.contains(expected)) @@ -1008,7 +1008,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path) + let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path, encoding: .utf8) let expected = """ { @@ -1047,7 +1047,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path) + let fileContents = try String(contentsOfFile: downloadedTo.appendingPathComponent("config.json").path, encoding: .utf8) let expected = """ X diff --git a/Tests/TokenizersTests/BertTokenizerTests.swift b/Tests/TokenizersTests/BertTokenizerTests.swift index 7cad3e34..f37f13e0 100644 --- a/Tests/TokenizersTests/BertTokenizerTests.swift +++ b/Tests/TokenizersTests/BertTokenizerTests.swift @@ -86,7 +86,7 @@ private enum Squad { private let bertTokenizer: BertTokenizer = { let vocab = { let url = Bundle.module.url(forResource: "bert-vocab", withExtension: "txt")! - let vocabTxt = try! String(contentsOf: url) + let vocabTxt = try! String(contentsOf: url, encoding: .utf8) let tokens = vocabTxt.split(separator: "\n").map { String($0) } var vocab: [String: Int] = [:] for (i, token) in tokens.enumerated() { From c48ccb1eb29bc9f777173255e6ddb6be5f1eb94b Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 25 Sep 2025 14:30:55 +0200 Subject: [PATCH 10/16] Move transformers-cli to an example --- Examples/Mistral7B/README.md | 2 +- Examples/transformers-cli/Package.swift | 22 +++++++++++++++++++ .../transformers-cli}/Transformers.swift | 19 ++++++++-------- 3 files changed, 33 insertions(+), 10 deletions(-) create mode 100644 Examples/transformers-cli/Package.swift rename {Sources/TransformersCLI => Examples/transformers-cli/Sources/transformers-cli}/Transformers.swift (94%) diff --git a/Examples/Mistral7B/README.md b/Examples/Mistral7B/README.md index 61123251..8e851304 100644 --- a/Examples/Mistral7B/README.md +++ b/Examples/Mistral7B/README.md @@ -15,7 +15,7 @@ Running compression: 100%|██████████████████ ### Generate Text ```shell -✗ swift run transformers "Best recommendations for a place to visit in Paris in August 2024:" --max-length 128 StatefulMistral7BInstructInt4.mlpackage +✗ swift run transformers-cli "Best recommendations for a place to visit in Paris in August 2024:" --max-length 128 StatefulMistral7BInstructInt4.mlpackage Best recommendations for a place to visit in Paris in August 2024: diff --git a/Examples/transformers-cli/Package.swift b/Examples/transformers-cli/Package.swift new file mode 100644 index 00000000..50ae0777 --- /dev/null +++ b/Examples/transformers-cli/Package.swift @@ -0,0 +1,22 @@ +// swift-tools-version: 6.2 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import PackageDescription + +let package = Package( + name: "transformers-cli", + platforms: [.iOS(.v18), .macOS(.v15)], + dependencies: [ + .package(url: "https://github.com/huggingface/swift-transformers", branch: "main"), + .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"), + ], + targets: [ + .executableTarget( + name: "transformers-cli", + dependencies: [ + .product(name: "Transformers", package: "swift-transformers"), + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ] + ) + ] +) diff --git a/Sources/TransformersCLI/Transformers.swift b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift similarity index 94% rename from Sources/TransformersCLI/Transformers.swift rename to Examples/transformers-cli/Sources/transformers-cli/Transformers.swift index 75c243aa..cb24e3ca 100644 --- a/Sources/TransformersCLI/Transformers.swift +++ b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift @@ -1,9 +1,8 @@ import ArgumentParser import CoreML import Foundation - -import Models import Generation +import Models @available(macOS 15.0, iOS 18.0, *) @main @@ -25,10 +24,11 @@ struct TransformersCLI: AsyncParsableCommand { @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") var computeUnits: ComputeUnits = .cpuAndGPU - @Option(help: """ - When enabled, two generation passes are ran, one to 'warm up' and another to collect \ - benchmark metrics. - """) + @Option( + help: """ + When enabled, two generation passes are ran, one to 'warm up' and another to collect \ + benchmark metrics. + """) var warmup: Bool = false func generate( @@ -63,7 +63,8 @@ struct TransformersCLI: AsyncParsableCommand { let tps = Double(tokensReceived) / endTime.timeIntervalSince(startTime) if printOutput { print("") - print(""" + print( + """ \(tps.formatted("%.2f")) tokens/s, \ prompt pre-filling time: \(promptProcessingTime.formatted("%.2f"))s, \ total time: \(completionTime.formatted("%.2f"))s @@ -86,7 +87,7 @@ struct TransformersCLI: AsyncParsableCommand { let compiledURL = try compile(at: url) print("Loading model \(compiledURL)") let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits) - + // Using greedy generation for now var config = model.defaultGenerationConfig config.doSample = false @@ -126,7 +127,7 @@ enum ComputeUnits: String, ExpressibleByArgument, CaseIterable { /// /// - Parameter respone: The response to clean and format. /// - Returns: A 'user friendly' representation of the generated response. -fileprivate func formatResponse(_ response: String) -> String { +private func formatResponse(_ response: String) -> String { response .replacingOccurrences(of: "\\n", with: "\n") .replacingOccurrences(of: "", with: "") From a7e812ae9338a1d130a623c7b772c03a9d6ea696 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 25 Sep 2025 14:34:50 +0200 Subject: [PATCH 11/16] Format --- Sources/Generation/Decoders.swift | 2 +- Sources/Generation/Generation.swift | 25 ++++++++++++------------ Sources/Models/LanguageModel.swift | 26 ++++++++++++++----------- Sources/Models/LanguageModelTypes.swift | 1 - 4 files changed, 29 insertions(+), 25 deletions(-) diff --git a/Sources/Generation/Decoders.swift b/Sources/Generation/Decoders.swift index 173ff97e..b7fccb95 100644 --- a/Sources/Generation/Decoders.swift +++ b/Sources/Generation/Decoders.swift @@ -12,7 +12,7 @@ func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, let temperatureAdjustedScores = scores / temperature let (topKScores, topKIndices) = temperatureAdjustedScores.topK(topK) let topKProbs = topKScores.softmax(alongAxis: -1) - let rnd = topKProbs.sum() * Float.random(in: 0 ..< 1) + let rnd = topKProbs.sum() * Float.random(in: 0..<1) var accumTopKProbs = topKProbs.cumulativeSum(alongAxis: -1) accumTopKProbs += (accumTopKProbs .< rnd) * 100.0 let topKIndex = accumTopKProbs.argsort()[..., 0] diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift index f08fcb5e..4a6e2445 100644 --- a/Sources/Generation/Generation.swift +++ b/Sources/Generation/Generation.swift @@ -47,18 +47,19 @@ extension Generation { var outputTokens = MLTensor(tokens).expandingShape(at: 0) while outputTokens.shape[1] < config.maxLength { let nextTokenScores = await model(outputTokens, config) - let nextToken = switch config.generationMode { - case .greedy: - selectNextTokenUsingGreedyDecoding(from: nextTokenScores) - case .sample: - selectNextTokenUsingTopKSampling( - from: nextTokenScores, - temperature: config.temperature, - topK: config.topK - ) - default: - fatalError("Generation mode \(config.generationMode) not implemented yet") - } + let nextToken = + switch config.generationMode { + case .greedy: + selectNextTokenUsingGreedyDecoding(from: nextTokenScores) + case .sample: + selectNextTokenUsingTopKSampling( + from: nextTokenScores, + temperature: config.temperature, + topK: config.topK + ) + default: + fatalError("Generation mode \(config.generationMode) not implemented yet") + } if let nextTokenId = await tensorToGenerationOutput(nextToken).first, nextTokenId == config.eosTokenId { break diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index f29a660b..7fc61d8b 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -28,7 +28,7 @@ public class LanguageModel { configuration = LanguageModelConfigurationFromHub(modelName: modelName) } - public func resetState() async { } + public func resetState() async {} public func predictNextTokenScores( _ tokens: MLTensor, @@ -160,7 +160,7 @@ public extension LanguageModel { var description: String { if let description = metadata[MLModelMetadataKey.description] as? String, - !description.isEmpty + !description.isEmpty { return description } @@ -170,7 +170,7 @@ public extension LanguageModel { /// `name_or_path` in the Python world var modelName: String { if let userFields = metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String], - let name = userFields["co.huggingface.exporters.name"] + let name = userFields["co.huggingface.exporters.name"] { return name } @@ -241,12 +241,14 @@ public extension LanguageModel { // Check if cache is dynamic or not. let kCacheConstraint = model.modelDescription.inputDescriptionsByName[Keys.keyCache]! if isDynamicallyShaped(kCacheConstraint) { - fatalError(""" + fatalError( + """ KV Cache using IO is currently not supported, please file a feature request on \ https://github.com/huggingface/swift-transformers """) } else { - fatalError(""" + fatalError( + """ KV Cache using IO is currently not supported, please file a feature request on \ https://github.com/huggingface/swift-transformers """) @@ -365,19 +367,21 @@ public class LanguageModelWithStatefulKVCache: LanguageModel { assert(tokens.rank == 2) // [batch, current sequence length] let tokenCount = tokens.shape[1] guard let state else { - fatalError(""" + fatalError( + """ Encountered uninitialized `state`. Ensure `resetState` is called prior to calling \ `predictNextTokenScores`. """) } - let inputIds = switch mode { - case .prefilling: tokens // Pass in all takens if pre-filling prompt - case .extending: tokens[nil, -1].expandingShape(at: 0) // otherwise just the last token - } + let inputIds = + switch mode { + case .prefilling: tokens // Pass in all takens if pre-filling prompt + case .extending: tokens[nil, -1].expandingShape(at: 0) // otherwise just the last token + } mode = .extending var inputDictionary = [ - Keys.inputIds: inputIds, + Keys.inputIds: inputIds ] if isRequiringAttentionMask { #if !((os(macOS) || (macCatalyst)) && arch(x86_64)) diff --git a/Sources/Models/LanguageModelTypes.swift b/Sources/Models/LanguageModelTypes.swift index 0a639184..892ba545 100644 --- a/Sources/Models/LanguageModelTypes.swift +++ b/Sources/Models/LanguageModelTypes.swift @@ -11,7 +11,6 @@ import CoreML import Generation import Tokenizers - /// A causal language model. @available(macOS 15.0, iOS 18.0, *) public protocol LanguageModelProtocol { From d785220fa97d7756b62dea18d6c01160c9851ff6 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 25 Sep 2025 15:43:24 +0200 Subject: [PATCH 12/16] Relax requirements for main package But keep iOS 18 / macOS 15 for Core ML --- Package.swift | 2 +- Sources/Generation/Decoders.swift | 4 ++++ Sources/Models/LanguageModel.swift | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/Package.swift b/Package.swift index 5d4f32d4..2617b751 100644 --- a/Package.swift +++ b/Package.swift @@ -10,7 +10,7 @@ let swiftSettings: [SwiftSetting] = [ let package = Package( name: "swift-transformers", - platforms: [.iOS("18.0"), .macOS("15.0")], + platforms: [.iOS(.v17), .macOS(.v14)], products: [ .library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"]) ], diff --git a/Sources/Generation/Decoders.swift b/Sources/Generation/Decoders.swift index b7fccb95..fafbd1b4 100644 --- a/Sources/Generation/Decoders.swift +++ b/Sources/Generation/Decoders.swift @@ -1,13 +1,16 @@ +#if canImport(CoreML) import CoreML // MARK: Greedy Decoding +@available(macOS 15.0, iOS 18.0, *) func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor { scores.argmax(alongAxis: -1).reshaped(to: [1, 1]) } // MARK: Top-K Sampling +@available(macOS 15.0, iOS 18.0, *) func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, topK: Int) -> MLTensor { let temperatureAdjustedScores = scores / temperature let (topKScores, topKIndices) = temperatureAdjustedScores.topK(topK) @@ -22,3 +25,4 @@ func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, ) return nextTokenTensor.reshaped(to: [1, 1]) } +#endif // canImport(CoreML) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 7fc61d8b..36a8086a 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -137,6 +137,7 @@ public extension LanguageModel { } } +@available(macOS 15.0, iOS 18.0, *) extension LanguageModel { enum KVCacheAvailability { /// Language models that support KV cache via state. Implementation details for handling state From 0296d28619284e06cf16c797db81746ae1992ec6 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 25 Sep 2025 16:28:21 +0200 Subject: [PATCH 13/16] Revert platform requirements --- Package.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Package.swift b/Package.swift index 2617b751..af09c380 100644 --- a/Package.swift +++ b/Package.swift @@ -10,7 +10,7 @@ let swiftSettings: [SwiftSetting] = [ let package = Package( name: "swift-transformers", - platforms: [.iOS(.v17), .macOS(.v14)], + platforms: [.iOS(.v16), .macOS(.v13)], products: [ .library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"]) ], From b0dd129de788e61612c4abc775905cf6568f1f73 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 25 Sep 2025 16:33:10 +0200 Subject: [PATCH 14/16] Relative package location plus comment --- Examples/transformers-cli/Package.swift | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Examples/transformers-cli/Package.swift b/Examples/transformers-cli/Package.swift index 50ae0777..f360afb1 100644 --- a/Examples/transformers-cli/Package.swift +++ b/Examples/transformers-cli/Package.swift @@ -7,7 +7,9 @@ let package = Package( name: "transformers-cli", platforms: [.iOS(.v18), .macOS(.v15)], dependencies: [ - .package(url: "https://github.com/huggingface/swift-transformers", branch: "main"), + .package(path: "../.."), + // If you copy this manifest as a template, use the following line instead + //.package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"), .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"), ], targets: [ From 6225c7ff1fe05388a003d3681806fae81a4499c3 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 25 Sep 2025 16:41:29 +0200 Subject: [PATCH 15/16] Mistral example: uv-ify and unpin --- Examples/Mistral7B/README.md | 2 +- Examples/Mistral7B/export.py | 11 +++++++++++ Examples/Mistral7B/generate.py | 11 +++++++++++ Examples/Mistral7B/requirements.txt | 12 ++++++------ 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/Examples/Mistral7B/README.md b/Examples/Mistral7B/README.md index 8e851304..30d64b14 100644 --- a/Examples/Mistral7B/README.md +++ b/Examples/Mistral7B/README.md @@ -1,7 +1,7 @@ ### Export Mistral 7B Instruct v0.3 ```shell -✗ python export.py +✗ uv run export.py Loading checkpoint shards: 100%|███████████████████████████| 3/3 [00:12<00:00, 4.11s/it] Converting PyTorch Frontend ==> MIL Ops: 100%|███| 5575/5575 [00:02<00:00, 2440.66 ops/s] diff --git a/Examples/Mistral7B/export.py b/Examples/Mistral7B/export.py index e1bc44a0..4509a937 100644 --- a/Examples/Mistral7B/export.py +++ b/Examples/Mistral7B/export.py @@ -1,3 +1,14 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "coremltools", +# "numpy", +# "sentencepiece", +# "torch", +# "tqdm", +# "transformers", +# ] +# /// import logging import os import warnings diff --git a/Examples/Mistral7B/generate.py b/Examples/Mistral7B/generate.py index 9e373592..0c90acd0 100644 --- a/Examples/Mistral7B/generate.py +++ b/Examples/Mistral7B/generate.py @@ -1,3 +1,14 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "coremltools", +# "numpy", +# "sentencepiece", +# "torch", +# "tqdm", +# "transformers", +# ] +# /// import argparse from typing import Dict, Generator, List, Tuple diff --git a/Examples/Mistral7B/requirements.txt b/Examples/Mistral7B/requirements.txt index f0f1fa68..3c954fb2 100644 --- a/Examples/Mistral7B/requirements.txt +++ b/Examples/Mistral7B/requirements.txt @@ -1,6 +1,6 @@ -coremltools==8.0b1 -numpy==1.26.4 -torch==2.3.1 -tqdm==4.66.4 -transformers==4.42.3 -sentencepiece==0.2.0 +coremltools +numpy +torch +tqdm +transformers +sentencepiece From 29afbf2edebb131cd94100584d96816a1706e6ae Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 25 Sep 2025 16:46:12 +0200 Subject: [PATCH 16/16] Remove obsolete GenerationTests again --- Package.swift | 1 - 1 file changed, 1 deletion(-) diff --git a/Package.swift b/Package.swift index fb69f9f3..dc6c3f5a 100644 --- a/Package.swift +++ b/Package.swift @@ -24,7 +24,6 @@ let package = Package( .target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings), .target(name: "Models", dependencies: ["Tokenizers", "Generation"]), .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]), - .testTarget(name: "GenerationTests", dependencies: ["Generation"]), .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings), .testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]), .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]),