[Relay][Pass] Add a relay pass to extract fake quantized ops #10089
I will take a look tomorrow
AndrewZhaoLuo
left a comment
Thanks for the PR! A few comments
    mod = tvm.IRModule.from_expr(op)
    fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)

    assert len(fake_quantized_op_freqs) == 1
I think you can just do direct equality throughout this file
dict(fake_quantized_op_freqs) == {"nn.conv2d": 1}
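To illustrate why direct equality is the stronger check, here is a minimal sketch. It assumes a plain Python dict as a stand-in for the map returned by `relay.analysis.list_fake_quantized_op_freqs`, so that it runs without TVM; the `wrong_result` value is hypothetical.

```python
# Stand-in for the value returned by the pass (assumption: a plain dict
# is used here so the sketch runs without TVM installed).
fake_quantized_op_freqs = {"nn.conv2d": 1}

# Length-only check: passes, but says nothing about which op was counted.
assert len(fake_quantized_op_freqs) == 1

# Direct equality: checks op names and counts in a single assertion.
assert dict(fake_quantized_op_freqs) == {"nn.conv2d": 1}

# A hypothetical wrong result that the length check cannot distinguish.
wrong_result = {"nn.dense": 7}
length_check_passes = len(wrong_result) == 1
equality_check_passes = dict(wrong_result) == {"nn.conv2d": 1}
```

The length check accepts the wrong result while the equality check rejects it, which is the point of the suggestion.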
    using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;

    class FakeQuantizedRegionExtractor : public ExprVisitor {
Can you just reuse SubgraphExtractor in src/relay/transforms/fake_quantization_to_integer.cc?
anwang2009
left a comment
tests look great! very comprehensive.
AndrewZhaoLuo
left a comment
LGTM, some nits
     * \file src/relay/transforms/fake_quantization_to_integer.h
     * \brief Extract subgraph of a fake quantized region.
     *
     * https://llvm.org/doxygen/CallGraph_8h_source.html
This line is probably copypasta?
    if (op != dequantize_op_) {
      if (fake_quantized_op_freqs_.find(op_name) != fake_quantized_op_freqs_.end()) {
        fake_quantized_op_freqs_.Set(op_name,
                                     int64_t(fake_quantized_op_freqs_.at(op_name)) + 1);
Why do you need to cast here?
I was getting compile-time errors because fake_quantized_op_freqs_.at(op_name) + 1 is a PrimExpr rather than a tvm::Integer, and casting seemed to work around the issue. Let me know if there's a better way around this.
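For readers following the counting logic in the snippet under discussion, here is a rough Python analogue. It is only a sketch: the real pass is written in C++ and walks Relay expressions, whereas this version assumes the ops of a region are already available as a flat list of names. The helper `count_region_ops` and the sample `region` list are hypothetical.

```python
from collections import Counter


def count_region_ops(op_names, skip="qnn.dequantize"):
    """Count how often each op name occurs, ignoring the skipped op.

    Hypothetical helper: mirrors the `op != dequantize_op_` check and the
    increment-or-insert update, but on a flat list of names.
    """
    freqs = Counter()
    for name in op_names:
        if name != skip:
            # Counter treats a missing key as 0, so no explicit lookup is needed.
            freqs[name] += 1
    return dict(freqs)


# Hypothetical region: two dequantize inputs feeding ordinary ops.
region = ["qnn.dequantize", "qnn.dequantize", "nn.conv2d", "nn.bias_add", "nn.conv2d"]
freqs = count_region_ops(region)
```

In Python the increment needs no cast because the counts are plain integers; the cast in the C++ snippet comes from the map's value type, as described in the reply above.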
…10089)
* add relay pass to collect fake quantized ops
* add more tests
* more tests
* lint
* lint
* remove unused imports
* update comment
* lint
* reuse SubgraphExtractor and update test assertions
* remove print
* lint
* remove unneeded comment
Co-authored-by: Margaret Qian <[email protected]>
Add a relay pass to collect fake quantized ops and frequencies from within fake quantized regions.