@@ -79,6 +79,26 @@ auto logical_not_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
7979 return true ;
8080 }});
8181
82+ auto isfinite_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
83+ {" aten::isfinite(Tensor self) -> Tensor" , [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
84+ auto in = args[0 ].ITensorOrFreeze (ctx);
85+ // assuming x-x = 0 for all values other than nan/inf/-inf where x-x = nan
86+ // x==x for all non-nan values
87+ auto inf_test_layer = ctx->net ->addElementWise (*in, *in, nvinfer1::ElementWiseOperation::kSUB );
88+ TORCHTRT_CHECK (inf_test_layer, " Unable to create sub layer from node: " << *n);
89+ inf_test_layer->setName ((util::node_info (n) + " _inf_test" ).c_str ());
90+ auto inf_test_tensor = inf_test_layer->getOutput (0 );
91+
92+ auto nan_test_layer =
93+ ctx->net ->addElementWise (*inf_test_tensor, *inf_test_tensor, nvinfer1::ElementWiseOperation::kEQUAL );
94+ TORCHTRT_CHECK (nan_test_layer, " Unable to create eq layer from node: " << *n);
95+ nan_test_layer->setName ((util::node_info (n) + " _nan_test" ).c_str ());
96+
97+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], nan_test_layer->getOutput (0 ));
98+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
99+ return true ;
100+ }});
101+
82102#define convert (unary, trt_type ) \
83103 auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
84104 {" aten::" #unary " (Tensor self) -> Tensor" , \
0 commit comments