@@ -9,6 +9,7 @@
 import torch
 
 from torch._ops import OpOverload
+from torch._subclasses import FakeTensor
 
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -42,6 +43,19 @@ def decorator(annotator: Callable):
     return decorator
 
 
+def _is_input_float_tensor(node: Node):
+    """Check whether the input is a float tensor. If it is not, quantization is
+    skipped for the node, since observers only work with float Tensors.
+    """
+    if (
+        not isinstance(node, Node)
+        or "val" not in node.meta
+        or not isinstance(node.meta["val"], FakeTensor)
+    ):
+        return False
+    return node.meta["val"].dtype == torch.float32
+
+
 def _is_annotated(nodes: List[Node]):
     """
     Given a list of nodes (that represents an operator pattern),
@@ -123,11 +137,11 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
 
     input_qspec_map = {}
     input_act0 = node.args[0]
-    if isinstance(input_act0, Node):
+    if _is_input_float_tensor(input_act0):
         input_qspec_map[input_act0] = input_act_qspec
 
     input_act1 = node.args[1]
-    if isinstance(input_act1, Node):
+    if _is_input_float_tensor(input_act1):
         input_qspec_map[input_act1] = input_act_qspec
 
     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
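For context, here is a minimal sketch (not part of the patch) of what the new guard accepts and rejects. The helper body is copied from the diff so the snippet runs standalone; the graph construction and the `FakeTensorMode` usage below are illustrative stand-ins for the `val` metadata that export/tracing populates on real graphs.

```python
import torch
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import Graph, Node

# Helper copied from the diff so this snippet is self-contained.
def _is_input_float_tensor(node):
    if (
        not isinstance(node, Node)
        or "val" not in node.meta
        or not isinstance(node.meta["val"], FakeTensor)
    ):
        return False
    return node.meta["val"].dtype == torch.float32

# Build two placeholder nodes and attach FakeTensors to their "val" meta,
# mimicking what export/tracing records for real graphs.
graph = Graph()
x = graph.placeholder("x")
y = graph.placeholder("y")
with FakeTensorMode():
    x.meta["val"] = torch.empty(2, 2, dtype=torch.float32)
    y.meta["val"] = torch.empty(2, 2, dtype=torch.int64)

print(_is_input_float_tensor(x))    # True:  float32 FakeTensor -> gets an input qspec
print(_is_input_float_tensor(y))    # False: int64 tensor -> no observer attached
print(_is_input_float_tensor(1.0))  # False: Python scalar, not a Node -> skipped
```

This is why `annotate_binary` now calls the helper instead of `isinstance(..., Node)`: a binary op like `add` can take a scalar or a non-float tensor as one argument, and the old check would have attached an observer to it anyway.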
|