Commit 8695761: [TorchFX] Use torchao for quantize_pt2e API when possible

1 parent 73b39a3

32 files changed: +599 −15391 lines
src/nncf/experimental/torch/fx/quantization/qdq_parameters.py (new file; path inferred from the import added in transformations.py below)

Lines changed: 42 additions & 0 deletions

```diff
@@ -0,0 +1,42 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class TorchQDQParameters:
+    """
+    Stores the quantization parameters required for
+    creation of a PyTorch quantize-dequantize pair.
+
+    :param quant_min: Minimum quant value.
+    :type quant_min: int
+    :param quant_max: Maximum quant value.
+    :type quant_max: int
+    :param scale: Defines the scale factor used for quantization.
+    :type scale: torch.Tensor
+    :param zero_point: Specifies the quantized value to which 0.0 in floating point maps.
+    :type zero_point: torch.Tensor
+    :param is_per_channel: Whether quantization is applied per channel.
+    :type is_per_channel: bool
+    :param ch_axis: Channel axis used for per-channel quantization.
+    :type ch_axis: int
+    """
+
+    quant_min: int
+    quant_max: int
+    scale: torch.Tensor
+    zero_point: torch.Tensor
+    is_per_channel: bool
+    ch_axis: int
```
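A minimal sketch of constructing the new dataclass by hand; the values are illustrative, not taken from the commit:

```python
import torch

from nncf.experimental.torch.fx.quantization.qdq_parameters import TorchQDQParameters

# Hypothetical per-tensor parameters for a signed 8-bit quantizer.
params = TorchQDQParameters(
    quant_min=-128,
    quant_max=127,
    scale=torch.tensor(0.05),
    zero_point=torch.tensor(0),
    is_per_channel=False,
    ch_axis=0,  # ignored for per-tensor quantization
)
```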

src/nncf/experimental/torch/fx/quantization/quantize_model.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -17,6 +17,7 @@
 from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
 from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat
 from torch.ao.quantization.pt2e.utils import _disallow_eval_train
+from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_manager import PassManager

@@ -27,7 +28,6 @@
 from nncf.data import Dataset
 from nncf.experimental.torch.fx.quantization.backend_parameters import is_weight_compression_needed
 from nncf.experimental.torch.fx.transformations import DuplicateDQPassNoAnnotations
-from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
 from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
 from nncf.experimental.torch.fx.transformations import fq_weights_transformation
 from nncf.parameters import BackupMode

@@ -87,8 +87,9 @@ def quantize_impl(
         advanced_parameters=advanced_parameters,
     )

-    # To make it easier for bias correction algorithms.
-    apply_quantization_transformations(copied_model)
+    # Fuse batch norms into convolution biases,
+    # the same way it is done in torchao.
+    _fuse_conv_bn_(copied_model)

     nncf_graph = NNCFGraphFactory.create(copied_model)
     quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)
```
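The removed `apply_quantization_transformations` helper only fused conv + batch norm, so the call site now invokes torch's pass directly. A standalone sketch of the same fusion on an exported graph (assuming a PyTorch version where `torch.export` and `_fuse_conv_bn_` behave as in the stock pt2e flow):

```python
import torch
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_


class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


example = (torch.ones(1, 3, 16, 16),)
gm = torch.export.export(ConvBN().eval(), example).module()
_fuse_conv_bn_(gm)  # folds the batch norm into the conv weight and bias in place
```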

src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py

Lines changed: 11 additions & 4 deletions

```diff
@@ -14,10 +14,6 @@

 import torch
 import torch.fx
-from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
-from torch.ao.quantization.pt2e.utils import _disallow_eval_train
-from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
-from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_manager import PassManager

@@ -38,6 +34,17 @@
 from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
 from nncf.quantization.range_estimator import RangeEstimatorParameters

+try:
+    from torchao.quantization.pt2e.quantizer import Quantizer
+    from torchao.quantization.pt2e.quantizer.port_metadata_pass import PortNodeMetaForQDQ
+    from torchao.quantization.pt2e.utils import _disallow_eval_train
+    from torchao.quantization.pt2e.utils import _fuse_conv_bn_
+except ImportError:
+    from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
+    from torch.ao.quantization.pt2e.utils import _disallow_eval_train
+    from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
+    from torch.ao.quantization.quantizer import Quantizer
+

 @api(canonical_alias="nncf.experimental.torch.fx.quantize_pt2e")
 def quantize_pt2e(
```
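For context, a minimal usage sketch of the API this commit targets (assuming `quantize_pt2e` keeps the `(model, quantizer, calibration_dataset)` signature; `XNNPACKQuantizer` is just one torch.ao quantizer satisfying the fallback `Quantizer` type, and lives elsewhere in newer PyTorch releases):

```python
import torch

import nncf
from nncf.experimental.torch.fx import quantize_pt2e  # canonical alias per the @api decorator
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import get_symmetric_quantization_config


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


example_input = torch.ones(1, 4)
model = torch.export.export(TinyModel().eval(), (example_input,)).module()

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

# The quantizer may come from torchao or torch.ao; the fallback above handles both.
quantized_model = quantize_pt2e(model, quantizer, nncf.Dataset([example_input]))
```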

src/nncf/experimental/torch/fx/quantization/quantizer/openvino_quantizer.py

Lines changed: 19 additions & 8 deletions

```diff
@@ -13,14 +13,6 @@
 from typing import Optional, Union

 import torch.fx
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.observer import PerChannelMinMaxObserver
-from torch.ao.quantization.quantizer.quantizer import EdgeOrNode
-from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation as TorchAOQuantizationAnnotation
-from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec
-from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase as TorchAOQuantizationSpecBase
-from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer
-from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec

 import nncf
 from nncf import IgnoredScope

@@ -43,6 +35,25 @@
 from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
 from nncf.torch.model_graph_manager import get_weight_tensor_port_ids

+try:
+    from torchao.quantization.pt2e.observer import HistogramObserver
+    from torchao.quantization.pt2e.observer import PerChannelMinMaxObserver
+    from torchao.quantization.pt2e.quantizer.quantizer import EdgeOrNode
+    from torchao.quantization.pt2e.quantizer.quantizer import QuantizationAnnotation as TorchAOQuantizationAnnotation
+    from torchao.quantization.pt2e.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec
+    from torchao.quantization.pt2e.quantizer.quantizer import QuantizationSpecBase as TorchAOQuantizationSpecBase
+    from torchao.quantization.pt2e.quantizer.quantizer import Quantizer as TorchAOQuantizer
+    from torchao.quantization.pt2e.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec
+except ImportError:
+    from torch.ao.quantization.observer import HistogramObserver
+    from torch.ao.quantization.observer import PerChannelMinMaxObserver
+    from torch.ao.quantization.quantizer.quantizer import EdgeOrNode
+    from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation as TorchAOQuantizationAnnotation
+    from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec
+    from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase as TorchAOQuantizationSpecBase
+    from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer
+    from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec
+
 QUANT_ANNOTATION_KEY = "quantization_annotation"
```
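A quick debugging sketch (not part of the commit) to confirm which branch the fallback resolved to at runtime:

```python
# Prints e.g. 'torchao.quantization.pt2e.quantizer.quantizer' when torchao is
# installed, or 'torch.ao.quantization.quantizer.quantizer' otherwise.
print(TorchAOQuantizer.__module__)
```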

src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py

Lines changed: 17 additions & 8 deletions

```diff
@@ -15,11 +15,6 @@

 import torch
 import torch.fx
-from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
-from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
-from torch.ao.quantization.quantizer import Quantizer as TorchAOQuantizer
-from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
-from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec

 import nncf
 from nncf.common.graph.graph import NNCFGraph

@@ -34,12 +29,26 @@
 from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
 from nncf.tensor.definitions import TensorDataType

+try:
+    from torchao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
+    from torchao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
+    from torchao.quantization.pt2e.quantizer import Quantizer as TorchAOQuantizer
+    from torchao.quantization.pt2e.quantizer.quantizer import QuantizationSpec
+    from torchao.quantization.pt2e.quantizer.quantizer import SharedQuantizationSpec
+except ImportError:
+    from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
+    from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
+    from torch.ao.quantization.quantizer import Quantizer as TorchAOQuantizer
+    from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
+    from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec
+
+
 EdgeOrNode = Union[tuple[torch.fx.Node, torch.fx.Node]]


 class TorchAOQuantizerAdapter(Quantizer):
     """
-    Implementation of the NNCF Quantizer interface for any given torch.ao quantizer.
+    Implementation of the NNCF Quantizer interface for any given torchao quantizer.
     """

     def __init__(self, quantizer: TorchAOQuantizer):

@@ -120,7 +129,7 @@ def _get_node_args(node: torch.fx.Node) -> tuple[Any, ...]:
 def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
     """
     Process a torch.fx.GraphModule annotated with quantization specifications
-    (e.g., via torch.ao observers) and generates a corresponding NNCF quantization setup object,
+    (e.g., via torchao observers) and generates a corresponding NNCF quantization setup object,
     which maps quantization configurations to graph edges.

     :param annotated: A torch.fx.GraphModule that has been annotated with Torch quantization observers.

@@ -149,7 +158,7 @@ def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -
         if qspec is None:
             continue
         if not isinstance(qspec, QuantizationSpec):
-            msg = f"Unknown torch.ao quantization spec: {qspec}"
+            msg = f"Unknown torchao quantization spec: {qspec}"
             raise nncf.InternalError(msg)

         if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
```
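For reference, a minimal sketch of the spec objects this adapter consumes, built against the torch.ao fallback paths; the field names follow the `QuantizationSpec` dataclass:

```python
import torch
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec

# An illustrative per-tensor activation spec; the adapter above inspects
# qspec.qscheme to decide between per-tensor and per-channel handling.
act_spec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=HistogramObserver,
)
assert act_spec.qscheme not in (torch.per_channel_affine, torch.per_channel_symmetric)
```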

src/nncf/experimental/torch/fx/transformations.py

Lines changed: 71 additions & 37 deletions

```diff
@@ -15,12 +15,9 @@

 import torch
 import torch.fx
-from torch.ao.quantization.fx.utils import create_getattr_from_value
-from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.fx.node import map_arg
 from torch.fx.passes.infra.pass_base import PassBase
 from torch.fx.passes.infra.pass_base import PassResult
-from torch.quantization.fake_quantize import FakeQuantize

 import nncf
 import nncf.torch

@@ -29,6 +26,7 @@
 from nncf.experimental.torch.fx.constant_folding import constant_fold
 from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
 from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
+from nncf.experimental.torch.fx.quantization.qdq_parameters import TorchQDQParameters
 from nncf.torch.graph.transformations.commands import PTTargetPoint

 TransformationFNType = Callable[[torch.fx.GraphModule], None]

@@ -223,16 +221,16 @@ def constant_update_fn(


 def qdq_insertion_transformation_builder(
-    quantizer: FakeQuantize, target_points: list[PTTargetPoint]
+    parameters: TorchQDQParameters, target_points: list[PTTargetPoint]
 ) -> TransformationFNType:
     """
-    Returns transformation which inserts quantize-dequantize operations with parameters
-    inherited from the given quantizer to each given target point.
+    Returns transformation which inserts quantize-dequantize operations with
+    the given parameters to each given target point.

-    :param quantizer: Quantizer module to inherit quantization parameters from.
+    :param parameters: Quantization parameters.
     :param target_points: List of target point used to insert quantize-dequantize pairs.
-    :return: Transformation which inserts quantize-dequantize operations with parameters
-        inherited from the given quantizer to each given target point.
+    :return: Transformation which inserts quantize-dequantize operations with
+        the given parameters to each given target point.
     """

     def qdq_insertion_transformation(model: torch.fx.GraphModule):

@@ -243,7 +241,7 @@ def qdq_insertion_transformation(model: torch.fx.GraphModule):
             )
             raise nncf.InternalError(msg)
         for target_point in target_points:
-            insert_one_qdq(model, target_point, quantizer)
+            insert_one_qdq(model, target_point, parameters)

     return qdq_insertion_transformation

@@ -311,38 +309,38 @@ def output_insertion_transformation(model: torch.fx.GraphModule):
     return output_insertion_transformation


-def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, quantizer: FakeQuantize):
+def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, parameters: TorchQDQParameters):
     """
     Inserts quantize-dequantize after the target node to the target model.

     :param model: Target model.
     :param target_node: Target node, quantizer-dequantizer pair is inserted just after the
         target node.
-    :param quantizer: Quantizer module to inherit quantization parameters from.
+    :param parameters: Quantization parameters.
     """
-    # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e
+    # Copied from torchao.quantization.quantize_pt2e.convert_pt2e
     # 1. extract information for inserting q/dq node from activation_post_process
     node_type = "call_function"
    quantize_op: Optional[Callable] = None

-    dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8
-    if quantizer.is_per_channel:
+    dtype = torch.int8 if parameters.quant_min < 0 else torch.uint8
+    if parameters.is_per_channel:
         qparams = {
-            "_scale_": quantizer.scale,
-            "_zero_point_": quantizer.zero_point,
-            "_axis_": quantizer.ch_axis,
-            "_quant_min_": quantizer.quant_min,
-            "_quant_max_": quantizer.quant_max,
+            "_scale_": parameters.scale,
+            "_zero_point_": parameters.zero_point,
+            "_axis_": parameters.ch_axis,
+            "_quant_min_": parameters.quant_min,
+            "_quant_max_": parameters.quant_max,
             "_dtype_": dtype,
         }
         quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
         dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
     else:
         qparams = {
-            "_scale_": float(quantizer.scale),
-            "_zero_point_": int(quantizer.zero_point),
-            "_quant_min_": quantizer.quant_min,
-            "_quant_max_": quantizer.quant_max,
+            "_scale_": float(parameters.scale),
+            "_zero_point_": int(parameters.zero_point),
+            "_quant_min_": parameters.quant_min,
+            "_quant_max_": parameters.quant_max,
             "_dtype_": dtype,
         }
         quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
```
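The per-tensor branch above resolves to the decomposed quantize/dequantize ops. A standalone sketch of what one inserted Q/DQ pair computes (the explicit `_decomposed` import, which registers the ops, is an assumption about how they are exposed outside the pt2e flow):

```python
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  registers quantized_decomposed ops

x = torch.randn(8)
scale, zero_point, qmin, qmax = 0.1, 0, -128, 127

q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    x, scale, zero_point, qmin, qmax, torch.int8
)
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    q, scale, zero_point, qmin, qmax, torch.int8
)
# dq approximates x up to the quantization step for in-range values
assert torch.all((dq - x).abs() <= scale)
```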
```diff
@@ -721,19 +719,6 @@ def match_filters(match, original_graph, graph):
     _set_meta_for_matches(model, matches)


-def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
-    """
-    Applies quantization transformations to the model.
-
-    :param model: Model to apply transformations to.
-    """
-    # BatchNorm operations have 3 output ports,
-    # to make it easier for algorithms to work
-    # with the target graph BatchNorm operations
-    # are being fused
-    _fuse_conv_bn_(model)
-
-
 def fold_constant_except_qdq(model: torch.fx.GraphModule):
     """
     Performs constant folding avoiding quantize-dequantize pattern.
```
```diff
@@ -826,3 +811,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)
+
+
+def get_device(module: torch.nn.Module) -> torch.device:
+    """
+    Retrieves the device of the first parameter of the given module.
+    If there are no parameters, returns the CPU device.
+
+    :param module: A torch.nn.Module instance.
+    :return: The device of the first parameter of the given module,
+        or the CPU device if there are no parameters.
+    """
+    try:
+        named_param = next(module.parameters())
+    except StopIteration:
+        named_param = None
+    if named_param is None:
+        return torch.device("cpu")
+    return named_param.device
+
+
+def create_getattr_from_value(module: torch.nn.Module, graph: torch.fx.Graph, prefix: str, value: Any) -> torch.fx.Node:
+    """
+    Given a value of any type, creates a getattr node corresponding to the value and
+    registers the value as a buffer of the module.
+
+    :param module: A torch.nn.Module instance.
+    :param graph: A torch.fx.Graph instance.
+    :param prefix: A string to use as a name prefix for the new getattr node.
+    :param value: A value to register as a buffer.
+    :return: A getattr node corresponding to the given value.
+    """
+
+    def get_new_attr_name(module: torch.nn.Module, prefix: str):
+        def get_attr_name(i: int):
+            return prefix + str(i)
+
+        i = 0
+        attr_name = get_attr_name(i)
+        while hasattr(module, attr_name):
+            i += 1
+            attr_name = get_attr_name(i)
+        return attr_name
+
+    attr_name = get_new_attr_name(module, prefix.replace(".", "_"))
+    device = get_device(module)
+    new_value = value.detach().clone() if isinstance(value, torch.Tensor) else torch.tensor(value, device=device)
+    module.register_buffer(attr_name, new_value)
+    attr_node = graph.create_node("get_attr", attr_name)
+    return attr_node
```
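A usage sketch for the vendored helper (module path taken from this diff): it registers the value as a buffer under a de-duplicated name and returns a `get_attr` node that can be wired in as a Q/DQ argument:

```python
import torch
import torch.fx

from nncf.experimental.torch.fx.transformations import create_getattr_from_value


class Identity(torch.nn.Module):
    def forward(self, x):
        return x


gm = torch.fx.symbolic_trace(Identity())
output_node = next(n for n in gm.graph.nodes if n.op == "output")
with gm.graph.inserting_before(output_node):
    scale_node = create_getattr_from_value(gm, gm.graph, "scale", torch.tensor(0.05))
gm.recompile()

print(scale_node.target)                  # 'scale0', the first free name with that prefix
print(gm.get_buffer(scale_node.target))   # tensor(0.0500)
```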
