
Commit ae31a33

Author: Xingyu Zhou
[Frontend][Tensorflow]add batch_dim support for gatherV2 (#7951)
* add batch_dim support
* fix lint
* add check for num of arguments for topi.take
* fix gpu test cases
* add check for batch_dims in take_grad

Together, these changes thread a new `batch_dims` attribute from the TensorFlow GatherV2 converter through `relay.take`, its shape and gradient functions, and `topi::take`.
1 parent 3e35130 · commit ae31a33

12 files changed: +182 -73 lines changed

include/tvm/relay/attrs/transform.h (+4 -0)

```diff
@@ -145,10 +145,14 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
 };
 
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
+  Integer batch_dims;
   Integer axis;
   std::string mode;
 
   TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
+    TVM_ATTR_FIELD(batch_dims)
+        .set_default(0)
+        .describe("The batch_dims over which to select values.");
     TVM_ATTR_FIELD(axis)
         .set_default(NullValue<Integer>())
         .describe("The axis over which to select values.");
```

include/tvm/topi/transform.h (+88 -32)

```diff
@@ -763,15 +763,17 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
  *
  * \param a The source array.
  * \param indices The indices of the values to extract.
+ * \param batch_dims The number of batch dimensions.
  * \param mode The mode of the operation.
  * \param name The name of the operation.
  * \param mode The mode of to handle out of bound indices.
  * \param tag The tag to mark the operation.
  *
  * \return A Tensor whose op member is the take operation
  */
-inline Tensor take(const Tensor& a, const Tensor& indices, std::string mode = "clip",
-                   std::string name = "T_take", std::string tag = kInjective) {
+inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
+                   std::string mode = "clip", std::string name = "T_take",
+                   std::string tag = kInjective) {
   Array<PrimExpr> a_shape = a->shape;
   Array<PrimExpr> out_shape = indices->shape;
   PrimExpr a_size = 1;
@@ -846,6 +848,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
  *
  * \param a The source array.
  * \param indices The indices of the values to extract.
+ * \param batch_dims The number of batch dimensions. By default is 0.
  * \param axis The axis over which to select values. By default,
  * the flattened input array is used.
  * \param mode The mode for handling out of bound indices.
@@ -854,46 +857,99 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
  *
  * \return A Tensor whose op member is the take operation
  */
-inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string mode = "clip",
-                   std::string name = "T_take", std::string tag = kInjective) {
+inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
+                   std::string mode = "clip", std::string name = "T_take",
+                   std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(a->shape.size());
   }
   ICHECK_GE(axis, 0) << "axis out of bounds";
   ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
   auto axis_dim = a->shape[axis];
-
   int indices_len = static_cast<int>(indices->shape.size());
-  Array<PrimExpr> out_shape;
-  for (size_t i = 0; i < a->shape.size(); ++i) {
-    if (axis == static_cast<int>(i)) {
-      for (size_t j = 0; j < indices->shape.size(); ++j) {
-        out_shape.push_back(indices->shape[j]);
-      }
-    } else {
-      out_shape.push_back(a->shape[i]);
+
+  int batch_dims_ = batch_dims;
+  if (batch_dims_ != 0) {
+    ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
+    ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";
+
+    if (batch_dims_ < 0) {
+      batch_dims_ = indices->shape.size() + batch_dims_;
     }
+
+    ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
+    ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
+    for (int i = 0; i < batch_dims_; ++i) {
+      auto addr1 = a->shape[i];
+      auto addr2 = indices->shape[i];
+      auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
+      auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
+      ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
+    }
+  }
+
+  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
+  // a.shape[axis + 1:].
+
+  Array<PrimExpr> out_shape;
+  for (int i = 0; i < batch_dims_; ++i) {
+    out_shape.push_back(a->shape[i]);
+  }
+  for (int i = batch_dims_; i < axis; ++i) {
+    out_shape.push_back(a->shape[i]);
+  }
+  for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
+    out_shape.push_back(indices->shape[i]);
+  }
+  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
+    out_shape.push_back(a->shape[i]);
   }
+
   if (mode == "clip") {
-    return compute(
-        out_shape,
-        [&](const Array<Var>& out_index) {
-          Array<PrimExpr> indices_position;
-          for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
-            indices_position.push_back(out_index[j]);
-          }
-          Array<PrimExpr> real_indices;
-          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
-            real_indices.push_back(out_index[j]);
-          }
-          auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
-          real_indices.push_back(idx);
-          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
-            real_indices.push_back(out_index[j]);
-          }
-          return a(real_indices);
-        },
-        name, tag);
+    if (batch_dims_ == 0) {
+      return compute(
+          out_shape,
+          [&](const Array<Var>& out_index) {
+            Array<PrimExpr> indices_position;
+            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
+              indices_position.push_back(out_index[j]);
+            }
+            Array<PrimExpr> real_indices;
+            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
+            real_indices.push_back(idx);
+            for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            return a(real_indices);
+          },
+          name, tag);
+    } else {
+      return compute(
+          out_shape,
+          [&](const Array<Var>& out_index) {
+            Array<PrimExpr> indices_position;
+            for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
+              indices_position.push_back(out_index[j]);
+            }
+            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
+              indices_position.push_back(out_index[j]);
+            }
+            Array<PrimExpr> real_indices;
+            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
+            real_indices.push_back(idx);
+            for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
+              real_indices.push_back(out_index[j]);
+            }
+            return a(real_indices);
+          },
+          name, tag);
+    }
   } else if (mode == "fast") {
     LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                     "Make sure input indices are in bound";
```

python/tvm/relay/frontend/mxnet.py (+1 -1)

With `batch_dims` inserted into the Python signature of `relay.take` ahead of `mode`, the old positional call would have bound `mode` to the new `batch_dims` parameter, so the MXNet converter switches to keyword arguments:

```diff
@@ -911,7 +911,7 @@ def _mx_take(inputs, attrs):
     if mode == "raise":
         raise tvm.error.OpAttributeUnimplemented("take with raise mode is not supported yet")
     axis = attrs.get_int("axis", 0)
-    return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
+    return _op.take(inputs[0], inputs[1].astype("int32"), axis=axis, mode=mode)
 
 
 def _mx_gather_nd(inputs, attrs):
```

python/tvm/relay/frontend/tensorflow.py (+9 -4)

```diff
@@ -2002,14 +2002,19 @@ def _impl(inputs, attr, params, mod):
             axis = _get_num_param(params, inputs.pop(2))
         else:
             axis = 0
+        batch_dims = 0
         if int(attr.get("batch_dims", 0)) != 0:
-            raise tvm.error.OpAttributeUnImplemented("Attribute batch_dims is not supported")
+            batch_dims = int(attr.get("batch_dims", 0))
         new_input = inputs[0:2]
-        return AttrCvt(
+        op_ = AttrCvt(
             op_name="take",
-            extras={"axis": tvm.tir.const(axis, "int32")},
-            ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class", "batch_dims"],
+            extras={
+                "axis": tvm.tir.const(axis, "int32"),
+                "batch_dims": tvm.tir.const(batch_dims, "int32"),
+            },
+            ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class"],
         )(new_input, attr)
+        return op_
 
     return _impl
```

python/tvm/relay/op/_tensor_grad.py (+7 -0)

```diff
@@ -711,6 +711,7 @@ def make_scalar_tensor(v):
     # TODO(@altanh): we currently assume indices are in range
     data, indices = orig.args
     axis = orig.attrs.axis
+    batch_dims = orig.attrs.batch_dims
     zero, one = map(make_scalar_tensor, [0, 1])
     data_grad = zeros_like(data)
     try:
@@ -726,6 +727,12 @@ def make_scalar_tensor(v):
         data_shape = (data_shape,)
     else:
         axis = int(axis)
+        if batch_dims is None:
+            batch_dims = 0
+        else:
+            batch_dims = int(batch_dims)
+        if batch_dims != 0:
+            raise OpError("take_grad only supports batch_dims equales to 0")
     strides = [1] * len(data_shape)
 
     if len(indices.checked_type.shape) == 0:
```
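The gradient path only gains a guard: differentiating a `take` with `batch_dims != 0` is still rejected. Roughly how that surfaces (an illustrative sketch of my own; the exact exception type seen by the caller may depend on FFI error translation):

```python
import tvm
from tvm import relay

data = relay.var("data", shape=(2, 3), dtype="float32")
indices = relay.var("indices", shape=(2, 2), dtype="int32")
fwd = relay.Function([data, indices], relay.take(data, indices, axis=1, batch_dims=1))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(fwd))

try:
    relay.transform.gradient(mod["main"], mod=mod)  # invokes take_grad
except (tvm.error.OpError, tvm.error.TVMError) as err:
    print(err)  # the "take_grad only supports batch_dims ..." guard above
```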

python/tvm/relay/op/_transform.py (+10 -5)

```diff
@@ -390,7 +390,7 @@ def _take_no_axis_shape_func(indices_shape, out_ndim):
 
 
 @script
-def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
+def _take_with_axis_shape_func(data_shape, indices_shape, axis, batch_dims, out_ndim):
     out = output_tensor((out_ndim,), "int64")
     for i in const_range(axis):
         out[i] = data_shape[i]
@@ -399,10 +399,10 @@ def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
         for i in const_range(axis + 1, len(data_shape)):
             out[i - 1] = data_shape[i]
     else:
-        for i in const_range(len(indices_shape)):
-            out[axis + i] = indices_shape[i]
+        for i in const_range(len(indices_shape) - batch_dims):
+            out[axis + i] = indices_shape[i + batch_dims]
         for i in const_range(axis + 1, len(data_shape)):
-            out[len(indices_shape) + i - 1] = data_shape[i]
+            out[len(indices_shape) + i - 1 - batch_dims] = data_shape[i]
     return out
 
 
@@ -414,11 +414,16 @@ def take_shape_func(attrs, inputs, out_ndims):
     if attrs.axis is None:
         return [_take_no_axis_shape_func(inputs[1], out_ndims[0])]
     axis = get_const_int(attrs.axis)
+    batch_dims = get_const_int(attrs.batch_dims)
     data_ndim = int(inputs[0].shape[0])
+    if inputs[1].shape:
+        indicies_ndim = int(inputs[1].shape[0])
     if axis < 0:
         axis += data_ndim
     assert 0 <= axis < data_ndim
-    return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
+    if batch_dims < 0:
+        batch_dims += indicies_ndim
+    return [_take_with_axis_shape_func(*inputs, convert(axis), convert(batch_dims), out_ndims[0])]
 
 
 @_reg.register_legalize("take")
```

python/tvm/relay/op/transform.py (+5 -2)

```diff
@@ -388,7 +388,7 @@ def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_e
     return _make.reshape_like(data, shape_like, lhs_begin, lhs_end, rhs_begin, rhs_end)
 
 
-def take(data, indices, axis=None, mode="clip"):
+def take(data, indices, axis=None, batch_dims=0, mode="clip"):
     """Take elements from an array along an axis.
 
     Parameters
@@ -403,6 +403,9 @@ def take(data, indices, axis=None, mode="clip"):
         The axis over which to select values. By default,
         the flattened input array is used.
 
+    batch_dims : int
+        The number of batch dimensions. By default is 0.
+
     mode : str, optional
         Specifies how out-of-bound indices will behave [clip, wrap, fast].
         clip: clip to the range (default).
@@ -414,7 +417,7 @@ def take(data, indices, axis=None, mode="clip"):
     ret : relay.Expr
         The computed result.
     """
-    return _make.take(data, indices, axis, mode)
+    return _make.take(data, indices, batch_dims, axis, mode)
 
 
 def full(fill_value, shape=(), dtype=""):
```

python/tvm/topi/transform.py (+6 -3)

```diff
@@ -396,7 +396,7 @@ def split(ary, indices_or_sections, axis=0):
     return cpp.split(ary, indices_or_sections, axis)
 
 
-def take(a, indices, axis=None, mode="clip"):
+def take(a, indices, axis=None, batch_dims=0, mode="clip"):
     """Take elements from an array along an axis.
 
     Parameters
@@ -411,6 +411,9 @@ def take(a, indices, axis=None, mode="clip"):
         The axis over which to select values. By default,
         the flattened input array is used.
 
+    batch_dims : int
+        The number of batch dimensions. By default is 0.
+
     mode : str, optional
         Specifies how out-of-bound indices will behave.
         clip - clip to the range (default)
@@ -422,8 +425,8 @@ def take(a, indices, axis=None, mode="clip"):
     ret : tvm.te.Tensor
     """
     if axis is None:
-        return cpp.take(a, indices, mode)
-    return cpp.take(a, indices, int(axis), mode)
+        return cpp.take(a, indices, int(batch_dims), mode)
+    return cpp.take(a, indices, int(batch_dims), int(axis), mode)
 
 
 @tvm.target.generic_func
```
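The same holds for TOPI: `batch_dims` is forwarded positionally to the C++ overloads shown earlier. A small TE-level sketch (illustrative; uses the `te.create_schedule` flow current at the time of this commit):

```python
import numpy as np
import tvm
from tvm import te, topi

a = te.placeholder((2, 3, 4), name="a", dtype="float32")
idx = te.placeholder((2, 5), name="idx", dtype="int32")
out = topi.take(a, idx, axis=1, batch_dims=1)  # -> cpp.take(a, idx, 1, 1, "clip")

s = te.create_schedule(out.op)
f = tvm.build(s, [a, idx, out], target="llvm")

dev = tvm.cpu()
a_np = np.arange(24, dtype="float32").reshape(2, 3, 4)
idx_np = np.array([[0, 2, 1, 0, 2], [1, 1, 0, 2, 0]], dtype="int32")
out_nd = tvm.nd.empty((2, 5, 4), "float32", dev)
f(tvm.nd.array(a_np, dev), tvm.nd.array(idx_np, dev), out_nd)
# row b, j of the result is a_np[b, idx_np[b, j], :]
print(out_nd.asnumpy()[0, 0])  # == a_np[0, 0, :]
```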

src/relay/op/make_op.h (+1 -1)

```diff
@@ -107,7 +107,7 @@ Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);
 
 Expr MakeShapeOf(Expr data, DataType dtype);
 
-Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);
+Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode);
 
 }  // namespace relay
 }  // namespace tvm
```

src/relay/op/tensor/transform.cc (+17 -6)

```diff
@@ -1204,15 +1204,24 @@ bool TakeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   const auto ndim_data = static_cast<int>(data->shape.size());
   const auto ndim_indices = static_cast<int>(indices->shape.size());
   int axis = static_cast<int>(param->axis->value);
+  int batch_dims = static_cast<int>(param->batch_dims->value);
   if (axis < 0) axis += ndim_data;
+  if (batch_dims < 0) axis += ndim_indices;
   ICHECK_LE(axis, ndim_data) << "axis should be with in data shape"
                              << ", but got = " << axis;
+  ICHECK_LE(batch_dims, ndim_indices) << "batch_dims should be with in indices shape"
+                                      << ", but got = " << batch_dims;
+  ICHECK_LE(batch_dims, axis) << "batch_dims should be less than or equal to axis"
+                              << ", but got = " << batch_dims;
 
-  oshape.reserve(ndim_data - 1 + ndim_indices);
-  for (int i = 0; i < axis; ++i) {
+  oshape.reserve(ndim_data - 1 + ndim_indices - batch_dims);
+  for (int i = 0; i < batch_dims; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  for (int i = batch_dims; i < axis; ++i) {
     oshape.emplace_back(data->shape[i]);
   }
-  for (int i = 0; i < ndim_indices; ++i) {
+  for (int i = batch_dims; i < ndim_indices; ++i) {
     oshape.emplace_back(indices->shape[i]);
   }
   for (int i = axis + 1; i < ndim_data; ++i) {
@@ -1228,14 +1237,16 @@ Array<te::Tensor> TakeCompute(const Attrs& attrs, const Array<te::Tensor>& input
   const auto* param = attrs.as<TakeAttrs>();
   ICHECK(param != nullptr);
   if (!param->axis.defined()) {
-    return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->mode)};
+    return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->batch_dims, param->mode)};
   } else {
-    return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->axis, param->mode)};
+    return Array<te::Tensor>{
+        topi::take(inputs[0], inputs[1], param->batch_dims, param->axis, param->mode)};
   }
 }
 
-Expr MakeTake(Expr data, Expr indices, Integer axis, String mode) {
+Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode) {
   auto attrs = make_object<TakeAttrs>();
+  attrs->batch_dims = std::move(batch_dims);
   attrs->axis = std::move(axis);
   attrs->mode = std::move(mode);
   static const Op& op = Op::Get("take");
```
