@@ -267,7 +267,7 @@ def __init__(self):
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
 
-        self.enable_per_channel_conv_quant: bool = True
+        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
         # the weight dtypes to use for 8-bit and 16-bit activations
         self.per_channel_weight_dtype: Dict = {
             "8bit_act": torch.int8,
@@ -290,16 +290,13 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
     def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
         """
         Priority:
-            1. per channel config when enable_per_channel_conv_quant is True
+            1. per channel config when op is in use_per_channel_weight_quant_ops
             2. int8 / int16 config
         """
         if type(op) == str:
             return
 
-        if self.enable_per_channel_conv_quant and op in [
-            torch.ops.aten.conv1d.default,
-            torch.ops.aten.conv2d.default,
-        ]:
+        if op in self.use_per_channel_weight_quant_ops:
             if op in self.bit16_quant_ops:
                 return get_ptq_per_channel_weight_config(
                     torch.uint16, self.per_channel_weight_dtype["16bit_act"]
@@ -316,6 +313,12 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
 
         print(f"No quant config is implemented for op, {op}")
 
+    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
+        if enable:
+            self.use_per_channel_weight_quant_ops.update(ops)
+        else:
+            self.use_per_channel_weight_quant_ops.difference_update(ops)
+
     def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None:
         for op in ops:
             assert (
@@ -368,8 +371,15 @@ def set_per_channel_weight_dtype(
         if weight_dtype_for_16bit_act:
             self.per_channel_weight_dtype["16bit_act"] = weight_dtype_for_16bit_act
 
-    def set_per_channel_quant(self, enable: bool) -> None:
-        self.enable_per_channel_conv_quant = enable
+    def set_per_channel_conv_quant(self, enable: bool) -> None:
+        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
+        self._update_per_channel_weight_quant_ops(conv_ops, enable)
+
+    def set_per_channel_linear_quant(self, enable: bool) -> None:
+        linear_ops = {
+            torch.ops.aten.linear.default,
+        }
+        self._update_per_channel_weight_quant_ops(linear_ops, enable)
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = RemoveClone()(model).graph_module
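
For reference, below is a minimal, self-contained sketch of the behavior this diff introduces: per-channel weight quantization is now opted in per op set instead of being gated by a single conv-only bool, so conv and linear can be toggled independently. The class name _Quantizer and the standalone layout are illustrative only; the real methods live on the quantizer class this file defines, and only the per-channel op-set bookkeeping is reproduced here. Note the removal path must use difference_update, since plain set.difference returns a new set without mutating the attribute.

from typing import Set

import torch
from torch._ops import OpOverload


class _Quantizer:
    # Illustrative stand-in for the quantizer class touched above;
    # only the per-channel op-set bookkeeping is reproduced.
    def __init__(self) -> None:
        # Replaces the old enable_per_channel_conv_quant bool: ops in this
        # set receive per-channel weight configs in _get_quant_config.
        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()

    def _update_per_channel_weight_quant_ops(
        self, ops: Set[OpOverload], enable: bool
    ) -> None:
        if enable:
            self.use_per_channel_weight_quant_ops.update(ops)
        else:
            # difference_update mutates in place; .difference() would
            # return a new set and leave the attribute unchanged.
            self.use_per_channel_weight_quant_ops.difference_update(ops)

    def set_per_channel_conv_quant(self, enable: bool) -> None:
        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
        self._update_per_channel_weight_quant_ops(conv_ops, enable)

    def set_per_channel_linear_quant(self, enable: bool) -> None:
        linear_ops = {torch.ops.aten.linear.default}
        self._update_per_channel_weight_quant_ops(linear_ops, enable)


quantizer = _Quantizer()
quantizer.set_per_channel_conv_quant(True)    # conv weights: per-channel
quantizer.set_per_channel_linear_quant(True)  # linear weights: per-channel
assert torch.ops.aten.linear.default in quantizer.use_per_channel_weight_quant_ops

quantizer.set_per_channel_linear_quant(False)  # opt linear back out
assert torch.ops.aten.linear.default not in quantizer.use_per_channel_weight_quant_ops

One consequence of keeping a set rather than a bool is that disabling one op family no longer affects the other, and callers can extend the same mechanism to further ops without new flags.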