Skip to content

Commit f775105

Browse files
committed
Address comment
1 parent 61f3fef commit f775105

File tree

3 files changed

+5
-13
lines changed

3 files changed

+5
-13
lines changed

python/tvm/autotvm/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def get_const_tuple(in_tuple):
180180
if isinstance(elem, expr.Var):
181181
ret.append(elem)
182182
elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
183-
elem = tvm.ir_pass.Simplify(elem)
183+
elem = ir_pass.Simplify(elem)
184184
if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
185185
ret.append(elem)
186186
else:

python/tvm/relay/op/_tensor.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,11 @@ def broadcast_shape_func(attrs, inputs, out_ndims):
154154
"""
155155
return [_broadcast_shape_func(*inputs, out_ndims[0])]
156156

157-
@script
158-
def _elemwise_shape_func(data_shape):
159-
out = output_tensor((data_shape.shape[0],), "int64")
160-
for i in const_range(data_shape.shape[0]):
161-
out[i] = data_shape[i]
162-
163-
return out
164-
165157
def elemwise_shape_func(attrs, inputs, _):
166158
"""
167159
Shape function for elemwise op.
168160
"""
169-
return [_elemwise_shape_func(inputs[0])]
161+
return [topi.math.identity(inputs[0])]
170162

171163
register_shape_func("cast", False, cast_shape_func)
172164

src/lang/data_layout.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,12 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
288288
// for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
289289
// e.g., (C * 16 + c) / 32
290290
std::unordered_map<const Variable*, Expr> bind_map;
291-
std::unordered_set<std::string> symbolic_var_set;
291+
std::unordered_set<size_t> symbolic_var_set;
292292
for (size_t i = 0; i < src_shape.size(); ++i) {
293293
Expr orig_shape = src_shape[i];
294294
IterVar orig_axis = src_axis[i];
295295
if (orig_shape.as<ir::Any>()) {
296-
symbolic_var_set.insert(orig_axis->var->name_hint);
296+
symbolic_var_set.insert(i);
297297
}
298298
if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
299299
if (orig_shape.defined()) {
@@ -321,7 +321,7 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
321321
if (!LayoutAxis::Get(axis).IsPrimal()) {
322322
result.push_back(axis->dom->extent);
323323
} else {
324-
if (symbolic_var_set.count(axis->var->name_hint)) {
324+
if (symbolic_var_set.count(i)) {
325325
result.push_back(ir::Any::make());
326326
} else {
327327
result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));

0 commit comments

Comments
 (0)