@@ -73,17 +73,23 @@ auto select_registrations TRTORCH_UNUSED =
73
73
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
74
74
auto in = args[0 ].ITensorOrFreeze (ctx);
75
75
auto maxDim = static_cast <int64_t >(in->getDimensions ().nbDims );
76
- auto axis = args[1 ].unwrapToInt ();
77
- axis = axis < 0 ? axis + maxDim : axis;
76
+ auto dim = args[1 ].unwrapToInt ();
77
+ // Handle negative axis by refering to nbDims of input Tensor
78
+ dim = dim < 0 ? dim + maxDim : dim;
78
79
auto ind = (int32_t )args[2 ].unwrapToInt ();
80
+ // Along the specified dimension, handle negative index by subtracting along length of dimension.
81
+ ind = ind < 0 ? ind + in->getDimensions ().d [dim] : ind;
82
+ LOG_DEBUG (" Gather input dimensions: " << in->getDimensions ());
83
+ LOG_DEBUG (" Dimension to select: " << dim);
84
+ LOG_DEBUG (" Index: " << ind);
79
85
80
86
// index to access needs to be an at::Tensor
81
87
at::Tensor indices = torch::tensor ({ind}).to (torch::kI32 );
82
88
auto const_out = tensor_to_const (ctx, indices);
83
89
84
90
// IGatherLayer takes in input tensor, the indices, and the axis
85
91
// of input tensor to take indices from
86
- auto gather_layer = ctx->net ->addGather (*in, *const_out, axis );
92
+ auto gather_layer = ctx->net ->addGather (*in, *const_out, dim );
87
93
TRTORCH_CHECK (gather_layer, " Unable to create gather layer from node: " << *n);
88
94
auto out = gather_layer->getOutput (0 );
89
95
@@ -93,7 +99,7 @@ auto select_registrations TRTORCH_UNUSED =
93
99
// IShuffleLayer removes redundant dimensions
94
100
auto shuffle_layer = ctx->net ->addShuffle (*out);
95
101
TRTORCH_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
96
- shuffle_layer->setReshapeDimensions (util::squeezeDims (out->getDimensions (), axis ));
102
+ shuffle_layer->setReshapeDimensions (util::squeezeDims (out->getDimensions (), dim ));
97
103
shuffle_layer->setName (util::node_info (n).c_str ());
98
104
out = shuffle_layer->getOutput (0 );
99
105
}
0 commit comments