Skip to content

Commit ef2a913

Browse files
authored
[Unity] Improved error message for matmul shape mismatch (#16308)
If a matrix multiplication cannot be performed due to incompatible shapes, the error message now specifies the arguments, the shape of each argument, and which dimension of the shape has a mismatch. Previously, this error message only provided the dimension of the mismatch.
1 parent 4c77f0f commit ef2a913

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

src/relax/op/tensor/linear_algebra.cc

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ TVM_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul);
4848

4949
StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
5050
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
51+
Expr lhs = call->args[0];
52+
Expr rhs = call->args[1];
5153
TensorStructInfo x1_sinfo = input_sinfo[0];
5254
TensorStructInfo x2_sinfo = input_sinfo[1];
5355

@@ -75,10 +77,19 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
7577
}
7678
int x1_ndim = x1_sinfo->ndim;
7779
int x2_ndim = x2_sinfo->ndim;
78-
if (x1_ndim == 0 || x2_ndim == 0) {
80+
if (x1_ndim == 0) {
7981
ctx->ReportFatal(Diagnostic::Error(call)
80-
<< "Matmul requires both inputs to have at least 1 dimension. However, "
81-
<< (x1_ndim == 0 ? "x1" : "x2") << " is a 0-rank tensor.");
82+
<< "Matmul operands must not be scalar. "
83+
<< "However, the expression " << call << " has a LHS of " << lhs
84+
<< " with struct info " << x1_sinfo
85+
<< ", which is scalar (zero-dimensional) tensor.");
86+
}
87+
if (x2_ndim == 0) {
88+
ctx->ReportFatal(Diagnostic::Error(call)
89+
<< "Matmul operands must not be scalar. "
90+
<< "However, the expression " << call << " has a RHS of " << rhs
91+
<< " with struct info " << x2_sinfo
92+
<< ", which is scalar (zero-dimensional) tensor.");
8293
}
8394

8495
int x1_prepended = 0;
@@ -120,9 +131,11 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
120131
PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2];
121132
if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) {
122133
ctx->ReportFatal(Diagnostic::Error(call)
123-
<< "Matmul requires the reduction length of x1 and x2 to be equal. However, "
124-
"the reduction lengths of x1 and x2 are "
125-
<< x1_reduction_length << " and " << x2_reduction_length << " respectively.");
134+
<< "Matmul requires the reduction length of the operands to be equal. "
135+
<< "However, the LHS " << lhs << " has shape " << x1_sinfo->shape
136+
<< ", while the RHS " << rhs << " has shape " << x2_sinfo->shape
137+
<< ". The reduction dimensions of " << x1_reduction_length << " and "
138+
<< x2_reduction_length << " are not equal.");
126139
}
127140

128141
Array<PrimExpr> output_shape = output_shape_prefix.value();

0 commit comments

Comments
 (0)