Merged
261 changes: 106 additions & 155 deletions src/relax/transform/fuse_tir.cc
@@ -385,58 +385,47 @@ class FusedTIRConstructor : public ExprVisitor {
: mod_(mod), func_name_(func_name) {}

void VisitExpr_(const FunctionNode* func) final {
// Step 1. Create buffers for function params

// Record which fields in a tuple passed as a parameter are actually accessed by the function.
std::unordered_set<const Object*> tuple_param;
for (auto param : func->params) {
if (GetStructInfo(param)->IsInstance<TupleStructInfoNode>()) {
tuple_param.insert(param.get());
}
}

PostOrderVisit(func->body, [=, &tuple_param](Expr e) {
if (auto tup_get = e.as<TupleGetItemNode>();
tup_get && tuple_param.count(tup_get->tuple.get())) {
func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index);
}
});

std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
for (const Var& relax_param : func->params) {
auto sinfo = GetStructInfo(relax_param);
if (sinfo->IsInstance<ShapeStructInfoNode>()) {
// It's a symbolic shape var, no need to alloc Buffers.
continue;
}

auto [params, buffers] = [=]() {
if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
// Add only those tuple fields which are actually used by the function body into the
// function parameters.
int index = 0;
Array<tir::Var> params;
Array<tir::Buffer> buffers;
for (auto i : func_info_.used_tuple_field_indices[relax_param.get()]) {
auto [ret_params, ret_buffers] =
CreateParamsAndBuffers(tuple->fields[i], relax_param->name_hint(), index);
ICHECK_EQ(ret_params.size(), ret_buffers.size());
// Adding tuple field results to the end of params and buffers.
params.insert(params.end(), ret_params.begin(), ret_params.end());
buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
index += ret_params.size();
size_t size_before = prim_func_params.size();
CollectPrimFuncParams(relax_param, &prim_func_params);

auto param_buffers = [&]() -> Array<tir::Buffer> {
Array<tir::Buffer> out;
for (size_t i = size_before; i < prim_func_params.size(); i++) {
if (auto buf = prim_func_params[i].as<tir::Buffer>()) {
out.push_back(buf.value());
}
return std::make_pair(params, buffers);
} else {
return CreateParamsAndBuffers(sinfo, relax_param->name_hint());
}
return out;
}();

ICHECK_EQ(params.size(), buffers.size());
for (size_t i = 0; i < params.size(); ++i) {
func_info_.buffer_map.Set(params[i], buffers[i]);
func_info_.params.push_back(params[i]);
func_info_.expr2buffers.Set(relax_param, param_buffers);
}

// Move all scalar params after buffer params. To ensure that the
// order is deterministic and predictable for testing purposes,
// std::stable_sort is used instead of std::sort.
std::stable_sort(prim_func_params.begin(), prim_func_params.end(),
                 [](const auto& a, const auto& b) {
                   bool a_is_var = a.template as<tir::VarNode>();
                   bool b_is_var = b.template as<tir::VarNode>();
                   return a_is_var < b_is_var;
                 });

Contributor: I assume it's important for the relative ordering to be preserved (hence the stable sort), might be good to call that out.

Contributor (Author): Good call, and updated the comment. It's mainly to make sure that the order is consistent and predictable for unit tests, not for any correctness of the output.

for (const auto& param : prim_func_params) {
if (auto opt = param.as<tir::Buffer>()) {
auto buffer = opt.value();
// Differentiate the buffer name and the param name by adding the
// prefix `p_` to the param name. Every symbol should be unique in
// TVMScript, and while duplicates can be de-duplicated when
// printed, it's more readable to keep them distinct explicitly.
// Since the buffer is referenced more often than the param, the
// buffer keeps the more readable name.
tir::Var param = tir::Var("p_" + buffer->name, PrimType(DataType::Handle()));
func_info_.params.push_back(param);
func_info_.buffer_map.Set(param, buffer);
}
func_info_.expr2buffers.Set(relax_param, buffers);
}

// Step 2. Visit Function body and create intermediate buffers
@@ -458,13 +447,9 @@ class FusedTIRConstructor : public ExprVisitor {
}

// Step 4. Append symbolic vars
const relax::Var& last_relax_param = func->params.back();
if (GetStructInfo(last_relax_param)->IsInstance<ShapeStructInfoNode>()) {
auto [params, buffers] =
CreateParamsAndBuffers(GetStructInfo(last_relax_param), last_relax_param->name_hint());
ICHECK(buffers.empty());
for (size_t i = 0; i < params.size(); ++i) {
func_info_.params.push_back(params[i]);
for (const auto& param : prim_func_params) {
if (auto var = param.as<tir::Var>()) {
func_info_.params.push_back(var.value());
}
}
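Taken together with the stable sort above, Steps 1 and 4 fix the fused PrimFunc's parameter order: buffer handles first, in their original relative order, then every scalar/symbolic-shape var at the end. A minimal TVMScript sketch of such a signature; the names and shapes (`p_x`, `p_y`, `n`) are hypothetical, not taken from this PR:

from tvm.script import tir as T

# Sketch only: buffer handles come first (prefixed `p_`, paired with the
# readable buffer names via the buffer_map), and the symbolic shape var
# `n` trails at the end of the parameter list.
@T.prim_func
def fused_func(p_x: T.handle, p_y: T.handle, n: T.int64):
    x = T.match_buffer(p_x, (n, 16), "float32")
    y = T.match_buffer(p_y, (n, 16), "float32")
    T.evaluate(0)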

@@ -548,12 +533,7 @@ class FusedTIRConstructor : public ExprVisitor {
int end_buf_idx = 0;
const TupleType& tuple_type = Downcast<TupleType>(tuple_get_item->tuple->checked_type());
for (int i = 0; i < tuple_get_item->index; ++i) {
auto it = func_info_.used_tuple_field_indices.find(tuple_get_item->tuple.get());
// If this tuple is not passed as a parameter, or if the field at the index i is actually
// used, the corresponding buffer needs to be taken into account by this function.
if (it == func_info_.used_tuple_field_indices.end() || it->second.count(i)) {
begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
}
begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
}
end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
func_info_.expr2buffers.Set(
@@ -719,64 +699,47 @@ class FusedTIRConstructor : public ExprVisitor {
}

/*!
* \brief Create TIR func params and buffers with the specified relax type and shape
* \brief Collect TIR func params and buffers with specified relax type and shape
* \param struct_info The struct info
* \param name_hint The name hint for params and buffers
* \param index The index used for unique name_hint if type is Tuple.
* -1 means no need to add postfix since the relax param is not a Tuple.
* \return The created TIR func params and buffers
* \param out The vector into which to collect the params/buffers
*/
static std::pair<Array<tir::Var>, Array<tir::Buffer>> CreateParamsAndBuffers(
StructInfo struct_info, const String& name_hint, int index = -1) {
Array<tir::Var> params;
Array<tir::Buffer> buffers;
// The symbolic shape params must be defined at the end of the param list.
bool symbolic_shape_param_started = false;
static void CollectPrimFuncParams(const Var& relax_param,
std::vector<Variant<tir::Var, tir::Buffer>>* out) {
auto struct_info = GetStructInfo(relax_param);

CHECK(!struct_info.as<TupleStructInfoNode>())
<< "InternalError: "
<< "All tuple parameters should be expanded before this point in FuseTIR. "
<< "However, parameter " << relax_param << " has struct info " << struct_info;

auto name_hint = relax_param->name_hint();

if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
// Case 1. the relax param is a Tensor, we directly create a tir var and buffer
// Case 1. The relax param is a Tensor, we directly create a tir var and buffer
const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape.";
CHECK(!symbolic_shape_param_started)
<< "The symbolic shape params must be defined at the end of the param "
"list.";
String name = index == -1 ? name_hint : name_hint + "_" + std::to_string(index);
ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
DataType dtype = tensor->dtype;
tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name);
// Differentiate buffer name and param name by adding prefix `p_` to the param.
// Every symbol should be unique in TVMScript, and Buffer is used more than param,
// so we make sure buffer names have the better readability.
tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle()));
params.push_back(std::move(param));
buffers.push_back(std::move(buffer));
} else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
// Case 2. the relax param is a Tuple, we recursively visit each field until it's a Tensor
// Enable postfix
CHECK(!symbolic_shape_param_started)
<< "The symbolic shape params must be defined at the end of the param "
"list.";
if (index == -1) index = 0;
for (size_t i = 0; i < tuple->fields.size(); ++i) {
auto [ret_params, ret_buffers] = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
ICHECK_EQ(ret_params.size(), ret_buffers.size());
// Adding tuple field results to the end of params and buffers.
params.insert(params.end(), ret_params.begin(), ret_params.end());
buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
index += ret_params.size();
}
tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
out->push_back(std::move(buffer));

} else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
// Case 2. The relax param is a scalar, we directly create a tir var
ICHECK(prim_value->value->IsInstance<tir::VarNode>());
out->push_back(Downcast<tir::Var>(prim_value->value));

} else if (const auto* shape_expr = struct_info.as<ShapeStructInfoNode>()) {
// Case 3. the relax param is a scalar, we directly create a tir var
symbolic_shape_param_started = true;
ICHECK(index == -1) << "TypeError: The ShapeExprNode should not be in a Tuple field.";
// Case 3. The relax param is a tuple of scalars, each represented as a tir var
for (const auto& var : shape_expr->values.value()) {
ICHECK(var->IsInstance<tir::VarNode>());
params.push_back(Downcast<tir::Var>(var));
out->push_back(Downcast<tir::Var>(var));
}
} else {
ICHECK(false) << "TypeError: The param type of PrimFunc is expected to be Tensor, Tuple or "
"ShapeExpr, but got "
<< struct_info->GetTypeKey();
LOG(FATAL) << "TypeError: "
<< "The param type of PrimFunc is expected to be "
<< "Tensor, PrimValue, or ShapeExpr, "
<< "but got " << struct_info->GetTypeKey();
}
return std::make_pair(params, buffers);
}
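To make the three accepted cases concrete, here is a hedged Python sketch constructing one Relax parameter of each kind. The names are hypothetical, and it assumes `relax.PrimStructInfo` accepts a `value` argument, as the `prim_value->value` check above implies:

from tvm import relax, tir

n = tir.Var("n", "int64")

# Case 1: R.Tensor with a symbolic shape -> one tir.Buffer
tensor_param = relax.Var("x", relax.TensorStructInfo([n, 16], "float32"))

# Case 2: R.Prim wrapping a tir.Var -> that var becomes a scalar param
prim_param = relax.Var("c", relax.PrimStructInfo(value=tir.Var("c", "int64")))

# Case 3: R.Shape -> one tir.Var per shape element
shape_param = relax.Var("s", relax.ShapeStructInfo([n]))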

/*!
@@ -870,9 +833,6 @@ class FusedTIRConstructor : public ExprVisitor {
/*! \brief The map from symbolic var to its corresponding var in the fused function */
tir::SymbolicMatcher symbolic_var_matcher =
tir::SymbolicMatcher(&analyzer, &symbolic_var_remap);

/*! \brief Record indices of tuple fields that are actually accessed. */
std::unordered_map<const Object*, std::unordered_set<size_t>> used_tuple_field_indices;
};

/*! \brief The IRModule */
@@ -987,34 +947,35 @@ class TIRFuseMutator : public ExprMutator {
Array<PrimExpr> tir_vars;
for (size_t i = 0; i < call->args.size(); ++i) {
auto arg = call->args[i];
Array<Expr> flattened;
if (GetStructInfo(relax_func->params[i])->IsInstance<TupleStructInfoNode>()) {
// Add only those tuple fields which are actually used by the function body
auto tup_get_indices = GetTupleAccessedIndices(relax_func.get(), relax_func->params[i]);
for (size_t tup_get_ind : tup_get_indices) {
auto flattened_inner = FlattenArg(builder_->Emit(TupleGetItem(arg, tup_get_ind)));
flattened.insert(flattened.end(), flattened_inner.begin(), flattened_inner.end());
auto sinfo = GetStructInfo(arg);

ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
!sinfo.as<TupleStructInfoNode>())
<< "InternalError: "
<< "All tuple parameters should be expanded before this point in FuseTIR. "
<< "However, argument " << arg << " with struct info " << arg->struct_info_
<< " is passed as argument " << i << " to Primitive Relax function " << old_gv
<< ", which expects parameter " << relax_func->params[i] << " to have struct info "
<< relax_func->params[i]->struct_info_;

if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
CHECK(shape->values.defined())
<< "FuseTIR requires all shape input has struct_info value.";
for (const PrimExpr& prim_value : shape->values.value()) {
CHECK(prim_value->IsInstance<tir::VarNode>())
<< "All shape inputs are expected to be single tir var.";
tir_vars.push_back(prim_value);
}
} else {
flattened.push_back(arg);
}
} else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
CHECK(prim_value->value.defined())
<< "FuseTIR requires all R.Prim arguments to have a known value.";
PrimExpr expr = prim_value->value.value();
CHECK(expr->IsInstance<tir::VarNode>())
<< "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var.";
tir_vars.push_back(expr);

for (const Expr& e : flattened) {
StructInfo sinfo = GetStructInfo(e);
if (sinfo->IsInstance<TensorStructInfoNode>()) {
arg_list.push_back(e);
} else if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
CHECK(shape->values.defined())
<< "FuseTIR requires all shape input has struct_info value.";
for (const PrimExpr& prim_value : shape->values.value()) {
CHECK(prim_value->IsInstance<tir::VarNode>())
<< "All shape inputs are expected to be single tir var.";
tir_vars.push_back(prim_value);
}
} else {
LOG(FATAL) << "The flattened arg is expected to be either tensor or shape, but got "
<< sinfo->GetTypeKey();
}
} else {
arg_list.push_back(arg);
}
}
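A hedged sketch of the call site this classification produces (names are hypothetical): tensor arguments stay in the positional tuple, while the collected `tir_vars` ride along as a trailing `R.shape`:

# Illustrative only; `fused_tir`, `x`, `y`, and `n` are hypothetical.
# Tensor args -> arg_list; ShapeExpr / R.Prim args -> tir_vars.
#
#   gv = R.call_tir(cls.fused_tir, (x, y),
#                   out_sinfo=R.Tensor((n, 16), "float32"),
#                   tir_vars=R.shape([n]))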
// Step b. Create call_tir
@@ -1042,23 +1003,6 @@ class TIRFuseMutator : public ExprMutator {
return call;
}

/********** Helper Functions **********/

/*! \brief Flatten the call args if it's Tuple by emitting `TupleGetItem`. */
Array<Expr> FlattenArg(const Expr& arg) {
if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(arg)) {
Array<Expr> arg_list;
for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
Expr new_arg = builder_->Emit(TupleGetItem(arg, i));
Array<Expr> flattened = FlattenArg(new_arg);
arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
}
return arg_list;
} else {
return {arg};
}
}

private:
/*! \brief The IRModule */
const IRModule& mod_;
@@ -1076,10 +1020,17 @@ namespace transform {
Pass FuseTIR() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule m, PassContext pc) { return relax::FuseTIR(m); };
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
/*pass_name=*/"FuseTIR", //
/*required=*/{});
auto inner_pass = CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
/*pass_name=*/"FuseTIRInner", //
/*required=*/{});
return tvm::transform::Sequential(
{
ExpandTupleArguments(),
RemoveUnusedParameters(),
inner_pass,
},
"FuseTIR");
}
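For reference, a minimal usage sketch, assuming `mod` is an existing `IRModule` containing primitive Relax functions; since the pass is now a `Sequential`, callers get tuple expansion and unused-parameter removal without invoking those passes themselves:

from tvm import relax

# `mod` is an assumed pre-existing IRModule; FuseTIR first expands tuple
# parameters, drops unused ones, then fuses each primitive relax function
# into a single TIR PrimFunc.
mod = relax.transform.FuseTIR()(mod)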

TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);
14 changes: 7 additions & 7 deletions tests/python/relax/test_transform_fuse_tir.py
@@ -205,7 +205,7 @@ def fused_exp_squeeze(x):
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit_te(fused_exp_squeeze, x)
lv2 = bb.emit_te(fused_exp_squeeze, lv)
lv2 = bb.call_te(fused_exp_squeeze, lv)
gv = bb.emit_output(lv2)
bb.emit_func_output(gv)
return bb.get()
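The `emit_te` -> `call_te` switches in these expected modules are behavior-preserving: `emit_te` emits a binding and returns the bound var, while `call_te` returns the bare call, which `emit_output` then binds. A short self-contained sketch using `topi.exp` as a stand-in:

from tvm import relax, topi
from tvm.script import relax as R

bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
    with bb.dataflow():
        lv = bb.emit_te(topi.exp, x)     # emits a binding, returns the relax.Var
        call = bb.call_te(topi.exp, lv)  # returns the unbound relax.Call
        gv = bb.emit_output(call)        # the binding happens here instead
    bb.emit_func_output(gv)
mod = bb.get()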
@@ -245,7 +245,7 @@ def fused_exp_exp_squeeze(x):
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit_te(fused_exp_exp_squeeze, x)
lv = bb.call_te(fused_exp_exp_squeeze, x)
gv = bb.emit_output(lv)
bb.emit_func_output(gv)
return bb.get()
@@ -257,7 +257,7 @@ def test_fuse_with_tuple_as_param():
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
with bb.function("fused_exp_add", [x], attrs={"Primitive": True}):
with bb.function("fused_exp_add", [x], attrs={"Primitive": True}, private=True):
with bb.dataflow():
lv0 = bb.emit(relax.TupleGetItem(x, 0))
lv1 = bb.emit(relax.TupleGetItem(x, 1))
@@ -300,7 +300,7 @@ def test_fuse_with_nested_tuple_as_param():
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", tuple_struct_info)
with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}):
with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}, private=True):
with bb.dataflow():
lv0 = bb.emit(relax.TupleGetItem(x, 0))
lv0_exp = bb.emit_te(topi.exp, lv0)
@@ -373,7 +373,7 @@ def fused_exp_squeeze(x):
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit_te(fused_exp_squeeze, x)
lv2 = bb.emit_te(topi.add, lv, relax.const(1, "float32"))
lv2 = bb.call_te(topi.add, lv, relax.const(1, "float32"))
gv = bb.emit_output(lv2)
bb.emit_func_output(gv)
return bb.get()
@@ -414,7 +414,7 @@ def fused_add_exp_squeeze(x, y):
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
lv = bb.call_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
gv = bb.emit_output(lv)
bb.emit_func_output(gv)
return bb.get()
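The `private=True` additions in the hunks above and below mark the fused subfunctions as private, i.e. without a `global_symbol` attribute, presumably to keep the modules well-formed now that these helpers are module-internal. A hedged note on the two spellings, assumed equivalent:

# Both forms produce a function with no "global_symbol" attribute,
# callable only from within the enclosing IRModule:
#
#   with bb.function("fused_exp_add", [x], attrs={"Primitive": True}, private=True):
#       ...
#
#   @R.function(private=True)
#   def fused_reshape(...): ...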
@@ -1268,7 +1268,7 @@ def reshape(
(v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
]

@R.function
@R.function(private=True)
def fused_reshape(
lv: R.Tuple(
R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")