From 802bc1222e7a189332c1ca460c0018d169393b7f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 1 Jul 2024 13:47:37 -0500 Subject: [PATCH] [Bugfix][Relax] Preserve existing DataflowBlock in ConvertToDataflow The `relax.transform.ConvertToDataflow` pass identifies portions of a Relax function that satisfy the requirements of a `relax::DataflowBlock`, and converts those portions to a new `DataflowBlock`, provided they contain at least some minimum number of operations. Prior to this commit, if a function contained a region that would be converted to a `DataflowBlock`, but also contained existing `DataflowBlock`s that were smaller than the size required for creating a `DataflowBlock`, those existing blocks would be erroneously converted to non-dataflow. This commit updates the `ConvertToDataflow` pass to preserve all existing `DataflowBlock`s present in the input. --- src/relax/transform/convert_dataflow.cc | 117 ++++++++++-------- .../relax/test_transform_convert_dataflow.py | 106 ++++++++++++++++ 2 files changed, 173 insertions(+), 50 deletions(-) diff --git a/src/relax/transform/convert_dataflow.cc b/src/relax/transform/convert_dataflow.cc index b927307c2e0e..528a466a9bb3 100644 --- a/src/relax/transform/convert_dataflow.cc +++ b/src/relax/transform/convert_dataflow.cc @@ -28,6 +28,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -39,10 +41,59 @@ class DataflowBlockExtractor : public ExprMutator { Array new_blocks; Expr new_body = VisitExpr(seq->body); bool changed = !new_body.same_as(seq->body); - bool dataflow_streak = false; - Array dataflow_bindings; + + // Accumulated bindings that are not going to be added to a + // DataflowBlock, either because they would be illegal within a + // DataflowBlock, or because there were insufficient bindings to + // make a dataflowblock. Because these bindings occur prior to + // `dataflow_bindings`, this array may only be accumulated into + // when `dataflow_bindings` is empty. 
Array non_dataflow_bindings; + // Current bindings that may legally be added to a DataflowBlock. + Array dataflow_bindings; + + // If present, a DataflowBlock whose bindings are currently in + // `dataflow_bindings`. Used to propagate DataflowBlock to the + // output, even if it doesn't meet the minimum size. + Optional input_dataflow_block; + + // Handle any bindings currently in `dataflow_bindings`. These + // are either pushed to their own block, or to the end of + // `non_dataflow_bindings`, depending on whether the bindings meet + // the minimum size requirement. + auto push_dataflow_bindings = [&]() { + if (dataflow_bindings.empty()) { + // No Dataflow bindings, so no action required. + return; + } + if (dataflow_bindings.size() < min_size_ && !input_dataflow_block) { + // The df block is below the minimum length, and no input + // DataflowBlock needs to be preserved. Combine the blocks + // and reset the dataflow collection. + + non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(), + dataflow_bindings.end()); + + } else { + // A new DataflowBlock can be generated, with bindings that + // occur after the non-dataflow bindings. + new_blocks.push_back(BindingBlock(non_dataflow_bindings)); + new_blocks.push_back(DataflowBlock(dataflow_bindings)); + non_dataflow_bindings = {}; + + // Making a dataflow block doesn't imply that the function was + // changed. A change requires that this either be a new + // dataflow block, or have additional dataflow bindings in the + // current block. 
+ changed = changed || !input_dataflow_block.defined() || + input_dataflow_block.value()->bindings.size() != dataflow_bindings.size(); + } + + dataflow_bindings = {}; + input_dataflow_block = NullOpt; + }; + for (auto block : seq->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); changed = changed || !new_block.same_as(block); @@ -50,74 +101,40 @@ class DataflowBlockExtractor : public ExprMutator { // For an existing dataflow block, we add to the current streak // or start a new streak in case there will be more dataflow operations // coming up - if (new_block.as()) { - if (!dataflow_streak) { - dataflow_streak = true; - } + if (auto dataflow_block = new_block.as()) { dataflow_bindings.insert(dataflow_bindings.end(), new_block->bindings.begin(), new_block->bindings.end()); + input_dataflow_block = dataflow_block; continue; } // for a binding block, attempt to extract dataflow blocks inside auto binding_block = Downcast(new_block); - for (size_t i = 0; i < binding_block->bindings.size(); i++) { - auto binding = binding_block->bindings[i]; + for (const auto& binding : binding_block->bindings) { Expr value = GetBoundValue(binding); // dataflow values: not an if node and not an impure call bool is_dataflow = (!value.as()) && (!(value.as() && IsImpureCall(Downcast(value)))); - if (!dataflow_streak) { - // we can start a dataflow streak - if (is_dataflow) { - dataflow_streak = true; - dataflow_bindings = {binding}; - } else { - non_dataflow_bindings.push_back(binding); - } + if (is_dataflow) { + // extend the streak + dataflow_bindings.push_back(binding); } else { - if (is_dataflow) { - // extend the streak - dataflow_bindings.push_back(binding); - } else { - // this is the end of the streak - dataflow_streak = false; - - // if the df block is below the minimum length, combine the blocks - // and reset the dataflow collection - if (dataflow_bindings.size() < min_size_) { - non_dataflow_bindings.insert(non_dataflow_bindings.end(), 
+ dataflow_bindings.begin(), - dataflow_bindings.end()); - dataflow_bindings = {}; - } else { - // otherwise insert both collections - changed = true; - new_blocks.push_back(BindingBlock(non_dataflow_bindings)); - new_blocks.push_back(DataflowBlock(dataflow_bindings)); - non_dataflow_bindings = {}; - dataflow_bindings = {}; - } - non_dataflow_bindings.push_back(binding); - } + // End the streak, if one currently exists. + push_dataflow_bindings(); + non_dataflow_bindings.push_back(binding); } } } // handle any remaining bindings - if (dataflow_bindings.size() < min_size_) { - non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(), - dataflow_bindings.end()); - new_blocks.push_back(BindingBlock(non_dataflow_bindings)); - } else { - changed = true; - new_blocks.push_back(BindingBlock(non_dataflow_bindings)); - new_blocks.push_back(DataflowBlock(dataflow_bindings)); - } + push_dataflow_bindings(); + new_blocks.push_back(BindingBlock(non_dataflow_bindings)); - if (!changed) { + if (changed) { + return SeqExpr(new_blocks, new_body); + } else { return GetRef(seq); } - return SeqExpr(new_blocks, new_body); } private: diff --git a/tests/python/relax/test_transform_convert_dataflow.py b/tests/python/relax/test_transform_convert_dataflow.py index 8a926cd4aedc..ab78ec0b3bc7 100644 --- a/tests/python/relax/test_transform_convert_dataflow.py +++ b/tests/python/relax/test_transform_convert_dataflow.py @@ -489,5 +489,111 @@ def main(x: R.Tensor, y: R.Tensor) -> R.Tensor: return v +class TestPreserveExistingDataflowBlocksAtBeginning(ExtractCompare): + """Preserve existing DataflowBlocks + + This is a regression test. In previous implementations, a + DataflowBlock in the input, without enough bindings to become a + new dataflow block, could be accidentally omitted. + + This test is identical to + `TestPreserveExistingDataflowBlocksAtEnd`, except that the + existing dataflow block is at the beginning of the function. 
+ + """ + + @I.ir_module + class Before: + @R.function(pure=False) + def main(A0: R.Tensor, B0: R.Tensor): + # This DataflowBlock is below the minimum size for a new + # block, but already exists in the input IRModule. + with R.dataflow(): + A1 = R.add(A0, A0) + R.output(A1) + + R.print(format="impure_function") + + # This sequence is large enough that it may be converted + # to a DataflowBlock. + B1 = R.add(B0, B0) + B2 = R.add(B1, B1) + B3 = R.add(B2, B2) + + return (A1, B3) + + @I.ir_module + class Expected: + @R.function(pure=False) + def main(A0: R.Tensor, B0: R.Tensor): + # This dataflow block should be preserved in the output. + with R.dataflow(): + A1 = R.add(A0, A0) + R.output(A1) + + R.print(format="impure_function") + + with R.dataflow(): + B1 = R.add(B0, B0) + B2 = R.add(B1, B1) + B3 = R.add(B2, B2) + R.output(B3) + + return (A1, B3) + + +class TestPreserveExistingDataflowBlocksAtEnd(ExtractCompare): + """Preserve existing DataflowBlocks + + This is a regression test. In previous implementations, a + DataflowBlock in the input, without enough bindings to become a + new dataflow block, could be accidentally omitted. + + This test is identical to + `TestPreserveExistingDataflowBlocksAtBeginning`, except that the + existing dataflow block is at the end of the function. + + """ + + @I.ir_module + class Before: + @R.function(pure=False) + def main(A0: R.Tensor, B0: R.Tensor): + # This sequence is large enough that it may be converted + # to a DataflowBlock. + B1 = R.add(B0, B0) + B2 = R.add(B1, B1) + B3 = R.add(B2, B2) + + R.print(format="impure_function") + + # This DataflowBlock is below the minimum size for a new + # block, but already exists in the input IRModule. 
+ with R.dataflow(): + A1 = R.add(A0, A0) + R.output(A1) + + return (A1, B3) + + @I.ir_module + class Expected: + @R.function(pure=False) + def main(A0: R.Tensor, B0: R.Tensor): + with R.dataflow(): + B1 = R.add(B0, B0) + B2 = R.add(B1, B1) + B3 = R.add(B2, B2) + R.output(B3) + + R.print(format="impure_function") + + # This dataflow block should be preserved in the output. + with R.dataflow(): + A1 = R.add(A0, A0) + R.output(A1) + + return (A1, B3) + + if __name__ == "__main__": tvm.testing.main()