Stateful cache, MLTensor #257
Merged

Commits (19):

- e0b3637 feat: preview (pcuenca)
- 4b3f7f6 Remove Random (#115) (pcuenca)
- 9ba8173 Throwing error when the configs fail JSON serialization (#114) (pcuenca)
- e6b6dd4 Allow archiving for Mac (#121) (pcuenca)
- d2bc390 chore: strategic deletes avoid OOM (FL33TW00D)
- 7d7870b Remove RepetitionPenaltyWarper, fix build (pcuenca)
- 551ae06 Remove GenerationTests (pcuenca)
- 67f1b08 Restore TokenizerError (pcuenca)
- b9c8f0c Fix deprecation warnings in tests (pcuenca)
- 785c0ce Merge remote-tracking branch 'origin/main' into preview-2025 (pcuenca)
- c48ccb1 Move transformers-cli to an example (pcuenca)
- a7e812a Format (pcuenca)
- d785220 Relax requirements for main package (pcuenca)
- 697dc34 Merge remote-tracking branch 'origin/main' into preview-2025 (pcuenca)
- 0296d28 Revert platform requirements (pcuenca)
- b0dd129 Relative package location plus comment (pcuenca)
- d1cacbe Merge branch 'main' of github.com:huggingface/swift-transformers into… (pcuenca)
- 6225c7f Mistral example: uv-ify and unpin (pcuenca)
- 29afbf2 Remove obsolete GenerationTests again (pcuenca)
New file (27 lines): Markdown usage notes for the Mistral example.

### Export Mistral 7B Instruct v0.3

```shell
✗ python export.py

Loading checkpoint shards: 100%|███████████████████████████| 3/3 [00:12<00:00, 4.11s/it]
Converting PyTorch Frontend ==> MIL Ops: 100%|███| 5575/5575 [00:02<00:00, 2440.66 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 7.12 passes/s]
Running MIL default pipeline: 100%|█████████████████| 79/79 [02:36<00:00, 1.98s/ passes]
Running MIL backend_mlprogram pipeline: 100%|███████| 12/12 [00:00<00:00, 22.90 passes/s]
Running compression: 100%|███████████████████████████| 296/296 [03:04<00:00, 1.60 ops/s]
...
```

### Generate Text

```shell
✗ swift run transformers-cli "Best recommendations for a place to visit in Paris in August 2024:" --max-length 128 StatefulMistral7BInstructInt4.mlpackage

Best recommendations for a place to visit in Paris in August 2024:

1. Palace of Versailles: This iconic palace is a must-visit. It's a short train ride from Paris and offers a glimpse into the opulence of the French monarchy.

2. Eiffel Tower: No trip to Paris is complete without a visit to the Eiffel Tower. You can take an elevator ride to the top for a stunning view of the city.

3. Louvre Museum: Home to thousands of works of art, including the Mona Lisa and the Winged Victory of Samothrace, the Louvre is a cultural treasure.
```
New file (219 lines): `export.py`, the Core ML export script invoked in the README above.
```python
import logging
import os
import warnings
from typing import List, Optional, Tuple

import coremltools as ct
import numpy as np
import torch
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import (
    MISTRAL_ATTENTION_CLASSES,
    MistralAttention,
    MistralConfig,
    MistralForCausalLM,
    apply_rotary_pos_emb,
    repeat_kv,
)

warnings.filterwarnings("ignore")
logging.getLogger("coremltools").setLevel(logging.ERROR)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3
MODEL_ID: str = "mistralai/Mistral-7B-Instruct-v0.3"
METADATA_TOKENIZER: str = "co.huggingface.exporters.name"


class SliceUpdateKeyValueCache(Cache):
    def __init__(
        self,
        shape: Tuple[int, ...],
        device="cpu",
        dtype=torch.float32,
    ) -> None:
        """KV cache of shape (#layers, batch_size, #kv_heads, context_size, head_dim)."""
        super().__init__()
        self.past_seen_tokens: int = 0
        self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)

    def update(
        self,
        k_state: torch.Tensor,
        v_state: torch.Tensor,
        layer_idx: int,
        slice_indices: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update key/value cache tensors for slice [slice_indices[0], slice_indices[1]).
        Return slice of key/value cache tensors from [0, slice_indices[1]).
        """
        if len(slice_indices) != 2:
            raise ValueError(f"Expect tuple of integers [start, end), got {slice_indices=}.")
        begin, end = slice_indices
        self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state
        self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state
        k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
        v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
        return k_cache, v_cache

    def get_seq_length(self, _: Optional[int] = 0) -> int:
        """Get the sequence length of the cache."""
        return self.past_seen_tokens


class SliceUpdateMistralAttention(MistralAttention):
    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
        super().__init__(config=config, layer_idx=layer_idx)

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        **kwargs,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
            1, 2
        )
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Slice-update the key/value cache for positions [end_step - q_len, end_step)
        end_step = attention_mask.shape[-1]
        key_states, value_states = past_key_value.update(
            key_states,
            value_states,
            self.layer_idx,
            slice_indices=(end_step - q_len, end_step),
        )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, None


class StatefulMistralForCausalLM(torch.nn.Module):
    def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1) -> None:
        super().__init__()

        # Custom attention implementation for the stateful slice-update key/value cache;
        # override "sdpa" to comply with transformers.modeling_utils._autoset_attn_implementation
        MISTRAL_ATTENTION_CLASSES["sdpa"] = SliceUpdateMistralAttention
        self.model = MistralForCausalLM.from_pretrained(model_path)

        # Register KV cache buffers to be recognized as Core ML states
        config: MistralConfig = self.model.config
        self.kv_cache_shape: Tuple[int, ...] = (
            config.num_hidden_layers,
            batch_size,
            config.num_key_value_heads,
            max_context_size,
            config.hidden_size // config.num_attention_heads,
        )
        self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
        self.register_buffer("keyCache", self.kv_cache.k_cache)
        self.register_buffer("valueCache", self.kv_cache.v_cache)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        causal_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Compute past seen tokens used for updating key/value cache slices
        self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1]
        return self.model(
            input_ids,
            attention_mask=causal_mask,
            past_key_values=self.kv_cache,
            use_cache=True,
        ).logits


def export() -> None:
    # Construct model from transformers and trace to TorchScript
    max_context_size: int = 2048
    torch_model = StatefulMistralForCausalLM(MODEL_ID, max_context_size=max_context_size)
    torch_model.eval()
    input_ids: torch.Tensor = torch.zeros((1, 2), dtype=torch.int32)
    causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float32)
    traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask])
    kv_cache_shape = torch_model.kv_cache_shape
    del torch_model

    # Convert traced TorchScript to Core ML format
    query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
    end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
    inputs: List[ct.TensorType] = [
        ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"),
        ct.TensorType(
            shape=(1, 1, query_length, end_step_dim),
            dtype=np.float16,
            name="causalMask",
        ),
    ]
    outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")]
    states: List[ct.StateType] = [
        ct.StateType(
            wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16),
            name="keyCache",
        ),
        ct.StateType(
            wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16),
            name="valueCache",
        ),
    ]

    # Convert model with FP16 precision
    mlmodel_fp16: ct.MLModel = ct.convert(
        traced_model,
        inputs=inputs,
        outputs=outputs,
        states=states,
        minimum_deployment_target=ct.target.iOS18,
        skip_model_load=True,
    )
    del traced_model

    # Block-wise quantize model weights to int4
    op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
        mode="linear_symmetric",
        dtype="int4",
        granularity="per_block",
        block_size=32,
    )
    config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
    mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config)
    mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID})
    del mlmodel_fp16
    mlmodel_int4.save("StatefulMistral7BInstructInt4.mlpackage")


if __name__ == "__main__":
    export()
```
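The slice-update cache is the core trick here: instead of reallocating a growing KV tensor on every step, fixed-size buffers are written in place, which is what lets Core ML expose them as states. A minimal sketch of the semantics with toy dimensions, not part of the PR, assuming the script above is saved as `export.py` with its dependencies installed:

```python
import torch

from export import SliceUpdateKeyValueCache  # assumes the file is named export.py

# Toy cache: 2 layers, batch 1, 1 KV head, context 8, head_dim 4.
cache = SliceUpdateKeyValueCache(shape=(2, 1, 1, 8, 4))

# Prompt step: write 3 positions into slice [0, 3) of layer 0.
k = torch.ones(1, 1, 3, 4)  # (batch, kv_heads, seq_len, head_dim)
v = torch.ones(1, 1, 3, 4)
k_all, v_all = cache.update(k, v, layer_idx=0, slice_indices=(0, 3))
print(k_all.shape)  # torch.Size([1, 1, 3, 4]): everything cached so far

# Decode step: append one position at [3, 4); the returned slice grows.
k1 = torch.full((1, 1, 1, 4), 2.0)
v1 = torch.full((1, 1, 1, 4), 2.0)
k_all, v_all = cache.update(k1, v1, layer_idx=0, slice_indices=(3, 4))
print(k_all.shape)  # torch.Size([1, 1, 4, 4])
```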
New file (88 lines): the Python generation script used for quick testing. Its filename isn't visible in this capture; it imports `METADATA_TOKENIZER` from `export.py`.
```python
import argparse
from typing import Dict, Generator, List, Tuple

import numpy as np
from coremltools.models import MLModel
from transformers import AutoTokenizer

from export import METADATA_TOKENIZER


def load(model_path: str) -> Tuple[MLModel, AutoTokenizer]:
    """Load a Core ML model and corresponding tokenizer."""
    model: MLModel = MLModel(model_path)
    description = model.get_spec().description
    if METADATA_TOKENIZER not in description.metadata.userDefined:
        raise ValueError("Model metadata does not contain tokenizer path.")
    tokenizer_path: str = description.metadata.userDefined[METADATA_TOKENIZER]
    tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    return model, tokenizer


def get_next_token(model: MLModel, prompt_tokens: np.ndarray) -> Generator[int, None, None]:
    """Generate a sequence of tokens with naive greedy decoding."""

    def sample(logits: np.ndarray) -> int:
        """Perform greedy decoding on the logits array to get the next token."""
        return int(np.argmax(logits[0][-1], axis=-1))

    def inference(model: MLModel, input_ids: np.ndarray, num_past_tokens: int) -> np.ndarray:
        """Perform inference with the given model and input data."""
        causal_mask: np.ndarray = np.triu(
            np.full(
                (1, 1, input_ids.shape[-1], num_past_tokens + input_ids.shape[-1]),
                fill_value=-np.inf if num_past_tokens == 0 else 0,
            ),
            k=1,
        ).astype(np.float16)
        outputs: Dict[str, np.ndarray] = model.predict(
            data={"inputIds": input_ids, "causalMask": causal_mask},
            state=kv_cache_state,
        )
        return outputs["logits"]

    # Fresh Core ML state per generation; the prompt is processed in one shot.
    kv_cache_state = model.make_state()
    logits: np.ndarray = inference(model, input_ids=prompt_tokens, num_past_tokens=0)
    token: int = sample(logits=logits)
    num_past_tokens: int = prompt_tokens.shape[-1]

    # Then decode one token at a time, reusing the cached keys/values.
    while True:
        yield token
        logits = inference(
            model,
            input_ids=np.array([[token]], dtype=np.int32),
            num_past_tokens=num_past_tokens,
        )
        token = sample(logits=logits)
        num_past_tokens += 1


def generate(
    model: MLModel,
    prompt: str,
    tokenizer: AutoTokenizer,
    max_new_tokens: int,
) -> str:
    prompt_tokens: np.ndarray = tokenizer(prompt, return_tensors="np").input_ids
    extend_tokens: List[int] = []
    for i, token in enumerate(get_next_token(model, prompt_tokens=prompt_tokens.astype(np.int32))):
        if token == tokenizer.eos_token_id or i == max_new_tokens:
            break
        extend_tokens.append(token)
    return tokenizer.decode(prompt_tokens[0].tolist() + extend_tokens)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_path", type=str)
    parser.add_argument("--prompt", type=str, default="Hello")
    parser.add_argument("--max_new_tokens", type=int, default=128)
    args = parser.parse_args()
    model, tokenizer = load(args.model_path)
    extend_text: str = generate(
        model,
        prompt=args.prompt,
        tokenizer=tokenizer,
        max_new_tokens=args.max_new_tokens,
    )
    print(extend_text)
```
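Tying the two scripts together, an end-to-end run could look like the sketch below. `generate.py` is an assumed filename (the capture doesn't show it); the `.mlpackage` name matches what `export()` saves:

```shell
python export.py  # writes StatefulMistral7BInstructInt4.mlpackage
python generate.py StatefulMistral7BInstructInt4.mlpackage \
    --prompt "Best recommendations for a place to visit in Paris:" \
    --max_new_tokens 64
```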
New file (6 lines): pinned Python dependencies for the export and generation scripts.
```
coremltools==8.0b1
numpy==1.26.4
torch==2.3.1
tqdm==4.66.4
transformers==4.42.3
sentencepiece==0.2.0
```
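These pins install in the usual way; the on-disk filename isn't shown in this capture, so `requirements.txt` is an assumption:

```shell
pip install -r requirements.txt  # assumes the pins above live in requirements.txt
```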
New file (22 lines): the SwiftPM manifest (`Package.swift`) for the `transformers-cli` example.
```swift
// swift-tools-version: 6.2
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
    name: "transformers-cli",
    platforms: [.iOS(.v18), .macOS(.v15)],
    dependencies: [
        .package(url: "https://github.com/huggingface/swift-transformers", branch: "main"),
        .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"),
    ],
    targets: [
        .executableTarget(
            name: "transformers-cli",
            dependencies: [
                .product(name: "Transformers", package: "swift-transformers"),
                .product(name: "ArgumentParser", package: "swift-argument-parser"),
            ]
        )
    ]
)
```
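With this manifest, SwiftPM builds and runs the CLI directly. A sketch of the invocation, assuming an exported model package sits in the working directory; the flags mirror the README above:

```shell
swift build
swift run transformers-cli "Hello, Paris!" --max-length 128 StatefulMistral7BInstructInt4.mlpackage
```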
Review comment: "We can use a relative import here."