Skip to content

Commit edce105

Browse files
committed
[MetaSchedule] Upstream the leftover changes
1 parent 2fc7d16 commit edce105

File tree

12 files changed

+268
-174
lines changed

12 files changed

+268
-174
lines changed

python/tvm/meta_schedule/builder/local_builder.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker
3030
from .builder import BuilderInput, BuilderResult, PyBuilder
3131

32-
3332
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
3433

3534

@@ -236,11 +235,9 @@ def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDA
236235
"""
237236
# pylint: disable=import-outside-toplevel
238237
from tvm.driver import build as tvm_build
239-
from tvm.ir.transform import PassContext
240238

241239
# pylint: enable=import-outside-toplevel
242-
with PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]):
243-
return tvm_build(mod, target=target)
240+
return tvm_build(mod, target=target)
244241

245242

246243
@register_func("meta_schedule.builder.default_export")

python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import tvm
2222
from tvm import auto_scheduler
23-
from tvm.meta_schedule.runner import RPCConfig
2423
from tvm.meta_schedule.testing.te_workload import CONFIGS
2524

2625

@@ -56,19 +55,18 @@ def _parse_args():
5655
type=str,
5756
required=True,
5857
)
58+
args.add_argument(
59+
"--rpc-workers",
60+
type=int,
61+
required=True,
62+
)
5963
args.add_argument(
6064
"--log-dir",
6165
type=str,
6266
required=True,
6367
)
6468
parsed = args.parse_args()
6569
parsed.target = tvm.target.Target(parsed.target)
66-
parsed.rpc_workers = RPCConfig(
67-
tracker_host=parsed.rpc_host,
68-
tracker_port=parsed.rpc_port,
69-
tracker_key=parsed.rpc_key,
70-
session_timeout_sec=30,
71-
).count_num_servers(allow_missing=True)
7270
return parsed
7371

7472

@@ -93,6 +91,7 @@ def main():
9391
cache_line_bytes=64,
9492
max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
9593
max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
94+
max_local_memory_per_block=12345678,
9695
max_vthread_extent=8,
9796
warp_size=32,
9897
)

python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,19 @@ def _parse_args():
6363
type=str,
6464
required=True,
6565
)
66+
args.add_argument(
67+
"--rpc-workers",
68+
type=int,
69+
required=True,
70+
)
6671
parsed = args.parse_args()
6772
parsed.target = tvm.target.Target(parsed.target)
68-
if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu":
69-
parsed.alloc_repeat = 3
70-
else:
71-
parsed.alloc_repeat = 1
7273
parsed.rpc_config = ms.runner.RPCConfig(
7374
tracker_host=parsed.rpc_host,
7475
tracker_port=parsed.rpc_port,
7576
tracker_key=parsed.rpc_key,
76-
session_timeout_sec=30,
77+
session_timeout_sec=60,
7778
)
78-
parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False)
7979
return parsed
8080

8181

@@ -85,6 +85,7 @@ def _parse_args():
8585

8686

8787
def main():
88+
alloc_repeat = 1
8889
runner = ms.runner.RPCRunner(
8990
rpc_config=ARGS.rpc_config,
9091
evaluator_config=ms.runner.EvaluatorConfig(
@@ -93,7 +94,7 @@ def main():
9394
min_repeat_ms=100,
9495
enable_cpu_cache_flush=False,
9596
),
96-
alloc_repeat=ARGS.alloc_repeat,
97+
alloc_repeat=alloc_repeat,
9798
max_workers=ARGS.rpc_workers,
9899
)
99100
sch: Optional[tir.Schedule] = ms.tune_tir(

src/driver/driver_api.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
270270
if (!disable_storage_rewrite) {
271271
pass_list.push_back(tir::transform::StorageRewrite());
272272
}
273+
pass_list.push_back(tir::transform::Simplify());
273274
pass_list.push_back(tir::transform::UnrollLoop());
274275

275276
// Add user-defined phase-2 passes

src/meta_schedule/measure_callback/add_to_database.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,15 @@ class AddToDatabaseNode : public MeasureCallbackNode {
3636
for (int i = 0; i < n; ++i) {
3737
RunnerResult result = runner_results[i];
3838
MeasureCandidate candidate = measure_candidates[i];
39-
if (result->error_msg.defined()) {
40-
continue;
39+
Array<FloatImm> run_secs{nullptr};
40+
if (result->run_secs.defined()) {
41+
run_secs = result->run_secs.value();
42+
} else {
43+
run_secs = Array<FloatImm>{FloatImm(DataType::Float(32), 1e10)};
4144
}
4245
database->CommitTuningRecord(TuningRecord(
4346
/*trace=*/candidate->sch->trace().value(),
44-
/*run_secs=*/result->run_secs.value(),
47+
/*run_secs=*/run_secs,
4548
/*workload=*/workload,
4649
/*target=*/target,
4750
/*args_info=*/candidate->args_info));

src/meta_schedule/mutator/mutate_tile_size.cc

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ int64_t Product(const std::vector<int64_t>& array) {
5151
return result;
5252
}
5353

54-
/*! \brief A mutator that mutates the decision of instruction Sample-Perfect-Tile */
54+
/*! \brief A mutator that mutates the tile size */
5555
class MutateTileSizeNode : public MutatorNode {
5656
public:
5757
void VisitAttrs(tvm::AttrVisitor* v) {}
@@ -66,10 +66,12 @@ class MutateTileSizeNode : public MutatorNode {
6666
};
6767

6868
/*!
69-
* \brief Find the Sample-Perfect-Tile instructions and their decisions in the trace
69+
* \brief Find the Sample-Perfect-Tile instructions and their decisions in the trace
7070
* \param trace The trace
71-
* \param inst The instructions found
72-
* \param decision The decisions of the instructions found
71+
* \param inst The instructions found
72+
* \param decision The decisions of the instructions found
73+
* \note There is no `rand_state` parameter; results are written to `inst` and `decision`
74+
* \note The function is declared `void` and returns no value
7375
*/
7476
void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
7577
std::vector<std::vector<int64_t>>* decision) {
@@ -92,13 +94,6 @@ void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
9294
}
9395
}
9496

95-
/*!
96-
* \brief Find all Sample-Categorical instructions (and their decisions) whose outputs are used for
97-
* cooperative fetch annotation
98-
* \param trace The trace
99-
* \param inst The instructions found
100-
* \param decision The decisions of the instructions found
101-
*/
10297
void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
10398
std::vector<int64_t>* decision) {
10499
static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
@@ -137,17 +132,12 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
137132
}
138133

139134
struct FactorMemo {
140-
/*!
141-
* \brief Find all factors of the input integer
142-
* \param n The integer to be factorized
143-
* \return The factors of the input integer
144-
*/
145135
static std::vector<int> Factorize(int n) {
146136
if (const std::vector<int>* result = Global()->Query(n)) {
147137
return *result;
148138
}
149139
std::vector<int> result;
150-
for (int64_t i = 1; i * i < n; ++i) {
140+
for (int64_t i = 1; i * i <= n; ++i) {
151141
if (n % i == 0) {
152142
result.push_back(i);
153143
if (i * i != n) {
@@ -162,26 +152,26 @@ struct FactorMemo {
162152

163153
private:
164154
const std::vector<int>* Query(int n) {
165-
std::unique_lock<std::mutex> lock(mutex);
166-
auto it = memo.find(n);
167-
if (it != memo.end()) {
155+
std::unique_lock<std::mutex> lock(mutex_);
156+
auto it = memo_.find(n);
157+
if (it != memo_.end()) {
168158
return &it->second;
169159
}
170160
return nullptr;
171161
}
172162

173163
void Add(int n, std::vector<int> result) {
174-
std::unique_lock<std::mutex> lock(mutex);
175-
memo.emplace(n, std::move(result));
164+
std::unique_lock<std::mutex> lock(mutex_);
165+
memo_.emplace(n, std::move(result));
176166
}
177167

178168
static FactorMemo* Global() {
179169
static FactorMemo singleton;
180170
return &singleton;
181171
}
182172

183-
std::unordered_map<int, std::vector<int>> memo;
184-
std::mutex mutex;
173+
std::unordered_map<int, std::vector<int>> memo_;
174+
std::mutex mutex_;
185175
};
186176

187177
Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst,

src/meta_schedule/postproc/rewrite_cooperative_fetch.cc

Lines changed: 55 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ Optional<Integer> ParseThreadBinding(const Schedule& sch, const Instruction& ins
4949
* \param vector_lane The number of vector lane in vectorized cooperative fetching
5050
* \return NullOpt if parsing fails; Otherwise, the annotated block
5151
*/
52-
Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst, int* vector_lane) {
52+
Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst,
53+
int64_t* vector_lane) {
5354
static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
5455
if (!inst->kind.same_as(inst_kind_annotate)) {
5556
return NullOpt;
@@ -87,55 +88,66 @@ class RewriteCooperativeFetchNode : public PostprocNode {
8788

8889
bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
8990
tir::Trace trace = sch->trace().value();
90-
int thread_extent_x = -1;
91-
int thread_extent_y = -1;
92-
int vector_lane = -1;
91+
int64_t thread_extent_x = -1;
92+
int64_t thread_extent_y = -1;
93+
int64_t vector_lane = 1;
9394
std::vector<std::function<void()>> tasks;
9495
for (const tir::Instruction& inst : trace->insts) {
9596
if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) {
9697
thread_extent_x = new_thread_extent.value()->value;
97-
} else if (Optional<Integer> new_thread_extent =
98-
tir::ParseThreadBinding(sch, inst, "threadIdx.y")) {
98+
continue;
99+
}
100+
if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.y")) {
99101
thread_extent_y = new_thread_extent.value()->value;
100-
} else if (Optional<tir::BlockRV> block_rv = tir::ParseAnnotate(sch, inst, &vector_lane)) {
101-
ICHECK_NE(thread_extent_x, -1);
102-
if (vector_lane > 1) {
103-
tasks.push_back([thread_extent_x, thread_extent_y, vector_lane, sch,
104-
block = block_rv.value()]() -> void {
105-
tir::LoopRV fused = sch->GetLoops(block).back();
106-
if (thread_extent_y == -1) {
107-
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
108-
Integer(thread_extent_x), //
109-
Integer(vector_lane)});
110-
sch->Vectorize(split[2]);
111-
sch->Bind(split[1], "threadIdx.x");
112-
} else {
113-
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
114-
Integer(thread_extent_y), //
115-
Integer(thread_extent_x), //
116-
Integer(vector_lane)});
117-
sch->Vectorize(split[3]);
118-
sch->Bind(split[2], "threadIdx.x");
119-
sch->Bind(split[1], "threadIdx.y");
120-
}
121-
});
102+
continue;
103+
}
104+
Optional<tir::BlockRV> opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane);
105+
if (!opt_block_rv.defined()) {
106+
continue;
107+
}
108+
auto task = [thread_extent_x, thread_extent_y, vector_lane, sch,
109+
block = opt_block_rv.value()]() mutable -> void {
110+
sch->Unannotate(block, tir::attr::meta_schedule_cooperative_fetch);
111+
tir::LoopRV fused = sch->GetLoops(block).back();
112+
int64_t fused_extent = -1;
113+
if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(fused).get())) {
114+
fused_extent = *extent;
122115
} else {
123-
tasks.push_back(
124-
[thread_extent_x, thread_extent_y, sch, block = block_rv.value()]() -> void {
125-
tir::LoopRV fused = sch->GetLoops(block).back();
126-
if (thread_extent_y == -1) {
127-
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)});
128-
sch->Bind(split[1], "threadIdx.x");
129-
} else {
130-
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
131-
Integer(thread_extent_y), //
132-
Integer(thread_extent_x)});
133-
sch->Bind(split[2], "threadIdx.x");
134-
sch->Bind(split[1], "threadIdx.y");
135-
}
136-
});
116+
return;
137117
}
138-
}
118+
if (fused_extent % vector_lane != 0) {
119+
vector_lane = 1;
120+
}
121+
if (thread_extent_y != -1) {
122+
if (vector_lane > 1) {
123+
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
124+
Integer(thread_extent_y), //
125+
Integer(thread_extent_x), //
126+
Integer(vector_lane)});
127+
sch->Vectorize(split[3]);
128+
sch->Bind(split[2], "threadIdx.x");
129+
sch->Bind(split[1], "threadIdx.y");
130+
} else {
131+
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
132+
Integer(thread_extent_y), //
133+
Integer(thread_extent_x)});
134+
sch->Bind(split[2], "threadIdx.x");
135+
sch->Bind(split[1], "threadIdx.y");
136+
}
137+
} else {
138+
if (vector_lane > 1) {
139+
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
140+
Integer(thread_extent_x), //
141+
Integer(vector_lane)});
142+
sch->Vectorize(split[2]);
143+
sch->Bind(split[1], "threadIdx.x");
144+
} else {
145+
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)});
146+
sch->Bind(split[1], "threadIdx.x");
147+
}
148+
}
149+
};
150+
tasks.push_back(task);
139151
}
140152
for (auto&& task : tasks) {
141153
task();

0 commit comments

Comments
 (0)