diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index adec5d3af630..0d4d594222e2 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -319,7 +319,7 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& if (lhs_dtype.is_void() || rhs_dtype.is_void()) { return DataType::Void(); - } else if (lhs_dtype != rhs_dtype) { + } else if (lhs_dtype != rhs_dtype && !lhs_dtype.is_bool() && !rhs_dtype.is_bool()) { ctx->ReportFatal(Diagnostic::Error(call) << "TypeError: " << "Binary operators must have the same datatype for both operands. "