
Commit fe3f359

pytorch/ao/tutorials/developer_api_guide
Differential Revision: D67388106
Pull Request resolved: #1440
1 parent 9f0dcdc commit fe3f359


1 file changed: +15 -10 lines


tutorials/developer_api_guide/export_to_executorch.py

Lines changed: 15 additions & 10 deletions
@@ -7,26 +7,29 @@
 
 This can also support exporting the model to other platforms like ONNX as well.
 """
+
+from typing import List, Optional
+
 import torch
 import torchao
-from my_dtype_tensor_subclass import (
-    MyDTypeTensor,
-)
-from torchao.utils import _register_custom_op
+from my_dtype_tensor_subclass import MyDTypeTensor
 from torchao.quantization.quant_primitives import dequantize_affine
-from typing import Optional, List
+from torchao.utils import _register_custom_op
 
 quant_lib = torch.library.Library("quant", "FRAGMENT")
 register_custom_op = _register_custom_op(quant_lib)
 
+
 class MyDTypeTensorExtended(MyDTypeTensor):
     pass
 
+
 implements = MyDTypeTensorExtended.implements
 to_my_dtype_extended = MyDTypeTensorExtended.from_float
 
 aten = torch.ops.aten
 
+
 # NOTE: the op must start with `_`
 # NOTE: typing must be compatible with infer_schema (https://github.com/pytorch/pytorch/blob/main/torch/_library/infer_schema.py)
 # This will register a torch.ops.quant.embedding
@@ -59,27 +62,29 @@ def _(func, types, args, kwargs):
 
 def main():
     group_size = 64
-    m = torch.nn.Sequential(
-        torch.nn.Embedding(4096, 128)
-    )
+    m = torch.nn.Sequential(torch.nn.Embedding(4096, 128))
     input = torch.randint(0, 4096, (1, 6))
 
-    m[0].weight = torch.nn.Parameter(to_my_dtype_extended(m[0].weight), requires_grad=False)
+    m[0].weight = torch.nn.Parameter(
+        to_my_dtype_extended(m[0].weight), requires_grad=False
+    )
     y_ref = m[0].weight.dequantize()[input]
     y_q = m(input)
     from torchao.quantization.utils import compute_error
+
     sqnr = compute_error(y_ref, y_q)
     assert sqnr > 45.0
 
     # export
     m_unwrapped = torchao.utils.unwrap_tensor_subclass(m)
-    m_exported = torch.export.export(m_unwrapped, (input,)).module()
+    m_exported = torch.export.export(m_unwrapped, (input,), strict=True).module()
     y_q_exported = m_exported(input)
 
     assert torch.equal(y_ref, y_q_exported)
     ops = [n.target for n in m_exported.graph.nodes]
     print(m_exported)
     assert torch.ops.quant.embedding_byte.default in ops
 
+
 if __name__ == "__main__":
     main()
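The only behavioral change in this diff is the strict=True argument to torch.export.export; the remaining edits are import sorting and formatting. The minimal sketch below shows that export call on a plain embedding module with illustrative shapes, without the tutorial's MyDTypeTensorExtended subclass or custom quant ops, to show the export / .module() / graph-inspection flow the tutorial relies on.

import torch

# Illustrative stand-in for the tutorial's model; the real tutorial first wraps
# the embedding weight in the MyDTypeTensorExtended tensor subclass.
m = torch.nn.Sequential(torch.nn.Embedding(4096, 128))
example_input = torch.randint(0, 4096, (1, 6))

# strict=True selects Dynamo-based strict tracing; this is the flag the diff adds.
exported = torch.export.export(m, (example_input,), strict=True)

# .module() returns a runnable torch.fx.GraphModule whose graph can be inspected,
# e.g. to check which ATen (or custom) ops ended up in the exported graph.
gm = exported.module()
ops = [n.target for n in gm.graph.nodes]
assert torch.equal(gm(example_input), m(example_input))
print(ops)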
