@@ -150,6 +150,97 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
150150 ASSERT_TRUE (trt_block_cnt == 1 && torch_block_cnt == 1 );
151151}
152152
153+ TEST (Partitioning, ResolveMultipleNonTensorInputsCorrectly) {
154+ const auto graph = R"IR(
155+ graph(%x.1 : Tensor):
156+ # TensorRT-intended Block
157+ %16 : int = prim::Constant[value=8]()
158+ %15 : int = prim::Constant[value=64]()
159+ %13 : int = prim::Constant[value=0]()
160+ %10 : int = prim::Constant[value=1]()
161+ %self.linear.bias : Float(4096, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
162+ %self.linear.weight : Float(4096, 64, strides=[64, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
163+ %3 : int = prim::Constant[value=-1]()
164+ %2 : int = prim::Constant[value=1]()
165+ %x.5 : Tensor = aten::flatten(%x.1, %2, %3)
166+ %4 : Tensor = aten::t(%self.linear.weight)
167+ %6 : Tensor = aten::matmul(%x.5, %4)
168+ %7 : Tensor = trt::const(%self.linear.bias)
169+ %9 : Tensor = aten::add(%7, %6, %10)
170+ %11 : int[] = aten::size(%9) # <string>:13:9
171+ %12 : int = aten::__getitem__(%11, %13)
172+ %shape.3 : int[] = prim::ListConstruct(%12, %15, %16, %16)
173+ %x.13 : Tensor = aten::reshape(%9, %shape.3)
174+
175+ # Torch-intended Block
176+ %num_spatial_dims.2 : int = prim::Constant[value=2]()
177+ %11 : int[] = prim::Constant[value=[0, 0]]()
178+ %10 : bool = prim::Constant[value=0]()
179+ %conv1_bias : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
180+ %conv1_weight : Float(32, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
181+ %6 : int = prim::Constant[value=1]()
182+ %5 : int[] = prim::Constant[value=[1, 1]]()
183+ %4 : int[] = prim::Constant[value=[2, 2]]()
184+ %conv_bias : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
185+ %conv_weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
186+ %input.16 : Tensor = aten::conv_transpose2d(%x.13, %conv_weight, %conv_bias, %4, %5, %5, %6, %5)
187+ %7 : Tensor = aten::_convolution(%input.16, %conv1_weight, %conv1_bias, %5, %5, %5, %10, %11, %6, %10, %10, %10, %10)
188+ %12 : int[] = aten::size(%7)
189+ %96 : int = aten::len(%12)
190+ %14 : int = aten::__range_length(%num_spatial_dims.2, %96, %6)
191+
192+ # TensorRT-intended Block
193+ %15 : float = prim::Constant[value=1e-05]()
194+ %14 : float = prim::Constant[value=0.1]()
195+ %13 : NoneType = prim::Constant()
196+ %num_spatial_dims.2 : int = prim::Constant[value=2]()
197+ %300 : int = prim::Constant[value=3]()
198+ %345 : int = aten::sub(%300, %96)
199+ %3 : int = aten::add(%345, %6)
200+ %2 : bool = prim::Constant[value=1]()
201+ %size_prods.2 : int = prim::Loop(%3, %2, %6)
202+ block0(%loop : int, %size_prods.13 : int):
203+ %i.3 : int = aten::__derive_index(%loop, %num_spatial_dims.2, %3)
204+ %8 : int = aten::__getitem__(%12, %i.3)
205+ %size_prods.15 : int = aten::mul(%size_prods.13, %8)
206+ -> (%2, %size_prods.15)
207+ %11 : Tensor = aten::instance_norm(%7, %13, %13, %13, %13, %2, %14, %15, %2)
208+ return (%11))IR" ;
209+
210+ auto g = std::make_shared<torch::jit::Graph>();
211+ torch::jit::parseIR (graph, g.get (), true );
212+
213+ torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
214+ partitioning_info.enabled = true ;
215+ std::vector<torch_tensorrt::core::ir::Input> inputs;
216+ inputs.push_back (torch_tensorrt::core::ir::Input ({1 , 64 }));
217+
218+ torch_tensorrt::core::ir::CollectionInputSpecMap inputs_map;
219+ std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
220+ for (size_t i = 0 ; i < g->inputs ().size (); ++i) {
221+ inputs_map.insert ({g->inputs ()[i], {inputs[i]}});
222+ input_types.insert ({g->inputs ()[i], {{at::kFloat }}});
223+ }
224+
225+ partitioning_info.collection_input_spec_map = inputs_map;
226+ partitioning_info.forced_fallback_operators = {" aten::_convolution" };
227+ torch_tensorrt::core::partitioning::PartitioningCtx ctx (g->block (), partitioning_info);
228+ ctx.input_types_map = input_types;
229+
230+ torch_tensorrt::core::partitioning::populateInputIValues (&ctx);
231+ torch_tensorrt::core::partitioning::partition (&ctx);
232+ std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
233+ ctx.partitioned_blocks .begin ()->second ;
234+
235+ // For each TensorRT segmented block, verify that all inputs are of Tensor type
236+ for (auto block : segmented_blocks) {
237+ if (block.target () == torch_tensorrt::core::partitioning::SegmentedBlock::SegmentedBlockTarget::kTensorRT ) {
238+ for (auto input : block.raw_inputs ())
239+ ASSERT_TRUE (input->type ()->isSubtypeOf (c10::TensorType::get ()));
240+ }
241+ }
242+ }
243+
153244TEST (Partitioning, ResolveTensorListInputsInTrtCorrectly) {
154245 const auto graph = R"IR(
155246 graph(%0 : Float(1, 3, 16, 16, strides=[768, 256, 16, 1]),
0 commit comments