Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 67 additions & 50 deletions src/relax/transform/convert_dataflow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <tvm/relax/transform.h>
#include <tvm/relax/utils.h>

#include <optional>

namespace tvm {
namespace relax {

Expand All @@ -39,85 +41,100 @@ class DataflowBlockExtractor : public ExprMutator {
Array<BindingBlock> new_blocks;
Expr new_body = VisitExpr(seq->body);
bool changed = !new_body.same_as(seq->body);
bool dataflow_streak = false;
Array<Binding> dataflow_bindings;

// Accumulated bindings that are not going to be added to a
// DataflowBlock, either because they would be illegal within a
// DataflowBlock, or because there were insufficient bindings to
// make a dataflowblock. Because these bindings occur prior to
// `dataflow_bindings`, this array may only be accumulated into
// when `dataflow_bindings` is empty.
Array<Binding> non_dataflow_bindings;

// Current bindings that may legally be added to a DataflowBlock.
Array<Binding> dataflow_bindings;

// If present, a DataflowBlock whose bindings are currently in
// `dataflow_bindings`. Used to propagate DataflowBlock to the
// output, even if it doesn't meet the minimum size.
Optional<DataflowBlock> input_dataflow_block;

// Handle any bindings currently in `dataflow_bindings`. These
// are either pushed to their own block, or to the end of
// `non_dataflow_bindings`, depending on whether the bindings meet
// the minimum size requirement.
auto push_dataflow_bindings = [&]() {
if (dataflow_bindings.empty()) {
// No Dataflow bindings, so no action required.
return;
}
if (dataflow_bindings.size() < min_size_ && !input_dataflow_block) {
// The df block is below the minimum length, and no input
// DataflowBlock needs to be preserved. Combine the blocks
// and reset the dataflow collection.

non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(),
dataflow_bindings.end());

} else {
// A new DataflowBlock can be generated, with bindings that
// occur after the non-dataflow bindings.
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
new_blocks.push_back(DataflowBlock(dataflow_bindings));
non_dataflow_bindings = {};

// Making a dataflow block doesn't imply that the function was
// changed. A change requires that this either be a new
// dataflow block, or have additional dataflow bindings in the
// current block.
changed = changed || !input_dataflow_block.defined() ||
input_dataflow_block.value()->bindings.size() != dataflow_bindings.size();
}

dataflow_bindings = {};
input_dataflow_block = NullOpt;
};

for (auto block : seq->blocks) {
BindingBlock new_block = this->VisitBindingBlock(block);
changed = changed || !new_block.same_as(block);

// For an existing dataflow block, we add to the current streak
// or start a new streak in case there will be more dataflow operations
// coming up
if (new_block.as<DataflowBlock>()) {
if (!dataflow_streak) {
dataflow_streak = true;
}
if (auto dataflow_block = new_block.as<DataflowBlock>()) {
dataflow_bindings.insert(dataflow_bindings.end(), new_block->bindings.begin(),
new_block->bindings.end());
input_dataflow_block = dataflow_block;
continue;
}

// for a binding block, attempt to extract dataflow blocks inside
auto binding_block = Downcast<BindingBlock>(new_block);
for (size_t i = 0; i < binding_block->bindings.size(); i++) {
auto binding = binding_block->bindings[i];
for (const auto& binding : binding_block->bindings) {
Expr value = GetBoundValue(binding);
// dataflow values: not an if node and not an impure call
bool is_dataflow = (!value.as<IfNode>()) &&
(!(value.as<CallNode>() && IsImpureCall(Downcast<Call>(value))));
if (!dataflow_streak) {
// we can start a dataflow streak
if (is_dataflow) {
dataflow_streak = true;
dataflow_bindings = {binding};
} else {
non_dataflow_bindings.push_back(binding);
}
if (is_dataflow) {
// extend the streak
dataflow_bindings.push_back(binding);
} else {
if (is_dataflow) {
// extend the streak
dataflow_bindings.push_back(binding);
} else {
// this is the end of the streak
dataflow_streak = false;

// if the df block is below the minimum length, combine the blocks
// and reset the dataflow collection
if (dataflow_bindings.size() < min_size_) {
non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(),
dataflow_bindings.end());
dataflow_bindings = {};
} else {
// otherwise insert both collections
changed = true;
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
new_blocks.push_back(DataflowBlock(dataflow_bindings));
non_dataflow_bindings = {};
dataflow_bindings = {};
}
non_dataflow_bindings.push_back(binding);
}
// End the streak, if one currently exists.
push_dataflow_bindings();
non_dataflow_bindings.push_back(binding);
}
}
}

// handle any remaining bindings
if (dataflow_bindings.size() < min_size_) {
non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(),
dataflow_bindings.end());
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
} else {
changed = true;
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
new_blocks.push_back(DataflowBlock(dataflow_bindings));
}
push_dataflow_bindings();
new_blocks.push_back(BindingBlock(non_dataflow_bindings));

if (!changed) {
if (changed) {
return SeqExpr(new_blocks, new_body);
} else {
return GetRef<SeqExpr>(seq);
}
return SeqExpr(new_blocks, new_body);
}

private:
Expand Down
106 changes: 106 additions & 0 deletions tests/python/relax/test_transform_convert_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,5 +489,111 @@ def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
return v


class TestPreserveExistingDataflowBlocksAtBeginning(ExtractCompare):
    """Preserve existing DataflowBlocks

    This is a regression test. In previous implementations, a
    DataflowBlock in the input, without enough bindings to become a
    new dataflow block, could be accidentally omitted.

    This test is identical to
    `TestPreserveExistingDataflowBlocksAtEnd`, except that the
    existing dataflow block is at the beginning of the function.

    """

    # Input module: a one-binding DataflowBlock, then a call that is
    # kept outside any DataflowBlock, then three plain bindings.
    @I.ir_module
    class Before:
        @R.function(pure=False)
        def main(A0: R.Tensor, B0: R.Tensor):
            # This DataflowBlock is below the minimum size for a new
            # block, but already exists in the input IRModule.
            with R.dataflow():
                A1 = R.add(A0, A0)
                R.output(A1)

            # This call separates the existing DataflowBlock from the
            # bindings below; it is outside any DataflowBlock in both
            # `Before` and `Expected`.
            R.print(format="impure_function")

            # This sequence is large enough that it may be converted
            # to a DataflowBlock.
            B1 = R.add(B0, B0)
            B2 = R.add(B1, B1)
            B3 = R.add(B2, B2)

            return (A1, B3)

    # Expected module: the original DataflowBlock is kept as-is, and
    # the three trailing bindings are wrapped in a second one.
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def main(A0: R.Tensor, B0: R.Tensor):
            # This dataflow block should be preserved in the output.
            with R.dataflow():
                A1 = R.add(A0, A0)
                R.output(A1)

            R.print(format="impure_function")

            # The B1..B3 bindings from `Before`, now grouped into a
            # DataflowBlock with B3 as its output.
            with R.dataflow():
                B1 = R.add(B0, B0)
                B2 = R.add(B1, B1)
                B3 = R.add(B2, B2)
                R.output(B3)

            return (A1, B3)


class TestPreserveExistingDataflowBlocksAtEnd(ExtractCompare):
    """Preserve existing DataflowBlocks

    This is a regression test. In previous implementations, a
    DataflowBlock in the input, without enough bindings to become a
    new dataflow block, could be accidentally omitted.

    This test is identical to
    `TestPreserveExistingDataflowBlocksAtBeginning`, except that the
    existing dataflow block is at the end of the function.

    """

    # Input module: three plain bindings, then a call that is kept
    # outside any DataflowBlock, then a one-binding DataflowBlock.
    @I.ir_module
    class Before:
        @R.function(pure=False)
        def main(A0: R.Tensor, B0: R.Tensor):
            # This sequence is large enough that it may be converted
            # to a DataflowBlock.
            B1 = R.add(B0, B0)
            B2 = R.add(B1, B1)
            B3 = R.add(B2, B2)

            # This call separates the bindings above from the existing
            # DataflowBlock; it is outside any DataflowBlock in both
            # `Before` and `Expected`.
            R.print(format="impure_function")

            # This DataflowBlock is below the minimum size for a new
            # block, but already exists in the input IRModule.
            with R.dataflow():
                A1 = R.add(A0, A0)
                R.output(A1)

            return (A1, B3)

    # Expected module: the three leading bindings are wrapped in a new
    # DataflowBlock, and the original trailing one is kept as-is.
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def main(A0: R.Tensor, B0: R.Tensor):
            # The B1..B3 bindings from `Before`, now grouped into a
            # DataflowBlock with B3 as its output.
            with R.dataflow():
                B1 = R.add(B0, B0)
                B2 = R.add(B1, B1)
                B3 = R.add(B2, B2)
                R.output(B3)

            R.print(format="impure_function")

            # This dataflow block should be preserved in the output.
            with R.dataflow():
                A1 = R.add(A0, A0)
                R.output(A1)

            return (A1, B3)


# When this file is executed directly as a script, hand control to
# tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()