diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index ee9ba4564af8..0d9ef6e4cf99 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -29,7 +29,6 @@ from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker from .builder import BuilderInput, BuilderResult, PyBuilder - logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -236,11 +235,9 @@ def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDA """ # pylint: disable=import-outside-toplevel from tvm.driver import build as tvm_build - from tvm.ir.transform import PassContext # pylint: enable=import-outside-toplevel - with PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]): - return tvm_build(mod, target=target) + return tvm_build(mod, target=target) @register_func("meta_schedule.builder.default_export") diff --git a/python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py b/python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py index b52f88aaa876..4649a8b9bbe0 100644 --- a/python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py +++ b/python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py @@ -20,7 +20,6 @@ import tvm from tvm import auto_scheduler -from tvm.meta_schedule.runner import RPCConfig from tvm.meta_schedule.testing.te_workload import CONFIGS @@ -56,6 +55,11 @@ def _parse_args(): type=str, required=True, ) + args.add_argument( + "--rpc-workers", + type=int, + required=True, + ) args.add_argument( "--log-dir", type=str, @@ -63,12 +67,6 @@ def _parse_args(): ) parsed = args.parse_args() parsed.target = tvm.target.Target(parsed.target) - parsed.rpc_workers = RPCConfig( - tracker_host=parsed.rpc_host, - tracker_port=parsed.rpc_port, - tracker_key=parsed.rpc_key, - session_timeout_sec=30, - ).count_num_servers(allow_missing=True) return parsed @@ -93,6 +91,7 @@ def main(): cache_line_bytes=64, max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]), max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]), + max_local_memory_per_block=12345678, max_vthread_extent=8, warp_size=32, ) diff --git a/python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py b/python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py index d4166b10f502..50ab5b93937d 100644 --- a/python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py @@ -63,19 +63,19 @@ def _parse_args(): type=str, required=True, ) + args.add_argument( + "--rpc-workers", + type=int, + required=True, + ) parsed = args.parse_args() parsed.target = tvm.target.Target(parsed.target) - if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu": - parsed.alloc_repeat = 3 - else: - parsed.alloc_repeat = 1 parsed.rpc_config = ms.runner.RPCConfig( tracker_host=parsed.rpc_host, tracker_port=parsed.rpc_port, tracker_key=parsed.rpc_key, - session_timeout_sec=30, + session_timeout_sec=60, ) - parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False) return parsed @@ -85,6 +85,7 @@ def _parse_args(): def main(): + alloc_repeat = 1 runner = ms.runner.RPCRunner( rpc_config=ARGS.rpc_config, evaluator_config=ms.runner.EvaluatorConfig( @@ -93,7 +94,7 @@ def main(): min_repeat_ms=100, enable_cpu_cache_flush=False, ), - alloc_repeat=ARGS.alloc_repeat, + alloc_repeat=alloc_repeat, max_workers=ARGS.rpc_workers, ) sch: Optional[tir.Schedule] = ms.tune_tir( diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc index b29405333d79..20581f4630a6 100644 --- a/src/meta_schedule/measure_callback/add_to_database.cc +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -36,12 +36,15 @@ class AddToDatabaseNode : public MeasureCallbackNode { for (int i = 0; i < n; ++i) { RunnerResult result = runner_results[i]; MeasureCandidate candidate = measure_candidates[i]; - if (result->error_msg.defined()) { - continue; + Array run_secs{nullptr}; + if (result->run_secs.defined()) { + run_secs = result->run_secs.value(); + } else { + run_secs = Array{FloatImm(DataType::Float(32), 1e10)}; } database->CommitTuningRecord(TuningRecord( /*trace=*/candidate->sch->trace().value(), - /*run_secs=*/result->run_secs.value(), + /*run_secs=*/run_secs, /*workload=*/workload, /*target=*/target, /*args_info=*/candidate->args_info)); diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 6e034886bdb5..00967aef7acd 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -51,7 +51,7 @@ int64_t Product(const std::vector& array) { return result; } -/*! \brief A mutator that mutates the decision of instruction Sample-Perfect-Tile */ +/*! \brief A mutator that mutates the tile size */ class MutateTileSizeNode : public MutatorNode { public: void VisitAttrs(tvm::AttrVisitor* v) {} @@ -66,10 +66,12 @@ class MutateTileSizeNode : public MutatorNode { }; /*! - * \brief Find the Sample-Perfect-Tile instructions and their decisions in the trace + * \brief Find a sample-perfect-tile decision in the trace * \param trace The trace - * \param inst The instructions found - * \param decision The decisions of the instructions found + * \param rand_state The random state + * \param inst The instruction selected + * \param decision The decision selected + * \return Whether a decision is found */ void FindSamplePerfectTile(const Trace& trace, std::vector* inst, std::vector>* decision) { @@ -92,13 +94,6 @@ void FindSamplePerfectTile(const Trace& trace, std::vector* inst, } } -/*! - * \brief Find all Sample-Categorical instructions (and their decisions) whose outputs are used for - * cooperative fetch annotation - * \param trace The trace - * \param inst The instructions found - * \param decision The decisions of the instructions found - */ void FindSampleVectorize(const Trace& trace, std::vector* inst, std::vector* decision) { static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical"); @@ -137,17 +132,12 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, } struct FactorMemo { - /*! - * \brief Find all factors of the input integer - * \param n The integer to be factorized - * \return The factors of the input integer - */ static std::vector Factorize(int n) { if (const std::vector* result = Global()->Query(n)) { return *result; } std::vector result; - for (int64_t i = 1; i * i < n; ++i) { + for (int64_t i = 1; i * i <= n; ++i) { if (n % i == 0) { result.push_back(i); if (i * i != n) { @@ -162,17 +152,17 @@ struct FactorMemo { private: const std::vector* Query(int n) { - std::unique_lock lock(mutex); - auto it = memo.find(n); - if (it != memo.end()) { + std::unique_lock lock(mutex_); + auto it = memo_.find(n); + if (it != memo_.end()) { return &it->second; } return nullptr; } void Add(int n, std::vector result) { - std::unique_lock lock(mutex); - memo.emplace(n, std::move(result)); + std::unique_lock lock(mutex_); + memo_.emplace(n, std::move(result)); } static FactorMemo* Global() { @@ -180,8 +170,8 @@ struct FactorMemo { return &singleton; } - std::unordered_map> memo; - std::mutex mutex; + std::unordered_map> memo_; + std::mutex mutex_; }; Optional MutateSampleTileSize(const Trace& trace, Instruction inst, diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index ad8ee9854265..798f00423f7b 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -49,7 +49,8 @@ Optional ParseThreadBinding(const Schedule& sch, const Instruction& ins * \param vector_lane The number of vector lane in vectorized cooperative fetching * \return NullOpt if parsing fails; Otherwise, the annotated block */ -Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, int* vector_lane) { +Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, + int64_t* vector_lane) { static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); if (!inst->kind.same_as(inst_kind_annotate)) { return NullOpt; @@ -87,55 +88,66 @@ class RewriteCooperativeFetchNode : public PostprocNode { bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { tir::Trace trace = sch->trace().value(); - int thread_extent_x = -1; - int thread_extent_y = -1; - int vector_lane = -1; + int64_t thread_extent_x = -1; + int64_t thread_extent_y = -1; + int64_t vector_lane = 1; std::vector> tasks; for (const tir::Instruction& inst : trace->insts) { if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) { thread_extent_x = new_thread_extent.value()->value; - } else if (Optional new_thread_extent = - tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { + continue; + } + if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { thread_extent_y = new_thread_extent.value()->value; - } else if (Optional block_rv = tir::ParseAnnotate(sch, inst, &vector_lane)) { - ICHECK_NE(thread_extent_x, -1); - if (vector_lane > 1) { - tasks.push_back([thread_extent_x, thread_extent_y, vector_lane, sch, - block = block_rv.value()]() -> void { - tir::LoopRV fused = sch->GetLoops(block).back(); - if (thread_extent_y == -1) { - Array split = sch->Split(fused, {NullOpt, // - Integer(thread_extent_x), // - Integer(vector_lane)}); - sch->Vectorize(split[2]); - sch->Bind(split[1], "threadIdx.x"); - } else { - Array split = sch->Split(fused, {NullOpt, // - Integer(thread_extent_y), // - Integer(thread_extent_x), // - Integer(vector_lane)}); - sch->Vectorize(split[3]); - sch->Bind(split[2], "threadIdx.x"); - sch->Bind(split[1], "threadIdx.y"); - } - }); + continue; + } + Optional opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane); + if (!opt_block_rv.defined()) { + continue; + } + auto task = [thread_extent_x, thread_extent_y, vector_lane, sch, + block = opt_block_rv.value()]() mutable -> void { + sch->Unannotate(block, tir::attr::meta_schedule_cooperative_fetch); + tir::LoopRV fused = sch->GetLoops(block).back(); + int64_t fused_extent = -1; + if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(fused).get())) { + fused_extent = *extent; } else { - tasks.push_back( - [thread_extent_x, thread_extent_y, sch, block = block_rv.value()]() -> void { - tir::LoopRV fused = sch->GetLoops(block).back(); - if (thread_extent_y == -1) { - Array split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)}); - sch->Bind(split[1], "threadIdx.x"); - } else { - Array split = sch->Split(fused, {NullOpt, // - Integer(thread_extent_y), // - Integer(thread_extent_x)}); - sch->Bind(split[2], "threadIdx.x"); - sch->Bind(split[1], "threadIdx.y"); - } - }); + return; } - } + if (fused_extent % vector_lane != 0) { + vector_lane = 1; + } + if (thread_extent_y != -1) { + if (vector_lane > 1) { + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent_y), // + Integer(thread_extent_x), // + Integer(vector_lane)}); + sch->Vectorize(split[3]); + sch->Bind(split[2], "threadIdx.x"); + sch->Bind(split[1], "threadIdx.y"); + } else { + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent_y), // + Integer(thread_extent_x)}); + sch->Bind(split[2], "threadIdx.x"); + sch->Bind(split[1], "threadIdx.y"); + } + } else { + if (vector_lane > 1) { + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent_x), // + Integer(vector_lane)}); + sch->Vectorize(split[2]); + sch->Bind(split[1], "threadIdx.x"); + } else { + Array split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)}); + sch->Bind(split[1], "threadIdx.x"); + } + } + }; + tasks.push_back(task); } for (auto&& task : tasks) { task(); diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index e2c71b7ec164..7d4a716b2e0c 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -22,6 +22,7 @@ namespace tvm { namespace tir { + class ThreadExtentChecker : private StmtVisitor { public: static bool Check(const Stmt& stmt) { @@ -35,24 +36,24 @@ class ThreadExtentChecker : private StmtVisitor { private: void VisitStmt_(const ForNode* loop) { - if (IsThreadIdx(GetThreadScope(loop))) { - const std::string& thread_tag = loop->thread_binding.value()->thread_tag; + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsThreadIdx(thread_scope)) { if (const int64_t* p_ext = GetLoopIntExtent(loop)) { - auto it = thread_tag2extent_.find(thread_tag); - bool new_thread = it == thread_tag2extent_.end(); - if (new_thread) { - thread_extent_product *= *p_ext; - thread_tag2extent_[thread_tag] = *p_ext; + int64_t ext = *p_ext; + if (thread_scope.dim_index == 0) { + std::swap(thread_idx_x, ext); + StmtVisitor::VisitStmt_(loop); + std::swap(thread_idx_x, ext); + } else if (thread_scope.dim_index == 1) { + std::swap(thread_idx_y, ext); + StmtVisitor::VisitStmt_(loop); + std::swap(thread_idx_y, ext); + } else if (thread_scope.dim_index == 2) { + std::swap(thread_idx_z, ext); + StmtVisitor::VisitStmt_(loop); + std::swap(thread_idx_z, ext); } else { - CHECK_EQ(it->second, *p_ext) - << "ValueError: All loops that are bound to `" << thread_tag - << "` should have the same extent. However, there are two loops with extent " - << it->second << " and " << p_ext << ", which are not equal"; - } - StmtVisitor::VisitStmt_(loop); - if (new_thread) { - thread_extent_product /= *p_ext; - thread_tag2extent_.erase(thread_tag); + StmtVisitor::VisitStmt_(loop); } return; } else { @@ -69,6 +70,7 @@ class ThreadExtentChecker : private StmtVisitor { GetAnn(block, attr::meta_schedule_thread_extent_high_inclusive)) { int64_t low = low_inclusive.value()->value; int64_t high = high_inclusive.value()->value; + int64_t thread_extent_product = thread_idx_x * thread_idx_y * thread_idx_z; if (!(low <= thread_extent_product && thread_extent_product <= high)) { throw dmlc::Error("Thread extent"); } @@ -77,12 +79,15 @@ class ThreadExtentChecker : private StmtVisitor { StmtVisitor::VisitStmt_(block); } - int64_t thread_extent_product = 1; - - /*! \brief A mapping from a thread tag to its thread extent */ - std::unordered_map thread_tag2extent_; + int64_t thread_idx_x = 1; + int64_t thread_idx_y = 1; + int64_t thread_idx_z = 1; }; + } // namespace tir +} // namespace tvm + +namespace tvm { namespace meta_schedule { /*! \brief Extract attribute from a target. */ @@ -105,9 +110,9 @@ class VerifyGPUCodeNode : public PostprocNode { Target target = context->target.value(); this->target_constraints_ = Map{ {"max_shared_memory_per_block", Extract(target, "max_shared_memory_per_block")}, - {"max_threads_per_block", Extract(target, "max_threads_per_block")}, {"max_vthread", Integer(8)}, - {"max_vector_bytes", Integer(16)}}; + {"max_vector_bytes", Integer(16)}, + }; } bool Verify(const IRModule& mod) const { @@ -150,14 +155,12 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); - // Phase 2 pass_list.push_back(tir::transform::VectorizeLoop(true)); pass_list.push_back(tir::transform::InjectVirtualThread()); pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - // Convert Function to IRModule transform::PassContext pass_ctx = transform::PassContext::Current(); tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index efe8407d6150..24d15b149e70 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -30,6 +30,30 @@ using tir::Schedule; /**************** Data Structure ****************/ +/*! \brief An auxiliary data structure to help deduplicate IRModules */ +class IRModuleSet { + public: + /*! \brief Add an IRModule to the set */ + void Add(const IRModule& mod, size_t shash) { tab_.insert(Item{mod, shash}); } + /*! \brief Check if the IRModule is in the set */ + bool Has(const IRModule& mod, size_t shash) const { return tab_.count(Item{mod, shash}); } + + private: + struct Item { + IRModule mod; + size_t shash; + }; + struct ItemHash { + size_t operator()(const Item& hash) const { return hash.shash; } + }; + struct ItemEqual { + bool operator()(const Item& lhs, const Item& rhs) const { + return lhs.shash == rhs.shash && StructuralEqual()(lhs.mod, rhs.mod); + } + }; + std::unordered_set tab_; +}; + /*! * \brief A heap with a size up-limit. If overflow happens, it evicted the worst items. * \note It maintains a min heap in terms of `Item::score`. Therefore, when @@ -40,21 +64,10 @@ class SizedHeap { public: struct Item { Schedule sch; - IRModule mod; - size_t shash; double score; bool operator<(const Item& other) const { return score > other.score; } }; - struct ItemHash { - size_t operator()(const Item& hash) const { return hash.shash; } - }; - - struct ItemEqual { - bool operator()(const Item& lhs, const Item& rhs) const { - return lhs.shash == rhs.shash && StructuralEqual()(lhs.mod, rhs.mod); - } - }; /*! * \brief Constructor * \param size_limit The up-limit of the heap size @@ -65,20 +78,16 @@ class SizedHeap { * \brief Push the specific item to the heap if its key did not appears in the heap * \param item The item to be pushed */ - void Push(Schedule sch, IRModule mod, double score) { - Item item{sch, mod, StructuralHash()(mod), score}; - if (!in_heap.insert(item).second) { - return; - } + void Push(Schedule sch, double score) { int size = heap.size(); if (size < size_limit) { // Heap is not full, just push - heap.emplace_back(item); + heap.emplace_back(Item{sch, score}); std::push_heap(heap.begin(), heap.end()); - } else if (item.score > heap.front().score) { + } else if (score > heap.front().score) { // if the item is better than the worst one in the heap, we can safely kick it out std::pop_heap(heap.begin(), heap.end()); - heap.back() = item; + heap.back() = {sch, score}; std::push_heap(heap.begin(), heap.end()); } // Otherwise, the item is worse than any other element in the heap @@ -88,8 +97,6 @@ class SizedHeap { int size_limit; /*! \brief The heap, the worse the topper */ std::vector heap; - /*! \brief The traces that are in the heap */ - std::unordered_set in_heap; }; struct PerThreadData { @@ -237,9 +244,15 @@ class EvolutionarySearchNode : public SearchStrategyNode { int st; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int ed; + /*! \brief The counter of returning empty results. */ + int num_empty_iters; explicit State(EvolutionarySearchNode* self, Array design_spaces) - : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + : self(self), + design_spaces(design_spaces), + st(0), + ed(self->num_trials_per_iter), + num_empty_iters(0) {} /*! * \brief Pick up best candidates from database. @@ -302,6 +315,11 @@ class EvolutionarySearchNode : public SearchStrategyNode { std::unique_ptr state_ = nullptr; /*! \brief The token registered for the given workload in database. */ Workload token_{nullptr}; + /*! + * \brief The workloads that are already measured. + * TODO(junrushao1994): add records from the database to avoid re-measuring. + * */ + IRModuleSet measured_workloads_; /*** Configuration: global ***/ /*! \brief The number of trials per iteration. */ @@ -310,6 +328,11 @@ class EvolutionarySearchNode : public SearchStrategyNode { int num_trials_total; /*! \brief The population size in the evolutionary search. */ int population_size; + /*! + * \brief The maximum number of iterations before early stopping to confirm the search space is + * exhausted + */ + int num_empty_iters_before_early_stop; /*** Configuration: the initial population ***/ /*! \brief The ratio of measured states used in the initial population */ double init_measured_ratio; @@ -343,6 +366,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { v->Visit("num_trials_total", &num_trials_total); v->Visit("num_trials_per_iter", &num_trials_per_iter); v->Visit("population_size", &population_size); + v->Visit("num_empty_iters_before_early_stop", &num_empty_iters_before_early_stop); /*** Configuration: the initial population ***/ v->Visit("init_measured_ratio", &init_measured_ratio); v->Visit("init_min_unmeasured", &init_min_unmeasured); @@ -368,6 +392,8 @@ class EvolutionarySearchNode : public SearchStrategyNode { this->postprocs_ = context->postprocs; this->num_threads_ = context->num_threads; this->rand_state_ = ForkSeed(&context->rand_state); + CHECK(context->task_scheduler != nullptr) + << "ValueError: TaskScheduler is not defined in TuneContext"; this->cost_model_ = context->task_scheduler->cost_model.value(); this->database_ = context->task_scheduler->database; this->token_ = this->database_->CommitWorkload(context->mod.value()); @@ -474,7 +500,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( std::vector population, int num) { ICHECK_GT(num, 0); // The heap to record best schedule, we do not consider schedules that are already measured - // Also we use `in_heap` to make sure items in the heap are de-duplicated + IRModuleSet exists = self->measured_workloads_; SizedHeap heap(num); for (int iter = 0;; ++iter) { // Predict normalized score with the cost model, @@ -486,9 +512,11 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( for (int i = 0, n = population.size(); i < n; ++i) { Schedule sch = population.at(i); IRModule mod = sch->mod(); + size_t shash = StructuralHash()(mod); double score = scores.at(i); - if (!self->database_->HasWorkload(mod)) { - heap.Push(sch, mod, score); + if (!exists.Has(mod, shash)) { + exists.Add(mod, shash); + heap.Push(sch, score); } } // Discontinue once it reaches end of search @@ -576,6 +604,7 @@ std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); std::vector results; results.reserve(num); + IRModuleSet& measured_workloads = self->measured_workloads_; for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) { bool has_best = i_bests < static_cast(bests.size()); bool has_rand = i_rands < static_cast(rands.size()); @@ -600,7 +629,12 @@ std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( break; } } - results.push_back(sch); + IRModule mod = sch->mod(); + size_t shash = StructuralHash()(mod); + if (!measured_workloads.Has(mod, shash)) { + measured_workloads.Add(mod, shash); + results.push_back(sch); + } } return results; } @@ -630,6 +664,12 @@ Optional> EvolutionarySearchNode::State::GenerateMeasure LOG(INFO) << "Got " << bests.size() << " candidate(s) with evolutionary search"; std::vector picks = PickWithEpsGreedy(unmeasured, bests, sample_num); LOG(INFO) << "Sending " << picks.size() << " candidates(s) for measurement"; + if (picks.empty()) { + ++this->num_empty_iters; + if (this->num_empty_iters >= self->num_empty_iters_before_early_stop) { + return NullOpt; + } + } return AssembleCandidates(picks, self->args_info_); } @@ -656,6 +696,7 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, / n->num_trials_per_iter = num_trials_per_iter; n->num_trials_total = num_trials_total; n->population_size = population_size; + n->num_empty_iters_before_early_stop = 5; n->init_measured_ratio = init_measured_ratio; n->init_min_unmeasured = init_min_unmeasured; n->genetic_num_iters = genetic_num_iters; diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 0e767825573f..b7ea3f539bce 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -299,22 +299,12 @@ std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandS return SamplePerfectTile(rand_state, extent, n_splits); } CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; - std::vector innermost_candidates; - innermost_candidates.reserve(max_innermost_factor); - for (int32_t i = 1; i <= max_innermost_factor; ++i) { - if (extent % i == 0) { - innermost_candidates.push_back(i); + while (true) { + std::vector result = SamplePerfectTile(rand_state, extent, n_splits); + if (result.back() <= max_innermost_factor) { + return result; } } - // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. - // We should do multiple factorization to weight the choices. However, it would lead to slower - // sampling speed. On the other hand, considering potential tricks we might do on the innermost - // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add - // more heuristics in the future - int32_t innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())]; - std::vector result = SamplePerfectTile(rand_state, extent / innermost, n_splits - 1); - result.push_back(innermost); - return result; } std::vector SamplePerfectTile( diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py index 31e92e09e50e..38847b6dba4c 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -72,7 +72,6 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) T.reads([A[v0, v1]]) T.writes([A_shared[v0, v1]]) - T.block_attr({"meta_schedule.cooperative_fetch":1}) A_shared[v0, v1] = A[v0, v1] for ax0_ax1_fused_0 in T.serial(0, 1024): for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): @@ -82,7 +81,6 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) T.reads([B[v0, v1]]) T.writes([B_shared[v0, v1]]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) B_shared[v0, v1] = B[v0, v1] for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): with T.block("C"): diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 80d645a5ce93..663614371eeb 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -17,16 +17,13 @@ """ Test Meta Schedule SearchStrategy """ # pylint: disable=missing-function-docstring import sys -from typing import List, Optional, Tuple, Union +from typing import List -import numpy as np import pytest import tvm -from tvm.ir import IRModule +from tvm import meta_schedule as ms from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.builder import LocalBuilder -from tvm.meta_schedule.cost_model import RandomModel -from tvm.meta_schedule.runner import LocalRunner, RunnerResult +from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.search_strategy import ( EvolutionarySearch, ReplayFunc, @@ -35,12 +32,11 @@ ) from tvm.meta_schedule.space_generator import ScheduleFn from tvm.meta_schedule.task_scheduler import RoundRobin -from tvm.meta_schedule.utils import derived_object -from tvm.meta_schedule.testing import DummyDatabase, DummyMutator +from tvm.meta_schedule.testing import DummyMutator +from tvm.meta_schedule.testing.utils import DummyDatabase from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace - MATMUL_M = 32 # pylint: disable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking @@ -49,7 +45,7 @@ @tvm.script.ir_module class Matmul: @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle) -> None: + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore T.func_attr({"global_symbol": "main"}) A = T.match_buffer(a, (32, 32), "float32") B = T.match_buffer(b, (32, 32), "float32") @@ -58,7 +54,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: with T.block("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): - C[vi, vj] = 0.0 + C[vi, vj] = 0.0 # type: ignore C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on @@ -116,8 +112,14 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name] + def _schedule_matmul_small(sch: Schedule): + block = sch.get_block("matmul") + _, j, k = sch.get_loops(block=block) + _, _ = sch.split(j, sch.sample_perfect_tile(j, n=2)) + _, _ = sch.split(k, sch.sample_perfect_tile(k, n=2)) + num_trials_per_iter = 10 - num_trials_total = 100 + num_trials_total = 2000 strategy = EvolutionarySearch( num_trials_per_iter=num_trials_per_iter, @@ -132,7 +134,7 @@ def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name] ) context = TuneContext( mod=Matmul, - space_generator=ScheduleFn(sch_fn=_schedule_matmul), + space_generator=ScheduleFn(sch_fn=_schedule_matmul_small), mutator_probs={ DummyMutator(): 1.0, }, @@ -141,10 +143,10 @@ def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name] ) _scheduler = RoundRobin( tasks=[context], - builder=LocalBuilder(), - runner=LocalRunner(), + builder=ms.builder.LocalBuilder(), + runner=ms.runner.LocalRunner(), database=DummyDatabase(), - cost_model=RandomModel(), + cost_model=ms.cost_model.RandomModel(), measure_callbacks=[], ) context.space_generator.initialize_with_tune_context(context) @@ -168,11 +170,68 @@ def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name] strategy.notify_runner_results(context, candidates, runner_results) candidates = strategy.generate_measure_candidates() strategy.post_tuning() - print(num_trials_each_iter) - correct_count = 10 # For each iteration except the last one - assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + ( - [num_trials_total % correct_count] if num_trials_total % correct_count != 0 else [] + assert sum(num_trials_each_iter) == 25 + assert num_trials_each_iter.count(0) < 5 + del _scheduler + + +def test_meta_schedule_evolutionary_search_early_stop(): # pylint: disable = invalid-name] + def _schedule_matmul_empty(sch: Schedule): + return sch + + num_trials_per_iter = 10 + num_trials_total = 100 + + strategy = EvolutionarySearch( + num_trials_per_iter=num_trials_per_iter, + num_trials_total=num_trials_total, + population_size=5, + init_measured_ratio=0.1, + init_min_unmeasured=50, + genetic_num_iters=3, + genetic_mutate_prob=0.5, + genetic_max_fail_count=10, + eps_greedy=0.9, + ) + context = TuneContext( + mod=Matmul, + space_generator=ScheduleFn(sch_fn=_schedule_matmul_empty), + mutator_probs={ + DummyMutator(): 1.0, + }, + target=tvm.target.Target("llvm"), + num_threads=1, # because we are using a mutator from the python side ) + _scheduler = RoundRobin( + tasks=[context], + builder=ms.builder.LocalBuilder(), + runner=ms.runner.LocalRunner(), + database=DummyDatabase(), + cost_model=ms.cost_model.RandomModel(), + measure_callbacks=[], + ) + context.space_generator.initialize_with_tune_context(context) + spaces = context.space_generator.generate_design_space(context.mod) + + strategy.initialize_with_tune_context(context) + strategy.pre_tuning(spaces) + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_each_iter: List[int] = [] + candidates = strategy.generate_measure_candidates() + while candidates is not None: + num_trials_each_iter.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + _is_trace_equal( + candidate.sch, + correct_sch, + remove_decisions=(isinstance(strategy, ReplayTrace)), + ) + runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) + strategy.notify_runner_results(context, candidates, runner_results) + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + assert num_trials_each_iter == [1, 0, 0, 0, 0] del _scheduler