@@ -763,15 +763,17 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
763763 *
764764 * \param a The source array.
765765 * \param indices The indices of the values to extract.
766+ * \param batch_dims The number of batch dimensions.
766767 * \param mode The mode of the operation.
767768 * \param name The name of the operation.
768769 * \param mode The mode used to handle out-of-bound indices.
769770 * \param tag The tag to mark the operation.
770771 *
771772 * \return A Tensor whose op member is the take operation
772773 */
773- inline Tensor take (const Tensor& a, const Tensor& indices, std::string mode = " clip" ,
774- std::string name = " T_take" , std::string tag = kInjective ) {
774+ inline Tensor take (const Tensor& a, const Tensor& indices, int batch_dims,
775+ std::string mode = " clip" , std::string name = " T_take" ,
776+ std::string tag = kInjective ) {
775777 Array<PrimExpr> a_shape = a->shape ;
776778 Array<PrimExpr> out_shape = indices->shape ;
777779 PrimExpr a_size = 1 ;
@@ -846,6 +848,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
846848 *
847849 * \param a The source array.
848850 * \param indices The indices of the values to extract.
851+ * \param batch_dims The number of batch dimensions; pass 0 for no batching.
849852 * \param axis The axis over which to select values. By default,
850853 * the flattened input array is used.
851854 * \param mode The mode for handling out of bound indices.
@@ -854,46 +857,99 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
854857 *
855858 * \return A Tensor whose op member is the take operation
856859 */
857- inline Tensor take (const Tensor& a, const Tensor& indices, int axis, std::string mode = " clip" ,
858- std::string name = " T_take" , std::string tag = kInjective ) {
860+ inline Tensor take (const Tensor& a, const Tensor& indices, int batch_dims, int axis,
861+ std::string mode = " clip" , std::string name = " T_take" ,
862+ std::string tag = kInjective ) {
859863 if (axis < 0 ) {
860864 axis += static_cast <int >(a->shape .size ());
861865 }
862866 ICHECK_GE (axis, 0 ) << " axis out of bounds" ;
863867 ICHECK_LT (axis, a->shape .size ()) << " axis out of bounds" ;
864868 auto axis_dim = a->shape [axis];
865-
866869 int indices_len = static_cast <int >(indices->shape .size ());
867- Array<PrimExpr> out_shape;
868- for ( size_t i = 0 ; i < a-> shape . size (); ++i) {
869- if (axis == static_cast < int >(i) ) {
870- for ( size_t j = 0 ; j < indices->shape .size (); ++j) {
871- out_shape. push_back ( indices->shape [j]) ;
872- }
873- } else {
874- out_shape. push_back (a ->shape [i]) ;
870+
871+ int batch_dims_ = batch_dims;
872+ if (batch_dims_ != 0 ) {
873+ ICHECK_GE (batch_dims_, - static_cast < int >( indices->shape .size ())) << " batch_dims out of bounds " ;
874+ ICHECK_LE (batch_dims_, indices->shape . size ()) << " batch_dims out of bounds " ;
875+
876+ if (batch_dims_ < 0 ) {
877+ batch_dims_ = indices ->shape . size () + batch_dims_ ;
875878 }
879+
880+ ICHECK_LT (batch_dims_, a->shape .size ()) << " batch_dims out of bounds" ;
881+ ICHECK_LE (batch_dims_, axis) << " batch_dims must be less than or equal to axis" ;
882+ for (int i = 0 ; i < batch_dims_; ++i) {
883+ auto addr1 = a->shape [i];
884+ auto addr2 = indices->shape [i];
885+ auto v1 = static_cast <IntImm*>(&addr1)->get ()->value ;
886+ auto v2 = static_cast <IntImm*>(&addr2)->get ()->value ;
887+ ICHECK_EQ (v1, v2) << " a.shape[" << i << " ] should be equal to indices.shape[" << i << " ]" ;
888+ }
889+ }
890+
891+ // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
892+ // a.shape[axis + 1:].
893+
894+ Array<PrimExpr> out_shape;
895+ for (int i = 0 ; i < batch_dims_; ++i) {
896+ out_shape.push_back (a->shape [i]);
897+ }
898+ for (int i = batch_dims_; i < axis; ++i) {
899+ out_shape.push_back (a->shape [i]);
900+ }
901+ for (size_t i = static_cast <size_t >(batch_dims_); i < indices->shape .size (); ++i) {
902+ out_shape.push_back (indices->shape [i]);
903+ }
904+ for (size_t i = axis + 1 ; i < a->shape .size (); ++i) {
905+ out_shape.push_back (a->shape [i]);
876906 }
907+
877908 if (mode == " clip" ) {
878- return compute (
879- out_shape,
880- [&](const Array<Var>& out_index) {
881- Array<PrimExpr> indices_position;
882- for (size_t j = axis; j < static_cast <size_t >(axis + indices_len); ++j) {
883- indices_position.push_back (out_index[j]);
884- }
885- Array<PrimExpr> real_indices;
886- for (size_t j = 0 ; j < static_cast <size_t >(axis); ++j) {
887- real_indices.push_back (out_index[j]);
888- }
889- auto idx = tvm::min (tvm::max (0 , indices (indices_position)), axis_dim - 1 );
890- real_indices.push_back (idx);
891- for (size_t j = axis + indices_len; j < out_index.size (); ++j) {
892- real_indices.push_back (out_index[j]);
893- }
894- return a (real_indices);
895- },
896- name, tag);
909+ if (batch_dims_ == 0 ) {
910+ return compute (
911+ out_shape,
912+ [&](const Array<Var>& out_index) {
913+ Array<PrimExpr> indices_position;
914+ for (size_t j = axis; j < static_cast <size_t >(axis + indices_len); ++j) {
915+ indices_position.push_back (out_index[j]);
916+ }
917+ Array<PrimExpr> real_indices;
918+ for (size_t j = 0 ; j < static_cast <size_t >(axis); ++j) {
919+ real_indices.push_back (out_index[j]);
920+ }
921+ auto idx = tvm::min (tvm::max (0 , indices (indices_position)), axis_dim - 1 );
922+ real_indices.push_back (idx);
923+ for (size_t j = axis + indices_len; j < out_index.size (); ++j) {
924+ real_indices.push_back (out_index[j]);
925+ }
926+ return a (real_indices);
927+ },
928+ name, tag);
929+ } else {
930+ return compute (
931+ out_shape,
932+ [&](const Array<Var>& out_index) {
933+ Array<PrimExpr> indices_position;
934+ for (size_t j = 0 ; j < static_cast <size_t >(batch_dims_); ++j) {
935+ indices_position.push_back (out_index[j]);
936+ }
937+ for (size_t j = axis; j < static_cast <size_t >(axis + indices_len - batch_dims_); ++j) {
938+ indices_position.push_back (out_index[j]);
939+ }
940+ Array<PrimExpr> real_indices;
941+ for (size_t j = 0 ; j < static_cast <size_t >(axis); ++j) {
942+ real_indices.push_back (out_index[j]);
943+ }
944+ auto idx = tvm::min (tvm::max (0 , indices (indices_position)), axis_dim - 1 );
945+ real_indices.push_back (idx);
946+ for (size_t j = axis + indices_len - batch_dims_; j < out_index.size (); ++j) {
947+ real_indices.push_back (out_index[j]);
948+ }
949+ return a (real_indices);
950+ },
951+ name, tag);
952+ }
897953 } else if (mode == " fast" ) {
898954 LOG (WARNING) << " Fast mode segfaults when there are out-of-bounds indices. "
899955 " Make sure input indices are in bound" ;
0 commit comments