@@ -194,60 +194,6 @@ bool add_expand_dynamic(
   return true;
 }
 
-bool add_repeat(ConversionCtx* ctx, const torch::jit::Node* n, args& args, const std::string& layer) {
-  auto in = args[0].ITensorOrFreeze(ctx);
-  auto input_dims = in->getDimensions();
-  auto repeats = args[1].unwrapToIntList().vec();
-  int repeats_rank = repeats.size();
-  TORCHTRT_CHECK(
-      repeats_rank >= input_dims.nbDims,
-      "Number of repeat dimensions cannot be smaller than number of input dimensions");
-
-  auto num_expand_dims = repeats_rank - input_dims.nbDims;
-
-  if (ctx->input_is_dynamic) {
-    int input_rank = input_dims.nbDims;
-    int output_rank = repeats_rank;
-    auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
-
-    auto shuffle = ctx->net->addShuffle(*in);
-    shuffle->setInput(1, *new_input_shape_tensor);
-    in = shuffle->getOutput(0);
-  } else {
-    if (num_expand_dims > 0) {
-      nvinfer1::Dims reshape_dims;
-      reshape_dims.nbDims = repeats.size();
-      for (int i = 0; i < num_expand_dims; i++) {
-        reshape_dims.d[i] = 1;
-      }
-      for (int i = 0; i < input_dims.nbDims; i++) {
-        reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-      }
-      // Add a reshape layer to expand dims
-      auto reshape_layer = ctx->net->addShuffle(*in);
-      reshape_layer->setReshapeDimensions(reshape_dims);
-      in = reshape_layer->getOutput(0);
-      LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-    }
-    LOG_DEBUG("Repeats: " << repeats);
-  }
-
-  // Concat across all repeat axes.
-  for (int i = repeats.size() - 1; i >= 0; --i) {
-    std::vector<nvinfer1::ITensor*> tensors_vec;
-    for (int j = 0; j < repeats[i]; j++) {
-      tensors_vec.push_back(in);
-    }
-    auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-    concat_layer->setAxis(i);
-    in = concat_layer->getOutput(0);
-  }
-
-  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
-  LOG_DEBUG(layer << " layer output tensor shape: " << out->getDimensions());
-  return true;
-}
-
 auto expand_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
@@ -284,7 +230,59 @@ auto expand_registrations TORCHTRT_UNUSED =
         .pattern(
             {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               return add_repeat(ctx, n, args, "Repeat");
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto input_dims = in->getDimensions();
+               auto repeats = args[1].unwrapToIntList().vec();
+               int repeats_rank = repeats.size();
+               TORCHTRT_CHECK(
+                   repeats_rank >= input_dims.nbDims,
+                   "Number of repeat dimensions cannot be smaller than number of input dimensions");
+               auto num_expand_dims = repeats_rank - input_dims.nbDims;
+
+               if (ctx->input_is_dynamic) {
+                 int input_rank = input_dims.nbDims;
+                 int output_rank = repeats_rank;
+                 auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
+
+                 // Add a reshape layer to expand dims
+                 auto shuffle = ctx->net->addShuffle(*in);
+                 shuffle->setInput(1, *new_input_shape_tensor);
+                 in = shuffle->getOutput(0);
+               } else {
+                 if (num_expand_dims > 0) {
+                   nvinfer1::Dims reshape_dims;
+                   reshape_dims.nbDims = repeats.size();
+                   for (int i = 0; i < num_expand_dims; i++) {
+                     reshape_dims.d[i] = 1;
+                   }
+                   for (int i = 0; i < input_dims.nbDims; i++) {
+                     reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+                   }
+                   // Add a reshape layer to expand dims
+                   auto reshape_layer = ctx->net->addShuffle(*in);
+                   reshape_layer->setReshapeDimensions(reshape_dims);
+                   in = reshape_layer->getOutput(0);
+                   LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+                 }
+                 LOG_DEBUG("Repeats: " << repeats);
+               }
+
+               // Concat across all repeat axes.
+               // TODO: Implementation might not be performant. Explore other strategies to improve performance.
+               for (int i = repeats.size() - 1; i >= 0; --i) {
+                 std::vector<nvinfer1::ITensor*> tensors_vec;
+                 for (int j = 0; j < repeats[i]; j++) {
+                   tensors_vec.push_back(in);
+                 }
+                 auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+                 concat_layer->setAxis(i);
+                 in = concat_layer->getOutput(0);
+               }
+
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+
+               LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
+               return true;
             }})
         .pattern(
             {"aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> (Tensor)",
@@ -397,11 +395,6 @@ auto expand_registrations TORCHTRT_UNUSED =
 
                return true;
              }})
-        .pattern(
-            {"aten::tile(Tensor self, int[] dims) -> (Tensor)",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               return add_repeat(ctx, n, args, "Tile");
-             }})
         .pattern(
             {"aten::meshgrid(Tensor[] tensors) -> (Tensor[])",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -491,4 +484,4 @@ auto expand_registrations TORCHTRT_UNUSED =
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace torch_tensorrt
+} // namespace torch_tensorrt
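Aside on the lowering above (not part of the commit): the inlined aten::repeat converter left-pads the input shape with 1s up to the rank of repeats, then concatenates repeats[i] copies of the tensor along each axis i, so every output dimension is the padded input dimension times its repeat count. The self-contained C++ sketch below only illustrates that shape arithmetic; repeat_output_shape is a hypothetical helper, and no TensorRT API is used.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: compute the output shape of aten::repeat given an input
// shape and a repeats vector, mirroring the reshape + per-axis concat lowering.
std::vector<int64_t> repeat_output_shape(const std::vector<int64_t>& input_dims, const std::vector<int64_t>& repeats) {
  assert(repeats.size() >= input_dims.size() && "repeats rank must be >= input rank");
  // Left-pad the input shape with 1s (the shuffle/reshape step above).
  std::vector<int64_t> padded(repeats.size() - input_dims.size(), 1);
  padded.insert(padded.end(), input_dims.begin(), input_dims.end());
  // Concatenating repeats[i] copies along axis i multiplies that dimension.
  std::vector<int64_t> out(padded.size());
  for (size_t i = 0; i < padded.size(); ++i) {
    out[i] = padded[i] * repeats[i];
  }
  return out;
}

int main() {
  // A [2, 3] tensor repeated with repeats = [4, 1, 2] gives shape [4, 2, 6],
  // matching torch.zeros(2, 3).repeat(4, 1, 2).shape.
  for (int64_t d : repeat_output_shape({2, 3}, {4, 1, 2})) {
    std::cout << d << " ";  // prints: 4 2 6
  }
  std::cout << std::endl;
  return 0;
}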