@@ -817,5 +817,54 @@ are data in batch.
817817.add_type_rel(" BatchMatmul" , BatchMatmulRel);
818818
819819
820+ // relay.nn.cross_entropy
821+ bool CrossEntropyRel (const Array<Type>& types,
822+ int num_inputs,
823+ const Attrs& attrs,
824+ const TypeReporter& reporter) {
825+ CHECK_EQ (types.size (), 3 );
826+ const auto * x = types[0 ].as <TensorTypeNode>();
827+ const auto * y = types[1 ].as <TensorTypeNode>();
828+ if (x == nullptr || y == nullptr ) return false ;
829+ CHECK (x->shape .size () == 2 && y->shape .size () == 2 )
830+ << " CrossEntropy: shapes of x and y is inconsistent, "
831+ << " x shape = " << x->shape << " , "
832+ << " y shape = " << y->shape ;
833+ CHECK (reporter->AssertEQ (x->shape [0 ], y->shape [0 ]))
834+ << " CrossEntropy: shapes of x and y is inconsistent, "
835+ << " x shape = " << x->shape << " , "
836+ << " y shape = " << y->shape ;
837+ CHECK (reporter->AssertEQ (x->shape [1 ], y->shape [1 ]))
838+ << " CrossEntropy: shapes of x and y is inconsistent, "
839+ << " x shape = " << x->shape << " , "
840+ << " y shape = " << y->shape ;
841+ // assign output type
842+ reporter->Assign (types[2 ], TensorTypeNode::make ({}, x->dtype ));
843+ return true ;
844+ }
845+
846+ // Positional relay function to create batch_matmul operator used by frontend FFI.
847+ Expr MakeCrossEntropy (Expr predictions, Expr targets) {
848+ static const Op& op = Op::Get (" nn.cross_entropy" );
849+ return CallNode::make (op, {predictions, targets}, Attrs (), {});
850+ }
851+
852+
853+ TVM_REGISTER_API (" relay.op.nn._make.cross_entropy" )
854+ .set_body_typed(MakeCrossEntropy);
855+
856+
857+ RELAY_REGISTER_OP (" nn.cross_entropy" )
858+ .describe(R"code(
859+ Computes cross entropy given predictions and targets.
860+ Do log on the data - do not accept logits.
861+ )code" TVM_ADD_FILELINE)
862+ .set_num_inputs(2 )
863+ .add_argument(" x" , " 1D Tensor" , " Predictions." )
864+ .add_argument(" y" , " 1D Tensor" , " Targets." )
865+ .set_support_level(10 )
866+ .add_type_rel(" CrossEntropy" , CrossEntropyRel);
867+
868+
820869} // namespace relay
821870} // namespace tvm
0 commit comments