@@ -937,6 +937,175 @@ RELAY_REGISTER_OP("arange")
937937.set_attr<FTVMCompute>(" FTVMCompute" , ArangeCompute)
938938.set_attr<TOpPattern>(" TOpPattern" , kInjective );
939939
940+ // repeat operator
941+ TVM_REGISTER_NODE_TYPE (RepeatAttrs);
942+
943+ bool RepeatRel (const Array<Type>& types,
944+ int num_inputs,
945+ const Attrs& attrs,
946+ const TypeReporter& reporter) {
947+ // `types` contains: [data, result]
948+ CHECK_EQ (types.size (), 2 );
949+ const auto * data = types[0 ].as <TensorTypeNode>();
950+ if (data == nullptr ) {
951+ CHECK (types[0 ].as <IncompleteTypeNode>())
952+ << " repeat: expect input type to be TensorType but get "
953+ << types[0 ];
954+ return false ;
955+ }
956+ const auto * param = attrs.as <RepeatAttrs>();
957+ const int ndim = static_cast <int >(data->shape .size ());
958+ const int repeats = param->repeats ;
959+ const int axis = param->axis ;
960+ CHECK (repeats >= 1 )
961+ << " repeat only accepts `repeats >= 1`"
962+ << " , but got repeats = " << repeats;
963+ CHECK (-ndim - 1 <= axis && axis <= ndim)
964+ << " repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
965+ << " , but got axis = " << axis
966+ << " , and data.ndim = " << ndim;
967+ const int pivot = axis < 0 ? ndim + axis : axis;
968+ std::vector<IndexExpr> oshape;
969+ oshape.reserve (ndim + repeats);
970+ for (int i = 0 ; i < pivot; ++i) {
971+ oshape.emplace_back (data->shape [i]);
972+ }
973+ oshape.emplace_back (data->shape [pivot] * repeats);
974+ for (int i = pivot + 1 ; i < ndim; ++i) {
975+ oshape.emplace_back (data->shape [i]);
976+ }
977+ reporter->Assign (types[1 ], TensorTypeNode::make (oshape, data->dtype ));
978+ return true ;
979+ }
980+
981+ Array<Tensor> RepeatCompute (const Attrs& attrs,
982+ const Array<Tensor>& inputs,
983+ const Type& out_type,
984+ const Target& target) {
985+ const RepeatAttrs *param = attrs.as <RepeatAttrs>();
986+ CHECK (param != nullptr );
987+ return { topi::repeat (inputs[0 ], param->repeats , param->axis ) };
988+ }
989+
990+ Expr MakeRepeat (Expr data,
991+ int repeats,
992+ int axis) {
993+ auto attrs = make_node<RepeatAttrs>();
994+ attrs->repeats = repeats;
995+ attrs->axis = axis;
996+ static const Op& op = Op::Get (" repeat" );
997+ return CallNode::make (op, {data}, Attrs (attrs), {});
998+ }
999+
1000+ TVM_REGISTER_API (" relay.op._make.repeat" )
1001+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
1002+ runtime::detail::unpack_call<Expr, 3 >(MakeRepeat, args, rv);
1003+ });
1004+
1005+ RELAY_REGISTER_OP (" repeat" )
1006+ .describe(R"code( Repeat elements of an array `repeats` times along axis `axis`
1007+
1008+ - **data**: The input data to the operator.
1009+
1010+ )code" TVM_ADD_FILELINE)
1011+ .set_num_inputs(1 )
1012+ .set_attrs_type_key(" relay.attrs.Repeat" )
1013+ .add_argument(" data" , " Tensor" , " The input tensor." )
1014+ .set_support_level(1 )
1015+ .add_type_rel(" Repeat" , RepeatRel)
1016+ .set_attr<FTVMCompute>(" FTVMCompute" , RepeatCompute)
1017+ .set_attr<TOpPattern>(" TOpPattern" , kBroadcast );
1018+
1019+ // tile operator
1020+ TVM_REGISTER_NODE_TYPE (TileAttrs);
1021+
1022+ bool TileRel (const Array<Type>& types,
1023+ int num_inputs,
1024+ const Attrs& attrs,
1025+ const TypeReporter& reporter) {
1026+ // `types` contains: [data, result]
1027+ CHECK_EQ (types.size (), 2 );
1028+ const auto * data = types[0 ].as <TensorTypeNode>();
1029+ if (data == nullptr ) {
1030+ CHECK (types[0 ].as <IncompleteTypeNode>())
1031+ << " tile: expect input type to be TensorType but get "
1032+ << types[0 ];
1033+ return false ;
1034+ }
1035+ const auto * param = attrs.as <TileAttrs>();
1036+ const int ndim = static_cast <int >(data->shape .size ());
1037+ const Array<Integer>& reps = param->reps ;
1038+ // check dimension match
1039+ CHECK (!reps.defined ())
1040+ << " repetition array is not defined. data.ndim = " << ndim;
1041+ const int rndim = static_cast <int >(reps.size ());
1042+ int tndim = (ndim > rndim) ? ndim : rndim;
1043+ // re-construct data shape or reps shape
1044+ std::vector<IndexExpr> data_shape;
1045+ std::vector<IndexExpr> reps_shape;
1046+ data_shape.reserve (tndim);
1047+ reps_shape.reserve (tndim);
1048+ if (ndim == rndim) {
1049+ for (int i = 0 ; i < tndim; ++i) {
1050+ data_shape.emplace_back (data->shape [i]);
1051+ reps_shape.emplace_back (reps[i]);
1052+ }
1053+ } else if (ndim > rndim) {
1054+ for (int i = 0 ; i < ndim; ++i)
1055+ data_shape.emplace_back (data->shape [i]);
1056+ for (int i = 0 ; i < (ndim - rndim); ++i)
1057+ reps_shape.emplace_back (1 );
1058+ for (int i = 0 ; i < rndim; ++i)
1059+ reps_shape.emplace_back (reps[i]);
1060+ } else {
1061+ for (int i = 0 ; i < rndim; ++i)
1062+ reps_shape.emplace_back (reps[i]);
1063+ }
1064+ std::vector<IndexExpr> oshape;
1065+ oshape.reserve (tndim);
1066+ for (int i = 0 ; i < tndim; ++i) {
1067+ oshape.emplace_back (data_shape[i] * reps_shape[i]);
1068+ }
1069+ reporter->Assign (types[1 ], TensorTypeNode::make (oshape, data->dtype ));
1070+ return true ;
1071+ }
1072+
1073+ Array<Tensor> TileCompute (const Attrs& attrs,
1074+ const Array<Tensor>& inputs,
1075+ const Type& out_type,
1076+ const Target& target) {
1077+ const TileAttrs *param = attrs.as <TileAttrs>();
1078+ CHECK (param != nullptr );
1079+ return { topi::tile (inputs[0 ], param->reps ) };
1080+ }
1081+
1082+ Expr MakeTile (Expr data,
1083+ Array<Integer> reps) {
1084+ auto attrs = make_node<TileAttrs>();
1085+ attrs->reps = reps;
1086+ static const Op& op = Op::Get (" tile" );
1087+ return CallNode::make (op, {data}, Attrs (attrs), {});
1088+ }
1089+
1090+ TVM_REGISTER_API (" relay.op._make.tile" )
1091+ .set_body([](const TVMArgs& args, TVMRetValue* rv) {
1092+ runtime::detail::unpack_call<Expr, 2 >(MakeTile, args, rv);
1093+ });
1094+
1095+ RELAY_REGISTER_OP (" tile" )
1096+ .describe(R"code( Repeat the whole array multiple times.
1097+
1098+ - **data**: The input data to the operator.
1099+
1100+ )code" TVM_ADD_FILELINE)
1101+ .set_num_inputs(1 )
1102+ .set_attrs_type_key(" relay.attrs.Tile" )
1103+ .add_argument(" data" , " Tensor" , " The input tensor." )
1104+ .set_support_level(1 )
1105+ .add_type_rel(" Tile" , TileRel)
1106+ .set_attr<FTVMCompute>(" FTVMCompute" , TileCompute)
1107+ .set_attr<TOpPattern>(" TOpPattern" , kBroadcast );
1108+
9401109// where operator
9411110bool WhereRel (const Array<Type>& types,
9421111 int num_inputs,
0 commit comments