|
7 | 7 |
|
8 | 8 | This can also support exporting the model to other platforms like ONNX as well. |
9 | 9 | """ |
| 10 | + |
| 11 | +from typing import List, Optional |
| 12 | + |
10 | 13 | import torch |
11 | 14 | import torchao |
12 | | -from my_dtype_tensor_subclass import ( |
13 | | - MyDTypeTensor, |
14 | | -) |
15 | | -from torchao.utils import _register_custom_op |
| 15 | +from my_dtype_tensor_subclass import MyDTypeTensor |
16 | 16 | from torchao.quantization.quant_primitives import dequantize_affine |
17 | | -from typing import Optional, List |
| 17 | +from torchao.utils import _register_custom_op |
18 | 18 |
|
# Open the existing "quant" op namespace; "FRAGMENT" extends it rather than
# redefining it, so other modules can contribute ops to the same library.
quant_lib = torch.library.Library("quant", "FRAGMENT")
# Decorator that registers a Python function as a custom op on quant_lib
# (presumably deriving the schema from the function's type hints — see
# torchao.utils._register_custom_op).
register_custom_op = _register_custom_op(quant_lib)
21 | 21 |
|
| 22 | + |
class MyDTypeTensorExtended(MyDTypeTensor):
    """Trivial extension of MyDTypeTensor.

    Adds no behavior of its own; exists to demonstrate that a subclass of the
    example dtype tensor still works through quantization and export below.
    """

    pass
24 | 25 |
|
| 26 | + |
# Re-export the subclass's op-dispatch decorator and float-conversion
# constructor under short module-level names used by the example below.
implements = MyDTypeTensorExtended.implements
to_my_dtype_extended = MyDTypeTensorExtended.from_float

# Shorthand for the ATen op namespace.
aten = torch.ops.aten
29 | 31 |
|
| 32 | + |
30 | 33 | # NOTE: the op must start with `_` |
31 | 34 | # NOTE: typing must be compatible with infer_schema (https://github.com/pytorch/pytorch/blob/main/torch/_library/infer_schema.py) |
32 | 35 | # This will register a torch.ops.quant.embedding |
@@ -59,27 +62,29 @@ def _(func, types, args, kwargs): |
59 | 62 |
|
def main():
    """Smoke-test the extended dtype tensor end to end.

    Quantizes an ``nn.Embedding`` weight via ``to_my_dtype_extended``, checks
    that the quantized forward pass is numerically close to the dequantized
    reference (SQNR > 45 dB), then exports the model with ``torch.export`` and
    verifies the custom ``quant.embedding_byte`` op appears in the exported
    graph.

    Raises:
        AssertionError: if the SQNR check, the export-equality check, or the
            exported-op check fails.
    """
    m = torch.nn.Sequential(torch.nn.Embedding(4096, 128))
    # Named `indices` rather than `input` to avoid shadowing the builtin.
    indices = torch.randint(0, 4096, (1, 6))

    # Swap the float weight for its quantized subclass representation.
    m[0].weight = torch.nn.Parameter(
        to_my_dtype_extended(m[0].weight), requires_grad=False
    )
    y_ref = m[0].weight.dequantize()[indices]
    y_q = m(indices)

    # Local import keeps this optional utility off the module import path.
    from torchao.quantization.utils import compute_error

    sqnr = compute_error(y_ref, y_q)
    assert sqnr > 45.0, f"quantization error too large: SQNR={sqnr}"

    # Export: unwrap the tensor subclass so torch.export can trace the model.
    m_unwrapped = torchao.utils.unwrap_tensor_subclass(m)
    m_exported = torch.export.export(m_unwrapped, (indices,), strict=True).module()
    y_q_exported = m_exported(indices)

    assert torch.equal(y_ref, y_q_exported), "exported model output mismatch"
    ops = [n.target for n in m_exported.graph.nodes]
    print(m_exported)
    # The custom op registered above must survive export into the graph.
    assert torch.ops.quant.embedding_byte.default in ops, (
        "quant.embedding_byte not found in exported graph"
    )


if __name__ == "__main__":
    main()
0 commit comments