Skip to content

Commit c1b0952

Browse files
committed
add prompt if cudnn isn't installed
1 parent 8ab139d commit c1b0952

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
123123
// track_running_stats=True
124124
LOG_DEBUG("Args[3] running_mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
125125
LOG_DEBUG("Args[4] running_var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
126-
LOG_DEBUG("use_input_stats, momemtum, cudnn_enabled disregarded");
126+
LOG_DEBUG("use_input_stats, momemtum are disregarded");
127127
LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);
128128

129129
// Expand spatial dims from 1D to 2D if needed
@@ -154,6 +154,17 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
154154
return true;
155155
}
156156

157+
auto cudnn_enabled = static_cast<bool>(args[8].unwrapToBool(false));
158+
if (!cudnn_enabled) {
159+
LOG_DEBUG(
160+
"cuDNN is not enabled, skipping instance_norm conversion. \
161+
Since TRT 10.0, cuDNN is loaded as a dynamic dependency, \
162+
so for some functionalities, users need to install correct \
163+
cuDNN version by themselves. Please see our support matrix \
164+
here: https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html.");
165+
return false;
166+
}
167+
157168
const int relu = 0;
158169
const float alpha = 0;
159170
LOG_DEBUG("Set parameter `relu` and `alpha` to 0");

tests/core/conversion/converters/test_instance_norm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ constexpr auto graph = R"IR(
1818
%running_mean.1 : Tensor?,
1919
%running_var.1 : Tensor?,
2020
%use_input_stats.1 : bool):
21-
%cudnn_enabled.1 : bool = prim::Constant[value=1]()
21+
%cudnn_enabled.1 : bool = prim::Constant[value=0]()
2222
%momentum.1 : float = prim::Constant[value=0.10000000000000001]()
2323
%eps.1 : float = prim::Constant[value=1.0000000000000001e-05]()
2424
%4 : Tensor = aten::instance_norm(%input.1,

0 commit comments

Comments
 (0)