Skip to content

Commit 6bcadc5

Browse files
committed
chore: remove dependency on fx util, move to dynamo
1 parent 7774738 commit 6bcadc5

File tree

2 files changed

+86
-4
lines changed

2 files changed

+86
-4
lines changed

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111
cast_trt_tensor,
1212
flatten_dims,
1313
get_trt_tensor,
14-
)
15-
from torch_tensorrt.fx.converters.converter_utils import (
16-
Frameworks,
1714
set_layer_name,
18-
unified_dtype_converter,
1915
)
16+
from torch_tensorrt.dynamo.utils import Frameworks, unified_dtype_converter
2017
from torch_tensorrt.fx.types import TRTTensor
2118

2219

py/torch_tensorrt/dynamo/utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
import logging
44
from dataclasses import fields, replace
5+
from enum import Enum
56
from typing import Any, Callable, Dict, Optional, Sequence, Union
67

8+
import numpy as np
9+
import tensorrt as trt
710
import torch
811
from torch_tensorrt._Device import Device
912
from torch_tensorrt._enums import dtype
@@ -13,12 +16,63 @@
1316

1417
from packaging import version
1518

19+
from .types import TRTDataType
20+
1621
logger = logging.getLogger(__name__)
1722

1823
COSINE_THRESHOLD = 0.99
1924
DYNAMIC_DIM = -1
2025

2126

27+
class Frameworks(Enum):
28+
NUMPY = "numpy"
29+
TORCH = "torch"
30+
TRT = "trt"
31+
32+
33+
DataTypeEquivalence: Dict[
34+
TRTDataType, Dict[Frameworks, Union[TRTDataType, np.dtype, torch.dtype]]
35+
] = {
36+
trt.int8: {
37+
Frameworks.NUMPY: np.int8,
38+
Frameworks.TORCH: torch.int8,
39+
Frameworks.TRT: trt.int8,
40+
},
41+
trt.int32: {
42+
Frameworks.NUMPY: np.int32,
43+
Frameworks.TORCH: torch.int32,
44+
Frameworks.TRT: trt.int32,
45+
},
46+
trt.int64: {
47+
Frameworks.NUMPY: np.int64,
48+
Frameworks.TORCH: torch.int64,
49+
Frameworks.TRT: trt.int64,
50+
},
51+
trt.float16: {
52+
Frameworks.NUMPY: np.float16,
53+
Frameworks.TORCH: torch.float16,
54+
Frameworks.TRT: trt.float16,
55+
},
56+
trt.float32: {
57+
Frameworks.NUMPY: np.float32,
58+
Frameworks.TORCH: torch.float32,
59+
Frameworks.TRT: trt.float32,
60+
},
61+
trt.bool: {
62+
Frameworks.NUMPY: bool,
63+
Frameworks.TORCH: torch.bool,
64+
Frameworks.TRT: trt.bool,
65+
},
66+
}
67+
68+
if trt.__version__ >= "7.0":
69+
DataTypeEquivalence[trt.bool] = {
70+
Frameworks.NUMPY: np.bool_,
71+
Frameworks.TORCH: torch.bool,
72+
Frameworks.TRT: trt.bool,
73+
}
74+
75+
2276
def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
2377
"""Parses a user-provided input argument regarding Python runtime
2478
@@ -317,3 +371,34 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any:
317371
return function_wrapper
318372

319373
return nested_decorator
374+
375+
376+
def unified_dtype_converter(
377+
dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
378+
) -> Union[np.dtype, torch.dtype, TRTDataType]:
379+
"""
380+
Convert TensorRT, Numpy, or Torch data types to any other of those data types.
381+
382+
Args:
383+
dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type.
384+
to (Frameworks): The framework to convert the data type to.
385+
386+
Returns:
387+
The equivalent data type in the requested framework.
388+
"""
389+
assert to in Frameworks, f"Expected valid Framework for translation, got {to}"
390+
trt_major_version = int(trt.__version__.split(".")[0])
391+
if dtype in (np.int8, torch.int8, trt.int8):
392+
return DataTypeEquivalence[trt.int8][to]
393+
elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool):
394+
return DataTypeEquivalence[trt.bool][to]
395+
elif dtype in (np.int32, torch.int32, trt.int32):
396+
return DataTypeEquivalence[trt.int32][to]
397+
elif dtype in (np.int64, torch.int64, trt.int64):
398+
return DataTypeEquivalence[trt.int64][to]
399+
elif dtype in (np.float16, torch.float16, trt.float16):
400+
return DataTypeEquivalence[trt.float16][to]
401+
elif dtype in (np.float32, torch.float32, trt.float32):
402+
return DataTypeEquivalence[trt.float32][to]
403+
else:
404+
raise TypeError("%s is not a supported dtype" % dtype)

0 commit comments

Comments
 (0)