Skip to content

Commit d8b2aa3

Browse files
committed
clean up using namespace
1 parent eb05d25 commit d8b2aa3

File tree

3 files changed

+21
-24
lines changed

3 files changed

+21
-24
lines changed

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ std::vector<int> GetReadBufferNDims(const StmtSRef& block_sref) {
5252
namespace tvm {
5353
namespace meta_schedule {
5454

55+
using tir::BlockRV;
56+
using tir::IterVarType;
57+
using tir::LoopRV;
58+
using tir::Schedule;
59+
5560
// Do nothing; Inherited from ScheduleRuleNode
5661
void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) {
5762
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("max_threads_per_block")) {
@@ -163,12 +168,12 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
163168
}
164169
// Do the split
165170
int n_tiles = idx->size();
166-
Array<ExprRV> factors = sch->SamplePerfectTile(
171+
Array<tir::ExprRV> factors = sch->SamplePerfectTile(
167172
/*loop=*/loop,
168173
/*n=*/n_tiles,
169174
/*max_innermost_factor=*/max_innermost_factor);
170-
Array<LoopRV> splits = sch->Split(/*loop=*/loop,
171-
/*factors=*/{factors.begin(), factors.end()});
175+
Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
176+
/*factors=*/{factors.begin(), factors.end()});
172177
// Put every tile to its slot
173178
for (int j = 0; j < n_tiles; ++j) {
174179
tiles[idx->at(j)].push_back(splits[j]);
@@ -230,7 +235,7 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
230235
if (!vector_load_lens.empty()) {
231236
int n = vector_load_lens.size();
232237
double prob = 1.0 / n;
233-
ExprRV vector_load_len =
238+
tir::ExprRV vector_load_len =
234239
sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
235240
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
236241
sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,

src/meta_schedule/schedule_rule/multi_level_tiling.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,6 @@
2323
namespace tvm {
2424
namespace meta_schedule {
2525

26-
using tir::BlockRV;
27-
using tir::ExprRV;
28-
using tir::IterVarType;
29-
using tir::LoopRV;
30-
using tir::Schedule;
31-
3226
/*!
3327
* \brief Configuration of data reuse type:
3428
* 0) kNoReuse: no reuse is allowed, then no cache_read/write is performed.
@@ -83,15 +77,16 @@ struct ReuseConfig {
8377
/*! \brief The state of auto scheduling for the multi-level tiling rule */
8478
struct State {
8579
/*! \brief The schedule to date */
86-
Schedule sch;
80+
tir::Schedule sch;
8781
/*! \brief The block to be tiled */
88-
BlockRV block_rv;
82+
tir::BlockRV block_rv;
8983
/*! \brief The loop tiles */
90-
Array<Array<LoopRV>> tiles;
84+
Array<Array<tir::LoopRV>> tiles;
9185

9286
/*! \brief Default constructor */
93-
explicit State(Schedule sch, BlockRV block_rv, Optional<BlockRV> write_cache = NullOpt,
94-
bool write_cache_is_added = false, Array<Array<LoopRV>> tiles = {})
87+
explicit State(tir::Schedule sch, tir::BlockRV block_rv,
88+
Optional<tir::BlockRV> write_cache = NullOpt, bool write_cache_is_added = false,
89+
Array<Array<tir::LoopRV>> tiles = {})
9590
: sch(sch), block_rv(block_rv), tiles(tiles) {}
9691
};
9792

@@ -131,7 +126,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
131126
void InitializeWithTuneContext(const TuneContext& context) final;
132127

133128
// Entry of the mega rule; Inherited from ScheduleRuleNode
134-
Array<Schedule> Apply(const Schedule& sch, const BlockRV& block_rv) final;
129+
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final;
135130

136131
protected:
137132
virtual std::vector<State> ApplySubRules(std::vector<State> states);

src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
namespace tvm {
2525
namespace meta_schedule {
2626

27+
using tir::LoopRV;
28+
2729
/*! \brief Necessary information used for tensorization */
2830
class TensorizeInfoNode : public Object {
2931
public:
@@ -182,7 +184,7 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
182184
return TensorizeInfo(ret);
183185
}
184186

185-
Optional<LoopRV> TilingwithTensorIntrin(const Schedule& sch, const BlockRV& block_rv,
187+
Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
186188
const String& intrin_name) {
187189
Optional<TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
188190
sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
@@ -244,15 +246,12 @@ Optional<LoopRV> TilingwithTensorIntrin(const Schedule& sch, const BlockRV& bloc
244246
}
245247

246248
std::vector<State> TileForVNNI(State state) {
247-
std::vector<State> result;
248-
BlockRV block_rv = state.block_rv;
249249
const std::string intrin_name = "dot_16x4_vnni";
250-
Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin(state.sch, block_rv, intrin_name);
250+
Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin(state.sch, state.block_rv, intrin_name);
251251
ICHECK(tiled_loop_rv.defined());
252252
state.block_rv = state.sch->Blockize(tiled_loop_rv.value());
253253
state.sch->Annotate(state.block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
254-
result.push_back(state);
255-
return result;
254+
return {state};
256255
}
257256

258257
class MultiLevelTilingVNNINode : public MultiLevelTilingNode {
@@ -267,8 +266,6 @@ class MultiLevelTilingVNNINode : public MultiLevelTilingNode {
267266
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingVNNINode, MultiLevelTilingNode);
268267
};
269268

270-
// Constructor
271-
272269
ScheduleRule ScheduleRule::MultiLevelTilingVNNI(String structure,
273270
Optional<Array<String>> tile_binds,
274271
Optional<Integer> max_innermost_factor,

0 commit comments

Comments
 (0)