@@ -99,18 +99,24 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
   return nullptr;
 }
 
-torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
+torch::jit::Node* createCastNode(
+    SegmentedBlock& seg_block,
+    size_t index,
+    bool is_input,
+    at::ScalarType dtype,
+    std::string device,
+    bool force_create_node = false) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
   auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
   torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
   auto g = seg_block.g();
   // if we can find an upstream aten::to node, we use its parameters for creating the new cast node
-  if (cast_node) {
+  if (cast_node && !force_create_node) {
     std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
     value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
     if (!is_input) {
       // if this value is output, we need to cast it to int32
-      auto const_val = g->insertConstant(3);
+      auto const_val = g->insertConstant(dtype);
       if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
         value_map.insert({cast_node->inputs()[2], const_val});
       } else {
@@ -122,7 +128,7 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
     // auto cast_node = g->prependNode(g->createClone(cast_node, env));
   } else {
     // if there is no explicit cast aten::to operation, we need to create a node
-    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_type = g->insertConstant(dtype);
     auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
     auto cuda = g->insertConstant(device);
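
When no upstream aten::to exists, the branch above assembles the cast from scratch: a constant for the target dtype, a boolean false reused for the non_blocking and copy arguments, and a constant for the device string. The following standalone sketch (an illustration only, assuming libtorch's torch::jit IR headers; none of these names come from the patch) builds an equivalent aten::to node in a fresh graph:

#include <ATen/ATen.h>
#include <torch/csrc/jit/ir/ir.h>

#include <iostream>
#include <memory>
#include <string>

// Build a minimal graph that casts its single tensor input via aten::to,
// mirroring the constants createCastNode inserts: target dtype, the two
// false flags (non_blocking, copy), and the device string.
int main() {
  auto g = std::make_shared<torch::jit::Graph>();
  auto* self = g->addInput("self");

  auto* dtype = g->insertConstant(at::kInt);                 // target ScalarType
  auto* flag = g->insertConstant(false);                     // non_blocking / copy
  auto* device = g->insertConstant(std::string("cuda:0"));   // device string

  auto* to_node = g->create(
      c10::Symbol::fromQualString("aten::to"), {self, device, dtype, flag, flag});
  g->appendNode(to_node);
  g->registerOutput(to_node->output());

  std::cout << *g << std::endl;
  return 0;
}
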
@@ -222,27 +228,56 @@ void getSegmentsOutputByRunning(
 
   auto target_device = partitioning_info.getGPUDeviceString();
 
-  // auto int64 <=> int32 conversion
-  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
+  // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
+  if (seg_block.target() == SegmentedBlock::kTorch) {
     // First, check if there is Int64 input
     for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
       if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
         auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor input type during shape analysis, "
+              << "inserting aten::to cast to Long to ensure this Torch block receives "
+              << "a Long-type tensor input.");
           // we add a cast operation to cast the type to Int64
-          auto cast_node = createCastNode(seg_block, i, true, target_device);
+          auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor input type during shape analysis, "
+              << "inserting aten::to cast to Byte to ensure this Torch block receives "
+              << "a Byte-type tensor input.");
+          // If the input has type Byte, ensure it is cast to the correct type
+          auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
           seg_block.g()->prependNode(cast_node);
           seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
         }
       }
     }
+
     for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
       if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
         auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
-          auto cast_node = createCastNode(seg_block, i, false, target_device);
+
+        // If the output has type Long and truncation was requested, insert a truncating cast
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
+          seg_block.g()->appendNode(cast_node);
+          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          // If the output has type Byte and casting was requested, insert Integer cast
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         }
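
The net effect of the two loops above, for a Torch-targeted block, is a small decision table: Long tensors keep (or regain) Long type at block inputs but are truncated to Int at block outputs, while Byte tensors are pinned to Byte at inputs and promoted to Int at outputs when int8 casting is enabled. A compact sketch of that table, using a hypothetical helper name that is not part of the patch:

#include <ATen/ATen.h>
#include <optional>

// Hypothetical summary of the dtype targeted by the aten::to cast inserted
// above, given the tensor type observed during shape analysis and the two
// partitioning flags. Returns nullopt when no cast is inserted.
std::optional<at::ScalarType> inserted_cast_dtype(
    bool is_block_input,
    at::ScalarType t,
    bool truncate_long_and_double,
    bool cast_int8_inputs) {
  if (t == at::kLong && truncate_long_and_double) {
    // Inputs are restored to Long for the Torch block; outputs are truncated
    // to Int for the downstream TensorRT block.
    return is_block_input ? at::kLong : at::kInt;
  }
  if (t == at::kByte && cast_int8_inputs) {
    // Byte inputs keep their type explicitly (force-created cast); Byte
    // outputs become Int before reaching TensorRT.
    return is_block_input ? at::kByte : at::kInt;
  }
  return std::nullopt;
}
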
@@ -254,11 +289,13 @@ void getSegmentsOutputByRunning(
   std::vector<std::vector<int64_t>> input_shapes;
   std::vector<at::ScalarType> input_types;
   for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+    auto current_input = seg_block.raw_inputs()[i];
+
+    if (ivalues_maps[current_input].isTensor()) {
       // set the input_shape and data_type
       // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
       // shape inference
-      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+      auto cur_ivalue = ivalues_maps[current_input];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();
 
       if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
@@ -271,10 +308,16 @@ void getSegmentsOutputByRunning(
         cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
         LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
       }
+
       c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
       if (dtype == c10::nullopt) {
         TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
+      } else if (dtype && dtype.value() == nvinfer1::DataType::kINT8 && partitioning_info.cast_int8_inputs) {
+        // Special case to ensure input IValues to TensorRT engine are not Int8 type if the
+        // model itself is not quantized
+        cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
       }
+
       if (cur_ivalue.toTensor().sizes().size() == 0) {
         // handle Scalar types, which have sizes of []
         input_shapes.push_back(util::toVec(util::toDims(c10::List<int64_t>({1}))));
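
The kINT8 branch above only adjusts the example IValue used for shape inference: when int8 casting is requested for a non-quantized model, a Byte example tensor is promoted to Int so the recorded TensorRT input type is never Int8. A trivial standalone illustration of that promotion (not taken from the patch):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Example tensor standing in for a Byte-typed graph input.
  at::Tensor example = at::zeros({2, 3}, at::kByte);
  if (example.scalar_type() == at::kByte) {
    example = example.to(at::kInt);  // promote so TensorRT never sees Int8
  }
  std::cout << example.dtype() << std::endl;  // now Int, not Byte
  return 0;
}
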
@@ -297,6 +340,7 @@ void runShapeAnalysis(
     const ir::ShapeMode& shape_mode) {
   // register every segment's input shape and its running output IValues
   for (auto& seg_block : ctx->partitioned_blocks[block]) {
+    LOG_GRAPH("Running shape analysis on block " << seg_block);
     torch::jit::ConstantPooling(seg_block.g());
     getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
   }