
Commit 9d200cf

Add gguf support for bloom (#33473)
* add bloom arch support for gguf
* apply format
* small refactoring, bug fix in GGUF_TENSOR_MAPPING naming
* optimize bloom GGUF_TENSOR_MAPPING
* implement reverse reshaping for bloom gguf
* add qkv weights test
* add q_8 test for bloom
1 parent 3e039d3 commit 9d200cf

6 files changed (+140, -8 lines)

docs/source/en/gguf.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -80,6 +80,7 @@ For now the supported model architectures are the architectures that have been v
 - Qwen2
 - Qwen2Moe
 - Phi3
+- Bloom
 
 ## Example usage
 
```
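For context, the doc's "Example usage" section loads GGUF checkpoints through `from_pretrained` with the `gguf_file` argument; a minimal sketch for the newly listed Bloom architecture (repo and file names taken from the integration tests added below, any supported Bloom GGUF quant should load the same way):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Repo and GGUF file names are the ones used in tests/quantization/ggml/test_ggml.py below.
model_id = "afrideva/bloom-560m-GGUF"
gguf_file = "bloom-560m.q8_0.gguf"

tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

inputs = tokenizer("Hello", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0]))
```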

src/transformers/convert_slow_tokenizer.py

Lines changed: 9 additions & 5 deletions

```diff
@@ -328,9 +328,11 @@ def converted(self) -> Tokenizer:
 
 
 class GPT2Converter(Converter):
-    def converted(self) -> Tokenizer:
-        vocab = self.original_tokenizer.encoder
-        merges = list(self.original_tokenizer.bpe_ranks.keys())
+    def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
+        if not vocab:
+            vocab = self.original_tokenizer.encoder
+        if not merges:
+            merges = list(self.original_tokenizer.bpe_ranks)
 
         tokenizer = Tokenizer(
             BPE(
@@ -343,9 +345,11 @@ def converted(self) -> Tokenizer:
             )
         )
 
-        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
+        add_prefix_space = False
+        add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
+        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
         tokenizer.decoder = decoders.ByteLevel()
-        if self.original_tokenizer.add_bos_token:
+        if getattr(self.original_tokenizer, "add_bos_token", False):
             bos = self.original_tokenizer.bos_token
             bos_token_id = self.original_tokenizer.bos_token_id
             tokenizer.post_processor = processors.TemplateProcessing(
```
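The refactor lets callers (such as the GGUF path) hand `converted` an explicit vocab and merges list instead of reading them from a slow tokenizer. What the converter then assembles is a plain byte-level BPE tokenizer; a toy sketch with made-up vocab/merges (the real values come from the GGUF tokenizer section), not part of the commit:

```python
from tokenizers import Tokenizer, decoders, pre_tokenizers
from tokenizers.models import BPE

# Toy vocab/merges standing in for the GGUF-provided ones.
vocab = {"H": 0, "e": 1, "l": 2, "o": 3, "He": 4, "ll": 5, "Hell": 6, "Hello": 7}
merges = [("H", "e"), ("l", "l"), ("He", "ll"), ("Hell", "o")]

# Mirrors what GPT2Converter.converted builds once vocab/merges are supplied.
tokenizer = Tokenizer(BPE(vocab, merges))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()

print(tokenizer.encode("Hello").tokens)  # ['Hello']
```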

src/transformers/integrations/ggml.py

Lines changed: 34 additions & 1 deletion

```diff
@@ -25,7 +25,7 @@
 from tokenizers.models import BPE
 
 from .. import AddedToken
-from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter
+from ..convert_slow_tokenizer import GPT2Converter, LlamaConverter, Qwen2Converter
 from ..utils import logging
 from ..utils.logging import tqdm
 
@@ -107,6 +107,19 @@
         "output.weight": "lm_head.weight",
         "output_norm": "model.norm",
     },
+    "bloom": {
+        "token_embd.weight": "transformer.word_embeddings.weight",
+        "token_embd_norm": "transformer.word_embeddings_layernorm",
+        "blk": "transformer.h",
+        "ffn_up": "mlp.dense_h_to_4h",
+        "ffn_down": "mlp.dense_4h_to_h",
+        "ffn_norm": "post_attention_layernorm",
+        "attn_norm": "input_layernorm",
+        "attn_qkv": "self_attention.query_key_value",
+        "attn_output": "self_attention.dense",
+        "output.weight": "lm_head.weight",
+        "output_norm": "transformer.ln_f",
+    },
 }
 
 
@@ -183,6 +196,13 @@
         "attention.layer_norm_rms_epsilon": "rms_norm_eps",
         "vocab_size": "vocab_size",
     },
+    "bloom": {
+        "block_count": "n_layer",
+        "embedding_length": "hidden_size",
+        "attention.head_count": "n_head",
+        "vocab_size": "vocab_size",
+        "attention.layer_norm_epsilon": "layer_norm_epsilon",
+    },
 }
 
 GGUF_TOKENIZER_MAPPING = {
@@ -492,11 +512,24 @@ def converted(self) -> Tokenizer:
         return tokenizer
 
 
+class GGUFBloomConverter(GPT2Converter):
+    def __init__(self, tokenizer_dict):
+        self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
+        self.additional_kwargs = {}
+
+    def converted(self) -> Tokenizer:
+        vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
+        merges = self.original_tokenizer.merges
+        tokenizer = super().converted(vocab, merges)
+        return tokenizer
+
+
 GGUF_TO_FAST_CONVERTERS = {
     "llama": GGUFLlamaConverter,
     "qwen2": GGUFQwen2Converter,
     "qwen2_moe": GGUFQwen2Converter,
     "phi3": GGUFPhi3Converter,
+    "bloom": GGUFBloomConverter,
 }
 
 
```
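To make the `"bloom"` tensor mapping concrete: the loader applies it as a substring replacement over each GGUF tensor name, so `blk.5.attn_qkv.weight` becomes `transformer.h.5.self_attention.query_key_value.weight`. A small standalone sketch of that renaming (a subset of the mapping, replicated outside transformers purely for illustration; the actual loop is in modeling_gguf_pytorch_utils.py below):

```python
# Subset of the "bloom" GGUF_TENSOR_MAPPING above, applied the same way the
# loader does: substring replacement over the GGUF tensor name.
bloom_mapping = {
    "token_embd.weight": "transformer.word_embeddings.weight",
    "blk": "transformer.h",
    "attn_qkv": "self_attention.query_key_value",
    "attn_output": "self_attention.dense",
    "output_norm": "transformer.ln_f",
}

def rename(name: str, mapping: dict) -> str:
    for gguf_key, hf_key in mapping.items():
        if gguf_key in name:
            name = name.replace(gguf_key, hf_key)
    return name

print(rename("blk.5.attn_qkv.weight", bloom_mapping))
# -> transformer.h.5.self_attention.query_key_value.weight
```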

src/transformers/modeling_gguf_pytorch_utils.py

Lines changed: 34 additions & 0 deletions

```diff
@@ -169,6 +169,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
             elif ".attn_k." in name:
                 weights = reverse_permute_weights(weights, num_heads, num_kv_heads)
 
+        if architecture == "bloom" and "attn_qkv" in name:
+            num_heads = parsed_parameters["config"]["n_head"]
+            n_embed = parsed_parameters["config"]["hidden_size"]
+            if "weight" in name:
+                weights = reverse_reshape_weights(weights, num_heads, n_embed)
+            else:
+                weights = reverse_reshape_bias(weights, num_heads, n_embed)
+
         for tensor_name in tensor_key_mapping:
             if tensor_name in name:
                 name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
@@ -191,3 +199,29 @@ def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Opti
     dim = weights.shape[0] // n_head // 2
     w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
     return w.swapaxes(2, 1).reshape(weights.shape)
+
+
+def reverse_reshape_weights(weights: np.ndarray, n_head: int, n_embed: int):
+    # Original reshape implementation
+    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
+    q, k, v = np.array_split(weights, 3, axis=0)
+
+    q = q.reshape(n_head, n_embed // n_head, n_embed)
+    k = k.reshape(n_head, n_embed // n_head, n_embed)
+    v = v.reshape(n_head, n_embed // n_head, n_embed)
+    qkv_weights = np.stack([q, k, v], axis=1)
+
+    return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
+
+
+def reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
+    # Original reshape implementation
+    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
+    q_bias, k_bias, v_bias = np.array_split(weights, 3)
+
+    q_bias = q_bias.reshape(n_head, n_embed // n_head)
+    k_bias = k_bias.reshape(n_head, n_embed // n_head)
+    v_bias = v_bias.reshape(n_head, n_embed // n_head)
+
+    qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
+    return qkv_bias
```
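As a sanity check on the reverse reshaping: the forward direction in llama.cpp's Bloom converter regroups the HF `query_key_value` weight from its per-head-interleaved layout into contiguous q/k/v blocks, and the helper above puts the rows back. A quick round-trip check with toy shapes and random data (not part of the commit; the helper body is copied from the diff above):

```python
import numpy as np

def reverse_reshape_weights(weights, n_head, n_embed):
    # Same logic as the helper added above.
    q, k, v = np.array_split(weights, 3, axis=0)
    q = q.reshape(n_head, n_embed // n_head, n_embed)
    k = k.reshape(n_head, n_embed // n_head, n_embed)
    v = v.reshape(n_head, n_embed // n_head, n_embed)
    return np.stack([q, k, v], axis=1).reshape(n_head * 3 * (n_embed // n_head), n_embed)

# Toy sizes; bigscience/bloom-560m itself uses n_head=16, hidden_size=1024.
n_head, n_embed = 4, 16
hf_qkv = np.random.rand(3 * n_embed, n_embed).astype(np.float32)

# Forward reshape as in llama.cpp's convert_hf_to_gguf.py for Bloom:
# per-head [q_h | k_h | v_h] row blocks are regrouped into [all q | all k | all v].
per_head = hf_qkv.reshape(n_head, 3, n_embed // n_head, n_embed)
gguf_qkv = np.concatenate([per_head[:, i].reshape(-1, n_embed) for i in range(3)], axis=0)

# The reverse helper recovers the original HF layout exactly.
assert np.array_equal(reverse_reshape_weights(gguf_qkv, n_head, n_embed), hf_qkv)
```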

src/transformers/models/bloom/tokenization_bloom_fast.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -99,8 +99,8 @@ def __init__(
         **kwargs,
     ):
         super().__init__(
-            vocab_file,
-            merges_file,
+            vocab_file=vocab_file,
+            merges_file=merges_file,
             tokenizer_file=tokenizer_file,
             unk_token=unk_token,
             bos_token=bos_token,
```

tests/quantization/ggml/test_ggml.py

Lines changed: 60 additions & 0 deletions

```diff
@@ -42,6 +42,8 @@ class GgufIntegrationTests(unittest.TestCase):
     llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
     tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
     phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
+    bloom_model_id = "afrideva/bloom-560m-GGUF"
+    original_bloom_model_id = "bigscience/bloom-560m"
 
     # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -69,6 +71,8 @@ class GgufIntegrationTests(unittest.TestCase):
     q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
     q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
     q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
+    fp16_bloom_model_id = "bloom-560m.fp16.gguf"
+    q8_bloom_model_id = "bloom-560m.q8_0.gguf"
     f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
 
     example_text = "Hello"
@@ -385,6 +389,62 @@ def test_llama3_q4_0(self):
         EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_bloom_fp16(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.bloom_model_id, gguf_file=self.fp16_bloom_model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.bloom_model_id,
+            gguf_file=self.fp16_bloom_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, I just want to say that I am very"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_bloom_q8_0(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.bloom_model_id, gguf_file=self.q8_bloom_model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.bloom_model_id,
+            gguf_file=self.q8_bloom_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, I just want to say that I am very"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_bloom_weights_conversion_fp16(self):
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.bloom_model_id,
+            gguf_file=self.fp16_bloom_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+        original_model = AutoModelForCausalLM.from_pretrained(
+            self.original_bloom_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        quantized_state_dict = quantized_model.state_dict()
+        original_state_dict = original_model.state_dict()
+
+        for (quantized_name, quantized_param), (original_name, original_param) in zip(
+            quantized_state_dict.items(), original_state_dict.items()
+        ):
+            if (
+                "self_attention.query_key_value" in quantized_name
+                and "self_attention.query_key_value" in original_name
+            ):
+                self.assertTrue(quantized_param.shape == original_param.shape)
+                torch.testing.assert_close(quantized_param, original_param)
+
     def test_tokenization_xnli(self):
         import tqdm
         from datasets import load_dataset
```