Skip to content

Commit 8e2ca35

Browse files
authored
Unify get_block_size (#3039)
* Unify get_block_size * Remove granularity defines in the pt2e path * Fix format
1 parent eadead5 commit 8e2ca35

File tree

11 files changed

+50
-188
lines changed

11 files changed

+50
-188
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2948,10 +2948,11 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
29482948
@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
29492949
class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
29502950
def test_channel_group_quantization(self):
2951+
from torchao.quantization import PerGroup, PerToken
29512952
from torchao.quantization.pt2e._affine_quantization import (
29522953
AffineQuantizedMinMaxObserver,
29532954
)
2954-
from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
2955+
from torchao.quantization.pt2e.observer import MappingType
29552956

29562957
class BackendAQuantizer(Quantizer):
29572958
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -3031,13 +3032,13 @@ def forward(self, x):
30313032
def test_dynamic_affine_act_per_channel_weights(self):
30323033
import operator
30333034

3035+
from torchao.quantization import PerToken
30343036
from torchao.quantization.pt2e._affine_quantization import (
30353037
AffineQuantizedMovingAverageMinMaxObserver,
30363038
)
30373039
from torchao.quantization.pt2e.observer import (
30383040
MappingType,
30393041
PerChannelMinMaxObserver,
3040-
PerToken,
30413042
)
30423043

30433044
class BackendAQuantizer(Quantizer):
@@ -3122,12 +3123,14 @@ def forward(self, x):
31223123
def test_dynamic_per_tok_act_per_group_weights(self):
31233124
import operator
31243125

3126+
from torchao.quantization import PerGroup, PerToken
3127+
31253128
# TODO: merge into torchao observer
31263129
from torchao.quantization.pt2e._affine_quantization import (
31273130
AffineQuantizedMinMaxObserver,
31283131
AffineQuantizedPlaceholderObserver,
31293132
)
3130-
from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
3133+
from torchao.quantization.pt2e.observer import MappingType
31313134

31323135
class BackendAQuantizer(Quantizer):
31333136
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
MultiTensorInputRecorder,
2020
)
2121
from .granularity import (
22+
Granularity,
2223
PerAxis,
2324
PerGroup,
2425
PerRow,
@@ -197,6 +198,7 @@
197198
"MappingType",
198199
"ZeroPointDomain",
199200
"TorchAODType",
201+
"Granularity",
200202
"PerTensor",
201203
"PerAxis",
202204
"PerGroup",

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def _same_metadata(
133133

134134
@implements([torch.nn.functional.linear, aten.linear.default])
135135
def _(func, types, args, kwargs):
136-
137136
input_tensor = kwargs.get("input", args[0] if len(args) > 0 else None)
138137
weight_tensor = kwargs.get("weight", args[1] if len(args) > 1 else None)
139138
bias = kwargs.get("bias", args[2] if len(args) > 2 else None)

torchao/quantization/observer.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from .granularity import (
1616
Granularity,
17-
PerAxis,
1817
PerRow,
1918
PerTensor,
2019
)
@@ -24,6 +23,7 @@
2423
_get_reduction_params,
2524
choose_qparams_affine_with_min_max,
2625
)
26+
from .utils import get_block_size
2727

2828
logger = logging.getLogger(__name__)
2929

@@ -63,26 +63,6 @@ def _with_args(cls_or_self, *args, **kwargs):
6363
return r
6464

6565

66-
def get_block_size(
67-
input_shape: Tuple[int, ...], granularity: Granularity
68-
) -> Tuple[int, ...]:
69-
"""Get the block size based on the input shape and granularity type.
70-
71-
Args:
72-
input_shape: The input tensor shape possibly more than 2 dimensions
73-
granularity: The granularity type of the quantization
74-
"""
75-
if isinstance(granularity, PerTensor):
76-
return input_shape
77-
elif isinstance(granularity, PerAxis):
78-
block_size = list(input_shape)
79-
block_size[granularity.axis] = 1
80-
return tuple(block_size)
81-
elif isinstance(granularity, PerRow):
82-
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
83-
raise ValueError(f"Unsupported Granularity: {granularity}")
84-
85-
8666
ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:
8767

8868

torchao/quantization/pt2e/__init__.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
from .observer import (
4949
AffineQuantizedObserverBase,
5050
FixedQParamsObserver,
51-
Granularity,
5251
HistogramObserver,
5352
MappingType,
5453
MinMaxObserver,
@@ -57,20 +56,13 @@
5756
NoopObserver,
5857
ObserverBase,
5958
PartialWrapper,
60-
PerAxis,
61-
PerBlock,
6259
PerChannelMinMaxObserver,
63-
PerGroup,
64-
PerRow,
65-
PerTensor,
66-
PerToken,
6760
PlaceholderObserver,
6861
RecordingObserver,
6962
ReuseInputObserver,
7063
TorchAODType,
7164
UniformQuantizationObserverBase,
7265
ZeroPointDomain,
73-
get_block_size,
7466
)
7567

7668
for _f in [
@@ -139,17 +131,9 @@
139131
"compare_results",
140132
# should be merged with torchao/quantization/observer.py in the future
141133
"AffineQuantizedObserverBase",
142-
"Granularity",
143134
"MappingType",
144-
"PerAxis",
145-
"PerBlock",
146-
"PerGroup",
147-
"PerRow",
148-
"PerTensor",
149-
"PerToken",
150135
"TorchAODType",
151136
"ZeroPointDomain",
152-
"get_block_size",
153137
"default_fake_quant",
154138
"default_dynamic_fake_quant",
155139
]

torchao/quantization/pt2e/_affine_quantization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
MappingType,
2020
TorchAODType,
2121
ZeroPointDomain,
22-
get_block_size,
2322
)
23+
from torchao.quantization.utils import get_block_size
2424

