Skip to content

Commit 5f0406e

Browse files
committed
address comments
1 parent 577387d commit 5f0406e

File tree

4 files changed

+133
-61
lines changed

4 files changed

+133
-61
lines changed

python/tvm/relay/quantize/quantize.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,20 @@ def annotate_context():
196196

197197

198198
def collect_stats(graph):
199+
"""Given an annotated graph, create a profile graph to collect profile data from the
200+
calibration dataset. This pass finds simulated_quantize op and collects its input into a tuple.
201+
The tuple is the output of the profile graph.
202+
203+
Parameters
204+
----------
205+
graph: Function
206+
The simulation graph after annotation.
207+
208+
Returns
209+
-------
210+
ret: Function
211+
The profile graph which outputs a tuple of profile data.
212+
"""
199213
return _quantize.CollectStats(graph)
200214

201215

@@ -215,6 +229,16 @@ def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
215229
ctx: tvm.relay.PassContext
216230
The pass context used for calibration.
217231
232+
weight_scales: 'power2' or 'max'.
233+
The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
234+
power2: Find the maximum of the absolute value of the tensor, and then round up to power
235+
of two.
236+
max: Find the maximum of the absolute value of the tensor.
237+
238+
scales: List[float]
239+
Pre-calculated scales for input and activations. Length and the order of elements of the
240+
scales list should match the output tuple of the profile graph created by collect_stats.
241+
218242
Returns
219243
-------
220244
ret: Function
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
*
23+
* \file calibration.cc
24+
*
25+
* \brief Create profile graph and calibrate on dataset
26+
*/
27+
#include <tvm/relay/analysis.h>
28+
#include <tvm/relay/expr_functor.h>
29+
#include "./quantize.h"
30+
31+
32+
namespace tvm {
33+
namespace relay {
34+
namespace quantize {
35+
36+
class StatsCollector : private ExprMutator {
37+
public:
38+
Expr Collect(const Expr& expr) {
39+
auto new_e = this->Mutate(expr);
40+
const FunctionNode* func = new_e.as<FunctionNode>();
41+
CHECK(func) << "Input shoule be Function";
42+
Expr new_body = TupleNode::make(std::move(profile_data_));
43+
return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
44+
func->attrs);
45+
}
46+
47+
private:
48+
Array<Expr> profile_data_;
49+
50+
Expr VisitExpr_(const CallNode* call) {
51+
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
52+
Expr new_e = ExprMutator::VisitExpr_(call);
53+
const CallNode* new_call = new_e.as<CallNode>();
54+
CHECK(new_call);
55+
if (new_call->op.same_as(simulated_quantize)) {
56+
auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
57+
const Expr& quantize_input = new_call->args[0]; // expression being quantized
58+
if (attrs->kind != QAnnotateKind::kQWeight) {
59+
CHECK(!quantize_input.as<ConstantNode>());
60+
profile_data_.push_back(quantize_input);
61+
}
62+
return quantize_input;
63+
} else {
64+
return new_e;
65+
}
66+
}
67+
};
68+
69+
/*
70+
* \brief Given an annotated graph, create a profile graph to collect profile data from the
71+
*
72+
* calibration dataset.
73+
*
74+
* This pass finds simulated_quantize op and collects its input into a tuple. The tuple is the
75+
* output of the profile graph. Both input and output of this pass
76+
* are relay::Function.
77+
*
78+
* \param expr Expression after Annotate pass.
79+
* \return The profile graph.
80+
*/
81+
Expr CollectStats(const Expr& expr) {
82+
return StatsCollector().Collect(expr);
83+
}
84+
85+
TVM_REGISTER_API("relay._quantize.CollectStats")
86+
.set_body_typed(CollectStats);
87+
88+
} // namespace quantize
89+
} // namespace relay
90+
} // namespace tvm

src/relay/pass/quantize.cc renamed to src/relay/pass/quantize/quantize.cc

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
#include <vector>
3737
#include <stack>
3838
#include <utility>
39-
#include "pattern_util.h"
40-
#include "quantize.h"
39+
#include "../pattern_util.h"
40+
#include "./quantize.h"
4141

4242

4343
namespace tvm {
@@ -46,22 +46,6 @@ namespace quantize {
4646

4747
using namespace relay::transform;
4848

49-
/*! \brief Attribute for simulated quantize operator */
50-
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
51-
int kind;
52-
bool sign;
53-
std::string rounding;
54-
55-
TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
56-
TVM_ATTR_FIELD(kind)
57-
.describe("kind of field, hint for nbit/dtype configuration.");
58-
TVM_ATTR_FIELD(sign).set_default(true)
59-
.describe("whether to use signed data type.");
60-
TVM_ATTR_FIELD(rounding).set_default("round")
61-
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
62-
}
63-
};
64-
6549
TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);
6650

6751
bool SimulatedQuantizeRel(const Array<Type>& types,
@@ -739,48 +723,6 @@ TVM_REGISTER_API("relay._quantize.temp_expr_realize")
739723
return n->Realize();
740724
});
741725

742-
// =============
743-
// calibration
744-
745-
class StatsCollector : private ExprMutator {
746-
public:
747-
Expr Collect(const Expr& expr) {
748-
auto new_e = this->Mutate(expr);
749-
const FunctionNode* func = new_e.as<FunctionNode>();
750-
CHECK(func);
751-
Expr new_body = TupleNode::make(std::move(profile_data_));
752-
return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
753-
func->attrs);
754-
}
755-
756-
private:
757-
Array<Expr> profile_data_;
758-
759-
Expr VisitExpr_(const CallNode* call) {
760-
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
761-
Expr new_e = ExprMutator::VisitExpr_(call);
762-
const CallNode* new_call = new_e.as<CallNode>();
763-
CHECK(new_call);
764-
if (new_call->op.same_as(simulated_quantize)) {
765-
auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
766-
if (attrs->kind != QAnnotateKind::kQWeight) {
767-
CHECK(!new_call->args[0].as<ConstantNode>());
768-
const Expr& quantize_input = new_call->args[0]; // expression being quantized
769-
profile_data_.push_back(quantize_input);
770-
}
771-
return new_call->args[0];
772-
} else {
773-
return new_e;
774-
}
775-
}
776-
};
777-
778-
Expr CollectStats(const Expr& expr) {
779-
return StatsCollector().Collect(expr);
780-
}
781-
782-
TVM_REGISTER_API("relay._quantize.CollectStats")
783-
.set_body_typed(CollectStats);
784726

785727
} // namespace quantize
786728
} // namespace relay

src/relay/pass/quantize.h renamed to src/relay/pass/quantize/quantize.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <tvm/relay/op.h>
3030
#include <tvm/relay/expr.h>
3131
#include <string>
32-
#include "pattern_util.h"
32+
#include "../pattern_util.h"
3333

3434
namespace tvm {
3535
namespace relay {
@@ -42,6 +42,22 @@ enum QAnnotateKind : int {
4242
kQActivation = 3,
4343
};
4444

45+
/*! \brief Attribute for simulated quantize operator */
46+
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
47+
int kind;
48+
bool sign;
49+
std::string rounding;
50+
51+
TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
52+
TVM_ATTR_FIELD(kind)
53+
.describe("kind of field, hint for nbit/dtype configuration.");
54+
TVM_ATTR_FIELD(sign).set_default(true)
55+
.describe("whether to use signed data type.");
56+
TVM_ATTR_FIELD(rounding).set_default("round")
57+
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
58+
}
59+
};
60+
4561
/*!
4662
* \brief TempExpr used during annotate forward rewrite.
4763
*/

0 commit comments

Comments
 (0)