| 
2 | 2 | 
 
  | 
3 | 3 | import logging  | 
4 | 4 | from dataclasses import fields, replace  | 
 | 5 | +from enum import Enum  | 
5 | 6 | from typing import Any, Callable, Dict, Optional, Sequence, Union  | 
6 | 7 | 
 
  | 
 | 8 | +import numpy as np  | 
 | 9 | +import tensorrt as trt  | 
7 | 10 | import torch  | 
8 | 11 | from torch_tensorrt._Device import Device  | 
9 | 12 | from torch_tensorrt._enums import dtype  | 
 | 
13 | 16 | 
 
  | 
14 | 17 | from packaging import version  | 
15 | 18 | 
 
  | 
 | 19 | +from .types import TRTDataType  | 
 | 20 | + | 
16 | 21 | logger = logging.getLogger(__name__)  | 
17 | 22 | 
 
  | 
18 | 23 | COSINE_THRESHOLD = 0.99  | 
19 | 24 | DYNAMIC_DIM = -1  | 
20 | 25 | 
 
  | 
21 | 26 | 
 
  | 
 | 27 | +class Frameworks(Enum):  | 
 | 28 | +    NUMPY = "numpy"  | 
 | 29 | +    TORCH = "torch"  | 
 | 30 | +    TRT = "trt"  | 
 | 31 | + | 
 | 32 | + | 
 | 33 | +DataTypeEquivalence: Dict[  | 
 | 34 | +    TRTDataType, Dict[Frameworks, Union[TRTDataType, np.dtype, torch.dtype]]  | 
 | 35 | +] = {  | 
 | 36 | +    trt.int8: {  | 
 | 37 | +        Frameworks.NUMPY: np.int8,  | 
 | 38 | +        Frameworks.TORCH: torch.int8,  | 
 | 39 | +        Frameworks.TRT: trt.int8,  | 
 | 40 | +    },  | 
 | 41 | +    trt.int32: {  | 
 | 42 | +        Frameworks.NUMPY: np.int32,  | 
 | 43 | +        Frameworks.TORCH: torch.int32,  | 
 | 44 | +        Frameworks.TRT: trt.int32,  | 
 | 45 | +    },  | 
 | 46 | +    trt.int64: {  | 
 | 47 | +        Frameworks.NUMPY: np.int64,  | 
 | 48 | +        Frameworks.TORCH: torch.int64,  | 
 | 49 | +        Frameworks.TRT: trt.int64,  | 
 | 50 | +    },  | 
 | 51 | +    trt.float16: {  | 
 | 52 | +        Frameworks.NUMPY: np.float16,  | 
 | 53 | +        Frameworks.TORCH: torch.float16,  | 
 | 54 | +        Frameworks.TRT: trt.float16,  | 
 | 55 | +    },  | 
 | 56 | +    trt.float32: {  | 
 | 57 | +        Frameworks.NUMPY: np.float32,  | 
 | 58 | +        Frameworks.TORCH: torch.float32,  | 
 | 59 | +        Frameworks.TRT: trt.float32,  | 
 | 60 | +    },  | 
 | 61 | +    trt.bool: {  | 
 | 62 | +        Frameworks.NUMPY: bool,  | 
 | 63 | +        Frameworks.TORCH: torch.bool,  | 
 | 64 | +        Frameworks.TRT: trt.bool,  | 
 | 65 | +    },  | 
 | 66 | +}  | 
 | 67 | + | 
 | 68 | +if trt.__version__ >= "7.0":  | 
 | 69 | +    DataTypeEquivalence[trt.bool] = {  | 
 | 70 | +        Frameworks.NUMPY: np.bool_,  | 
 | 71 | +        Frameworks.TORCH: torch.bool,  | 
 | 72 | +        Frameworks.TRT: trt.bool,  | 
 | 73 | +    }  | 
 | 74 | + | 
 | 75 | + | 
22 | 76 | def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:  | 
23 | 77 |     """Parses a user-provided input argument regarding Python runtime  | 
24 | 78 | 
  | 
@@ -317,3 +371,34 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any:  | 
317 | 371 |         return function_wrapper  | 
318 | 372 | 
 
  | 
319 | 373 |     return nested_decorator  | 
 | 374 | + | 
 | 375 | + | 
 | 376 | +def unified_dtype_converter(  | 
 | 377 | +    dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks  | 
 | 378 | +) -> Union[np.dtype, torch.dtype, TRTDataType]:  | 
 | 379 | +    """  | 
 | 380 | +    Convert TensorRT, Numpy, or Torch data types to any other of those data types.  | 
 | 381 | +
  | 
 | 382 | +    Args:  | 
 | 383 | +        dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type.  | 
 | 384 | +        to (Frameworks): The framework to convert the data type to.  | 
 | 385 | +
  | 
 | 386 | +    Returns:  | 
 | 387 | +        The equivalent data type in the requested framework.  | 
 | 388 | +    """  | 
 | 389 | +    assert to in Frameworks, f"Expected valid Framework for translation, got {to}"  | 
 | 390 | +    trt_major_version = int(trt.__version__.split(".")[0])  | 
 | 391 | +    if dtype in (np.int8, torch.int8, trt.int8):  | 
 | 392 | +        return DataTypeEquivalence[trt.int8][to]  | 
 | 393 | +    elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool):  | 
 | 394 | +        return DataTypeEquivalence[trt.bool][to]  | 
 | 395 | +    elif dtype in (np.int32, torch.int32, trt.int32):  | 
 | 396 | +        return DataTypeEquivalence[trt.int32][to]  | 
 | 397 | +    elif dtype in (np.int64, torch.int64, trt.int64):  | 
 | 398 | +        return DataTypeEquivalence[trt.int64][to]  | 
 | 399 | +    elif dtype in (np.float16, torch.float16, trt.float16):  | 
 | 400 | +        return DataTypeEquivalence[trt.float16][to]  | 
 | 401 | +    elif dtype in (np.float32, torch.float32, trt.float32):  | 
 | 402 | +        return DataTypeEquivalence[trt.float32][to]  | 
 | 403 | +    else:  | 
 | 404 | +        raise TypeError("%s is not a supported dtype" % dtype)  | 
0 commit comments