@@ -20,40 +20,34 @@ auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
2020 .pattern({
2121 " aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))" ,
2222 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
23- std::cout << " select.int converter recognized" << std::endl;
24-
2523 auto in = args[0 ].ITensor ();
2624 auto axis = args[1 ].unwrapToInt ();
2725 auto ind = (int32_t ) args[2 ].unwrapToInt ();
2826
29- // tried: vector for input
30- // std::vector<int32_t> indices_input = {ind};
31-
32- auto options = torch::TensorOptions ().device (torch::kCUDA , 1 ).dtype (torch::kInt32 );
33- at::Tensor indices = torch::tensor (torch::detail::TensorDataContainer (ind), options);
34-
27+ // index to access needs to be an at::Tensor
28+ at::Tensor indices = torch::tensor ({ind}).to (torch::kI32 );
3529 auto weights = Weights (ctx, indices);
36- // manually setting weights
37- // weights.data.type = nvinfer1::DataType::kINT32;
3830
31+ // IConstantLayer to convert indices from Weights to ITensor
3932 auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
40- const_layer->setName (util::node_info (n).c_str ());
41- // manually setting output type
42- // const_layer->setOutputType(0, nvinfer1::DataType::kINT32);
43-
44- auto const_out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], const_layer->getOutput (0 ));
33+ TRTORCH_CHECK (const_layer, " Unable to create constant layer from node: " << *n);
34+ auto const_out = const_layer->getOutput (0 );
4535
36+ // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
4637 auto gather_layer = ctx->net ->addGather (*in, *const_out, axis);
47- gather_layer->setName (util::node_info (n).c_str ());
48- // manually setting output type
49- // gather_layer->setOutputType(0, nvinfer1::DataType::kINT32);
38+ TRTORCH_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
39+ auto gather_out = gather_layer->getOutput (0 );
5040
51- auto gather_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], gather_layer->getOutput (0 ));
41+ // IShuffleLayer removes redundant dimensions
42+ auto shuffle_layer = ctx->net ->addShuffle (*gather_out);
43+ TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
44+ shuffle_layer->setReshapeDimensions (util::unpadDims (gather_out->getDimensions ()));
45+ shuffle_layer->setName (util::node_info (n).c_str ());
46+ auto shuffle_out = shuffle_layer->getOutput (0 );
5247
53- LOG_DEBUG (" Output tensor shape: " << gather_output->getDimensions ());
54-
55- // for debugging
56- // std::raise(SIGTRAP);
48+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_out);
49+
50+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
5751
5852 return true ;
5953 }
0 commit comments