Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add falcon gguf #33437

Merged
merged 9 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ For now the supported model architectures are the architectures that have been v
- Qwen2Moe
- Phi3
- Bloom
- Falcon

## Example usage

Expand Down
1 change: 0 additions & 1 deletion src/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,6 @@ def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]]
)
)

add_prefix_space = False
add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
tokenizer.decoder = decoders.ByteLevel()
Expand Down
36 changes: 36 additions & 0 deletions src/transformers/integrations/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,29 @@
"output.weight": "lm_head.weight",
"output_norm": "transformer.ln_f",
},
"falcon7b": {
"token_embd": "word_embeddings",
"blk": "h",
"ffn_up": "mlp.dense_h_to_4h",
"ffn_down": "mlp.dense_4h_to_h",
"attn_norm": "input_layernorm",
"attn_qkv": "self_attention.query_key_value",
"attn_output": "self_attention.dense",
".output.": ".lm_head.",
"output_norm": "ln_f",
},
"falcon40b": {
"token_embd": "word_embeddings",
"blk": "h",
"ffn_up": "mlp.dense_h_to_4h",
"ffn_down": "mlp.dense_4h_to_h",
".attn_norm.": ".ln_mlp.",
"attn_norm_2": "ln_attn",
"attn_qkv": "self_attention.query_key_value",
"attn_output": "self_attention.dense",
".output.": ".lm_head.",
"output_norm": "ln_f",
},
}


Expand Down Expand Up @@ -178,6 +201,18 @@
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"falcon": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"tokenizer": {
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
Expand Down Expand Up @@ -530,6 +565,7 @@ def converted(self) -> Tokenizer:
"qwen2_moe": GGUFQwen2Converter,
"phi3": GGUFPhi3Converter,
"bloom": GGUFBloomConverter,
"falcon": GGUFBloomConverter,
}


Expand Down
27 changes: 16 additions & 11 deletions src/transformers/modeling_gguf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -99,8 +100,20 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
if "qwen2moe" in architecture:
updated_architecture = "qwen2_moe"

if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
raise ValueError(f"Architecture {architecture} not supported")
model_size = ""
# extract the number of params from file name as architectures can differ ;
# eg. for falcon : `...falcon-7b-...`
if "falcon" in architecture:
gguf_file_name = gguf_checkpoint_path.split("/")[-1].lower()
m = re.search(r"-\d+b-", gguf_file_name) # regex to catch `-7b-`
if m is None:
raise ValueError(
f"From file name, cannot determine the number of parameters for {architecture} architecture"
)
model_size = m.group().strip("-") # only keeps `7b`

if architecture + model_size not in GGUF_SUPPORTED_ARCHITECTURES:
raise ValueError(f"Architecture {architecture + model_size} not supported")

# List all key-value pairs in a columnized format
for gguf_key, field in reader.fields.items():
Expand Down Expand Up @@ -146,17 +159,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
)

if return_tensors:
tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture]
tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture + model_size]

for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
renamed_tensor_name = tensor.name

for tensor_name_mapping in GGUF_TO_TRANSFORMERS_MAPPING["tensors"]:
if tensor_name_mapping in renamed_tensor_name:
renamed_tensor_name = renamed_tensor_name.replace(
tensor_name_mapping, GGUF_TO_TRANSFORMERS_MAPPING["tensors"][tensor_name_mapping]
)

SunMarc marked this conversation as resolved.
Show resolved Hide resolved
name = tensor.name

weights = dequantize(tensor.data, tensor.tensor_type)
Expand Down
58 changes: 58 additions & 0 deletions tests/quantization/ggml/test_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class GgufIntegrationTests(unittest.TestCase):
phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
bloom_model_id = "afrideva/bloom-560m-GGUF"
original_bloom_model_id = "bigscience/bloom-560m"
falcon7b_model_id = "xaviviro/falcon-7b-quantized-gguf"
falcon40b_model_id = "maddes8cht/tiiuae-falcon-40b-gguf"
original_flacon7b_model_id = "tiiuae/falcon-7b"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
Expand Down Expand Up @@ -74,6 +77,9 @@ class GgufIntegrationTests(unittest.TestCase):
fp16_bloom_model_id = "bloom-560m.fp16.gguf"
q8_bloom_model_id = "bloom-560m.q8_0.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
q2_k_falcon7b_model_id = "falcon-7b-q2_k.gguf"
fp16_falcon7b_model_id = "falcon-7b-fp16.gguf"
q2_k_falcon40b_model_id = "tiiuae-falcon-40b-Q2_K.gguf"

example_text = "Hello"

Expand Down Expand Up @@ -445,6 +451,58 @@ def test_bloom_weights_conversion_fp16(self):
self.assertTrue(quantized_param.shape == original_param.shape)
torch.testing.assert_close(quantized_param, original_param)

@unittest.skip(reason="Heavy memory")
def test_falcon40b_q2_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.falcon40b_model_id, gguf_file=self.q2_k_falcon40b_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.falcon40b_model_id,
gguf_file=self.q2_k_falcon40b_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello All,\nI am new to this forum."
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_falcon7b_q2_k(self):
tokenizer = AutoTokenizer.from_pretrained(self.falcon7b_model_id, gguf_file=self.q2_k_falcon7b_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.falcon7b_model_id,
gguf_file=self.q2_k_falcon7b_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello All,\nI am new to this forum."
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
SunMarc marked this conversation as resolved.
Show resolved Hide resolved

def test_falcon7b_weights_conversion_fp16(self):
quantized_model = AutoModelForCausalLM.from_pretrained(
self.falcon7b_model_id,
gguf_file=self.fp16_falcon7b_model_id,
device_map="auto",
torch_dtype=torch.float16,
)
original_model = AutoModelForCausalLM.from_pretrained(
self.original_flacon7b_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

quantized_state_dict = quantized_model.state_dict()
original_state_dict = original_model.state_dict()

for layer_name, original_params in original_state_dict.items():
if layer_name in quantized_state_dict:
self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
torch.testing.assert_close(original_params, quantized_state_dict[layer_name])

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
Expand Down
Loading