Merged
5 changes: 1 addition & 4 deletions python/tvm/meta_schedule/builder/local_builder.py
@@ -29,7 +29,6 @@
 from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker
 from .builder import BuilderInput, BuilderResult, PyBuilder
 
-
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
 
@@ -236,11 +235,9 @@ def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDA
     """
     # pylint: disable=import-outside-toplevel
     from tvm.driver import build as tvm_build
-    from tvm.ir.transform import PassContext
 
     # pylint: enable=import-outside-toplevel
-    with PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]):
-        return tvm_build(mod, target=target)
+    return tvm_build(mod, target=target)
 
 
 @register_func("meta_schedule.builder.default_export")
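The net effect of this hunk: `default_build` previously compiled with the `tir.CommonSubexprElimTIR` pass disabled, and now uses the stock pass pipeline. A condensed before/after sketch (not the full `default_build` signature):

```python
from tvm.driver import build as tvm_build
from tvm.ir.transform import PassContext


def default_build_before(mod, target):
    # Old behavior: compile with TIR common-subexpression elimination disabled.
    with PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]):
        return tvm_build(mod, target=target)


def default_build_after(mod, target):
    # New behavior: rely on the default pass pipeline unchanged.
    return tvm_build(mod, target=target)
```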
13 changes: 6 additions & 7 deletions python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py
@@ -20,7 +20,6 @@
 
 import tvm
 from tvm import auto_scheduler
-from tvm.meta_schedule.runner import RPCConfig
 from tvm.meta_schedule.testing.te_workload import CONFIGS
 
 
@@ -56,19 +55,18 @@ def _parse_args():
         type=str,
         required=True,
     )
+    args.add_argument(
+        "--rpc-workers",
+        type=int,
+        required=True,
+    )
     args.add_argument(
         "--log-dir",
         type=str,
         required=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
-    parsed.rpc_workers = RPCConfig(
-        tracker_host=parsed.rpc_host,
-        tracker_port=parsed.rpc_port,
-        tracker_key=parsed.rpc_key,
-        session_timeout_sec=30,
-    ).count_num_servers(allow_missing=True)
     return parsed
 
 
@@ -93,6 +91,7 @@ def main():
         cache_line_bytes=64,
         max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
         max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
+        max_local_memory_per_block=12345678,
         max_vthread_extent=8,
         warp_size=32,
     )

Contributor, commenting on the added max_local_memory_per_block line:
nit: Have a comment saying this is useless but just a placeholder.
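A minimal sketch of how that nit could be addressed, assuming the enclosing call is `auto_scheduler.HardwareParams` (which these keyword arguments match); the assignment name and comment wording are illustrative, the values are from the diff:

```python
hardware_params = auto_scheduler.HardwareParams(
    cache_line_bytes=64,
    max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
    max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
    # Placeholder only; this value is not actually used here (see review comment).
    max_local_memory_per_block=12345678,
    max_vthread_extent=8,
    warp_size=32,
)
```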
15 changes: 8 additions & 7 deletions python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py
@@ -63,19 +63,19 @@ def _parse_args():
         type=str,
         required=True,
     )
+    args.add_argument(
+        "--rpc-workers",
+        type=int,
+        required=True,
+    )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
-    if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu":
-        parsed.alloc_repeat = 3
-    else:
-        parsed.alloc_repeat = 1
     parsed.rpc_config = ms.runner.RPCConfig(
         tracker_host=parsed.rpc_host,
         tracker_port=parsed.rpc_port,
         tracker_key=parsed.rpc_key,
-        session_timeout_sec=30,
+        session_timeout_sec=60,
     )
-    parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False)
     return parsed
 
 
@@ -85,6 +85,7 @@ def _parse_args():
 
 
 def main():
+    alloc_repeat = 1
     runner = ms.runner.RPCRunner(
         rpc_config=ARGS.rpc_config,
         evaluator_config=ms.runner.EvaluatorConfig(
@@ -93,7 +94,7 @@ def main():
             min_repeat_ms=100,
             enable_cpu_cache_flush=False,
         ),
-        alloc_repeat=ARGS.alloc_repeat,
+        alloc_repeat=alloc_repeat,
         max_workers=ARGS.rpc_workers,
     )
     sch: Optional[tir.Schedule] = ms.tune_tir(
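Taken together: the worker count now comes from the required --rpc-workers flag instead of querying the tracker at argument-parsing time, the RPC session timeout rises from 30 s to 60 s, and alloc_repeat is pinned to 1 in main() rather than derived from the target triple. A sketch of the resulting runner wiring (tracker address, key, and worker count are placeholder values; EvaluatorConfig fields beyond those shown in the hunks are omitted):

```python
rpc_config = ms.runner.RPCConfig(
    tracker_host="127.0.0.1",  # placeholder
    tracker_port=9190,         # placeholder
    tracker_key="local",       # placeholder
    session_timeout_sec=60,
)
runner = ms.runner.RPCRunner(
    rpc_config=rpc_config,
    evaluator_config=ms.runner.EvaluatorConfig(
        min_repeat_ms=100,
        enable_cpu_cache_flush=False,
    ),
    alloc_repeat=1,  # no longer depends on the target triple
    max_workers=8,   # placeholder; now supplied via --rpc-workers
)
```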
9 changes: 6 additions & 3 deletions src/meta_schedule/measure_callback/add_to_database.cc
@@ -36,12 +36,15 @@ class AddToDatabaseNode : public MeasureCallbackNode {
     for (int i = 0; i < n; ++i) {
       RunnerResult result = runner_results[i];
       MeasureCandidate candidate = measure_candidates[i];
-      if (result->error_msg.defined()) {
-        continue;
+      Array<FloatImm> run_secs{nullptr};
+      if (result->run_secs.defined()) {
+        run_secs = result->run_secs.value();
+      } else {
+        run_secs = Array<FloatImm>{FloatImm(DataType::Float(32), 1e10)};
       }
       database->CommitTuningRecord(TuningRecord(
           /*trace=*/candidate->sch->trace().value(),
-          /*run_secs=*/result->run_secs.value(),
+          /*run_secs=*/run_secs,
           /*workload=*/workload,
           /*target=*/target,
           /*args_info=*/candidate->args_info));
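The behavioral change: candidates whose runs failed (no timings) were previously skipped; now every candidate is committed to the database, with failures recorded as a huge penalty latency of 1e10 seconds so the search learns to avoid them. A minimal Python sketch of the same policy (the function name is illustrative, not a TVM API):

```python
PENALTY_SECONDS = 1e10  # mirrors FloatImm(DataType::Float(32), 1e10) in the diff


def run_secs_to_commit(run_secs):
    """Timings to store for one measured candidate.

    Failed runs (run_secs is None) are committed with a penalty value
    instead of being dropped, so the database covers every candidate.
    """
    if run_secs is not None:
        return run_secs
    return [PENALTY_SECONDS]


assert run_secs_to_commit([0.012, 0.011]) == [0.012, 0.011]  # successful run
assert run_secs_to_commit(None) == [1e10]                    # failed run
```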
38 changes: 14 additions & 24 deletions src/meta_schedule/mutator/mutate_tile_size.cc
@@ -51,7 +51,7 @@ int64_t Product(const std::vector<int64_t>& array) {
   return result;
 }
 
-/*! \brief A mutator that mutates the decision of instruction Sample-Perfect-Tile */
+/*! \brief A mutator that mutates the tile size */
 class MutateTileSizeNode : public MutatorNode {
  public:
   void VisitAttrs(tvm::AttrVisitor* v) {}
@@ -66,10 +66,12 @@ class MutateTileSizeNode : public MutatorNode {
 };
 
 /*!
- * \brief Find the Sample-Perfect-Tile instructions and their decisions in the trace
+ * \brief Find a sample-perfect-tile decision in the trace
  * \param trace The trace
- * \param inst The instructions found
- * \param decision The decisions of the instructions found
+ * \param rand_state The random state
+ * \param inst The instruction selected
+ * \param decision The decision selected
+ * \return Whether a decision is found
  */
 void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
                            std::vector<std::vector<int64_t>>* decision) {
@@ -92,13 +94,6 @@ void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
   }
 }
 
-/*!
- * \brief Find all Sample-Categorical instructions (and their decisions) whose outputs are used for
- * cooperative fetch annotation
- * \param trace The trace
- * \param inst The instructions found
- * \param decision The decisions of the instructions found
- */
 void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
                          std::vector<int64_t>* decision) {
   static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
@@ -137,17 +132,12 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
 }
 
 struct FactorMemo {
-  /*!
-   * \brief Find all factors of the input integer
-   * \param n The integer to be factorized
-   * \return The factors of the input integer
-   */
   static std::vector<int> Factorize(int n) {
     if (const std::vector<int>* result = Global()->Query(n)) {
       return *result;
     }
     std::vector<int> result;
-    for (int64_t i = 1; i * i < n; ++i) {
+    for (int64_t i = 1; i * i <= n; ++i) {
       if (n % i == 0) {
         result.push_back(i);
         if (i * i != n) {
@@ -162,26 +152,26 @@ struct FactorMemo {
 
  private:
   const std::vector<int>* Query(int n) {
-    std::unique_lock<std::mutex> lock(mutex);
-    auto it = memo.find(n);
-    if (it != memo.end()) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    auto it = memo_.find(n);
+    if (it != memo_.end()) {
       return &it->second;
     }
     return nullptr;
   }
 
   void Add(int n, std::vector<int> result) {
-    std::unique_lock<std::mutex> lock(mutex);
-    memo.emplace(n, std::move(result));
+    std::unique_lock<std::mutex> lock(mutex_);
+    memo_.emplace(n, std::move(result));
   }
 
   static FactorMemo* Global() {
     static FactorMemo singleton;
     return &singleton;
   }
 
-  std::unordered_map<int, std::vector<int>> memo;
-  std::mutex mutex;
+  std::unordered_map<int, std::vector<int>> memo_;
+  std::mutex mutex_;
 };
 
 Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst,
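Two substantive fixes ride along with the doc-comment cleanup: the loop bound in Factorize changes from `i * i < n` to `i * i <= n`, so perfect squares no longer lose their square-root factor, and the private members gain trailing underscores (memo_, mutex_) per the style guide. A small Python sketch of the corrected factorization (memoization and locking from the C++ version omitted):

```python
def factorize(n):
    """All factors of n, mirroring FactorMemo::Factorize after the fix."""
    result = []
    i = 1
    while i * i <= n:        # '<=' so i == sqrt(n) is also tested
        if n % i == 0:
            result.append(i)
            if i * i != n:   # avoid adding sqrt(n) twice for perfect squares
                result.append(n // i)
        i += 1
    return sorted(result)


# With the old 'i * i < n' bound, 4 would be missing from factorize(16):
assert factorize(16) == [1, 2, 4, 8, 16]
assert factorize(12) == [1, 2, 3, 4, 6, 12]
```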
98 changes: 55 additions & 43 deletions src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
@@ -49,7 +49,8 @@ Optional<Integer> ParseThreadBinding(const Schedule& sch, const Instruction& ins
  * \param vector_lane The number of vector lane in vectorized cooperative fetching
  * \return NullOpt if parsing fails; Otherwise, the annotated block
  */
-Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst, int* vector_lane) {
+Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst,
+                                int64_t* vector_lane) {
   static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
   if (!inst->kind.same_as(inst_kind_annotate)) {
     return NullOpt;
@@ -87,55 +88,66 @@ class RewriteCooperativeFetchNode : public PostprocNode {
 
 bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
   tir::Trace trace = sch->trace().value();
-  int thread_extent_x = -1;
-  int thread_extent_y = -1;
-  int vector_lane = -1;
+  int64_t thread_extent_x = -1;
+  int64_t thread_extent_y = -1;
+  int64_t vector_lane = 1;
   std::vector<std::function<void()>> tasks;
   for (const tir::Instruction& inst : trace->insts) {
     if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) {
       thread_extent_x = new_thread_extent.value()->value;
-    } else if (Optional<Integer> new_thread_extent =
-                   tir::ParseThreadBinding(sch, inst, "threadIdx.y")) {
+      continue;
+    }
+    if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.y")) {
       thread_extent_y = new_thread_extent.value()->value;
-    } else if (Optional<tir::BlockRV> block_rv = tir::ParseAnnotate(sch, inst, &vector_lane)) {
-      ICHECK_NE(thread_extent_x, -1);
-      if (vector_lane > 1) {
-        tasks.push_back([thread_extent_x, thread_extent_y, vector_lane, sch,
-                         block = block_rv.value()]() -> void {
-          tir::LoopRV fused = sch->GetLoops(block).back();
-          if (thread_extent_y == -1) {
-            Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
-                                                          Integer(thread_extent_x),  //
-                                                          Integer(vector_lane)});
-            sch->Vectorize(split[2]);
-            sch->Bind(split[1], "threadIdx.x");
-          } else {
-            Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
-                                                          Integer(thread_extent_y),  //
-                                                          Integer(thread_extent_x),  //
-                                                          Integer(vector_lane)});
-            sch->Vectorize(split[3]);
-            sch->Bind(split[2], "threadIdx.x");
-            sch->Bind(split[1], "threadIdx.y");
-          }
-        });
+      continue;
+    }
+    Optional<tir::BlockRV> opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane);
+    if (!opt_block_rv.defined()) {
+      continue;
+    }
+    auto task = [thread_extent_x, thread_extent_y, vector_lane, sch,
+                 block = opt_block_rv.value()]() mutable -> void {
+      sch->Unannotate(block, tir::attr::meta_schedule_cooperative_fetch);
+      tir::LoopRV fused = sch->GetLoops(block).back();
+      int64_t fused_extent = -1;
+      if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(fused).get())) {
+        fused_extent = *extent;
       } else {
-        tasks.push_back(
-            [thread_extent_x, thread_extent_y, sch, block = block_rv.value()]() -> void {
-              tir::LoopRV fused = sch->GetLoops(block).back();
-              if (thread_extent_y == -1) {
-                Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)});
-                sch->Bind(split[1], "threadIdx.x");
-              } else {
-                Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
-                                                              Integer(thread_extent_y),  //
-                                                              Integer(thread_extent_x)});
-                sch->Bind(split[2], "threadIdx.x");
-                sch->Bind(split[1], "threadIdx.y");
-              }
-            });
+        return;
       }
-    }
+      if (fused_extent % vector_lane != 0) {
+        vector_lane = 1;
+      }
+      if (thread_extent_y != -1) {
+        if (vector_lane > 1) {
+          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
+                                                        Integer(thread_extent_y),  //
+                                                        Integer(thread_extent_x),  //
+                                                        Integer(vector_lane)});
+          sch->Vectorize(split[3]);
+          sch->Bind(split[2], "threadIdx.x");
+          sch->Bind(split[1], "threadIdx.y");
+        } else {
+          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
+                                                        Integer(thread_extent_y),  //
+                                                        Integer(thread_extent_x)});
+          sch->Bind(split[2], "threadIdx.x");
+          sch->Bind(split[1], "threadIdx.y");
+        }
+      } else {
+        if (vector_lane > 1) {
+          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt,  //
+                                                        Integer(thread_extent_x),  //
+                                                        Integer(vector_lane)});
+          sch->Vectorize(split[2]);
+          sch->Bind(split[1], "threadIdx.x");
+        } else {
+          Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)});
+          sch->Bind(split[1], "threadIdx.x");
+        }
+      }
+    };
+    tasks.push_back(task);
   }
   for (auto&& task : tasks) {
     task();
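The restructured Apply() flattens the if/else-if chain into early continues, widens the extents to int64_t, unannotates the block inside the deferred task, and only vectorizes when the fused loop's extent is divisible by the sampled vector lane, falling back to vector_lane = 1 otherwise. A minimal Python sketch of just the split-factor arithmetic (plain integers rather than TVM objects; the function name is illustrative):

```python
def cooperative_fetch_split(fused_extent, thread_x, thread_y=-1, vector_lane=1):
    """Split factors for the fused loop, mirroring the rewritten Apply():
    the outermost factor is inferred (None), inner factors bind to
    threadIdx.y / threadIdx.x, and an optional innermost vectorized lane."""
    # Fall back to no vectorization if the extent is not divisible.
    if fused_extent % vector_lane != 0:
        vector_lane = 1
    factors = [None, thread_x]
    if thread_y != -1:
        factors.insert(1, thread_y)  # [None, thread_y, thread_x]
    if vector_lane > 1:
        factors.append(vector_lane)  # innermost lane gets sch.Vectorize
    return factors


# Extent 256, 32 threads in x, vector lane 4 -> [None, 32, 4].
assert cooperative_fetch_split(256, 32, vector_lane=4) == [None, 32, 4]
# Extent 100 is not divisible by 8, so vectorization is dropped.
assert cooperative_fetch_split(100, 32, vector_lane=8) == [None, 32]
```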