Skip to content

Conversation

@margaretqian
Copy link
Contributor

Add a relay pass to collect fake quantized ops and frequencies from within fake quantized regions.

@AndrewZhaoLuo
Copy link
Contributor

I will take a look tomorrow

Copy link
Contributor

@AndrewZhaoLuo AndrewZhaoLuo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! A few comments

mod = tvm.IRModule.from_expr(op)
fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)

assert len(fake_quantized_op_freqs) == 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just do direct equality throughout this file

dict(fake_quantized_op_freqs) == {"nn.conv2d": 1}


using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;

class FakeQuantizedRegionExtractor : public ExprVisitor {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you just reuse SubgraphExtractor in src/relay/transforms/fake_quantization_to_integer.cc?

Copy link
Contributor

@anwang2009 anwang2009 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests look great! very comprehensive.

Margaret Qian added 2 commits January 31, 2022 18:31
Copy link
Contributor

@AndrewZhaoLuo AndrewZhaoLuo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, some nits

* \file src/relay/transforms/fake_quantization_to_integer.h
* \brief Extract subgraph of a fake quantized region.
*
* https://llvm.org/doxygen/CallGraph_8h_source.html
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is probably copypasta?

if (op != dequantize_op_) {
if (fake_quantized_op_freqs_.find(op_name) != fake_quantized_op_freqs_.end()) {
fake_quantized_op_freqs_.Set(op_name,
int64_t(fake_quantized_op_freqs_.at(op_name)) + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to cast here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was getting compile-time errors that fake_quantized_op_freqs_.at(op_name) + 1 is a PrimExpr instead of a tvm::Integer and it seemed like casting worked around the issue -- lmk if there's a better way around this?

@AndrewZhaoLuo AndrewZhaoLuo merged commit efe662f into apache:main Feb 2, 2022
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
…10089)

* add relay pass to collect fake quantized ops

* add more tests

* more tests

* lint

* lint

* remove unused imports

* update comment

* lint

* reuse SubgraphExtractor and update test assertions

* remove print

* lint

* remove unneeded comment

Co-authored-by: Margaret Qian <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants