1818from torch .fx .immutable_collections import immutable_list
1919from torch .fx .node import Argument , Target
2020
21- from ..utils import get_dynamic_dims , torch_dtype_from_trt , torch_dtype_to_trt
21+ from ..utils import get_dynamic_dims , unified_dtype_converter , Frameworks
2222
2323from .converter_utils import * # noqa: F403
2424from torch_tensorrt .fx .passes .lower_basic_pass import (
@@ -400,7 +400,7 @@ def acc_ops_pad_with_slice_layer(
400400 )
401401
402402 # cast value to TRTensor
403- dt = torch_dtype_from_trt (input_val .dtype )
403+ dt = unified_dtype_converter (input_val .dtype , Frameworks . TORCH )
404404 value = 0 if value == None else value
405405 value_const = get_trt_tensor (
406406 network , torch .tensor ([value ], dtype = dt ), f"{ name } _value"
@@ -1550,7 +1550,7 @@ def acc_ops_to_dtype(
15501550 input_t = get_trt_tensor (network , input_val , f"{ name } _input_t" )
15511551 if input_dtype :
15521552 if isinstance (input_dtype , torch .dtype ):
1553- input_dtype = torch_dtype_to_trt (input_dtype )
1553+ input_dtype = unified_dtype_converter (input_dtype , Frameworks . TRT )
15541554 input_t = type_cast (network , target , f"{ name } _input" , input_t , input_dtype )
15551555 return input_t
15561556
@@ -1811,7 +1811,7 @@ def acc_ops_logical_xor(
18111811# f"isinf received input {input_t} that is not part "
18121812# "of the TensorRT region!"
18131813# )
1814- # tdtype = torch_dtype_from_trt (input_t.dtype)
1814+ # tdtype = unified_dtype_converter (input_t.dtype, Frameworks.TORCH )
18151815
18161816# inf_t = torch.ones(tuple(input_t.shape))
18171817# inf_t = inf_t * float("inf")
@@ -1849,7 +1849,7 @@ def acc_ops_any(
18491849
18501850 if input_t .dtype in (trt .float32 , trt .float16 , trt .int32 ):
18511851 comp_t = torch .zeros (tuple ([* input_t .shape ])).to (
1852- torch_dtype_from_trt (input_t .dtype )
1852+ unified_dtype_converter (input_t .dtype , Frameworks . TORCH )
18531853 )
18541854 comp_t = get_trt_tensor (network , comp_t , f"{ name } _comp_t" )
18551855 kwargs_new = {"input" : input_t , "other" : comp_t }
@@ -2738,7 +2738,7 @@ def acc_ops_masked_fill_tensor(
27382738 if type (value_t ) is torch .Tensor :
27392739 value_t = value_t .cpu ().numpy ()
27402740 # cast to input type
2741- input_dtype = torch_dtype_from_trt (input_t .dtype )
2741+ input_dtype = unified_dtype_converter (input_t .dtype , Frameworks . TORCH )
27422742 value_t = (torch .ones (shape ) * value_t ).to (input_dtype )
27432743 input_val = get_trt_tensor (network , input_t , f"{ name } _input" )
27442744 value_val = get_trt_tensor (network , value_t , f"{ name } _input" )
@@ -2872,7 +2872,11 @@ def add_clamp(network, input, val, op, name):
28722872 # clamping scalar
28732873 acc_ops_clamp_trt = get_trt_tensor (
28742874 network ,
2875- squeeze_left (torch .tensor ([val ], dtype = torch_dtype_from_trt (input .dtype ))),
2875+ squeeze_left (
2876+ torch .tensor (
2877+ [val ], dtype = unified_dtype_converter (input .dtype , Frameworks .TORCH )
2878+ )
2879+ ),
28762880 f"{ name } _clamp_{ val } " ,
28772881 )
28782882 else :
@@ -2881,7 +2885,8 @@ def add_clamp(network, input, val, op, name):
28812885 (
28822886 val
28832887 * torch .ones (
2884- acc_ops_clamp_shape , dtype = torch_dtype_from_trt (input .dtype )
2888+ acc_ops_clamp_shape ,
2889+ dtype = unified_dtype_converter (input .dtype , Frameworks .TORCH ),
28852890 )
28862891 )
28872892 .cpu ()
@@ -3527,7 +3532,9 @@ def acc_ops_cumsum(
35273532 iterator = loop .add_iterator (input_val , dim , False )
35283533 data = iterator .get_output (0 )
35293534 new_dims = tuple (data .shape )
3530- zero_tensor = torch .zeros (new_dims , dtype = trt_dtype_to_torch_dtype (input_val .dtype ))
3535+ zero_tensor = torch .zeros (
3536+ new_dims , dtype = unified_dtype_converter (input_val .dtype , Frameworks .TORCH )
3537+ )
35313538 zero_tensor = network .add_constant (
35323539 zero_tensor .shape , to_numpy (zero_tensor )
35333540 ).get_output (0 )
@@ -3670,7 +3677,7 @@ def acc_ops_new_ones(
36703677 dtype_val = kwargs .get ("dtype" )
36713678 if dtype_val is None :
36723679 dtype_val = input_val .dtype
3673- dtype_val = torch_dtype_from_trt (dtype_val )
3680+ dtype_val = unified_dtype_converter (dtype_val , Frameworks . TORCH )
36743681
36753682 device_val = kwargs .get ("device" )
36763683 assert (
@@ -3694,7 +3701,7 @@ def acc_ops_new_empty(
36943701 dtype_val = kwargs .get ("dtype" )
36953702 if dtype_val is None :
36963703 dtype_val = input_val .dtype
3697- dtype_val = torch_dtype_from_trt (dtype_val )
3704+ dtype_val = unified_dtype_converter (dtype_val , Frameworks . TORCH )
36983705
36993706 device_val = kwargs .get ("device" )
37003707 assert (
0 commit comments