@@ -9,6 +9,7 @@
 import torch
 
 from torch._ops import OpOverload
+from torch._subclasses import FakeTensor
 
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -41,6 +42,18 @@ def decorator(annotator: Callable):
 
     return decorator
 
+def _is_input_float_tensor(node: Node):
+    """Check if the input is a float tensor, so that we can skip quantization
+    for the node when it is not, since observers only work with float Tensors
+    """
+    if (
+        not isinstance(node, Node)
+        or "val" not in node.meta
+        or not isinstance(node.meta["val"], FakeTensor)
+    ):
+        return False
+    return node.meta["val"].dtype == torch.float32
+
 
 def _is_annotated(nodes: List[Node]):
     """
@@ -123,11 +136,11 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
 
     input_qspec_map = {}
     input_act0 = node.args[0]
-    if isinstance(input_act0, Node):
+    if _is_input_float_tensor(input_act0):
         input_qspec_map[input_act0] = input_act_qspec
 
     input_act1 = node.args[1]
-    if isinstance(input_act1, Node):
+    if _is_input_float_tensor(input_act1):
         input_qspec_map[input_act1] = input_act_qspec
 
     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
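
For context, here is a minimal sketch (not part of the diff) of how the new helper behaves on fx nodes; the graph setup below is illustrative and assumes the public torch.fx and FakeTensorMode APIs:

import torch
from torch.fx import Graph, Node
from torch._subclasses import FakeTensor, FakeTensorMode

def _is_input_float_tensor(node: Node):
    # Mirrors the helper added in this diff: True only for Nodes whose
    # traced value ("val" in node.meta) is a float32 FakeTensor.
    if (
        not isinstance(node, Node)
        or "val" not in node.meta
        or not isinstance(node.meta["val"], FakeTensor)
    ):
        return False
    return node.meta["val"].dtype == torch.float32

graph = Graph()
x = graph.placeholder("x")  # an fx Node with an (initially empty) meta dict
mode = FakeTensorMode()

x.meta["val"] = mode.from_tensor(torch.randn(2, 2))
print(_is_input_float_tensor(x))    # True: float32 FakeTensor

x.meta["val"] = mode.from_tensor(torch.zeros(2, dtype=torch.int64))
print(_is_input_float_tensor(x))    # False: integer tensor

print(_is_input_float_tensor(1.0))  # False: a scalar argument, not a Node

This is why annotate_binary now leaves non-float inputs (e.g. integer tensors or Python scalars) out of input_qspec_map instead of annotating every Node argument.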
|