Skip to content

Commit 491a0f6

Browse files
authored
[Relax] Require correct input/output shapes R.call_tir (#17285)
Prior to this commit, the Relax well-formed checker validated arguments provided to Relax functions, but did not validate arguments provided to `R.call_tir`. As a result, incorrect arguments from Relax to TIR would not be checked until runtime, if at all. This commit updates the well-formed checker to verify that `R.call_tir` has received the correct arguments, and has the correct output shape specified in the `out_sinfo` parameter. Initial implementation performed the validation as part of `FNormalize`, to maximize coverage of this check. This increased end-to-end compilation time by ~10%, and so the check was requested to be restricted to the well-formed checker. Expensive operator-specific validation is now performed in the new `FValidate` attribute.
1 parent f432ebd commit 491a0f6

19 files changed

+928
-106
lines changed

include/tvm/relax/op_attr_types.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,40 @@ using FCallPacked = String;
5656
* expressed in multiple syntactically valid and semantically
5757
* equivalent forms, to normalize to a single representation.
5858
*
59+
* Note: `FNormalize` is applied for each expression as part of the
60+
* `relax::BlockBuilder`. While operator-specific validation may
61+
* be performed within the `FNormalize` implementation, ensuring
62+
* that errors are caught as early as possible, this should only be
63+
* used when validation is fast to apply. If the validation logic
64+
* may be slow, it should instead be implemented in `FValidate`,
65+
* which is only run as part of the well-formed checker.
66+
*
5967
* \param bb The BlockBuilder context.
6068
*
6169
* \param call The call to be normalized. It is provided by-value, to
6270
* avoid copies for the common case where the call is already normalized.
6371
*/
6472
using FNormalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, Call call)>;
6573

74+
/*!
75+
* \brief The function type of a validation function.
76+
*
77+
* A validation function is used to define constraints that should be
78+
* verified for an operator as part of the well-formed checker.
79+
*
80+
* Note: `FValidate` is only applied as part of the well-formed
81+
* checker. While this minimizes overhead while compiling Relax,
82+
* this delay between generating an ill-formed `relax::Call` and
83+
* identifying the ill-formed call may complicate debugging. If
84+
* the validation logic is very fast to check, and doing so would
85+
* not introduce a signficant overhead, consider validating as part
86+
* of `FNormalize`, which is applied by the block builder for each
87+
* `relax::Call`.
88+
*
89+
* \param call The call to be validated.
90+
*/
91+
using FValidate = runtime::TypedPackedFunc<void(const Call& call)>;
92+
6693
/*! \brief The function type of a legalization function.
6794
*
6895
* A legalization function is used to replace a `relax::Call` with

src/relax/analysis/well_formed.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,16 @@ class WellFormedChecker : public relax::ExprVisitor,
352352
<< after_normalize);
353353
}
354354
}
355+
356+
if (auto func_validate = op_map_validate_.get(call->op, nullptr); func_validate != nullptr) {
357+
try {
358+
func_validate(GetRef<Call>(call));
359+
} catch (std::exception& err) {
360+
Malformed(Diagnostic::Error(call) << "Operator-specific validation (FValidate) for "
361+
<< call->op << " identified error: \n"
362+
<< err.what());
363+
}
364+
}
355365
}
356366

357367
void VisitExpr_(const IfNode* op) final {
@@ -574,6 +584,7 @@ class WellFormedChecker : public relax::ExprVisitor,
574584
std::unordered_map<tir::Var, const FunctionNode*> symbolic_var_func_map_;
575585

576586
tvm::OpAttrMap<FNormalize> op_map_normalize_ = Op::GetAttrMap<FNormalize>("FNormalize");
587+
tvm::OpAttrMap<FValidate> op_map_validate_ = Op::GetAttrMap<FValidate>("FValidate");
577588
};
578589

579590
bool WellFormed(Variant<IRModule, Function> obj, bool check_struct_info) {

0 commit comments

Comments
 (0)