-
Notifications
You must be signed in to change notification settings - Fork 796
Migrate CoreMLQuantizer to ET #16473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| # Copyright (c) 2024, Apple Inc. All rights reserved. | ||
| # | ||
| # Use of this source code is governed by a BSD-3-clause license that can be | ||
| # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause | ||
|
|
||
| from typing import Optional as _Optional | ||
|
|
||
| import torch as _torch | ||
|
|
||
| from attr import define as _define | ||
|
|
||
| from coremltools.optimize.torch.quantization.quantization_config import ( | ||
| ModuleLinearQuantizerConfig as _ModuleLinearQuantizerConfig, | ||
| QuantizationScheme as _QuantizationScheme, | ||
| ) | ||
|
|
||
| from torchao.quantization.pt2e.fake_quantize import FakeQuantize as _FakeQuantize | ||
|
|
||
| from torchao.quantization.pt2e.observer import ( | ||
| MinMaxObserver as _MinMaxObserver, | ||
| MovingAverageMinMaxObserver as _MovingAverageMinMaxObserver, | ||
| MovingAveragePerChannelMinMaxObserver as _MovingAveragePerChannelMinMaxObserver, | ||
| PerChannelMinMaxObserver as _PerChannelMinMaxObserver, | ||
| ) | ||
| from torchao.quantization.pt2e.quantizer import ( | ||
| QuantizationSpec as _TorchQuantizationSpec, | ||
| ) | ||
|
|
||
|
|
||
| def _get_observer(observer_type, is_per_channel: bool): | ||
| _str_to_observer_map = { | ||
| "moving_average_min_max": _MovingAverageMinMaxObserver, | ||
| "min_max": _MinMaxObserver, | ||
| "moving_average_min_max_per_channel": _MovingAveragePerChannelMinMaxObserver, | ||
| "min_max_per_channel": _PerChannelMinMaxObserver, | ||
| } | ||
| observer_name = observer_type.value | ||
| if is_per_channel: | ||
| observer_name = f"{observer_name}_per_channel" | ||
| if observer_name not in _str_to_observer_map: | ||
| raise ValueError(f"Unsupported observer type: {observer_name}") | ||
| return _str_to_observer_map[observer_name] | ||
|
|
||
|
|
||
| @_define | ||
| class AnnotationConfig: | ||
| """ | ||
| Module/Operator level configuration class for :py:class:`CoreMLQuantizer`. | ||
|
|
||
| For each module/operator, defines the dtype, quantization scheme and observer type | ||
| for input(s), output and weights (if any). | ||
| """ | ||
|
|
||
| input_activation: _Optional[_TorchQuantizationSpec] = None | ||
| output_activation: _Optional[_TorchQuantizationSpec] = None | ||
| weight: _Optional[_TorchQuantizationSpec] = None | ||
|
|
||
| @staticmethod | ||
| def _normalize_dtype(dtype: _torch.dtype) -> _torch.dtype: | ||
| """ | ||
| PyTorch export quantizer only supports uint8 and int8 data types, | ||
| so we map the quantized dtypes to the corresponding supported dtype. | ||
| """ | ||
| dtype_map = { | ||
| _torch.quint8: _torch.uint8, | ||
| _torch.qint8: _torch.int8, | ||
| } | ||
| return dtype_map.get(dtype, dtype) | ||
|
|
||
| @classmethod | ||
| def from_quantization_config( | ||
| cls, | ||
| quantization_config: _Optional[_ModuleLinearQuantizerConfig], | ||
| ) -> _Optional["AnnotationConfig"]: | ||
| """ | ||
| Creates a :py:class:`AnnotationConfig` from ``ModuleLinearQuantizerConfig`` | ||
| """ | ||
| if ( | ||
| quantization_config is None | ||
| or quantization_config.weight_dtype == _torch.float32 | ||
| ): | ||
| return None | ||
|
|
||
| # Activation QSpec | ||
| if quantization_config.activation_dtype == _torch.float32: | ||
| output_activation_qspec = None | ||
| else: | ||
| activation_qscheme = _QuantizationScheme.get_qscheme( | ||
| quantization_config.quantization_scheme, | ||
| is_per_channel=False, | ||
| ) | ||
| activation_dtype = cls._normalize_dtype( | ||
| quantization_config.activation_dtype | ||
| ) | ||
| output_activation_qspec = _TorchQuantizationSpec( | ||
| observer_or_fake_quant_ctr=_FakeQuantize.with_args( | ||
| observer=_get_observer( | ||
| quantization_config.activation_observer, | ||
| is_per_channel=False, | ||
| ), | ||
| dtype=activation_dtype, | ||
| qscheme=activation_qscheme, | ||
| ), | ||
| dtype=activation_dtype, | ||
| qscheme=activation_qscheme, | ||
| ) | ||
|
|
||
| # Weight QSpec | ||
| weight_qscheme = _QuantizationScheme.get_qscheme( | ||
| quantization_config.quantization_scheme, | ||
| is_per_channel=quantization_config.weight_per_channel, | ||
| ) | ||
| weight_dtype = cls._normalize_dtype(quantization_config.weight_dtype) | ||
| weight_qspec = _TorchQuantizationSpec( | ||
| observer_or_fake_quant_ctr=_FakeQuantize.with_args( | ||
| observer=_get_observer( | ||
| quantization_config.weight_observer, | ||
| is_per_channel=quantization_config.weight_per_channel, | ||
| ), | ||
| dtype=weight_dtype, | ||
| qscheme=weight_qscheme, | ||
| ), | ||
| dtype=weight_dtype, | ||
| qscheme=weight_qscheme, | ||
| ) | ||
| return AnnotationConfig( | ||
| input_activation=output_activation_qspec, | ||
| output_activation=output_activation_qspec, | ||
| weight=weight_qspec, | ||
| ) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's better for these to continue to use torch.ao since we are planning to deprecate these in torchao/pt2e
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you say more on the deprecation plan?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh wait, I remember there might be some incompatibilities of the observer in torchao/pt2e v.s. torch.ao
does the previous torch.ao import work for CoreMLQuantizer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't try it, I assumed it wouldn't work. In all other quantizers we migrated to use observers in torchao.ao
Were the observers removed from torch.ao?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, I think they don't work.
observers are not removed from torch.ao, it's just we'd like to deprecate them together with all the fx / eager flows.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just updated my PR: apple/coremltools#2634, it seems that coreml is currently using the same observer for both fx and pt2e flow, and since torchao pt2e uses a different set of observer/fake_quant, we can't make all tests pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any other concerns with this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @jerryzh168