2525
ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:
2626

torchao/quantization/pt2e/observer.py

Lines changed: 1 addition & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torch.fx import Node
2828

2929
import torchao
30+
from torchao.quantization import Granularity
3031
from torchao.quantization.pt2e.utils import (
3132
calculate_qmin_qmax,
3233
check_min_max_valid,
@@ -67,17 +68,9 @@
6768
"ReuseInputObserver",
6869
"UniformQuantizationObserverBase",
6970
"AffineQuantizedObserverBase",
70-
"Granularity",
7171
"MappingType",
72-
"PerAxis",
73-
"PerBlock",
74-
"PerGroup",
75-
"PerRow",
76-
"PerTensor",
77-
"PerToken",
7872
"TorchAODType",
7973
"ZeroPointDomain",
80-
"get_block_size",
8174
]
8275

8376

@@ -1622,7 +1615,6 @@ def calculate_qparams(self):
16221615
We plan to merge the following with torchao repo after we move pt2e flow to torchao
16231616
copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
16241617
"""
1625-
from dataclasses import dataclass
16261618
from enum import Enum, auto
16271619

16281620

@@ -1679,139 +1671,6 @@ class TorchAODType(Enum):
16791671
INT7 = auto()
16801672

16811673

1682-
@dataclass(frozen=True)
1683-
class Granularity:
1684-
"""
1685-
Base class for representing the granularity of quantization.
1686-
1687-
This class serves as a parent for specific granularity types used in
1688-
quantization operations, such as per-tensor or per-axis quantization.
1689-
"""
1690-
1691-
1692-
@dataclass(frozen=True)
1693-
class PerBlock(Granularity):
1694-
"""
1695-
Represents per-block granularity in quantization. See
1696-
:func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
1697-
`block_size`
1698-
1699-
Attributes:
1700-
block_size (Tuple[int, ...]): The size of each quantization group
1701-
"""
1702-
1703-
block_size: tuple[int, ...]
1704-
1705-
1706-
@dataclass(frozen=True)
1707-
class PerTensor(Granularity):
1708-
"""
1709-
Represents per-tensor granularity in quantization.
1710-
1711-
This granularity type calculates the quantization parameters
1712-
based off the entire tensor.
1713-
1714-
"""
1715-
1716-
1717-
@dataclass(frozen=True)
1718-
class PerAxis(Granularity):
1719-
"""
1720-
Represents per-axis granularity in quantization.
1721-
1722-
This granularity type calculates different quantization parameters
1723-
along a specified axis of the tensor.
1724-
1725-
For example if the input tensor is shape [8, 16] and axis=0, then
1726-
the quantization parameters are calculated for each row of the tensor.
1727-
Giving a total of 8 quantization parameters.
1728-
1729-
Attributes:
1730-
axis (int): The axis along which reduction is performed.
1731-
"""
1732-
1733-
axis: int
1734-
1735-
1736-
@dataclass(frozen=True)
1737-
class PerGroup(Granularity):
1738-
"""
1739-
Represents per-channel group granularity in quantization.
1740-
1741-
This granularity type calculates different quantization parameters
1742-
for each group of <group_size> elements.
1743-
1744-
For example if the input tensor is shape [8, 16], and the group size is 4, then
1745-
the input tensor is reshaped to [64, 4]
1746-
quantization parameters are calculated for each group of 4 elements,
1747-
giving a total of 64 quantization parameters.
1748-
1749-
Attributes:
1750-
group_size (int): The size of each quantization group
1751-
1752-
"""
1753-
1754-
group_size: int
1755-
1756-
1757-
class PerRow(Granularity):
1758-
"""
1759-
Represents row-wise granularity in quantization.
1760-
1761-
This is a special case of per-axis quantization and is unique to Float8 matmuls
1762-
where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
1763-
is quantized with a block_size of (1, weight.shape[1]).
1764-
"""
1765-
1766-
1767-
class PerToken(Granularity):
1768-
"""
1769-
Represents per-token granularity in quantization.
1770-
1771-
This granularity type calculates a different set of quantization parameters
1772-
for each token, which is represented as the last dimension of the tensor.
1773-
1774-
For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
1775-
with 4 elements each, and we will calculate 6 sets of quantization parameters,
1776-
one for each token.
1777-
1778-
If the input tensor has only two dimensions, e.g. [8, 16], then this is
1779-
equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
1780-
"""
1781-
1782-
1783-
def get_block_size(
1784-
input_shape: tuple[int, ...], granularity: Granularity
1785-
) -> tuple[int, ...]:
1786-
"""Get the block size based on the input shape and granularity type.
1787-
1788-
Args:
1789-
input_shape: The input tensor shape possibly more than 2 dimensions
1790-
granularity: The granularity type of the quantization
1791-
"""
1792-
assert isinstance(granularity, Granularity), (
1793-
"Please provide an instance of Granularity, not subclass of it"
1794-
)
1795-
if isinstance(granularity, PerTensor):
1796-
return input_shape
1797-
elif isinstance(granularity, PerAxis):
1798-
block_size = list(input_shape)
1799-
block_size[granularity.axis] = 1
1800-
return tuple(block_size)
1801-
elif isinstance(granularity, PerRow):
1802-
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
1803-
elif isinstance(granularity, PerGroup):
1804-
assert len(input_shape) == 2, (
1805-
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
1806-
)
1807-
return (1, granularity.group_size)
1808-
elif isinstance(granularity, PerToken):
1809-
block_size = [1] * len(input_shape)
1810-
block_size[-1] = input_shape[-1]
1811-
return tuple(block_size)
1812-
raise ValueError(f"Unsupported Granularity: {granularity}")
1813-
1814-
18151674
class AffineQuantizedObserverBase(ABC, torch.nn.Module):
18161675
"""Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
18171676

