+#include "torch/torch.h"
 #include "core/util/prelude.h"
 #include "core/conversion/converters/converters.h"
 
@@ -8,93 +9,59 @@ namespace converters {
 namespace impl {
 namespace {
 
-bool ConvertConvBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-    auto input = args[0].ITensor();
-    auto shape = util::toVec(input->getDimensions());
-    LOG_WARNING("Assuming channel dimension is 3 because input is from a conv layer, please verify");
-    auto gamma = args[1].unwrapToTensor(at::full({shape[shape.size() - 3]}, 1));
-    auto beta = args[2].unwrapToTensor(at::full({shape[shape.size() - 3]}, 1));
-    auto mean = args[3].unwrapToTensor(at::full({shape[shape.size() - 3]}, 0));
-    auto var = args[4].unwrapToTensor(at::full({shape[shape.size() - 3]}, 0));
-    LOG_WARNING("Momentum argument is disregarded");
-    // auto momentum = args[6].unwrapToDouble(0);
-    auto eps = args[7].unwrapToDouble(0);
-
-    auto w = at::diag(gamma / at::sqrt(var + eps));
-    auto w_shape = w.sizes().vec();
-    w_shape.push_back(1);
-    w_shape.push_back(1);
-    w = w.reshape(w_shape);
-    auto b = beta - gamma * (mean / at::sqrt(var + eps));
-
-    auto weights = Weights(ctx, w);
-    auto bias = Weights(ctx, b);
-
-    auto bn_as_conv = ctx->net->addConvolutionNd(*input, weights.num_output_maps, weights.kernel_shape, weights.data, bias.data);
-    TRTORCH_CHECK(bn_as_conv, "Unable to create fused batch norm from node: " << *n);
-
-    bn_as_conv->setName(util::node_info(n).c_str());
-
-    auto bn_out = ctx->AssociateValueAndTensor(n->outputs()[0], bn_as_conv->getOutput(0));
-    LOG_DEBUG("Output tensor shape: " << bn_out->getDimensions());
-    return true;
-}
-
-bool ConvertLinearBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-    auto input = args[0].ITensor();
-    auto shape = util::toVec(input->getDimensions());
-    auto gamma = args[1].unwrapToTensor(at::full({shape}, 1));
-    auto beta = args[2].unwrapToTensor(at::full({shape}, 1));
-    auto mean = args[3].unwrapToTensor(at::full({shape}, 0));
-    auto var = args[4].unwrapToTensor(at::full({shape}, 0));
-    LOG_WARNING("Momentum argument is disregarded");
-    // auto momentum = args[6].unwrapToDouble(0);
-    auto eps = args[7].unwrapToDouble(0);
-
-    auto mean_ = tensor_to_const(ctx, mean);
-    auto bot_half = at::sqrt(var + eps);
-    auto bot_half_ = tensor_to_const(ctx, bot_half);
-    auto gamma_ = tensor_to_const(ctx, gamma);
-    auto beta_ = tensor_to_const(ctx, beta);
-
-    auto top_half = ctx->net->addElementWise(*input, *mean_, nvinfer1::ElementWiseOperation::kSUB);
-    auto top_half_out = top_half->getOutput(0);
-    auto x_hat = ctx->net->addElementWise(*top_half_out, *bot_half_, nvinfer1::ElementWiseOperation::kDIV);
-    auto x_hat_out = x_hat->getOutput(0);
-    auto bn_scaled = ctx->net->addElementWise(*gamma_, *x_hat_out, nvinfer1::ElementWiseOperation::kPROD);
-    auto bn_scaled_out = bn_scaled->getOutput(0);
-    auto bn_biased = ctx->net->addElementWise(*beta_, *bn_scaled_out, nvinfer1::ElementWiseOperation::kSUM);
-    auto bn_biased_out = bn_biased->getOutput(0);
-
-    bn_biased->setName(util::node_info(n).c_str());
-    ctx->AssociateValueAndTensor(n->outputs()[0], bn_biased_out);
-
-    return true;
-}
-
-volatile auto batch_norm_registrations = RegisterNodeConversionPatterns()
-    .pattern({
-        R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
-                               Tensor? mean, Tensor? var,
-                               bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
-        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto input = args[0].ITensor();
-            auto shape = input->getDimensions();
-            auto gamma = args[1].unwrapToTensor();
-
-            if (/*training*/ args[5].unwrapToBool()) {
-                LOG_WARNING(R"WARN(TRTorch only converts forward pass of graphs, but saw training = True, may see
-                    unexpected behavior, consider placing module in eval mode before exporting the TorchScript module)WARN");
-            }
-
-            // If gamma is None this fails
-            if (util::volume(shape) == gamma.numel()) {
-                return ConvertLinearBatchNorm(ctx, n, args);
-            } else {
-                return ConvertConvBatchNorm(ctx, n, args);
-            }
-        }
-    });
+auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
+    .pattern({
+        R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
+                               Tensor? mean, Tensor? var,
+                               bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            auto input = args[0].ITensor();
+            auto orig_shape = input->getDimensions();
+            auto shape = util::toVec(orig_shape);
+            auto options = torch::TensorOptions().dtype(torch::kFloat32);
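+            // aten::batch_norm's gamma/beta/mean/var are optional; when an
+            // argument is None, substitute a constant tensor (gamma = 1 and
+            // beta = 0 leave the affine step as the identity)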
+            auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
+            auto beta = args[2].unwrapToTensor(at::full({shape}, 0, {options}));
+            auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
+            auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+            auto eps = args[7].unwrapToDouble(1e-5);
+
+            LOG_DEBUG("momentum disregarded");
+            LOG_DEBUG("training disregarded");
+            LOG_DEBUG("cudnn disregarded");
+
+            auto should_unpack = shape.size() < 4;
+            if (should_unpack) {
+                // expand spatial dims from 1D to 2D
+                auto new_shape = util::toDimsPad(shape, 4);
+                LOG_DEBUG("Input shape is less than 4D, got: " << orig_shape << ", inserting shuffle layer to reshape to 4D tensor shape: " << new_shape);
+                auto in_shuffle = ctx->net->addShuffle(*input);
+                in_shuffle->setReshapeDimensions(new_shape);
+                in_shuffle->setName(std::string("[Reshape input to " + util::toStr(new_shape) + ']').c_str());
+                input = in_shuffle->getOutput(0);
+            }
+
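+            // Fold the normalization into a single per-channel affine transform:
+            //   y = gamma * (x - mean) / sqrt(var + eps) + beta
+            //     = scale * x + bias
+            // where scale = gamma / sqrt(var + eps) and bias = beta - mean * scale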
+            auto scale = gamma / torch::sqrt(var + eps);
+            auto bias = beta - mean * scale;
+
+            auto scale_weights = Weights(ctx, scale);
+            auto bias_weights = Weights(ctx, bias);
+
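+            // IScaleLayer in kCHANNEL mode computes y = (x * scale + shift) ^ power
+            // per channel (axis 1); the empty Weights argument leaves the power
+            // term at its default of 1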
+            auto bn = ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, {}, 1);
+            TRTORCH_CHECK(bn, "Unable to create batch norm layer from node: " << *n);
+            bn->setName(util::node_info(n).c_str());
+            auto out_tensor = bn->getOutput(0);
+
+            if (should_unpack) {
+                LOG_DEBUG("Inserting shuffle layer to reshape back to original shape: " << orig_shape);
+                auto out_shuffle = ctx->net->addShuffle(*out_tensor);
+                out_shuffle->setReshapeDimensions(orig_shape);
+                out_shuffle->setName(std::string("[Reshape output to " + util::toStr(orig_shape) + ']').c_str());
+                out_tensor = out_shuffle->getOutput(0);
+            }
+
+            ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+            return true;
+        }
+    });
 
 } // namespace