
Commit aaab1f9

feat: Adding support for native int64

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

Parent: 4fe5feb

File tree: 9 files changed (+256, −15 lines)

core/runtime/register_jit_hooks.cpp (5 additions, 0 deletions)

@@ -122,6 +122,11 @@ TORCH_LIBRARY(tensorrt, m) {
   m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
     MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
   });
+  m.def("set_logging_level", [](int64_t level) -> void {
+    util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
+  });
+  m.def(
+      "get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); });
 }

 } // namespace
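
Once the torch_tensorrt runtime library is loaded, these two ops are reachable from Python through the TorchScript custom-op registry. A minimal sketch of calling them (assumes a build with the runtime enabled; the int64 value is interpreted as util::logging::LogLevel on the C++ side):

import torch
import torch_tensorrt  # noqa: F401 -- importing the package loads the ops under torch.ops.tensorrt

# Read the runtime logger's current reportable level as an integer
level = torch.ops.tensorrt.get_logging_level()

# Write it back (a no-op round trip, shown only to exercise both ops)
torch.ops.tensorrt.set_logging_level(level)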

core/util/trt_util.cpp (4 additions, 3 deletions)

@@ -164,8 +164,8 @@ nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val, bool use
   // Acceptable range for pos is [-d.nbDims - 1, d.nbDims]
   TORCHTRT_ASSERT(
       pos >= (-d.nbDims - 1) && pos <= d.nbDims,
-      "ERROR: Index to unsqueeze is out of bounds. "
-          << "Expected value in range [" << (-d.nbDims - 1) << ", " << d.nbDims << "], but got " << pos);
+      "ERROR: Index to unsqueeze is out of bounds. " << "Expected value in range [" << (-d.nbDims - 1) << ", "
+                                                     << d.nbDims << "], but got " << pos);

   // Unsqueeze with negative dimensions creates a new dimension at that index
   pos = (pos < 0) ? (pos + d.nbDims + 1) : pos;

@@ -292,7 +292,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
       {at::kFloat, nvinfer1::DataType::kFLOAT},
       {at::kHalf, nvinfer1::DataType::kHALF},
       {at::kInt, nvinfer1::DataType::kINT32},
-      {at::kLong, nvinfer1::DataType::kINT32},
+      {at::kLong, nvinfer1::DataType::kINT64},
       {at::kChar, nvinfer1::DataType::kINT8},
       {at::kByte, nvinfer1::DataType::kINT8},
       {at::kBool, nvinfer1::DataType::kBOOL}};

@@ -304,6 +304,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_ma
       {nvinfer1::DataType::kFLOAT, at::kFloat},
       {nvinfer1::DataType::kHALF, at::kHalf},
       {nvinfer1::DataType::kINT32, at::kInt},
+      {nvinfer1::DataType::kINT64, at::kLong},
       {nvinfer1::DataType::kINT8, at::kChar},
       {nvinfer1::DataType::kBOOL, at::kBool},
   };
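
The substantive change is the at::kLong entry: long tensors were previously demoted to kINT32 on the way into TensorRT, and now map to the native kINT64, with a matching kINT64 -> at::kLong entry added to the reverse map. A Python-side illustration of the updated forward mapping (the C++ table above is the source of truth; this dict is just a sketch):

import tensorrt as trt
import torch

# Mirrors get_at_trt_type_map after this commit
AT_TO_TRT = {
    torch.float32: trt.DataType.FLOAT,
    torch.half: trt.DataType.HALF,
    torch.int32: trt.DataType.INT32,
    torch.int64: trt.DataType.INT64,  # was trt.DataType.INT32 before this commit
    torch.int8: trt.DataType.INT8,
    torch.uint8: trt.DataType.INT8,
    torch.bool: trt.DataType.BOOL,
}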

core/util/trt_util.h (2 additions, 0 deletions)

@@ -53,6 +53,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
       return stream << "Int8";
     case nvinfer1::DataType::kINT32:
       return stream << "Int32";
+    case nvinfer1::DataType::kINT64:
+      return stream << "Int64";
     case nvinfer1::DataType::kBOOL:
       return stream << "Bool";
     default:

py/torch_tensorrt/_enums.py (6 additions, 1 deletion)

@@ -5,10 +5,11 @@
 from typing import Any, Optional, Type, Union

 import numpy as np
-import tensorrt as trt
 import torch
 from torch_tensorrt._features import ENABLED_FEATURES

+import tensorrt as trt
+

 class dtype(Enum):
     """Enum to set supported dtypes in the compiler"""

@@ -103,6 +104,8 @@ def _from(
             return dtype.i8
         elif t == trt.int32:
             return dtype.i32
+        elif t == trt.int64:
+            return dtype.i64
         elif t == trt.float16:
             return dtype.f16
         elif t == trt.float32:

@@ -227,6 +230,8 @@ def to(
             return trt.DataType.INT8
         elif self == dtype.i32:
             return trt.DataType.INT32
+        elif self == dtype.i64:
+            return trt.DataType.INT64
         elif self == dtype.f16:
             return trt.DataType.HALF
         elif self == dtype.f32:
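
With both directions wired up, the dtype enum can round-trip TensorRT's 64-bit integer type. A quick sketch (assumes a TensorRT version whose Python bindings expose trt.int64; _from and to are the methods touched above):

import tensorrt as trt
from torch_tensorrt._enums import dtype

assert dtype._from(trt.int64) == dtype.i64                # TensorRT type -> enum
assert dtype.i64.to(trt.DataType) == trt.DataType.INT64   # enum -> TensorRT type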

py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py (5 additions, 4 deletions)

@@ -21,19 +21,20 @@
 from torch.fx.node import Argument, Node, Target, _get_qualified_name
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+import tensorrt as trt

 logger = logging.getLogger(__name__)

 LegacyConverterImplSignature = Callable[
     [
-        TRTNetwork,
+        trt.INetworkDefinition,
         Target,
         Tuple[Argument, ...],
         Dict[str, Argument],
         str,
     ],
-    Union[TRTTensor, Sequence[TRTTensor]],
+    Union[trt.ITensor, Sequence[trt.ITensor]],
 ]

 DynamoConverterImplSignature = Callable[

@@ -44,7 +45,7 @@
         Dict[str, Argument],
         str,
     ],
-    Union[TRTTensor, Sequence[TRTTensor]],
+    Union[trt.ITensor, Sequence[trt.ITensor]],
 ]

 ConverterImplSignature = Union[
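
For converter authors nothing changes functionally; the aliases now name the TensorRT classes directly instead of the torch_tensorrt.fx.types wrappers. A sketch of a function satisfying DynamoConverterImplSignature under the new spelling (the leading ConversionContext parameter and the function name are assumptions for illustration; the diff only shows the signature's tail):

from typing import Dict, Sequence, Tuple, Union

import tensorrt as trt
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext


def example_converter(  # hypothetical name
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    # A real converter would add TensorRT layers via ctx and return their outputs
    raise NotImplementedError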

py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py (5 additions, 6 deletions)

@@ -59,13 +59,12 @@ def _repair_64bit_input(
         dtype: Data type of tensor at position in submodule (double/long)
     """
     assert dtype in (
-        torch.int64,
         torch.float64,
-    ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
+    ), f"dtype argument must be torch.float64, got {dtype}"

     # Determine target data type in 32 and 64 bit forms
     dtype_64bit = dtype
-    dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
+    dtype_32bit = torch.float32

     # Find the node representing the submodule in the graph
     module_node = None

@@ -143,7 +142,7 @@ def _repair_64bit_input(
     cast_node_64bit = gm.graph.call_function(
         torch.ops.aten._to_copy.default,
         args=(get_node,),
-        kwargs={"dtype": torch.int64},
+        kwargs={"dtype": torch.float64},
     )

     get_node.replace_all_uses_with(

@@ -189,7 +188,7 @@ def repair_long_or_double_inputs(

         # If the data type of the input is long/double, insert necessary
         # casts to replace the operation
-        if param.dtype in (torch.int64, torch.float64):
+        if param.dtype == torch.float64:
            # Ensure outputs are only repaired once per submodule to avoid
            # unnecessary ops showing up in the graph
            if not repaired_outputs_once:

@@ -206,7 +205,7 @@ def repair_long_or_double_inputs(
                repaired_outputs_once = True

         # Repair submodule inputs in accordance with inserted casts
-        dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32
+        dtype_32bit = torch.float32
         submodule_torch_inputs = (
             list(submodule_torch_inputs[:position])
             + [
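
This pass shrinks because long tensors no longer need repair: with at::kLong mapped to native INT64, only float64 inputs are still downcast to float32 at the TensorRT boundary (and cast back afterwards). The new gating condition in isolation, for emphasis:

import torch

# Before this commit: param.dtype in (torch.int64, torch.float64)
def needs_repair(dtype: torch.dtype) -> bool:
    return dtype == torch.float64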

py/torch_tensorrt/logging.py (45 additions, 0 deletions)

@@ -1,6 +1,7 @@
 import logging
 from typing import Any

+import torch
 from torch_tensorrt._features import ENABLED_FEATURES

 import tensorrt as trt

@@ -51,6 +52,12 @@ def __enter__(self) -> None:
             self.ts_level = ts_logging.get_reportable_log_level()
             ts_logging.set_reportable_log_level(ts_logging.Level.InternalError)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            self.rt_level = torch.ops.tensorrt.get_logging_level()
+            torch.ops.tensorrt.set_logging_level(
+                int(trt.ILogger.Severity.INTERNAL_ERROR)
+            )
+
     def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
         _LOGGER.setLevel(self.external_lvl)

@@ -59,6 +66,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:

             ts_logging.set_reportable_log_level(self.ts_level)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            torch.ops.tensorrt.set_logging_level(self.rt_level)
+

 class errors:
     """Context-manager to limit displayed log messages to just errors and above

@@ -79,6 +89,10 @@ def __enter__(self) -> None:
             self.ts_level = ts_logging.get_reportable_log_level()
             ts_logging.set_reportable_log_level(ts_logging.Level.Error)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            self.rt_level = torch.ops.tensorrt.get_logging_level()
+            torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR))
+
     def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
         _LOGGER.setLevel(self.external_lvl)

@@ -87,6 +101,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:

             ts_logging.set_reportable_log_level(self.ts_level)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            torch.ops.tensorrt.set_logging_level(self.rt_level)
+

 class warnings:
     """Context-manager to limit displayed log messages to just warnings and above

@@ -107,6 +124,10 @@ def __enter__(self) -> None:
             self.ts_level = ts_logging.get_reportable_log_level()
             ts_logging.set_reportable_log_level(ts_logging.Level.Warning)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            self.rt_level = torch.ops.tensorrt.get_logging_level()
+            torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING))
+
     def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
         _LOGGER.setLevel(self.external_lvl)

@@ -115,6 +136,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:

             ts_logging.set_reportable_log_level(self.ts_level)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            torch.ops.tensorrt.set_logging_level(self.rt_level)
+

 class info:
     """Context-manager to display all info and greater severity messages

@@ -135,6 +159,10 @@ def __enter__(self) -> None:
             self.ts_level = ts_logging.get_reportable_log_level()
             ts_logging.set_reportable_log_level(ts_logging.Level.Info)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            self.rt_level = torch.ops.tensorrt.get_logging_level()
+            torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO))
+
     def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
         _LOGGER.setLevel(self.external_lvl)

@@ -143,6 +171,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:

             ts_logging.set_reportable_log_level(self.ts_level)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            torch.ops.tensorrt.set_logging_level(self.rt_level)
+

 class debug:
     """Context-manager to display full debug information through the logger

@@ -163,6 +194,10 @@ def __enter__(self) -> None:
             self.ts_level = ts_logging.get_reportable_log_level()
             ts_logging.set_reportable_log_level(ts_logging.Level.Debug)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            self.rt_level = torch.ops.tensorrt.get_logging_level()
+            torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE))
+
     def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
         _LOGGER.setLevel(self.external_lvl)

@@ -171,6 +206,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:

             ts_logging.set_reportable_log_level(self.ts_level)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            torch.ops.tensorrt.set_logging_level(self.rt_level)
+

 class graphs:
     """Context-manager to display the results of intermediate lowering passes

@@ -192,10 +230,17 @@ def __enter__(self) -> None:
             self.ts_level = ts_logging.get_reportable_log_level()
             ts_logging.set_reportable_log_level(ts_logging.Level.Graph)

+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            self.rt_level = torch.ops.tensorrt.get_logging_level()
+            torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE) + 1)
+
     def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
         _LOGGER.setLevel(self.external_lvl)

         if ENABLED_FEATURES.torchscript_frontend:
             from torch_tensorrt.ts import logging as ts_logging

             ts_logging.set_reportable_log_level(self.ts_level)
+
+        elif ENABLED_FEATURES.torch_tensorrt_runtime:
+            torch.ops.tensorrt.set_logging_level(self.rt_level)
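
Taken together, each context manager now saves the C++ runtime's reportable level on entry and restores it on exit whenever the TorchScript frontend is unavailable but the torch_tensorrt runtime is. A usage sketch (model and inputs are assumed to be defined elsewhere):

import torch_tensorrt

# Suppress everything below ERROR from the runtime logger for the duration of the block
with torch_tensorrt.logging.errors():
    trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)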

tests/py/dynamo/models/conftest.py (2 additions, 1 deletion)

@@ -7,9 +7,10 @@ def pytest_addoption(parser):
         metavar="Internal Representation",
         nargs=1,
         type=str,
-        required=True,
+        required=False,
         help="IR to compile with",
         choices=["dynamo", "torch_compile"],
+        default="dynamo",
     )
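
The practical effect: --ir is now optional for the dynamo model tests. Running `pytest tests/py/dynamo/models` with no flag falls back to the dynamo IR, while `pytest tests/py/dynamo/models --ir torch_compile` still selects the alternate backend explicitly.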