torchao/quantization/qat/fake_quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
PerRow,
1515
PerToken,
1616
)
17-
from torchao.quantization.observer import get_block_size
1817
from torchao.quantization.quant_primitives import (
1918
_DTYPE_TO_BIT_WIDTH,
2019
_DTYPE_TO_QVALUE_BOUNDS,
@@ -28,6 +27,7 @@
2827
)
2928
from torchao.quantization.utils import (
3029
_get_per_token_block_size,
30+
get_block_size,
3131
get_group_qparams_symmetric,
3232
get_groupwise_affine_qparams,
3333
)

torchao/quantization/quant_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
from torchao.quantization.linear_activation_weight_observed_tensor import (
6565
LinearActivationWeightObservedTensor,
6666
)
67-
from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
67+
from torchao.quantization.observer import AffineQuantizedObserverBase
6868
from torchao.quantization.quantize_.common import (
6969
KernelPreference,
7070
)
@@ -87,6 +87,7 @@
8787
_QUANTIZE_CONFIG_HANDLER,
8888
register_quantize_module_handler,
8989
)
90+
from torchao.quantization.utils import get_block_size
9091
from torchao.quantization.weight_tensor_linear_activation_quantization import (
9192
to_weight_tensor_with_linear_activation_quantization_metadata,
9293
)

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
preprocess_scale,
2424
)
2525
from torchao.quantization.granularity import PerRow, PerTensor
26-
from torchao.quantization.observer import get_block_size
2726
from torchao.quantization.quant_primitives import (
2827
_choose_scale_float8,
2928
_dequantize_affine_float8,
@@ -34,6 +33,7 @@
3433
QuantizeTensorKwargs,
3534
_choose_quant_func_and_quantize_tensor,
3635
)
36+
from torchao.quantization.utils import get_block_size
3737
from torchao.utils import (
3838
TorchAOBaseTensor,
3939
_is_fbgemm_genai_gpu_available,

0 commit comments

Comments
 (0)