@@ -70,13 +70,13 @@ def aten_ops_div(
7070 kwargs_new ["input" ].dtype == trt .int8 or kwargs_new ["input" ].dtype == trt .int32
7171 ):
7272 kwargs_new ["input" ] = cast_trt_tensor (
73- network , kwargs_new ["input" ], trt .float32 , name
73+ network , kwargs_new ["input" ], trt .float32 , name , target
7474 )
7575 elif isinstance (args [1 ], TRTTensor ) and (
7676 kwargs_new ["other" ].dtype == trt .int8 or kwargs_new ["other" ].dtype == trt .int32
7777 ):
7878 kwargs_new ["other" ] = cast_trt_tensor (
79- network , kwargs_new ["other" ], trt .float32 , name
79+ network , kwargs_new ["other" ], trt .float32 , name , target
8080 )
8181 rounding_mode = kwargs .get ("rounding_mode" )
8282 if rounding_mode is None :
@@ -377,3 +377,77 @@ def aten_ops_permute(
377377 args [0 ],
378378 args [1 ],
379379 )
380+
381+
382+ def to_copy_dtype_validator (to_copy_node : Node ):
383+ allowed_casts = {torch .float , torch .int32 , torch .bool , torch .int8 , torch .float16 }
384+
385+ # Validate input node has convertible kwargs
386+ if "dtype" in to_copy_node .kwargs :
387+ if to_copy_node .kwargs ["dtype" ] in allowed_casts :
388+ return True
389+ else :
390+ _LOGGER .debug (
391+ f"_to_copy converter rejected node { to_copy_node } with dtype { to_copy_node .kwargs ['dtype' ]} "
392+ )
393+ return False
394+ else :
395+ _LOGGER .debug (
396+ f"_to_copy converter rejected node { to_copy_node } with kwargs { to_copy_node .kwargs } "
397+ )
398+ return False
399+
400+
401+ @dynamo_tensorrt_converter (
402+ torch .ops .aten ._to_copy .default , capability_validator = to_copy_dtype_validator
403+ )
404+ def aten_ops_to_copy_dtype (
405+ network : TRTNetwork ,
406+ target : Target ,
407+ args : Tuple [Argument , ...],
408+ kwargs : Dict [str , Argument ],
409+ name : str ,
410+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
411+ return impl .cast .to_copy (
412+ network ,
413+ target ,
414+ SourceIR .ATEN ,
415+ name ,
416+ args [0 ],
417+ kwargs ["dtype" ],
418+ )
419+
420+
421+ @dynamo_tensorrt_converter (operator .getitem )
422+ def operator_getitem (
423+ network : TRTNetwork ,
424+ target : Target ,
425+ args : Tuple [Argument , ...],
426+ kwargs : Dict [str , Argument ],
427+ name : str ,
428+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
429+ return impl .evaluators .getitem (
430+ network ,
431+ target ,
432+ SourceIR .ATEN ,
433+ name ,
434+ args [0 ],
435+ args [1 ],
436+ )
437+
438+
439+ @dynamo_tensorrt_converter (torch .ops .aten .clone .default )
440+ def aten_ops_clone (
441+ network : TRTNetwork ,
442+ target : Target ,
443+ args : Tuple [Argument , ...],
444+ kwargs : Dict [str , Argument ],
445+ name : str ,
446+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
447+ return impl .evaluators .clone (
448+ network ,
449+ target ,
450+ SourceIR .ATEN ,
451+ name ,
452+ args [0 ],
453+ )
0 commit comments