@@ -187,27 +187,20 @@ def aten_ops_fmod(
187
187
return acc_ops_converters .acc_ops_fmod (network , target , None , kwargs_new , name )
188
188
189
189
190
- @tensorrt_converter (torch .ops .aten .mm .default )
191
- @tensorrt_converter (torch .ops .aten .addmm .default )
190
+ @tensorrt_converter (torch .ops .aten .linear )
192
191
def aten_ops_linear (
193
192
network : TRTNetwork ,
194
193
target : Target ,
195
194
args : Tuple [Argument , ...],
196
195
kwargs : Dict [str , Argument ],
197
196
name : str ,
198
197
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
199
- if target == torch .ops .aten .addmm .default :
200
- kwargs_new = {
201
- "bias" : args [0 ],
202
- "input" : args [1 ],
203
- "weight" : args [2 ],
204
- }
205
- elif target == torch .ops .aten .mm .default :
206
- kwargs_new = {
207
- "bias" : None ,
208
- "input" : args [0 ],
209
- "weight" : args [1 ],
210
- }
198
+ kwargs_new = {
199
+ "input" : args [0 ],
200
+ "weight" : args [1 ],
201
+ "bias" : args [2 ],
202
+ }
203
+
211
204
return acc_ops_converters .acc_ops_linear (network , target , None , kwargs_new , name )
212
205
213
206
@@ -320,3 +313,35 @@ def aten_ops_reshape(
320
313
"acc_out_ty" : acc_utils .build_raw_tensor_meta (shape = args [1 ]),
321
314
}
322
315
return acc_ops_converters .acc_ops_reshape (network , target , None , kwargs_new , name )
316
+
317
+
318
+ @tensorrt_converter (torch .ops .aten .cat .default )
319
+ def aten_ops_cat (
320
+ network : TRTNetwork ,
321
+ target : Target ,
322
+ args : Tuple [Argument , ...],
323
+ kwargs : Dict [str , Argument ],
324
+ name : str ,
325
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
326
+ kwargs_new = {
327
+ "tensors" : args [0 ],
328
+ "dim" : args [1 ],
329
+ }
330
+ return acc_ops_converters .acc_ops_cat (network , target , None , kwargs_new , name )
331
+
332
+
333
+ @tensorrt_converter (torch .ops .aten .expand .default )
334
+ def aten_ops_expand (
335
+ network : TRTNetwork ,
336
+ target : Target ,
337
+ args : Tuple [Argument , ...],
338
+ kwargs : Dict [str , Argument ],
339
+ name : str ,
340
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
341
+ kwargs_new = {
342
+ "input" : args [0 ],
343
+ "sizes" : args [1 ],
344
+ }
345
+ return acc_ops_converters .acc_ops_expand_tensor (
346
+ network , target , None , kwargs_new , name
347
+ )
0 commit comments