
[WIP] Add diffllama #34083

Open · wants to merge 44 commits into base: main

Conversation

@weak-kajuma

What does this PR do?

This PR adds the code for DiffLlama, a Llama model with Differential Transformer attention. Please refer to the Differential Transformer paper. @ArthurZucker
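For reference, here is a minimal sketch of the differential attention computation from the paper (parameter names such as lambda_q1/lambda_k1 follow the paper's notation; this is not the exact code added in this PR):

import torch
import torch.nn.functional as F

def diff_attention(q1, q2, k1, k2, v, lambda_q1, lambda_k1, lambda_q2, lambda_k2, lambda_init):
    # q1/q2, k1/k2: (batch, num_heads, seq_len, head_dim); v: (batch, num_heads, seq_len, 2 * head_dim)
    scale = q1.shape[-1] ** -0.5
    attn1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    attn2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    # lambda is re-parameterized from learnable vectors plus a depth-dependent initial value
    lam = torch.exp(torch.dot(lambda_q1, lambda_k1)) - torch.exp(torch.dot(lambda_q2, lambda_k2)) + lambda_init
    # the second attention map acts as noise cancellation on the first
    return (attn1 - lam * attn2) @ v

The paper additionally applies a per-head normalization to the output and scales it by (1 - lambda_init); causal masking is also omitted here for brevity.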

@weak-kajuma (Author)

I am working on the code now, but this is my first time contributing to transformers or any other OSS project, so I may ask you for some help.

@weak-kajuma (Author)

I still have an error in modeling_diffllama.py at line 377, in apply_rotary_pos_emb. The variable "query_states" should be torch.Size([2, 32, 10, 128]), but it is torch.Size([2, 64, 10, 64]), so I need to change either "query_states" or "cos"/"sin".
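One possible way to reconcile the shapes, shown as a self-contained sketch (rotate_half/apply_rotary_pos_emb are simplified stand-ins for the Llama helpers, and the actual fix in this PR may differ): apply RoPE while query_states still has the full head_dim that cos/sin were computed for, and only split into the two differential-attention halves afterwards.

import torch

# hypothetical shapes taken from the error message: batch=2, seq=10, 32 heads of dim 128
bsz, q_len, num_heads, head_dim = 2, 10, 32, 128
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
cos = torch.randn(bsz, q_len, head_dim)  # rotary tables computed for the full head_dim
sin = torch.randn(bsz, q_len, head_dim)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

# apply RoPE while the shape is still (2, 32, 10, 128), matching cos/sin ...
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# ... and only then split each head into the two differential-attention components
q1, q2 = query_states.chunk(2, dim=-1)  # (2, 32, 10, 64) each
k1, k2 = key_states.chunk(2, dim=-1)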

@ArthurZucker (Collaborator) left a comment

Hey! I think this would be an awesome fit for modular transformers!
A bit of doc here: https://huggingface.co/docs/transformers/en/modular_transformers

This would help isolate the changes!
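For example, a modular file could look roughly like this (only a sketch, assuming the eventual class names; the full modeling_diffllama.py is then generated from it by the modular converter):

# modular_diffllama.py -- inherit everything from Llama and override only what changes
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaModel


class DiffLlamaConfig(LlamaConfig):
    model_type = "diffllama"


class DiffLlamaAttention(LlamaAttention):
    # override __init__/forward here with the differential-attention computation;
    # the MLP, decoder layer, and the rest of the stack are inherited from Llama
    pass


class DiffLlamaModel(LlamaModel):
    pass


class DiffLlamaForCausalLM(LlamaForCausalLM):
    pass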

@weak-kajuma (Author)

I've finished implementing the normal/eager attention, and I can run it with AutoModelForCausalLM.generate().
Next I'll adapt it for FlashAttention2 and SDPA attention.

@weak-kajuma (Author)

I've also reworked the code to fit modular transformers.

weak-kajuma and others added 8 commits October 20, 2024 11:52
You don't need to divide by 2 if we use the same number of attention heads as Llama; instead you can just split in forward (see the sketch after this commit list).

Co-authored-by: Minho Ryu <[email protected]>
fit to the changed "num_heads // 2" placement

Co-authored-by: Minho Ryu <[email protected]>
new code is more meaningful than before

Co-authored-by: Minho Ryu <[email protected]>
new code is more meaningful than before

Co-authored-by: Minho Ryu <[email protected]>
fit to the changed "num_heads // 2" placement

Co-authored-by: Minho Ryu <[email protected]>
fix dividing twice by sqrt(self.head_dim)

Co-authored-by: Minho Ryu <[email protected]>
fix dividing twice by sqrt(self.head_dim)

Co-authored-by: Minho Ryu <[email protected]>
fit to the changed "num_heads // 2" placement,
and make it more readable

Co-authored-by: Minho Ryu <[email protected]>
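The "split in forward" idea from the commit messages above, as a rough sketch (shapes and names are illustrative, not the PR's exact code): the projections keep the same sizes as Llama, and the halving only happens in forward.

import torch
from torch import nn

hidden_size, num_heads = 4096, 32
head_dim = hidden_size // num_heads

# identical projection shapes to Llama -- no "num_heads // 2" anywhere in __init__
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)

hidden_states = torch.randn(2, 10, hidden_size)
bsz, q_len, _ = hidden_states.shape

query_states = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = k_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

# split only here, in forward: each component gets half of the head dimension
q1, q2 = query_states.chunk(2, dim=-1)  # (bsz, num_heads, q_len, head_dim // 2) each
k1, k2 = key_states.chunk(2, dim=-1)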
@bzantium (Contributor) left a comment

implemented flash and sdpa attention as well.
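Since (A1 - lam * A2) @ V = A1 @ V - lam * (A2 @ V), an SDPA path can in principle be expressed as two scaled_dot_product_attention calls. A sketch of that identity (not necessarily how this PR implements it):

import torch.nn.functional as F

def diff_sdpa(q1, q2, k1, k2, v, lam, attn_mask=None):
    # differential attention via two SDPA calls, relying on the linearity of the value aggregation
    out1 = F.scaled_dot_product_attention(q1, k1, v, attn_mask=attn_mask)
    out2 = F.scaled_dot_product_attention(q2, k2, v, attn_mask=attn_mask)
    return out1 - lam * out2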

src/transformers/models/diffllama/modeling_diffllama.py (outdated review thread, resolved)
@weak-kajuma (Author)

@bzantium
I found that the attention was still implemented incorrectly (relative to the paper) as of e072544.
So I'll revert to e072544 and re-implement it with your suggested code style.

@bzantium (Contributor)

Could you review this PR?
to: @Cyrilvallez

@bzantium (Contributor)

Could you make all tests pass?
to: @weak-kajuma

@bzantium (Contributor)

I found that you need to place diffllama in alphabetical order in src/transformers/__init__.py to pass check_code_quality.
to: @weak-kajuma

@ArthurZucker (Collaborator)

I think running make fixup should help you with this!

@ArthurZucker (Collaborator)

Rebasing / merging from main will fix the other unrelated tests!

@Cyrilvallez (Member) left a comment

Hey! Here is a first round of review comments. Could you please rebase onto main, as some modifications to Llama have happened as well? 🤗

src/transformers/__init__.py (outdated review threads, resolved)
src/transformers/models/diffllama/__init__.py (outdated review thread, resolved)
Comment on lines 488 to 523
    def test_rope_class_retrocompatibility(self):
        # Delete me when we remove compatibility for the old API :)
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        scaling_factor = 10
        short_input_length = 10
        long_input_length = int(config.max_position_embeddings * 1.5)
        config.rope_scaling = {"type": "linear", "factor": 10}

        # Inputs
        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
        position_ids_short = position_ids_short.unsqueeze(0)
        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
        position_ids_long = position_ids_long.unsqueeze(0)

        # Old API -- under the hood, "type": "linear" is set and `DiffLlamaRotaryEmbedding` is called
        old_api_rope = DiffLlamaLinearScalingRotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
            scaling_factor=scaling_factor,
        ).to(torch_device)
        old_cos_short, old_sin_short = old_api_rope(x, position_ids_short)
        old_cos_long, old_sin_long = old_api_rope(x, position_ids_long)

        # New API
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
        new_api_rope = DiffLlamaRotaryEmbedding(config=config).to(torch_device)
        new_cos_short, new_sin_short = new_api_rope(x, position_ids_short)
        new_cos_long, new_sin_long = new_api_rope(x, position_ids_long)

        # The results should match
        torch.testing.assert_close(old_cos_short, new_cos_short)
        torch.testing.assert_close(old_sin_short, new_sin_short)
        torch.testing.assert_close(old_cos_long, new_cos_long)
        torch.testing.assert_close(old_sin_long, new_sin_long)
Member:

to remove

Comment on lines 579 to 617
    @require_flash_attn
    @require_torch_gpu
    @require_bitsandbytes
    @pytest.mark.flash_attn_test
    @require_read_token
    @slow
    def test_flash_attn_2_generate_padding_right(self):
        """
        Overwriting the common test as the test is flaky on tiny models
        """
        model = DiffLlamaForCausalLM.from_pretrained(
            "meta-diffllama/DiffLlama-2-7b-hf",
            load_in_4bit=True,
            device_map={"": 0},
        )

        tokenizer = DiffLlamaTokenizer.from_pretrained("meta-diffllama/DiffLlama-2-7b-hf")

        texts = ["hi", "Hello this is a very long sentence"]

        tokenizer.padding_side = "right"
        tokenizer.pad_token = tokenizer.eos_token

        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)

        output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        output_native = tokenizer.batch_decode(output_native)

        model = DiffLlamaForCausalLM.from_pretrained(
            "meta-diffllama/DiffLlama-2-7b-hf",
            load_in_4bit=True,
            device_map={"": 0},
            attn_implementation="flash_attention_2",
        )

        output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        output_fa_2 = tokenizer.batch_decode(output_fa_2)

        self.assertListEqual(output_native, output_fa_2)
Member:

If you have weights on the hub use them here, otherwise you can remove the test

Comment on lines 647 to 707
    @require_torch_sdpa
    @slow
    def test_eager_matches_sdpa_generate(self):
        """
        Overwriting the common test as the test is flaky on tiny models
        """
        max_new_tokens = 30

        tokenizer = DiffLlamaTokenizer.from_pretrained("saibo/diffllama-1B")

        model_sdpa = DiffLlamaForCausalLM.from_pretrained(
            "saibo/diffllama-1B",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(torch_device)

        self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

        model_eager = DiffLlamaForCausalLM.from_pretrained(
            "saibo/diffllama-1B",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            attn_implementation="eager",
        ).to(torch_device)

        self.assertTrue(model_eager.config._attn_implementation == "eager")

        for name, submodule in model_eager.named_modules():
            if "SdpaAttention" in submodule.__class__.__name__:
                raise ValueError("The eager model should not have SDPA attention layers")

        has_sdpa = False
        for name, submodule in model_sdpa.named_modules():
            if "SdpaAttention" in submodule.__class__.__name__:
                has_sdpa = True
                break
        if not has_sdpa:
            raise ValueError("The SDPA model should have SDPA attention layers")

        texts = [
            "hi here's a longer context, getting longer and",
            "Hello this is a very long sentence my friend, very long for real",
            "Today I am in Paris and",
        ]

        for padding_side in ["left", "right"]:
            tokenizer.padding_side = padding_side
            tokenizer.pad_token = tokenizer.eos_token

            inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)

            res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
            res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

            with self.subTest(f"{padding_side}"):
                torch.testing.assert_close(
                    res_eager,
                    res_sdpa,
                    msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
                )

Member:

same here

Comment on lines 721 to 902
        # fmt: off
        # Expected mean on dim = -1
        EXPECTED_MEAN = {
            7: torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]),
            8: torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]])
        }

        self.assertTrue(torch.allclose(EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2))

        # slicing logits[0, 0, 0:15]
        EXPECTED_SLICE = {
            7: torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]),
            8: torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328])
        }
        # fmt: on

        self.assertTrue(
            torch.allclose(
                EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
                out.logits[0, 0, :15],
                atol=1e-2,
                rtol=1e-2,
            )
        )

    @slow
    def test_model_7b_dola_generation(self):
        # ground truth text generated with dola_layers="low", repetition_penalty=1.2
        EXPECTED_TEXT_COMPLETION = (
            "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of "
            "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of "
            "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our "
            "understanding of space and time."
        )
        prompt = "Simply put, the theory of relativity states that "
        tokenizer = DiffLlamaTokenizer.from_pretrained("meta-diffllama/DiffLlama-2-7b-chat-hf")
        model = DiffLlamaForCausalLM.from_pretrained(
            "meta-diffllama/DiffLlama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16
        )
        model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # greedy generation outputs
        generated_ids = model.generate(
            **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low"
        )
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

    @slow
    @require_torch_gpu
    @require_read_token
    def test_compile_static_cache(self):
        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
        # work as intended. See https://github.com/pytorch/pytorch/issues/121943
        if version.parse(torch.__version__) < version.parse("2.3.0"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        NUM_TOKENS_TO_GENERATE = 40
        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
        # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
        EXPECTED_TEXT_COMPLETION = [
            "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
            "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
            "theory of relativ",
            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
            "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
        ]

        prompts = [
            "Simply put, the theory of relativity states that ",
            "My favorite all time favorite condiment is ketchup.",
        ]
        tokenizer = DiffLlamaTokenizer.from_pretrained(
            "meta-diffllama/DiffLlama-2-7b-hf", pad_token="</s>", padding_side="right"
        )
        model = DiffLlamaForCausalLM.from_pretrained(
            "meta-diffllama/DiffLlama-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16
        )
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

        # Dynamic Cache
        generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
        dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)

        # Static Cache
        generated_ids = model.generate(
            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
        )
        static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

        # Static Cache + compile
        model._cache = None  # clear cache object, initialized when we pass `cache_implementation="static"`
        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
        generated_ids = model.generate(
            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
        )
        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
Member:

same here!

Comment on lines 904 to 916

@slow
@require_torch_accelerator
class Mask4DTestHard(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def setUp(self):
        model_name = "TinyDiffLlama/TinyDiffLlama-1.1B-Chat-v1.0"
        self.model_dtype = torch.float32
        self.tokenizer = DiffLlamaTokenizer.from_pretrained(model_name)
        self.model = DiffLlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
Member:

same here, if you do not have weights to load, remove the class

@weak-kajuma (Author)

To pass the test_initialization and test_mismatched_shapes_have_properly_initialized_weights tests, I would like to change/add to the code of tests/test_modeling_common.py. But this is common code. Could I change/add to the code as shown below?

tests/test_modeling_common.py:705

def test_initialization(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    configs_no_init = _config_zero_init(config)
+   configs_no_init.zero_init = True
    for model_class in self.all_model_classes:
        model = model_class(config=configs_no_init)
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.assertIn(
                    ((param.data.mean() * 1e9).round() / 1e9).item(),
                    [0.0, 1.0],
                    msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                )

Similarly for test_mismatched_shapes_have_properly_initialized_weights at line 3437.

@weak-kajuma (Author)

All tests pass other than tests/utils/test_modeling_utils.py::ModelUtilsTest::test_generation_config_is_loaded_with_model, which is unrelated to adding this model.

Could you please review this PR again?
And could you tell me how to fix that error?
to: @Cyrilvallez

@Cyrilvallez (Member) left a comment

Nice! Still some small issues to address but we're very close! 🤗

@@ -0,0 +1,27 @@
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
Member:

Please check the company here, it is very likely wrong!

Comment on lines 2 to 7
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
Member:

Similarly, this is likely wrong

Comment on lines 73 to 77
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
Member:

We want to remove everything related to pretraining_tp

Comment on lines 2 to 7
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
Member:

Check this one again as well

Comment on lines 292 to 308
        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

Member:

Remove everything related to pretraining_tp

self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
Member:

here also

Comment on lines 1568 to 1574
"DiffLlamaRMSNorm",
"DiffLlamaRotaryEmbedding",
"DiffLlamaMLP",
"DiffLlamaAttention",
"DiffLlamaFlashAttention2",
"DiffLlamaSdpaAttention",
"DiffLlamaDecoderLayer",
Member:

To remove; we don't want to expose those objects

Comment on lines 320 to 321
# used in `test_torch_compile`
_torch_compile_test_ckpt = "meta-diffllama/DiffLlama-2-7b-hf"
Member:

Please remove this

@@ -706,6 +706,7 @@ def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

configs_no_init = _config_zero_init(config)
configs_no_init.zero_init = True
Member:

We want to remove that, see other comments 😉

Comment on lines 267 to 271
std_dev = 1e-10 if getattr(config, "zero_init", False) else 0.1
self.lambda_q1 = nn.Parameter(torch.normal(0, std_dev, size=(self.head_dim,)))
self.lambda_k1 = nn.Parameter(torch.normal(0, std_dev, size=(self.head_dim,)))
self.lambda_q2 = nn.Parameter(torch.normal(0, std_dev, size=(self.head_dim,)))
self.lambda_k2 = nn.Parameter(torch.normal(0, std_dev, size=(self.head_dim,)))
Member:

Here, instead of "hacking" into the config (the attribute is not supposed to exist), let's add a lambda_std_dev attribute to the config, with a default value of 0.1.
This is cleaner than relying on the default value of getattr, and it will take care of passing test_initialization without modifying test_modeling_common.py.
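A sketch of what that could look like (class and attribute names are illustrative, following the suggestion above; not the final PR code):

import torch
from torch import nn
from transformers import PretrainedConfig


class DiffLlamaConfig(PretrainedConfig):
    model_type = "diffllama"

    def __init__(self, hidden_size=4096, num_attention_heads=32, lambda_std_dev=0.1, **kwargs):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        # std used to initialize the lambda_* vectors; a real config attribute, no getattr needed
        self.lambda_std_dev = lambda_std_dev
        super().__init__(**kwargs)


class DiffLlamaAttention(nn.Module):
    def __init__(self, config: DiffLlamaConfig):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_attention_heads
        std_dev = config.lambda_std_dev
        self.lambda_q1 = nn.Parameter(torch.normal(0.0, std_dev, size=(self.head_dim,)))
        self.lambda_k1 = nn.Parameter(torch.normal(0.0, std_dev, size=(self.head_dim,)))
        self.lambda_q2 = nn.Parameter(torch.normal(0.0, std_dev, size=(self.head_dim,)))
        self.lambda_k2 = nn.Parameter(torch.normal(0.0, std_dev, size=(self.head_dim,)))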

@Cyrilvallez (Member)

All tests pass other than tests/utils/test_modeling_utils.py::ModelUtilsTest::test_generation_config_is_loaded_with_model, which is unrelated to adding this model.

Could you please review this PR again? And could you tell me how to fix that error? to: @Cyrilvallez

This failing test seems to be due only to an internal CI error (this happens sometimes, unfortunately). When it happens, you can push an empty commit to re-trigger the CI.

@Cyrilvallez (Member)

To pass the test_initialization and test_mismatched_shapes_have_properly_initialized_weights tests, I would like to change/add to the code of tests/test_modeling_common.py. But this is common code. Could I change/add to the code as shown below?

tests/test_modeling_common.py:705

def test_initialization(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    configs_no_init = _config_zero_init(config)
+   configs_no_init.zero_init = True
    for model_class in self.all_model_classes:
        model = model_class(config=configs_no_init)
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.assertIn(
                    ((param.data.mean() * 1e9).round() / 1e9).item(),
                    [0.0, 1.0],
                    msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                )

Similarly for test_mismatched_shapes_have_properly_initialized_weights at line 3437.

We want to avoid that; see my review for more info 🤗

@weak-kajuma (Author)

All of your review comments are implemented. I've tried the tests many times, but they still don't pass. What should I do?
To: @Cyrilvallez
