diff --git a/ai_edge_torch/generative/examples/test_models/toy_model.py b/ai_edge_torch/generative/examples/test_models/toy_model.py
index c0487fb0..a226ed54 100644
--- a/ai_edge_torch/generative/examples/test_models/toy_model.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -71,6 +71,56 @@ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
     return self.lm_head(x)
 
 
+class ToySingleLayerModelWeightSharing(torch.nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig) -> None:
+    super().__init__()
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.lm_head = nn.Linear(
+        config.embedding_dim,
+        config.vocab_size,
+        bias=config.lm_head_use_bias,
+    )
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_block = TransformerBlock(config)
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.max_seq_len,
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device('cpu'),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    x = self.tok_embedding(idx)
+    cos, sin = self.rope_cache
+
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.max_seq_len]
+
+    x = self.transformer_block(x, (cos, sin), mask, input_pos)
+    x = self.final_norm(x)
+    res = self.lm_head(x)
+    return res
+
+
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
diff --git a/ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py b/ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py
index 8a08bdfd..e703d619 100644
--- a/ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py
+++ b/ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py
@@ -17,7 +17,8 @@
 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipe
 
-_OpExecutionMode = quantizer.qtyping.OpExecutionMode
+_ComputePrecision = quantizer.qtyping.ComputePrecision
+_QuantGranularity = quantizer.qtyping.QuantGranularity
 _OpName = quantizer.qtyping.TFLOperationName
 _TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
 _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
@@ -50,21 +51,31 @@ def _get_dtype_from_dtype(
     return quantizer.qtyping.TensorDataType.INT
 
 
-def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
+def _get_compute_precision_from_mode(
+    mode: quant_attrs.Mode,
+) -> _ComputePrecision:
   if mode == quant_attrs.Mode.DYNAMIC_RANGE:
-    return _OpExecutionMode.DRQ
+    return _ComputePrecision.INTEGER
   elif mode == quant_attrs.Mode.WEIGHT_ONLY:
-    return _OpExecutionMode.WEIGHT_ONLY
+    return _ComputePrecision.FLOAT
   raise ValueError('Unimplemented execution mode')
 
 
-def _get_channelwise_from_granularity(
+def _get_explicit_dequant_from_mode(mode: quant_attrs.Mode) -> bool:
+  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+    return False
+  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+    return True
+  raise ValueError('Unimplemented execution mode')
+
+
+def _get_granularity(
     granularity: quant_attrs.Granularity,
-) -> bool:
+) -> _QuantGranularity:
   if granularity == quant_attrs.Granularity.CHANNELWISE:
-    return True
-  elif granularity == quant_attrs.Granularity.NONE:
-    return False
+    return _QuantGranularity.CHANNELWISE
+  if granularity == quant_attrs.Granularity.NONE:
+    return _QuantGranularity.TENSORWISE
   raise ValueError('Unimplemented granularity')
 
 
@@ -88,12 +99,13 @@ def _set_quant_config(
           weight_tensor_config=_TensorQuantConfig(
               num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
               symmetric=True,
-              channel_wise=_get_channelwise_from_granularity(
-                  layer_recipe.granularity
-              ),
+              granularity=_get_granularity(layer_recipe.granularity),
               dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
           ),
-          execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+          compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
+          explicit_dequantize=_get_explicit_dequant_from_mode(
+              layer_recipe.mode
+          ),
       ),
       algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
   )
diff --git a/ai_edge_torch/generative/test/test_quantize.py b/ai_edge_torch/generative/test/test_quantize.py
index 73c250ca..55768fb1 100644
--- a/ai_edge_torch/generative/test/test_quantize.py
+++ b/ai_edge_torch/generative/test/test_quantize.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 import ai_edge_torch
+from absl.testing import parameterized
 from ai_edge_torch import config
 from ai_edge_torch.generative.examples.test_models import toy_model  # NOQA
 from ai_edge_torch.generative.quantize import quant_recipe
@@ -25,16 +26,15 @@
 from ai_edge_torch.generative.quantize.quant_attrs import Mode
 from ai_edge_torch.quantize import quant_config
 from ai_edge_torch.testing import model_coverage
-from parameterized import parameterized
 import torch
 
 from absl.testing import absltest as googletest
 
 
-class TestVerifyRecipes(googletest.TestCase):
+class TestVerifyRecipes(parameterized.TestCase):
   """Unit tests that check for model quantization recipes."""
 
