Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../../node/ndarray_hash_equal.h"
Expand Down Expand Up @@ -102,32 +103,49 @@ class BlockBuilderImpl : public BlockBuilderNode {

context_mod_->Add(gvar, func);

ctx_func_dedup_map_->emplace(func, gvar);
(*ctx_func_dedup_map_)[func].insert(gvar);
return gvar;
} else {
return it->second;
ICHECK(it->second.size()) << "Values contained in de-duplication map must be non-empty sets, "
<< "but found an empty set for function of type "
<< func->GetTypeKey();
// To provide deterministic results, return the GlobalVar that
// comes first in lexicographic order.
return *std::min_element(
it->second.begin(), it->second.end(),
[](const GlobalVar& a, const GlobalVar& b) { return a->name_hint < b->name_hint; });
}
}

void UpdateFunction(const GlobalVar& gv, BaseFunc function) final {
context_mod_.CopyOnWrite();

// invalidate old dedup map
// Remove function from the de-duplication map.
if (ctx_func_dedup_map_ != nullptr) {
auto it = context_mod_->functions.find(gv);
if (it != context_mod_->functions.end()) {
BaseFunc old_func = (*it).second;
auto ptr = ctx_func_dedup_map_->find(old_func);
ICHECK(ptr != ctx_func_dedup_map_->end());
ctx_func_dedup_map_->erase(ptr);
ICHECK(ptr != ctx_func_dedup_map_->end())
<< "BlockBuilder::UpdateFunction is updating " << gv
<< ", which appears in the BlockBuilder's context_mod_, "
<< "but does not appear in the de-duplication map";
ICHECK(ptr->second.count(gv))
<< "BlockBuilder::UpdateFunction is updating " << gv
<< ", but the de-duplication map for the previous value of this function "
<< "does not include " << gv;
ptr->second.erase(gv);
if (ptr->second.empty()) {
ctx_func_dedup_map_->erase(ptr);
}
}
}

context_mod_->Update(gv, function);

// add new dedup map item.
if (ctx_func_dedup_map_ != nullptr) {
ctx_func_dedup_map_->emplace(function, gv);
(*ctx_func_dedup_map_)[function].insert(gv);
}
}

Expand Down Expand Up @@ -399,7 +417,8 @@ class BlockBuilderImpl : public BlockBuilderNode {
* We use a custom hash to avoid hashing constants that may be bound to each BaseFunc.
*/
std::unique_ptr<
std::unordered_map<BaseFunc, GlobalVar, StructuralHashIgnoreNDarray, StructuralEqual>>
std::unordered_map<BaseFunc, std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual>,
StructuralHashIgnoreNDarray, StructuralEqual>>
ctx_func_dedup_map_ = nullptr;

/*!
Expand All @@ -408,11 +427,12 @@ class BlockBuilderImpl : public BlockBuilderNode {
void LazyInitCtxFuncDedupMap() {
if (ctx_func_dedup_map_ != nullptr) return;
ctx_func_dedup_map_ = std::make_unique<
std::unordered_map<BaseFunc, GlobalVar, StructuralHashIgnoreNDarray, StructuralEqual>>();
std::unordered_map<BaseFunc, std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual>,
StructuralHashIgnoreNDarray, StructuralEqual>>();
for (const auto& kv : context_mod_->functions) {
const GlobalVar gv = kv.first;
const BaseFunc func = kv.second;
ctx_func_dedup_map_->emplace(func, gv);
(*ctx_func_dedup_map_)[func].insert(gv);
}
}

Expand Down
67 changes: 66 additions & 1 deletion tests/python/relax/test_blockbuilder_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm import relax as rx, relay
from tvm.ir.base import assert_structural_equal
from tvm.relax import ExternFunc
from tvm.script import relax as R, tir as T
from tvm.script import ir as I, relax as R, tir as T
from tvm.tir.function import PrimFunc


Expand Down Expand Up @@ -925,5 +925,70 @@ def test_error_when_unwrapping_dataflowvar():
bb.emit_func_output(out)


def test_deduplication_when_input_contains_duplicates():
    """De-duplication of IRModules

    A well-formed IRModule may contain duplicate function definitions.
    This is rare, as most functions can be disambiguated by the
    function attribute `tvm::attr::kGlobalSymbol`. However, private
    functions do not have this attribute, and a well-formed IRModule
    may contain multiple copies of the same function.

    This is a regression test. The previous implementation
    de-duplicated using a `Dict[Function, GlobalVar]`, which has the
    failure mode walked through in the comments below. This was
    resolved by de-duplicating using a `Dict[Function, Set[GlobalVar]]`
    instead.

    The test passes if both `update_func` calls complete without
    raising an error.

    """

    @I.ir_module
    class Module:
        @R.function
        def main(A: R.Tensor):
            B = Module.subroutine_a(A)
            C = Module.subroutine_b(B)
            return C

        @R.function(private=True)
        def subroutine_a(arg: R.Tensor) -> R.Tensor:
            return R.add(arg, arg)

        @R.function(private=True)
        def subroutine_b(arg: R.Tensor) -> R.Tensor:
            return R.add(arg, arg)

        @R.function(private=True)
        def subroutine_c(arg: R.Tensor) -> R.Tensor:
            return R.multiply(arg, arg)

    # This test case is only valid when the two subroutines are
    # structurally equal, and therefore allowed to be de-duplicated by
    # the BlockBuilder.
    tvm.ir.assert_structural_equal(Module["subroutine_a"], Module["subroutine_b"])

    gvar_a = Module.get_global_var("subroutine_a")
    gvar_b = Module.get_global_var("subroutine_b")
    subroutine_c = Module["subroutine_c"]

    bb = rx.BlockBuilder(Module)

    # Add a function to the module. What we add doesn't matter, as
    # this is only to initialize the de-duplication map.
    bb.add_func(subroutine_c, "_unused")
    # Under the previous `Dict[Function, GlobalVar]` implementation,
    # the de-duplication map now held a single entry for the function
    # body shared by `subroutine_a` and `subroutine_b`, mapped to
    # either `gvar_a` or `gvar_b`.

    # Update gvar_a.
    bb.update_func(gvar_a, subroutine_c)
    # Under the previous implementation, this removed the single entry
    # for the shared `subroutine_a`/`subroutine_b` body from the
    # de-duplication map.

    # Update gvar_b. The de-duplication map is present (because we
    # called `add_func`). Under the previous implementation it no
    # longer contained an entry for the shared body (because it was
    # just removed), and this call raised an error. It must now
    # complete without error.
    bb.update_func(gvar_b, subroutine_c)


if __name__ == "__main__":
tvm.testing.main()