@@ -14,35 +14,57 @@ namespace converters {
1414namespace impl {
1515namespace {
1616
17- auto squeeze_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
18- {" aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))" ,
19- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
20- auto self = args[0 ].ITensorOrFreeze (ctx);
21- auto dim = args[1 ].unwrapToInt ();
17+ auto squeeze_registrations TORCHTRT_UNUSED =
18+ RegisterNodeConversionPatterns ()
19+ .pattern(
20+ {" aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))" ,
21+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
22+ auto self = args[0 ].ITensorOrFreeze (ctx);
23+ auto dim = args[1 ].unwrapToInt ();
2224
23- auto selfDim = util::toVec (self->getDimensions ());
24- if (dim < 0 ) {
25- dim = selfDim.size () + dim;
26- }
25+ auto selfDim = util::toVec (self->getDimensions ());
26+ if (dim < 0 ) {
27+ dim = selfDim.size () + dim;
28+ }
2729
28- if (selfDim[dim] != 1 ) {
29- auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], self);
30+ if (selfDim[dim] != 1 ) {
31+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], self);
3032
31- LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
33+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
3234
33- return true ;
34- }
35+ return true ;
36+ }
3537
36- auto shuffle_layer = ctx->net ->addShuffle (*self);
37- TORCHTRT_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
38- shuffle_layer->setReshapeDimensions (util::squeezeDims (self->getDimensions (), dim));
38+ auto shuffle_layer = ctx->net ->addShuffle (*self);
39+ TORCHTRT_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
40+ shuffle_layer->setReshapeDimensions (util::squeezeDims (self->getDimensions (), dim));
3941
40- auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_layer->getOutput (0 ));
42+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_layer->getOutput (0 ));
4143
42- LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
44+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
4345
44- return true ;
45- }});
46+ return true ;
47+ }})
48+ .pattern(
49+ {" aten::squeeze(Tensor(a) self) -> (Tensor(a))" ,
50+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
51+ auto self = args[0 ].ITensorOrFreeze (ctx);
52+ auto self_dims = self->getDimensions ();
53+ auto out = self;
54+ auto squeeze_dims = util::squeezeAllDims (self_dims);
55+ if (squeeze_dims != self_dims) {
56+ auto shuffle_layer = ctx->net ->addShuffle (*self);
57+ TORCHTRT_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
58+ shuffle_layer->setReshapeDimensions (squeeze_dims);
59+ out = shuffle_layer->getOutput (0 );
60+ }
61+
62+ auto trt_out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out);
63+
64+ LOG_DEBUG (" Output tensor shape: " << trt_out->getDimensions ());
65+
66+ return true ;
67+ }});
4668
4769} // namespace
4870} // namespace impl
0 commit comments