Skip to content

Commit 1fac10b

Browse files
authored
Fix GatherND attribute registration (#8269)
1 parent ec6a817 commit 1fac10b

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

include/tvm/relay/attrs/transform.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
148148
Integer batch_dims;
149149
Optional<Integer> index_rank;
150150

151-
TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
151+
TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
152152
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
153153
TVM_ATTR_FIELD(index_rank)
154154
.set_default(NullValue<Integer>())

src/relay/op/tensor/transform.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3290,6 +3290,8 @@ which must just be not null. Output will have same shape as ``indices``.
32903290
.set_attr<FTVMCompute>("FTVMCompute", GatherCompute)
32913291
.set_attr<TOpPattern>("TOpPattern", kInjective);
32923292

3293+
TVM_REGISTER_NODE_TYPE(GatherNDAttrs);
3294+
32933295
// gather_nd operator
32943296
bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
32953297
const TypeReporter& reporter) {
@@ -3367,6 +3369,7 @@ When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M
33673369
In both cases, if M + B == N, the output shape will simply be (Y_0, ..., Y_{K-1}).
33683370
)code" TVM_ADD_FILELINE)
33693371
.set_num_inputs(2)
3372+
.set_attrs_type<GatherNDAttrs>()
33703373
.add_argument("data", "Tensor", "The input tensor.")
33713374
.add_argument("indices", "Tensor", "The indices of values to gather.")
33723375
.set_support_level(3)

0 commit comments

Comments
 (0)