diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md
index 5ac7279292bb15..5277e9f966dc50 100644
--- a/docs/source/en/gguf.md
+++ b/docs/source/en/gguf.md
@@ -81,6 +81,7 @@ For now the supported model architectures are the architectures that have been v
 - Qwen2Moe
 - Phi3
 - Bloom
+- Falcon
 
 ## Example usage
 
diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
index a4fd90f2bfe473..92371415918150 100644
--- a/src/transformers/convert_slow_tokenizer.py
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -345,7 +345,6 @@ def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]]
             )
         )
 
-        add_prefix_space = False
         add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
         tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
         tokenizer.decoder = decoders.ByteLevel()
diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index d7a548791e8a70..ca39b5ef5f917a 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -120,6 +120,29 @@
         "output.weight": "lm_head.weight",
         "output_norm": "transformer.ln_f",
     },
+    "falcon7b": {
+        "token_embd": "word_embeddings",
+        "blk": "h",
+        "ffn_up": "mlp.dense_h_to_4h",
+        "ffn_down": "mlp.dense_4h_to_h",
+        "attn_norm": "input_layernorm",
+        "attn_qkv": "self_attention.query_key_value",
+        "attn_output": "self_attention.dense",
+        ".output.": ".lm_head.",
+        "output_norm": "ln_f",
+    },
+    "falcon40b": {
+        "token_embd": "word_embeddings",
+        "blk": "h",
+        "ffn_up": "mlp.dense_h_to_4h",
+        "ffn_down": "mlp.dense_4h_to_h",
+        ".attn_norm.": ".ln_mlp.",
+        "attn_norm_2": "ln_attn",
+        "attn_qkv": "self_attention.query_key_value",
+        "attn_output": "self_attention.dense",
+        ".output.": ".lm_head.",
+        "output_norm": "ln_f",
+    },
 }
 
 
@@ -178,6 +201,18 @@
         "attention.layer_norm_rms_epsilon": "rms_norm_eps",
         "vocab_size": "vocab_size",
     },
+    "falcon": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "rope.freq_base": "rope_theta",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "rms_norm_eps",
+        "vocab_size": "vocab_size",
+    },
     "tokenizer": {
         "ggml.bos_token_id": "bos_token_id",
         "ggml.eos_token_id": "eos_token_id",
@@ -530,6 +565,7 @@ def converted(self) -> Tokenizer:
     "qwen2_moe": GGUFQwen2Converter,
     "phi3": GGUFPhi3Converter,
     "bloom": GGUFBloomConverter,
+    "falcon": GGUFBloomConverter,
 }
 
 
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index c2e06624e15714..3bca05b1251f3f 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 from typing import Optional
 
 import numpy as np
@@ -99,8 +100,20 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
     if "qwen2moe" in architecture:
         updated_architecture = "qwen2_moe"
 
-    if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
-        raise ValueError(f"Architecture {architecture} not supported")
+    model_size = ""
+    # Extract the number of params from the file name, as architectures can differ;
+    # e.g. for falcon: `...falcon-7b-...`
+    if "falcon" in architecture:
+        gguf_file_name = gguf_checkpoint_path.split("/")[-1].lower()
+        m = re.search(r"-\d+b-", gguf_file_name)  # regex to catch `-7b-`
+        if m is None:
+            raise ValueError(
+                f"From file name, cannot determine the number of parameters for {architecture} architecture"
+            )
+        model_size = m.group().strip("-")  # only keeps `7b`
+
+    if architecture + model_size not in GGUF_SUPPORTED_ARCHITECTURES:
+        raise ValueError(f"Architecture {architecture + model_size} not supported")
 
     # List all key-value pairs in a columnized format
     for gguf_key, field in reader.fields.items():
@@ -146,17 +159,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
            )
 
     if return_tensors:
-        tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture]
+        tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture + model_size]
 
         for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
-            renamed_tensor_name = tensor.name
-
-            for tensor_name_mapping in GGUF_TO_TRANSFORMERS_MAPPING["tensors"]:
-                if tensor_name_mapping in renamed_tensor_name:
-                    renamed_tensor_name = renamed_tensor_name.replace(
-                        tensor_name_mapping, GGUF_TO_TRANSFORMERS_MAPPING["tensors"][tensor_name_mapping]
-                    )
-
             name = tensor.name
 
             weights = dequantize(tensor.data, tensor.tensor_type)
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index 13e64677be5c42..ddc6288f36dd31 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -44,6 +44,9 @@ class GgufIntegrationTests(unittest.TestCase):
     phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
     bloom_model_id = "afrideva/bloom-560m-GGUF"
     original_bloom_model_id = "bigscience/bloom-560m"
+    falcon7b_model_id = "xaviviro/falcon-7b-quantized-gguf"
+    falcon40b_model_id = "maddes8cht/tiiuae-falcon-40b-gguf"
+    original_falcon7b_model_id = "tiiuae/falcon-7b"
 
     # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -74,6 +77,9 @@ class GgufIntegrationTests(unittest.TestCase):
     fp16_bloom_model_id = "bloom-560m.fp16.gguf"
     q8_bloom_model_id = "bloom-560m.q8_0.gguf"
     f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
+    q2_k_falcon7b_model_id = "falcon-7b-q2_k.gguf"
+    fp16_falcon7b_model_id = "falcon-7b-fp16.gguf"
+    q2_k_falcon40b_model_id = "tiiuae-falcon-40b-Q2_K.gguf"
 
     example_text = "Hello"
 
@@ -445,6 +451,58 @@ def test_bloom_weights_conversion_fp16(self):
                 self.assertTrue(quantized_param.shape == original_param.shape)
                 torch.testing.assert_close(quantized_param, original_param)
 
+    @unittest.skip(reason="Heavy memory")
+    def test_falcon40b_q2_k(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.falcon40b_model_id, gguf_file=self.q2_k_falcon40b_model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.falcon40b_model_id,
+            gguf_file=self.q2_k_falcon40b_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello All,\nI am new to this forum."
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_falcon7b_q2_k(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.falcon7b_model_id, gguf_file=self.q2_k_falcon7b_model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.falcon7b_model_id,
+            gguf_file=self.q2_k_falcon7b_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello All,\nI am new to this forum."
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_falcon7b_weights_conversion_fp16(self):
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.falcon7b_model_id,
+            gguf_file=self.fp16_falcon7b_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+        original_model = AutoModelForCausalLM.from_pretrained(
+            self.original_falcon7b_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        quantized_state_dict = quantized_model.state_dict()
+        original_state_dict = original_model.state_dict()
+
+        for layer_name, original_params in original_state_dict.items():
+            if layer_name in quantized_state_dict:
+                self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
+                torch.testing.assert_close(original_params, quantized_state_dict[layer_name])
+
     def test_tokenization_xnli(self):
         import tqdm
         from datasets import load_dataset
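A minimal usage sketch of the Falcon GGUF support added above. The repository and file name mirror the constants used in the tests (`xaviviro/falcon-7b-quantized-gguf`, `falcon-7b-q2_k.gguf`); any other Falcon GGUF export should work the same way, as long as its file name carries the `-7b-` / `-40b-` size marker that `load_gguf_checkpoint` parses to pick the right tensor mapping.

```python
# Sketch: load a quantized Falcon 7B GGUF checkpoint through transformers.
# Repo and file name are taken from the tests above; swap in your own export,
# keeping the `-7b-` / `-40b-` marker in the file name.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "xaviviro/falcon-7b-quantized-gguf"
gguf_file = "falcon-7b-q2_k.gguf"

tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```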