11from typing import Dict , Optional , Sequence , Union
22
3+ import numpy as np
34import torch
45from torch .fx .node import Target
56from torch_tensorrt .dynamo ._SourceIR import SourceIR
6- from torch_tensorrt .fx .converters .converter_utils import set_layer_name
77from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
8- from torch_tensorrt .dynamo .conversion .converter_utils import get_trt_tensor
8+ from torch_tensorrt .dynamo .conversion .converter_utils import (
9+ SourceIR ,
10+ get_positive_dim ,
11+ get_trt_tensor ,
12+ )
13+ from torch_tensorrt .fx .converters .converter_utils import set_layer_name
914from torch_tensorrt .fx .types import TRTNetwork , TRTTensor
1015
1116
@@ -14,16 +19,16 @@ def cat(
1419 target : Target ,
1520 source_ir : Optional [SourceIR ],
1621 name : str ,
17- input : Union [TRTTensor , Sequence [ TRTTensor ]],
22+ input : Sequence [ Union [TRTTensor , torch . Tensor , np . ndarray ]],
1823 dim : int ,
1924) -> Union [TRTTensor , Sequence [TRTTensor ]]:
25+ trt_inputs = []
2026 for each_input in input :
21- if (not isinstance (each_input , TRTTensor )):
22- each_input = get_trt_tensor (each_input )
23- concat_layer = ctx .net .add_concatenation (input )
24- if dim < 0 :
25- dim = len (input [0 ].shape ) + dim
26-
27+ if not isinstance (each_input , TRTTensor ):
28+ each_input = get_trt_tensor (ctx , each_input , name + "_tensor_{i}" )
29+ trt_inputs .append (each_input )
30+ concat_layer = ctx .net .add_concatenation (trt_inputs )
31+ dim = get_positive_dim (dim , len (input [0 ].shape ))
2732 concat_layer .axis = dim
2833 set_layer_name (concat_layer , target , name + "_gather" , source_ir )
2934 return concat_layer .get_output (0 )
0 commit comments