@@ -353,4 +353,82 @@ TEST(CppAPITests, TestCollectionListInputOutput) {
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toList().vec()[0].toTensor(), trt_out.toList().vec()[0].toTensor(), 1e-5));
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toList().vec()[1].toTensor(), trt_out.toList().vec()[1].toTensor(), 1e-5));
+}
+
+TEST(CppAPITests, TestCollectionComplexModel) {
+
+  std::string path = "/root/Torch-TensorRT/complex_model.ts";
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
+  std::vector<at::Tensor> inputs;
+  inputs.push_back(in0);
+
+  torch::jit::Module mod;
+  try {
+    // Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(path);
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+  }
+  mod.eval();
+  mod.to(torch::kCUDA);
+
+  std::vector<torch::jit::IValue> inputs_;
+
+  for (auto in : inputs) {
+    inputs_.push_back(torch::jit::IValue(in.clone()));
+  }
+
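+  // The module's forward takes a single List[Tensor] argument, so the cloned
+  // tensor is pushed twice into a GenericList and passed as one IValue.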
+  std::vector<torch::jit::IValue> complex_inputs;
+  auto input_list = c10::impl::GenericList(c10::TensorType::get());
+  input_list.push_back(inputs_[0]);
+  input_list.push_back(inputs_[0]);
+
+  torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
+
+  complex_inputs.push_back(input_list_ivalue);
+
+  auto out = mod.forward(complex_inputs);
+  LOG_DEBUG("Finish torchscript forward");
+
+  // auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
+  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
+
+  auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
+
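+  // Mirror the runtime input structure in the compile spec: two Input objects in
+  // a list (one per tensor in the List[Tensor]), wrapped in a single-element tuple.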
+  c10::TypePtr elementType = input_shape_ivalue.type();
+  auto list = c10::impl::GenericList(elementType);
+  list.push_back(input_shape_ivalue);
+  list.push_back(input_shape_ivalue);
+
+  torch::jit::IValue complex_input_shape(list);
+  std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
+  torch::jit::IValue complex_input_shape2(input_tuple2);
+
+  auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
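+  // Partial compilation: the list construction and indexing ops below are left
+  // to run in Torch rather than being converted to TensorRT.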
+  compile_settings.require_full_compilation = false;
+  compile_settings.min_block_size = 1;
+
+  // Need to skip the conversion of __getitem__ and ListConstruct
+  compile_settings.torch_executed_ops.push_back("aten::__getitem__");
+  compile_settings.torch_executed_ops.push_back("prim::ListConstruct");
+
+  // FP16 execution
+  compile_settings.enabled_precisions = {torch::kHalf};
+  // Compile module
+  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
+  LOG_DEBUG("Finish compile");
+  auto trt_out = trt_mod.forward(complex_inputs);
+  // auto trt_out = trt_mod.forward(complex_inputs_list);
+
+  // std::cout << out.toTuple()->elements()[0].toTensor() << std::endl;
+
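+  // The model returns a tuple of two tensors; compare each element of the
+  // TorchScript and TensorRT outputs.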
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5));
 }