From 34c632a79337f0a38142f86bc1b21e2c66c1fb79 Mon Sep 17 00:00:00 2001 From: "Vinh H. Pham" Date: Fri, 20 Dec 2024 13:21:49 +0700 Subject: [PATCH 1/5] Update modeling_qwen2.py --- src/transformers/models/qwen2/modeling_qwen2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 36fb1ddf1390ac..6ad8cf33ed0911 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -331,6 +331,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" From 38fc180047a5f203360864d23f7fd5ed5b3382e0 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sat, 21 Dec 2024 13:39:35 +0700 Subject: [PATCH 2/5] fix in modeling_llama --- src/transformers/models/llama/modeling_llama.py | 3 +++ src/transformers/models/qwen2/modeling_qwen2.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5be33c26414cd7..361b62da1f9a76 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -128,6 +128,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 6ad8cf33ed0911..996aa3cdd69726 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -333,7 +333,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) - + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" From a81fc34eba6f4261166a4ffa834a13c13783d706 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sat, 21 Dec 2024 13:44:51 +0700 Subject: [PATCH 3/5] quality fix --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 361b62da1f9a76..2dce32cc97268e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -130,7 +130,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) - + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" From 1949712799cb4d947335b93942a2c2774f875eff Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sat, 21 Dec 2024 13:48:50 +0700 Subject: [PATCH 4/5] fix consistency --- src/transformers/models/aria/modeling_aria.py | 3 +++ src/transformers/models/bamba/modeling_bamba.py | 3 +++ src/transformers/models/falcon/modeling_falcon.py | 3 +++ src/transformers/models/gemma/modeling_gemma.py | 3 +++ src/transformers/models/gemma2/modeling_gemma2.py | 3 +++ src/transformers/models/glm/modeling_glm.py | 3 +++ src/transformers/models/gpt_neox/modeling_gpt_neox.py | 3 +++ src/transformers/models/granite/modeling_granite.py | 3 +++ src/transformers/models/mistral/modeling_mistral.py | 3 +++ src/transformers/models/mixtral/modeling_mixtral.py | 3 +++ src/transformers/models/nemotron/modeling_nemotron.py | 3 +++ src/transformers/models/olmo/modeling_olmo.py | 3 +++ src/transformers/models/olmo2/modeling_olmo2.py | 3 +++ src/transformers/models/olmoe/modeling_olmoe.py | 3 +++ src/transformers/models/persimmon/modeling_persimmon.py | 3 +++ src/transformers/models/phi/modeling_phi.py | 3 +++ src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 3 +++ src/transformers/models/stablelm/modeling_stablelm.py | 3 +++ src/transformers/models/starcoder2/modeling_starcoder2.py | 3 +++ 19 files changed, 57 insertions(+) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 6481d6f3c434c7..7a99d88bdce2cf 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -771,6 +771,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index c89d8d7853008d..3c67531a6311e1 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -168,6 +168,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 8d5a224f4f6654..81c8947cb443d5 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -159,6 +159,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e2ea12b03fe434..5ea4a56d852d01 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -140,6 +140,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 67fc6c86a3bac6..c2707e09127f17 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -372,6 +372,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 95ad0d9719951d..557e4e7145eddf 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -304,6 +304,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 7152d72f5b7fc8..680ca5fcac5e56 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -536,6 +536,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 2e045e149d95de..9b0514044305f7 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -357,6 +357,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 90c38895b4280b..657b76c51f6c1b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -318,6 +318,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 84ed327d9be920..f9fa2551aca3bf 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -440,6 +440,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index a0a10bdc6f3550..65d7cf92949a2e 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -131,6 +131,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 11d3d99f4f72c9..11506dd255757e 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -322,6 +322,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 49ae798e7f1101..efe59b6ca62925 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -323,6 +323,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index fa3c2f3cd4d11b..065a3aba072e88 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -206,6 +206,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8d3c20b9ace717..8ca5802c488d94 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -105,6 +105,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 477896decd5318..da0be698d4237d 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -318,6 +318,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 1ce41509a5c0d1..1b1e2bef376ac2 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -215,6 +215,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 88dc437cdcb91d..3122ebd8d6681b 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -111,6 +111,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 3b4fdbcb81ccc4..907dbea0feb073 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -322,6 +322,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" From 39bab933f678b7776b0860b7118ffe32629ef209 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sat, 21 Dec 2024 13:52:09 +0700 Subject: [PATCH 5/5] fix consistency --- .../models/gpt_neox_japanese/modeling_gpt_neox_japanese.py | 3 +++ src/transformers/models/granitemoe/modeling_granitemoe.py | 3 +++ src/transformers/models/jetmoe/modeling_jetmoe.py | 3 +++ src/transformers/models/mimi/modeling_mimi.py | 3 +++ src/transformers/models/moshi/modeling_moshi.py | 3 +++ 5 files changed, 15 insertions(+) diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 71602f01e7d6f8..00605ff8109b86 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -273,6 +273,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 1c4c06bbc8d71e..4014a65d13c14b 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -206,6 +206,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 7b7fd5a90d69ed..1c6d019d2b4f9f 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -434,6 +434,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 1440ce1e075c95..86c623b9a59c4b 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -413,6 +413,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index f0281f57cf1c75..3f68b3c2a30da5 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -356,6 +356,9 @@ def forward(self, x, position_ids): # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() + + inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"