Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion include/tvm/tir/index_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,11 @@ class IndexMapNode : public Object {

/*!
* \brief Convert to string representation in Python.
* \param f_name_map Optional function to specify the stringified name of the variables.
* \return The stringified lambda expression in Python.
*/
String ToPythonString() const;
String ToPythonString(
const std::function<Optional<String>(const Var& var)>& f_name_map = nullptr) const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("initial_indices", &initial_indices);
Expand Down Expand Up @@ -203,6 +205,17 @@ class IndexMap : public ObjectRef {
*/
IndexMap Inverse(Array<Range> initial_ranges) const;

/*! \brief Rename the variables in the index map and ensure the names are unique.
*
* Construct a new index map with the same transformation, but with name_hint of variables to be
* guaranteed unique. The optional f_name_map can be provided to rename the variables.
*
* \param f_name_map The optional name map to rename the variables.
* \return The renamed index map.
*/
IndexMap RenameVariables(
const std::function<Optional<String>(const Var& var)>& f_name_map = nullptr) const;

/*! \brief Generate the inverse mapping.
*
* Determine the inverse, where the output range may contain
Expand All @@ -217,6 +230,14 @@ class IndexMap : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
};

/*! \brief Substitute variables in an index map.
*
* \param index_map The index_map
* \param f_subst The substitution function
*/
IndexMap Substitute(const IndexMap& index_map,
std::function<Optional<PrimExpr>(const Var& var)> f_subst);

} // namespace tir
} // namespace tvm

Expand Down
5 changes: 4 additions & 1 deletion python/tvm/tir/schedule/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from tvm._ffi import register_object as _register_object
from tvm.runtime import Object

from ...ir import Array, Map
from ...ir import Array, Map, save_json
from ...runtime import String
from ..expr import FloatImm, IntImm
from ..function import IndexMap
from . import _ffi_api
from .instruction import ATTR_TYPE, INPUT_RV_TYPE, Instruction

