import torch
import torch.fx
-from torch.ao.quantization.fx.utils import create_getattr_from_value
-from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.fx.node import map_arg
from torch.fx.passes.infra.pass_base import PassBase
from torch.fx.passes.infra.pass_base import PassResult
-from torch.quantization.fake_quantize import FakeQuantize

import nncf
import nncf.torch
from nncf.experimental.torch.fx.constant_folding import constant_fold
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
+from nncf.experimental.torch.fx.quantization.qdq_parameters import TorchQDQParameters
from nncf.torch.graph.transformations.commands import PTTargetPoint


TransformationFNType = Callable[[torch.fx.GraphModule], None]
@@ -223,16 +221,16 @@ def constant_update_fn(


def qdq_insertion_transformation_builder(
-    quantizer: FakeQuantize, target_points: list[PTTargetPoint]
+    parameters: TorchQDQParameters, target_points: list[PTTargetPoint]
) -> TransformationFNType:
    """
-    Returns transformation which inserts quantize-dequantize operations with parameters
-    inherited from the given quantizer to each given target point.
+    Returns a transformation which inserts quantize-dequantize operations with
+    the given parameters to each given target point.

-    :param quantizer: Quantizer module to inherit quantization parameters from.
+    :param parameters: Quantization parameters.
    :param target_points: List of target points used to insert quantize-dequantize pairs.
-    :return: Transformation which inserts quantize-dequantize operations with parameters
-    inherited from the given quantizer to each given target point.
+    :return: Transformation which inserts quantize-dequantize operations with
+    the given parameters to each given target point.
    """

    def qdq_insertion_transformation(model: torch.fx.GraphModule):
@@ -243,7 +241,7 @@ def qdq_insertion_transformation(model: torch.fx.GraphModule):
            )
            raise nncf.InternalError(msg)
        for target_point in target_points:
-            insert_one_qdq(model, target_point, quantizer)
+            insert_one_qdq(model, target_point, parameters)

    return qdq_insertion_transformation

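A minimal usage sketch of the updated builder follows. It is illustrative only and not part of the patch: the TorchQDQParameters constructor arguments are assumed from the attributes read in insert_one_qdq below (scale, zero_point, ch_axis, quant_min, quant_max, is_per_channel), so the real signature may differ.

# Hypothetical usage sketch; constructor arguments of TorchQDQParameters are assumptions.
import torch

from nncf.common.graph.transformations.commands import TargetType
from nncf.torch.graph.transformations.commands import PTTargetPoint


def add_input_qdq(captured_model: torch.fx.GraphModule) -> None:
    # Symmetric int8 per-tensor parameters for a single activation.
    params = TorchQDQParameters(
        scale=torch.tensor(0.02),
        zero_point=torch.tensor(0),
        ch_axis=0,
        quant_min=-128,
        quant_max=127,
        is_per_channel=False,
    )
    # Quantize input port 0 of the node named "conv2d" in the captured graph.
    target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, "conv2d", input_port_id=0)
    transformation = qdq_insertion_transformation_builder(params, [target_point])
    transformation(captured_model)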
@@ -311,38 +309,38 @@ def output_insertion_transformation(model: torch.fx.GraphModule):
    return output_insertion_transformation


-def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, quantizer: FakeQuantize):
+def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, parameters: TorchQDQParameters):
    """
    Inserts a quantize-dequantize pair after the target node in the target model.

    :param model: Target model.
    :param target_point: Target point; the quantize-dequantize pair is inserted just after the
        target node.
-    :param quantizer: Quantizer module to inherit quantization parameters from.
+    :param parameters: Quantization parameters.
    """
-    # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e
+    # Copied from torchao.quantization.quantize_pt2e.convert_pt2e
    # 1. extract information for inserting q/dq node from activation_post_process
    node_type = "call_function"
    quantize_op: Optional[Callable] = None

-    dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8
-    if quantizer.is_per_channel:
+    dtype = torch.int8 if parameters.quant_min < 0 else torch.uint8
+    if parameters.is_per_channel:
        qparams = {
-            "_scale_": quantizer.scale,
-            "_zero_point_": quantizer.zero_point,
-            "_axis_": quantizer.ch_axis,
-            "_quant_min_": quantizer.quant_min,
-            "_quant_max_": quantizer.quant_max,
+            "_scale_": parameters.scale,
+            "_zero_point_": parameters.zero_point,
+            "_axis_": parameters.ch_axis,
+            "_quant_min_": parameters.quant_min,
+            "_quant_max_": parameters.quant_max,
            "_dtype_": dtype,
        }
        quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
        dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
    else:
        qparams = {
-            "_scale_": float(quantizer.scale),
-            "_zero_point_": int(quantizer.zero_point),
-            "_quant_min_": quantizer.quant_min,
-            "_quant_max_": quantizer.quant_max,
+            "_scale_": float(parameters.scale),
+            "_zero_point_": int(parameters.zero_point),
+            "_quant_min_": parameters.quant_min,
+            "_quant_max_": parameters.quant_max,
            "_dtype_": dtype,
        }
        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
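For intuition, the per-tensor pair built here reduces to q = clamp(round(x / scale) + zero_point, quant_min, quant_max) cast to dtype, followed by dq = (q - zero_point) * scale. A small standalone check of these decomposed ops, illustrative only and not part of the patch (on some torch versions the ops are only registered after importing the private torch.ao.quantization.fx._decomposed module):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  may be needed to register quantized_decomposed ops

x = torch.tensor([-1.0, -0.013, 0.0, 0.013, 1.0])
scale, zero_point, quant_min, quant_max = 0.02, 0, -128, 127

q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    x, scale, zero_point, quant_min, quant_max, torch.int8
)
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    q, scale, zero_point, quant_min, quant_max, torch.int8
)
print(q)   # tensor([-50,  -1,   0,   1,  50], dtype=torch.int8)
print(dq)  # tensor([-1.0000, -0.0200,  0.0000,  0.0200,  1.0000])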
@@ -721,19 +719,6 @@ def match_filters(match, original_graph, graph):
    _set_meta_for_matches(model, matches)


-def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
-    """
-    Applies quantization transformations to the model.
-
-    :param model: Model to apply transformations to.
-    """
-    # BatchNorm operations have 3 output ports,
-    # to make it easier for algorithms to work
-    # with the target graph BatchNorm operations
-    # are being fused
-    _fuse_conv_bn_(model)
-
-
def fold_constant_except_qdq(model: torch.fx.GraphModule):
    """
    Performs constant folding avoiding quantize-dequantize pattern.
@@ -826,3 +811,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
+
+
+def get_device(module: torch.nn.Module) -> torch.device:
+    """
+    Retrieves the device of the first parameter of the given module.
+    If there are no parameters, returns the CPU device.
+
+    :param module: A torch.nn.Module instance.
+    :return: The device of the first parameter of the given module.
+        If there are no parameters, the CPU device is returned.
+    """
+    try:
+        named_param = next(module.parameters())
+    except StopIteration:
+        named_param = None
+    if named_param is None:
+        return torch.device("cpu")
+    return named_param.device
+
+
+def create_getattr_from_value(module: torch.nn.Module, graph: torch.fx.Graph, prefix: str, value: Any) -> torch.fx.Node:
+    """
+    Given a value of any type, creates a getattr node corresponding to the value and
+    registers the value as a buffer of the module.
+
+    :param module: A torch.nn.Module instance.
+    :param graph: A torch.fx.Graph instance.
+    :param prefix: A string to use as a name prefix for the new getattr node.
+    :param value: The value to register as a buffer.
+    :return: A getattr node corresponding to the given value.
+    """
+
+    def get_new_attr_name(module: torch.nn.Module, prefix: str):
+        def get_attr_name(i: int):
+            return prefix + str(i)
+
+        i = 0
+        attr_name = get_attr_name(i)
+        while hasattr(module, attr_name):
+            i += 1
+            attr_name = get_attr_name(i)
+        return attr_name
+
+    attr_name = get_new_attr_name(module, prefix.replace(".", "_"))
+    device = get_device(module)
+    new_value = value.detach().clone() if isinstance(value, torch.Tensor) else torch.tensor(value, device=device)
+    module.register_buffer(attr_name, new_value)
+    attr_node = graph.create_node("get_attr", attr_name)
+    return attr_node
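A short standalone sketch of how the new helper can be used; the toy module and attribute names are illustrative only and not part of the patch.

import torch
import torch.fx


class Identity(torch.nn.Module):
    def forward(self, x):
        return x


gm = torch.fx.symbolic_trace(Identity())
output_node = next(n for n in gm.graph.nodes if n.op == "output")
# Register the scale tensor as a buffer and reference it through a get_attr node
# placed before the output node of the traced graph.
with gm.graph.inserting_before(output_node):
    scale_node = create_getattr_from_value(gm, gm.graph, "scale", torch.tensor(0.02))
print(scale_node.target, gm.get_buffer(scale_node.target))  # e.g. scale0 tensor(0.0200)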