diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 47003c6faa25..89c5dc3c8c21 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -20,6 +20,7 @@ from tvm._ffi import register_object from tvm.ir import IRModule, transform +from tvm.meta_schedule.database.database import Database from tvm.relay import Any, Function as RelayFunc, vm from tvm.runtime import NDArray, Object from tvm.target import Target @@ -174,10 +175,16 @@ def __init__(self) -> None: @register_object("meta_schedule.ApplyHistoryBest") class ApplyHistoryBest(MetaScheduleContext): - pass + """An integration context that allows application of historically best records from a database""" + database: Database + """ The database to be queried from""" -def extract_task( + def __init__(self, database) -> None: + self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member + + +def extract_task_from_relay( mod: Union[IRModule, RelayFunc], target: Target, params: Optional[Dict[str, NDArray]] = None, diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 21d7a2614261..bcfa08cdfc7d 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -19,10 +19,11 @@ import logging import os.path from typing import Callable, Dict, List, Optional, Union +from tvm.ir.base import structural_equal, structural_hash from tvm.ir.module import IRModule from tvm.runtime import NDArray -from tvm.meta_schedule.integration import extract_task +from tvm.meta_schedule.integration import extract_task_from_relay from tvm.target.target import Target from tvm.te import Tensor, create_prim_func from tvm.tir import PrimFunc, Schedule @@ -650,11 +651,12 @@ def tune_relay( """ logger.info("Working directory: %s", work_dir) - extracted_tasks = extract_task(mod, target, params) + extracted_tasks = extract_task_from_relay(mod, target, params) # pylint: disable=protected-access tune_contexts = [] target = Parse._target(target) database = Parse._database(database, task_name, work_dir) + # parse the tuning contexts for task in extracted_tasks: assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" mod = Parse._mod(task.dispatched[0]) @@ -664,7 +666,7 @@ def tune_relay( mod=mod, target=target, config=config, - task_name=task_name, + task_name=task.task_name, space_generator=space, sch_rules=sch_rules, postprocs=postprocs, @@ -672,9 +674,27 @@ def tune_relay( num_threads=num_threads, ) ) + # deduplication + logger.info(f"Before task deduplication: {len(tune_contexts)} tasks") + tasks: List[TuneContext] = [] + hashs: List[int] = [] + for i, task in enumerate(tune_contexts): + struct_hash: int = structural_hash(task.mod) + flag: bool = False + if struct_hash in hashs: + for other_task in tune_contexts[i + 1 :]: + if structural_equal(task.mod, other_task.mod): + flag = True + break + if not flag: + tasks.append(task) + hashs.append(struct_hash) + logger.info(f"After task deduplication: {len(tasks)} tasks") + + # parse the task scheduler task_scheduler = Parse._task_scheduler( task_scheduler, - tune_contexts, + tasks, builder=Parse._builder(builder), runner=Parse._runner(runner), database=database, @@ -684,7 +704,7 @@ def tune_relay( # pylint: enable=protected-access task_scheduler.tune() schs: List[Schedule] = [] - for task in tune_contexts: + for task in tasks: mod = task.mod workload = database.commit_workload(mod) bests: List[TuningRecord] = database.get_top_k(workload, top_k=1) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index cf4262814947..130b3a534b70 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -112,7 +112,17 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) { Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Optional> dispatched) { - throw; + ICHECK(dispatched.defined()); + ICHECK_EQ(dispatched.value().size(), 1); + IRModule prim_mod = dispatched.value()[0]; + ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; + ICHECK(HasOnlyOneFunction(mod)) << mod; + if (database->HasWorkload(prim_mod)) { + Array records = database->GetTopK(database->CommitWorkload(prim_mod), 1); + ICHECK(records.size() == 1) << "No records was found for given workload" << prim_mod; + return records[0]->workload->mod; + } else + return NullOpt; } /**************** FFI ****************/ @@ -146,6 +156,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery") TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction { return TaskExtraction(); }); +TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest") + .set_body_typed([](Database database) -> ApplyHistoryBest { + return ApplyHistoryBest(database); + }); } // namespace meta_schedule } // namespace tvm diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 6ca748e28573..545290a3de23 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -28,15 +28,6 @@ namespace tvm { namespace tir { -/*! - * \brief Create a sampling function that does multinomial sampling. - * \param rand_state The random state. - * \param weights The weights for multinomial sampling. - * \return The multinomial sampling function. - */ -TVM_DLL std::function MakeMultinomialSampler( - support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights); - /******** Schedule: Sampling ********/ /*! * \brief Sample a random integer from a given range. diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index f508c7d252e1..0ace4d2bd02c 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -112,7 +112,7 @@ def test_meta_schedule_integration_extract_from_resnet(): layout="NHWC", dtype="float32", ) - extracted_tasks = ms.integration.extract_task(mod, target="llvm", params=params) + extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params) assert len(extracted_tasks) == 30 diff --git a/tests/python/unittest/test_meta_schedule_task_extraction.py b/tests/python/unittest/test_meta_schedule_task_extraction.py index 8d1eca51432e..8523275f5186 100644 --- a/tests/python/unittest/test_meta_schedule_task_extraction.py +++ b/tests/python/unittest/test_meta_schedule_task_extraction.py @@ -91,7 +91,7 @@ def test_meta_schedule_extract_from_torch_model(model_name: str, batch_size: int dtype="float32", ) target = tvm.target.Target(target) - ms.integration.extract_task(mod, params=params, target=target) + ms.integration.extract_task_from_relay(mod, params=params, target=target) if __name__ == "__main__":