Expand All @@ -45,6 +46,8 @@ def _json_from_tvm(obj):
return str(obj)
if isinstance(obj, (IntImm, FloatImm)):
return obj.value
if isinstance(obj, IndexMap):
return save_json(obj)
raise TypeError("Not supported type: " + str(type(obj)))


Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/database/database_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void JSONDumps(ObjectRef json_obj, std::ostringstream& os) {
}
os << "}";
} else if (json_obj->IsInstance<tir::IndexMapNode>()) {
// Do nothing for index maps to start
JSONDumps(String(SaveJSON(json_obj)), os);
} else {
LOG(FATAL) << "TypeError: Unsupported type in JSON object: " << json_obj->GetTypeKey();
}
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,8 @@ class ScheduleBuilder : public ExprVisitor {
TuningRecord record = opt_record.value();
for (const Instruction& inst : record->trace->insts) {
if (inst->kind.same_as(kind_transform_layout)) {
ICHECK_EQ(inst->attrs.size(), 4);
auto index_map = Downcast<IndexMap>(inst->attrs[2]);
ICHECK_EQ(inst->inputs.size(), 2);
auto index_map = Downcast<IndexMap>(inst->inputs[1]);

if (!const_collector.constants.empty()) {
// In this case, RewriteLayout is acting on an AllocateConst node.
Expand Down
94 changes: 73 additions & 21 deletions src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_set.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ir/name_supply.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
Expand Down Expand Up @@ -310,8 +311,59 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const {
return arr_dst;
}

IndexMap IndexMap::RenameVariables(
const std::function<Optional<String>(const Var& var)>& f_name_map) const {
std::unordered_set<std::string> used_names;
Map<Var, PrimExpr> var_remap;
NameSupply name_supply{""};
const IndexMapNode* n = this->get();
if (f_name_map != nullptr) {
// Collect variables with pre-defined names provided by f_name_map.
std::unordered_set<const Object*> visited;
std::for_each(n->final_indices.begin(), n->final_indices.end(), [&](const PrimExpr& expr) {
PostOrderVisit(expr, [&](const ObjectRef& obj) {
if (!obj->IsInstance<VarNode>()) {
return;
}
if (visited.count(obj.get())) {
return;
}
visited.emplace(obj.get());
Var var = Downcast<Var>(obj);
if (Optional<String> opt_name = f_name_map(var); opt_name.defined()) {
String name = opt_name.value();
ICHECK(!name_supply->ContainsName(name, /*add_prefix=*/false));
name_supply->ReserveName(name, /*add_prefix=*/false);
var_remap.Set(var, Var(name, var->dtype));
}
});
});
}

for (const Var& initial_index : n->initial_indices) {
if (var_remap.count(initial_index)) {
// The name of the variable is pre-defined.
continue;
}
String unique_name = name_supply->FreshName(initial_index->name_hint, /*add_prefix=*/false);
if (unique_name != initial_index->name_hint) {
var_remap.Set(initial_index, Var(unique_name));
}
}

auto new_initial_indices = n->initial_indices.Map(
[&](const Var& var) { return Downcast<Var>(Substitute(var, var_remap)); });
auto new_final_indices =
n->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, var_remap); });
Optional<IndexMap> new_inverse_index_map = NullOpt;
if (n->inverse_index_map.defined()) {
new_inverse_index_map = Downcast<IndexMap>(n->inverse_index_map).RenameVariables(f_name_map);
}
return IndexMap(new_initial_indices, new_final_indices, new_inverse_index_map);
}

/*!
* \brief Auxilarry function to comvert an index map to lambda expression in Python.
* \brief Auxilarry function to convert an index map to lambda expression in Python.
* \param initial_indices The initial indices in the index map.
* \param final_indices The final indices in the index map.
* \return The lambda expression string.
Expand All @@ -320,47 +372,36 @@ std::string IndexMap2PythonLambdaExpr(const Array<Var>& initial_indices,
const Array<PrimExpr>& final_indices) {
std::unordered_set<std::string> used_names;
Map<Var, PrimExpr> var_remap;
for (const Var& initial_index : initial_indices) {
if (used_names.count(initial_index->name_hint)) {
std::string new_name = initial_index->name_hint + std::to_string(used_names.size());
used_names.insert(new_name);
var_remap.Set(initial_index, Var(new_name));
} else {
used_names.insert(initial_index->name_hint);
}
}
std::ostringstream oss;
oss << "lambda ";
for (size_t i = 0; i < initial_indices.size(); ++i) {
if (i != 0) {
oss << ", ";
}
auto it = var_remap.find(initial_indices[i]);
if (it != var_remap.end()) {
oss << (*it).second;
} else {
oss << initial_indices[i];
}
oss << initial_indices[i];
}
oss << ": (";
for (size_t i = 0; i < final_indices.size(); ++i) {
if (i != 0) {
oss << " ";
}
oss << Substitute(final_indices[i], var_remap);
oss << final_indices[i];
oss << ",";
}
oss << ")";
return oss.str();
}

String IndexMapNode::ToPythonString() const {
std::string lambda_expr = IndexMap2PythonLambdaExpr(initial_indices, final_indices);
if (!inverse_index_map.defined()) {
String IndexMapNode::ToPythonString(
const std::function<Optional<String>(const Var& var)>& f_name_map) const {
auto index_map = GetRef<IndexMap>(this).RenameVariables(f_name_map);
std::string lambda_expr =
IndexMap2PythonLambdaExpr(index_map->initial_indices, index_map->final_indices);
if (!index_map->inverse_index_map.defined()) {
return String(lambda_expr);
}
// Also convert the inverse index map.
IndexMap inverse = Downcast<IndexMap>(inverse_index_map.value());
IndexMap inverse = Downcast<IndexMap>(index_map->inverse_index_map.value());
std::string inverse_lambda_expr =
IndexMap2PythonLambdaExpr(inverse->initial_indices, inverse->final_indices);
std::ostringstream oss;
Expand All @@ -369,6 +410,17 @@ String IndexMapNode::ToPythonString() const {
return String(oss.str());
}

IndexMap Substitute(const IndexMap& index_map,
std::function<Optional<PrimExpr>(const Var& var)> f_subst) {
Array<PrimExpr> new_output =
index_map->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, f_subst); });
Optional<IndexMap> new_inverse_map = NullOpt;
if (index_map->inverse_index_map.defined()) {
new_inverse_map = Substitute(Downcast<IndexMap>(index_map->inverse_index_map.value()), f_subst);
}
return IndexMap{index_map->initial_indices, new_output, new_inverse_map};
}

TVM_REGISTER_NODE_TYPE(IndexMapNode);

TVM_REGISTER_GLOBAL("tir.IndexMap")
Expand Down
8 changes: 6 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,12 @@ void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_i
const IndexMap& index_map,
const Optional<IndexMap>& pad_value) {
TVM_TIR_SCHEDULE_BEGIN();
tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, index_map,
pad_value);
auto f_subst = [&](const Var& var) -> Optional<PrimExpr> {
return Downcast<Optional<PrimExpr>>(symbol_table_.Get(var));
};
auto new_index_map = Substitute(index_map, f_subst);
tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
new_index_map, pad_value);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
}
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
std::ostringstream os;
os << new_expr;
inputs.push_back(String(os.str()));
} else if (obj.as<IndexMapNode>()) {
inputs.push_back(obj);
} else {
LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << obj->GetTypeKey();
throw;
Expand Down
27 changes: 16 additions & 11 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1483,20 +1483,20 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
static constexpr bool kIsPure = false;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 4;
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 3;
static constexpr size_t kNumDecisions = 0;

static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
Integer buffer_index_type, IndexMap index_map,
static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map,
Integer buffer_index, Integer buffer_index_type,
Optional<IndexMap> pad_value) {
return sch->TransformLayout(block_rv, buffer_index.IntValue(),
static_cast<BufferIndexType>(buffer_index_type->value), index_map,
pad_value);
}

static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
Integer buffer_index_type, IndexMap index_map,
static String UnpackedAsPython(Array<String> outputs, String block_rv, IndexMap index_map,
Integer buffer_index, Integer buffer_index_type,
Optional<IndexMap> pad_value) {
PythonAPICall py("transform_layout");
py.Input("block", block_rv);
Expand All @@ -1505,7 +1505,6 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
<< "\", " << buffer_index << ")";
py.Input("buffer", os.str());

py.Input("index_map", index_map->ToPythonString());
py.Input("pad_value", pad_value ? pad_value.value()->ToPythonString() : "None");

Expand All @@ -1518,8 +1517,11 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
attrs_record.reserve(kNumAttrs);
attrs_record.push_back(attrs[0]);
attrs_record.push_back(attrs[1]);
attrs_record.push_back(String(::tvm::SaveJSON(attrs[2])));
attrs_record.push_back(attrs[3]);
if (attrs[2].defined()) {
attrs_record.push_back(String(::tvm::SaveJSON(attrs[2])));
} else {
attrs_record.push_back(attrs[2]);
}
return std::move(attrs_record);
}

Expand All @@ -1528,8 +1530,11 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
Array<ObjectRef> attrs;
attrs.push_back(attrs_record[0]);
attrs.push_back(attrs_record[1]);
attrs.push_back(::tvm::LoadJSON(Downcast<String>(attrs_record[2])));
attrs.push_back(attrs_record[3]);
if (attrs_record[2].defined()) {
attrs.push_back(::tvm::LoadJSON(Downcast<String>(attrs_record[2])));
} else {
attrs.push_back(attrs_record[2]);
}
return attrs;
}

Expand Down
54 changes: 43 additions & 11 deletions src/tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
const std::unordered_map<const Object*, const Object*>& rv_map) {
Array<ObjectRef> result;
result.reserve(inputs.size());
auto f_subst_with_rv_map = [&rv_map](const Var& var) -> Optional<PrimExpr> {
auto it = rv_map.find(var.get());
if (it == rv_map.end()) {
return NullOpt;
}
const Object* dst = it->second;
ICHECK(dst->IsInstance<VarNode>())
<< "TypeError: Expect 'tir.Var', but gets: " << dst->GetTypeKey();
return GetRef<Var>(static_cast<const VarNode*>(dst));
};

for (const ObjectRef& input : inputs) {
if (!input.defined() || // constant: nullptr
input->IsInstance<StringObj>() || // constant: string
Expand All @@ -68,17 +79,9 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input;
result.push_back(GetRef<ObjectRef>(it->second));
} else if (const auto* expr = input.as<PrimExprNode>()) { // RV: Expr
result.push_back(
Substitute(GetRef<PrimExpr>(expr), [&rv_map](const Var& var) -> Optional<PrimExpr> {
auto it = rv_map.find(var.get());
if (it == rv_map.end()) {
return NullOpt;
}
const Object* dst = it->second;
ICHECK(dst->IsInstance<VarNode>())
<< "TypeError: Expect 'tir.Var', but gets: " << dst->GetTypeKey();
return GetRef<Var>(static_cast<const VarNode*>(dst));
}));
result.push_back(Substitute(GetRef<PrimExpr>(expr), f_subst_with_rv_map));
} else if (const auto* index_map = input.as<IndexMapNode>()) {
result.push_back(Substitute(GetRef<IndexMap>(index_map), f_subst_with_rv_map));
} else if (input->IsInstance<ArrayNode>()) {
// Recursively convert elements of the array into a new list of ObjectRefs.
result.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), rv_map));
Expand Down Expand Up @@ -118,6 +121,16 @@ Array<ObjectRef> TranslateInputRVs(
} else if (input->IsInstance<MapNode>()) {
// Case 5: dict
results.push_back(input);
} else if (input->IsInstance<IndexMapNode>()) {
// // Case 6: IndexMap
IndexMap index_map = Downcast<IndexMap>(input);
index_map = index_map.RenameVariables([&rv_names](const Var& var) -> Optional<String> {
if (auto it = rv_names.find(var); it != rv_names.end()) {
return it->second;
}
return NullOpt;
});
results.push_back(index_map);
} else if (input->IsInstance<BlockRVNode>() || inputs->IsInstance<LoopRVNode>() ||
inputs->IsInstance<VarNode>()) {
LOG(FATAL) << "IndexError: Random variable is not defined " << input;
Expand Down Expand Up @@ -155,6 +168,25 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names";
const char* name = str->data;
int64_t size = str->size;
if (name[0] == '{' && name[size - 1] == '}') {
ObjectRef obj = LoadJSON(name);
// Case 6. IndexMap
if (obj->IsInstance<IndexMapNode>()) {
IndexMap index_map = Downcast<IndexMap>(obj);
index_map = Substitute(index_map, [&named_rvs](const Var& var) -> Optional<PrimExpr> {
auto it = named_rvs.find(var->name_hint);
if (it != named_rvs.end()) {
return Downcast<Var>(it->second);
}
return NullOpt;
});
results.push_back(index_map);
continue;
} else {
LOG(FATAL) << "TypeError: Unexpected object: " << obj->GetTypeKey();
throw;
}
}
// Case 2. string
if (size >= 2 && name[0] == '"' && name[size - 1] == '"') {
results.push_back(String(std::string(name + 1, size - 2)));
Expand Down
Loading