diff --git a/mediapipe/tasks/python/genai/converter/safetensors_converter.py b/mediapipe/tasks/python/genai/converter/safetensors_converter.py
index 9f5b4f0503..6304dc26a2 100644
--- a/mediapipe/tasks/python/genai/converter/safetensors_converter.py
+++ b/mediapipe/tasks/python/genai/converter/safetensors_converter.py
@@ -392,6 +392,7 @@ def __init__(
       backend: str,
       reader: _SafetensorsReader,
       is_v2: bool,
+      is_nested: bool = False,
   ):
     super().__init__(
         is_symmetric=is_symmetric,
@@ -402,6 +403,7 @@ def __init__(
     )
     self._reader = reader
     self._is_v2 = is_v2
+    self._is_nested = is_nested

   def map_to_actions(
       self, layer_name: str
@@ -458,7 +460,8 @@ def update_target_name(self, target_name: str) -> str:
     """Updates the target name to match the tensor name convention."""

     # For removing multimodality stack from Gemma3-4B
-    target_name = target_name.replace("language_model.", "")
+    if self._is_nested:
+      target_name = target_name.replace("language_model.", "")

     target_name = target_name.replace("base_model.model.", "")
     target_name = target_name.replace(
@@ -609,7 +612,16 @@ def __init__(
         "GEMMA3_12B",
         "GEMMA3_27B",
         "GEMMA3_300M",
+        "GEMMA3N_2B",
+        "GEMMA3N_4B",
+        "GEMMA3N_8B",
+        "GEMMA_3N_E2B_IT",
+        "GEMMA_3N_E4B_IT",
     ]:
+      # Identify all models that have the nested 'language_model.' prefix.
+      is_nested_model = (
+          special_model == "GEMMA3-4B" or "3N" in special_model.upper())
+
       self.mapper = GemmaMapper(
           is_symmetric,
           attention_quant_bits,
@@ -617,7 +629,8 @@ def __init__(
           embedding_quant_bits,
           backend,
           self._reader,
-          False if special_model in ["GEMMA_2B", "GEMMA_7B"] else True,
+          is_v2=(special_model not in ["GEMMA_2B", "GEMMA_7B"]),
+          is_nested=is_nested_model,
       )
     else:
       raise ValueError(f"Unknown special model: {special_model}")
diff --git a/mediapipe/tasks/python/genai/converter/safetensors_converter_test.py b/mediapipe/tasks/python/genai/converter/safetensors_converter_test.py
index 0ea5f11c60..3c2fe0c992 100644
--- a/mediapipe/tasks/python/genai/converter/safetensors_converter_test.py
+++ b/mediapipe/tasks/python/genai/converter/safetensors_converter_test.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 """Unit tests for safetensors_converter."""

 import os
+from unittest import mock

 from absl.testing import absltest
 from absl.testing import parameterized
+import numpy as np

 from mediapipe.tasks.python.genai.converter import safetensors_converter
 from mediapipe.tasks.python.test import test_utils

@@ -78,6 +81,60 @@ def test_load_to_actions(self, quant_bits):
     actions = loader.load_to_actions()
     self.assertLen(list(actions), 15)

+  @parameterized.named_parameters(
+      ('gemma_3n_nested', 'GEMMA3N_4B'),
+      ('gemma_3_4b_nested', 'GEMMA3-4B'),
+  )
+  @mock.patch.object(safetensors_converter, '_SafetensorsReader')
+  def test_nested_gemma_conversion(self, model_name, mock_reader_cls):
+    """Tests that nested Gemma models have their prefixes stripped."""
+    mock_reader_instance = mock_reader_cls.return_value
+    gemma_nested_variable_names = [
+        # Standard language model layers with the 'language_model.' prefix.
+        'language_model.model.embed_tokens.weight',
+        'language_model.model.layers.0.input_layernorm.weight',
+        'language_model.model.layers.0.mlp.down_proj.weight',
+        'language_model.model.layers.0.self_attn.o_proj.weight',
+        'language_model.model.norm.weight',
+        # Vision tower layers that should be skipped.
+        'vision_tower.vision_tower.encoder.layers.0.blocks.0.attn.qkv.weight',
+        'multi_modal_projector.linear_1.weight',
+    ]
+    mock_reader_instance.get_tensor_names.return_value = (
+        gemma_nested_variable_names
+    )
+    mock_reader_instance.read_tensor_as_numpy.return_value = np.zeros(
+        (1, 1), dtype=np.float32
+    )
+
+    loader = safetensors_converter.SafetensorsCkptLoader(
+        ckpt_path='/fake/path',
+        is_symmetric=True,
+        attention_quant_bits=8,
+        feedforward_quant_bits=8,
+        embedding_quant_bits=8,
+        special_model=model_name,
+        backend='gpu',
+    )
+    actions_list = list(loader.load_to_actions())
+
+    # The vision tower and projector tensors are skipped, leaving 5 actions.
+    self.assertLen(actions_list, 5)
+
+    # The 'language_model.' prefix must be stripped from every target name.
+    target_names = [actions[0].target_name for actions in actions_list]
+    self.assertIn('params.lm.softmax.logits_ffn.w', target_names)
+    self.assertIn(
+        'params.lm.transformer.x_layers_0.pre_layer_norm.scale', target_names
+    )
+    self.assertIn(
+        'params.lm.transformer.x_layers_0.ff_layer.ffn_layer2.w', target_names
+    )
+    self.assertIn(
+        'params.lm.transformer.x_layers_0.self_attention.post.w', target_names
+    )
+    self.assertIn('params.lm.final_ln.scale', target_names)
+


 if __name__ == '__main__':
   absltest.main()
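For reference, a minimal sketch of what the new is_nested flag is expected to do, assuming GemmaMapper can be constructed directly with the positional arguments used at its call site in the loader above, and that update_target_name never consults the reader (a mock stands in for it):

from unittest import mock

from mediapipe.tasks.python.genai.converter import safetensors_converter

# Placeholder quantization settings; only the prefix handling matters here.
mapper = safetensors_converter.GemmaMapper(
    True,              # is_symmetric
    8,                 # attention_quant_bits
    8,                 # feedforward_quant_bits
    8,                 # embedding_quant_bits
    'gpu',             # backend
    mock.MagicMock(),  # reader; assumed unused by update_target_name
    is_v2=True,
    is_nested=True,
)

# With is_nested=True the 'language_model.' prefix is stripped first, so this
# should resolve to the same target as a plain 'model.norm.weight'
# ('params.lm.final_ln.scale', per the assertions in the new test above).
print(mapper.update_target_name('language_model.model.norm.weight'))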