-  @parameterized.expand([
+  @parameterized.parameters([
       (Dtype.FP32, Dtype.FP32),
       (Dtype.INT8, Dtype.INT8),
       (Dtype.INT8, Dtype.FP16),
@@ -52,7 +52,7 @@ def test_verify_invalid_recipes(
     with self.assertRaises(ValueError):
       quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
 
-  @parameterized.expand([
+  @parameterized.parameters([
       (
           Dtype.FP32,
           Dtype.INT8,
@@ -88,7 +88,7 @@ def test_verify_valid_recipes(
     ).verify()
 
 
-class TestQuantizeConvert(googletest.TestCase):
+class TestQuantizeConvert(parameterized.TestCase):
   """Test conversion with quantization."""
 
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
@@ -105,17 +105,13 @@ def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
         )
     )
 
-  @parameterized.expand([
+  @parameterized.parameters([
       (quant_recipes.full_fp16_recipe()),
       (quant_recipes.full_int8_dynamic_recipe()),
       (quant_recipes.full_int8_weight_only_recipe()),
       (_attention_int8_dynamic_recipe()),
       (_feedforward_int8_dynamic_recipe()),
   ])
-  @googletest.skipIf(
-      not config.Config.use_torch_xla,
-      reason="Not working with odml_torch at the moment.",
-  )
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
@@ -132,6 +128,23 @@ def test_quantize_convert_toy_sizes(self, quant_config):
         "Quantized model isn't smaller than F32 model.",
model.", ) + def test_quantize_convert_toy_weight_sharing(self): + config = toy_model.get_model_config() + pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config) + idx = torch.unsqueeze(torch.arange(0, 100), 0) + input_pos = torch.arange(0, 100) + + quant_config = quant_recipes.full_int8_dynamic_recipe() + quantized_model = ai_edge_torch.convert( + pytorch_model, (idx, input_pos), quant_config=quant_config + ) + float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos)) + self.assertLess( + len(quantized_model._tflite_model), + len(float_model._tflite_model), + "Quantized model isn't smaller than F32 model.", + ) + def test_quantize_convert_compare_toy(self): self.skipTest("b/338288901") config = toy_model_with_kv_cache.get_model_config() diff --git a/ai_edge_torch/lowertools/odml_torch_utils.py b/ai_edge_torch/lowertools/odml_torch_utils.py index c634d022..d031f70d 100644 --- a/ai_edge_torch/lowertools/odml_torch_utils.py +++ b/ai_edge_torch/lowertools/odml_torch_utils.py @@ -29,6 +29,7 @@ from tensorflow.compiler.tf2xla.python import xla as tfxla from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb +from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA MlirBundle = odml_torch.export.MlirLowered @@ -186,10 +187,29 @@ def merged_bundle_to_tfl_model( converter._experimental_enable_composite_direct_lowering = True converter.model_origin_framework = "PYTORCH" + conversion_utils.set_tfl_converter_quant_flags(converter, quant_config) + if ( + quant_config is not None + and quant_config._quantizer_mode + == quant_config._QuantizerMode.AI_EDGE_QUANTIZER + ): + translated_recipe = translate_recipe.translate_to_ai_edge_recipe( + quant_config.generative_recipe + ) + conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags) tflite_model = converter.convert() + if ( + quant_config is not None + and quant_config._quantizer_mode + == quant_config._QuantizerMode.AI_EDGE_QUANTIZER + ): + tflite_model = translate_recipe.quantize_model( + tflite_model, translated_recipe + ) + return tflite_model diff --git a/odmltorch-requirements.txt b/odmltorch-requirements.txt index 93ca4f65..003b3445 100644 --- a/odmltorch-requirements.txt +++ b/odmltorch-requirements.txt @@ -7,7 +7,7 @@ torchaudio==2.4.0+cpu --pre tf-nightly>=2.18.0.dev20240722 torch_xla2[odml]>=0.0.1.dev20240801 -ai-edge-quantizer-nightly==0.0.1.dev20240718 +ai-edge-quantizer-nightly jax[cpu] scipy numpy diff --git a/requirements.txt b/requirements.txt index 021e6913..7478afc7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ torchaudio==2.4.0+cpu torch_xla==2.4.0 --pre tf-nightly>=2.18.0.dev20240722 -ai-edge-quantizer-nightly==0.0.1.dev20240718 +ai-edge-quantizer-nightly scipy numpy tabulate diff --git a/setup.py b/setup.py index 1d65ca99..4f2d380a 100644 --- a/setup.py +++ b/setup.py @@ -88,6 +88,6 @@ "torch>=2.4.0", "torch_xla>=2.4.0", "tf-nightly>=2.18.0.dev20240722", - "ai-edge-quantizer-nightly==0.0.1.dev20240718", + "ai-edge-quantizer-nightly", ], )