Closed
Description
@tqchen I ran into this problem while implementing argmax and argmin in TOPI. The following code raises an error:
import tvm

def argmax_comp(x, y):
    # x and y are (index, value) pairs; keep the pair with the larger value.
    idx = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
    val = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
    return idx, val

def argmax_init(idx_typ, val_typ):
    # Identity element: index -1 and the smallest representable value.
    return tvm.const(-1, idx_typ), tvm.min_value(val_typ)

argmax = tvm.comm_reducer(argmax_comp, argmax_init, name='argmax')

m = tvm.var('m')
n = tvm.var('n')
val = tvm.placeholder((m, n), name='val', dtype='float32')
# Elementwise intermediate stage feeding the reduction.
val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
k = tvm.reduce_axis((0, n), 'k')
# Two-output reduction: T_idx holds the argmax index, T_val the max value.
T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')
s = tvm.create_schedule(T_idx.op)
# val2 is inlined into T before lowering.
s[val2].compute_inline()
tvm.lower(s, [val, T_idx, T_val], simple_mode=True)

TVMError: [14:13:24] src/op/compute_op.cc:122: Check failed: ReduceEqual(reduce_, reduce) The Reduce inputs of ComputeOp should have the same attribute except value_index
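To make the semantics of the reducer above concrete, here is a plain-Python sketch (no TVM involved) of what argmax_comp and argmax_init compute when folded over one row. It assumes the first combiner argument is the running accumulator, and it uses float -inf as a stand-in for tvm.min_value('float32'); the helpers row_argmax and exp_argmax are hypothetical names for illustration only.

```python
import math
from functools import reduce

def argmax_comp(x, y):
    # x is the running (index, value) accumulator, y the next (index, value) pair.
    # Mirrors the Select-based combiner: keep x when its value is >= y's value.
    if x[1] >= y[1]:
        return x
    return y

def argmax_init():
    # Identity element: index -1 and the smallest float value
    # (stand-in for tvm.const(-1, ...) and tvm.min_value('float32')).
    return (-1, -math.inf)

def row_argmax(row):
    # Pair each element with its index, as argmax((k.var, val2[i, k]), axis=k)
    # pairs the reduction variable with the value, then fold from the identity.
    return reduce(argmax_comp, enumerate(row), argmax_init())

def exp_argmax(matrix):
    # Hypothetical helper: apply exp (the val2 stage) and reduce each row.
    # exp is monotonic, so the argmax index is the same as for the raw values.
    return [row_argmax([math.exp(v) for v in row]) for row in matrix]

print(row_argmax([0.5, 2.0, 2.0, -1.0]))  # (1, 2.0): ties keep the earlier index
```

Because the combiner uses >=, ties keep the accumulator, so the first index of the maximum wins; an empty row returns the identity pair unchanged.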
I printed the internal state at this line, https://github.com/dmlc/tvm/blob/master/src/op/compute_op.cc#L122, and found that the only difference between the two Reduce nodes is their source field (reduce_.source vs. reduce.source):
reduce_->source=[k, val2(i, k)]
reduce->source=[k, exp(val(i, k))]
This problem does not happen if we use tvm.sum.
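The failing check can be modelled roughly as follows. A multi-output reducer such as argmax produces one Reduce node per output (here T_idx and T_val); the nodes should share the same combiner, source, axis and condition and differ only in value_index. The sketch below is a hypothetical model: the Reduce dataclass, reduce_equal, check_body and inline_val2 are illustrative stand-ins, not TVM classes or functions, and expressions are plain strings. It reproduces the mismatch printed above, where only one node's source has val2 substituted. It also suggests why tvm.sum is presumably unaffected: a single-output reduction has only one Reduce node, so there is nothing to compare it against.

```python
from dataclasses import dataclass, replace
from typing import Tuple

@dataclass(frozen=True)
class Reduce:
    # Hypothetical stand-in for an IR Reduce node; expressions are strings.
    combiner: str
    source: Tuple[str, ...]
    axis: str
    condition: str
    value_index: int

def reduce_equal(a, b):
    # Equal in every field except value_index, as the error message describes.
    return replace(a, value_index=0) == replace(b, value_index=0)

def check_body(body):
    # One Reduce node per output; compare every node against the first one.
    for node in body[1:]:
        if not reduce_equal(node, body[0]):
            raise ValueError("The Reduce inputs of ComputeOp should have the "
                             "same attribute except value_index")

def inline_val2(node):
    # Hypothetical inliner: substitute val2's body into the node's source.
    src = tuple(s.replace("val2(i, k)", "exp(val(i, k))") for s in node.source)
    return replace(node, source=src)

base = Reduce("argmax", ("k", "val2(i, k)"), "k", "true", 0)
T_idx, T_val = base, replace(base, value_index=1)

# Consistent rewrite: both outputs are inlined, so the check passes.
check_body([inline_val2(T_idx), inline_val2(T_val)])

# Inconsistent rewrite matching the printed sources: reduce has
# [k, exp(val(i, k))] while reduce_ still has [k, val2(i, k)].
try:
    check_body([inline_val2(T_idx), T_val])
except ValueError as err:
    print("check failed:", err)
```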