From fd6e099a62ed3e5fb9f9861dfb49483fa7d7c255 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Wed, 18 Sep 2024 14:38:06 +0200 Subject: [PATCH 1/8] fix tests with main revision and read token --- tests/models/mamba2/test_modeling_mamba2.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 276ecf2fd6b0fb..eaa1526862b521 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -20,7 +20,7 @@ from parameterized import parameterized from transformers import AutoTokenizer, Mamba2Config, is_torch_available -from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -96,7 +96,7 @@ def __init__( self.tie_word_embeddings = tie_word_embeddings def get_large_model_config(self): - return Mamba2Config.from_pretrained("revision='refs/pr/9'") + return Mamba2Config.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1") def prepare_config_and_inputs( self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False @@ -292,12 +292,11 @@ def test_inputs_embeds_matches_input_ids_with_generate(self): @require_torch @slow +@require_read_token class Mamba2IntegrationTest(unittest.TestCase): def setUp(self): self.model_id = "mistralai/Mamba-Codestral-7B-v0.1" - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_id, revision="refs/pr/9", from_slow=True, legacy=False - ) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False) self.prompt = ("[INST]Write a hello world program in C++.",) @parameterized.expand( @@ -317,7 +316,7 @@ def test_simple_generate(self, device): tokenizer = self.tokenizer tokenizer.pad_token_id = tokenizer.eos_token_id - model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16) + model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16) model.to(device) input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to( device @@ -343,9 +342,7 @@ def test_batched_equivalence_with_cache(self): "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]", ] - model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to( - torch_device - ) + model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device) tokenizer.pad_token_id = tokenizer.eos_token_id # batched generation tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device) @@ -375,9 +372,7 @@ def test_batched_equivalence_without_cache(self): "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]", ] - model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to( - torch_device - ) + model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device) tokenizer.pad_token_id = tokenizer.eos_token_id # batched generation tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device) From 72deed0b06c4efb7bb6d12081201e0dee66e041f Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Wed, 18 Sep 2024 14:39:43 +0200 Subject: [PATCH 2/8] [run-slow]mamba2 From 39fc615a4f06294e78619e33a40de35e10f7541c Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Wed, 18 Sep 2024 15:03:22 +0200 Subject: [PATCH 3/8] test previously skipped tests --- tests/models/mamba2/test_modeling_mamba2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index eaa1526862b521..441671d0c47330 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -199,6 +199,7 @@ def test_initialization(self): def test_tied_weights_keys(self): pass + """ @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") def test_beam_search_generate_dict_outputs_use_cache(self): pass @@ -226,6 +227,7 @@ def test_multi_gpu_data_parallel_forward(self): @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") def test_generate_from_inputs_embeds_decoder_only(self): pass + """ def test_model_outputs_equivalence(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 78298a0e5dda1f47aa6888c95a4a48888d813afb Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Wed, 18 Sep 2024 15:03:39 +0200 Subject: [PATCH 4/8] [run-slow]mamba2 From f8cca0b416c2e392069b9001203e70063dbf6d12 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Wed, 18 Sep 2024 15:11:58 +0200 Subject: [PATCH 5/8] skip some tests --- tests/models/mamba2/test_modeling_mamba2.py | 24 +++++++++++---------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 441671d0c47330..128caa2bdc955d 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -199,23 +199,27 @@ def test_initialization(self): def test_tied_weights_keys(self): pass - """ - @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") - def test_beam_search_generate_dict_outputs_use_cache(self): - pass - - @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") - def test_beam_sample_generate(self): + @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") + def test_generate_without_input_ids(self): pass @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - def test_generate_without_input_ids(self): + def test_generate_from_inputs_embeds_decoder_only(self): pass @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") def test_greedy_generate_dict_outputs_use_cache(self): pass + @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") + def test_beam_search_generate_dict_outputs_use_cache(self): + pass + + """ + @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") + def test_beam_sample_generate(self): + pass + @unittest.skip(reason="Initialization of mamba2 fails this") def test_save_load_fast_init_from_base(self): pass @@ -224,9 +228,7 @@ def test_save_load_fast_init_from_base(self): def test_multi_gpu_data_parallel_forward(self): pass - @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - def test_generate_from_inputs_embeds_decoder_only(self): - pass + """ def test_model_outputs_equivalence(self): From 2ec7b14d7b761c444f17460680e5ffa966ed6902 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Wed, 18 Sep 2024 15:12:44 +0200 Subject: [PATCH 6/8] [run-slow]mamba2 From 149ec87a90e8fd1f9cacb4cefc29133f8bf6c526 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Wed, 18 Sep 2024 15:18:19 +0200 Subject: [PATCH 7/8] finalize tests --- tests/models/mamba2/test_modeling_mamba2.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 128caa2bdc955d..a1e2138d4d6d78 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -215,22 +215,10 @@ def test_greedy_generate_dict_outputs_use_cache(self): def test_beam_search_generate_dict_outputs_use_cache(self): pass - """ - @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") - def test_beam_sample_generate(self): - pass - - @unittest.skip(reason="Initialization of mamba2 fails this") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that") def test_multi_gpu_data_parallel_forward(self): pass - - """ - def test_model_outputs_equivalence(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 949da602e62b38450dc4ac8fe2914e63e77d8565 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Wed, 18 Sep 2024 15:18:31 +0200 Subject: [PATCH 8/8] [run-slow]mamba2