|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | +#include "../utils.h" |
| 20 | + |
| 21 | +namespace tvm { |
| 22 | +namespace meta_schedule { |
| 23 | +class RandomComputeLocationNode : public ScheduleRuleNode { |
| 24 | + public: |
| 25 | + // Inherited from ScheduleRuleNode |
| 26 | + void InitializeWithTuneContext(const TuneContext& context) final {} |
| 27 | + |
| 28 | + // Inherited from ScheduleRuleNode |
| 29 | + Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { |
| 30 | + if (!CheckConditions(sch, block_rv)) { |
| 31 | + return {sch}; |
| 32 | + } |
| 33 | + |
| 34 | + // Step 1. If the producer of the input block needs a random compute-at location (specified by |
| 35 | + // the annotation), we colect the producer first, and transform the producer block later. |
| 36 | + // - The reason we collect the producer before transforming the input block is that, if the |
| 37 | + // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer |
| 38 | + // access the input block. Hence we collect its producer ahead of time. |
| 39 | + // - Note that only single producer is allowed in this case. |
| 40 | + Array<tir::BlockRV> producers{nullptr}; |
| 41 | + if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer, |
| 42 | + true)) { |
| 43 | + producers = sch->GetProducers(block_rv); |
| 44 | + sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer); |
| 45 | + ICHECK_EQ(producers.size(), 1); |
| 46 | + } |
| 47 | + |
| 48 | + // Step 2. Transform the input block. |
| 49 | + tir::Schedule res = RandomlyComputeAt(sch, block_rv); |
| 50 | + |
| 51 | + // Step 3. Transform the producer block if compute-location sampling is needed. |
| 52 | + if (producers.defined()) { |
| 53 | + res = RandomlyComputeAt(res, producers[0]); |
| 54 | + } |
| 55 | + |
| 56 | + return {res}; |
| 57 | + } |
| 58 | + |
| 59 | + private: |
| 60 | + bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const { |
| 61 | + const tir::StmtSRef& block_sref = sch->GetSRef(block_rv); |
| 62 | + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); |
| 63 | + |
| 64 | + // Cond 1. The block is not the root block. |
| 65 | + if (block_sref->parent == nullptr) { |
| 66 | + return false; |
| 67 | + } |
| 68 | + // Cond 2. The block should be the direct child block of the root block. |
| 69 | + if (GetScopeRoot(sch->state(), block_sref, // |
| 70 | + /*require_stage_pipeline=*/false, // |
| 71 | + /*require_subtree_compact_dataflow=*/false) |
| 72 | + ->parent != nullptr) { |
| 73 | + return false; |
| 74 | + } |
| 75 | + // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child |
| 76 | + // block. |
| 77 | + Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref); |
| 78 | + if (loop_srefs.empty()) { |
| 79 | + return false; |
| 80 | + } |
| 81 | + if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) { |
| 82 | + return false; |
| 83 | + } |
| 84 | + // Cond 5. The block is not tiled. We check this condition by examine the block's annotation. |
| 85 | + if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) { |
| 86 | + return false; |
| 87 | + } |
| 88 | + // Cond 6. The block has at lease one consumer. |
| 89 | + if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) { |
| 90 | + return false; |
| 91 | + } |
| 92 | + return true; |
| 93 | + } |
| 94 | + |
| 95 | + /*! |
| 96 | + * \brief Keep sampling a compute-at location for the input block until success. |
| 97 | + * \param sch The TIR schedule |
| 98 | + * \param block_rv The block whose compute-at location is to be sampled |
| 99 | + * \return The TIR schedule after transformation |
| 100 | + */ |
| 101 | + tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) { |
| 102 | + for (;;) { |
| 103 | + tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv); |
| 104 | + try { |
| 105 | + sch->ComputeAt(block_rv, compute_at_loc, true); |
| 106 | + } catch (const dmlc::Error& e) { |
| 107 | + // ComputeAt fails, cleanup the following before re-try: |
| 108 | + // 1) trace: instruction & decisions |
| 109 | + // 2) sym_tab |
| 110 | + sch->trace().value()->Pop(); |
| 111 | + sch->RemoveRV(compute_at_loc); |
| 112 | + continue; |
| 113 | + } |
| 114 | + break; |
| 115 | + } |
| 116 | + return sch; |
| 117 | + } |
| 118 | + |
| 119 | + public: |
| 120 | + void VisitAttrs(tvm::AttrVisitor* v) {} |
| 121 | + |
| 122 | + static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation"; |
| 123 | + TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode); |
| 124 | +}; |
| 125 | + |
| 126 | +ScheduleRule ScheduleRule::RandomComputeLocation() { |
| 127 | + ObjectPtr<RandomComputeLocationNode> n = make_object<RandomComputeLocationNode>(); |
| 128 | + return ScheduleRule(n); |
| 129 | +} |
| 130 | + |
| 131 | +TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode); |
| 132 | +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation") |
| 133 | + .set_body_typed(ScheduleRule::RandomComputeLocation); |
| 134 | +} // namespace meta_schedule |
| 135 | +} // namespace tvm |
0 commit comments