@@ -49,7 +49,8 @@ Optional<Integer> ParseThreadBinding(const Schedule& sch, const Instruction& ins
4949 * \param vector_lane The number of vector lanes in vectorized cooperative fetching
5050 * \return NullOpt if parsing fails; otherwise, the annotated block
5151 */
52- Optional<BlockRV> ParseAnnotate (const Schedule& sch, const Instruction& inst, int * vector_lane) {
52+ Optional<BlockRV> ParseAnnotate (const Schedule& sch, const Instruction& inst,
53+ int64_t * vector_lane) {
5354 static InstructionKind inst_kind_annotate = InstructionKind::Get (" Annotate" );
5455 if (!inst->kind .same_as (inst_kind_annotate)) {
5556 return NullOpt;
@@ -87,55 +88,66 @@ class RewriteCooperativeFetchNode : public PostprocNode {
8788
8889bool RewriteCooperativeFetchNode::Apply (const tir::Schedule& sch) {
8990 tir::Trace trace = sch->trace ().value ();
90- int thread_extent_x = -1 ;
91- int thread_extent_y = -1 ;
92- int vector_lane = - 1 ;
91+ int64_t thread_extent_x = -1 ;
92+ int64_t thread_extent_y = -1 ;
93+ int64_t vector_lane = 1 ;
9394 std::vector<std::function<void ()>> tasks;
9495 for (const tir::Instruction& inst : trace->insts ) {
9596 if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding (sch, inst, " threadIdx.x" )) {
9697 thread_extent_x = new_thread_extent.value ()->value ;
97- } else if (Optional<Integer> new_thread_extent =
98- tir::ParseThreadBinding (sch, inst, " threadIdx.y" )) {
98+ continue ;
99+ }
100+ if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding (sch, inst, " threadIdx.y" )) {
99101 thread_extent_y = new_thread_extent.value ()->value ;
100- } else if (Optional<tir::BlockRV> block_rv = tir::ParseAnnotate (sch, inst, &vector_lane)) {
101- ICHECK_NE (thread_extent_x, -1 );
102- if (vector_lane > 1 ) {
103- tasks.push_back ([thread_extent_x, thread_extent_y, vector_lane, sch,
104- block = block_rv.value ()]() -> void {
105- tir::LoopRV fused = sch->GetLoops (block).back ();
106- if (thread_extent_y == -1 ) {
107- Array<tir::LoopRV> split = sch->Split (fused, {NullOpt, //
108- Integer (thread_extent_x), //
109- Integer (vector_lane)});
110- sch->Vectorize (split[2 ]);
111- sch->Bind (split[1 ], " threadIdx.x" );
112- } else {
113- Array<tir::LoopRV> split = sch->Split (fused, {NullOpt, //
114- Integer (thread_extent_y), //
115- Integer (thread_extent_x), //
116- Integer (vector_lane)});
117- sch->Vectorize (split[3 ]);
118- sch->Bind (split[2 ], " threadIdx.x" );
119- sch->Bind (split[1 ], " threadIdx.y" );
120- }
121- });
102+ continue ;
103+ }
104+ Optional<tir::BlockRV> opt_block_rv = tir::ParseAnnotate (sch, inst, &vector_lane);
105+ if (!opt_block_rv.defined ()) {
106+ continue ;
107+ }
108+ auto task = [thread_extent_x, thread_extent_y, vector_lane, sch,
109+ block = opt_block_rv.value ()]() mutable -> void {
110+ sch->Unannotate (block, tir::attr::meta_schedule_cooperative_fetch);
111+ tir::LoopRV fused = sch->GetLoops (block).back ();
112+ int64_t fused_extent = -1 ;
113+ if (const int64_t * extent = tir::GetLoopIntExtent (sch->Get (fused).get ())) {
114+ fused_extent = *extent;
122115 } else {
123- tasks.push_back (
124- [thread_extent_x, thread_extent_y, sch, block = block_rv.value ()]() -> void {
125- tir::LoopRV fused = sch->GetLoops (block).back ();
126- if (thread_extent_y == -1 ) {
127- Array<tir::LoopRV> split = sch->Split (fused, {NullOpt, Integer (thread_extent_x)});
128- sch->Bind (split[1 ], " threadIdx.x" );
129- } else {
130- Array<tir::LoopRV> split = sch->Split (fused, {NullOpt, //
131- Integer (thread_extent_y), //
132- Integer (thread_extent_x)});
133- sch->Bind (split[2 ], " threadIdx.x" );
134- sch->Bind (split[1 ], " threadIdx.y" );
135- }
136- });
116+ return ;
137117 }
138- }
118+ if (fused_extent % vector_lane != 0 ) {
119+ vector_lane = 1 ;
120+ }
121+ if (thread_extent_y != -1 ) {
122+ if (vector_lane > 1 ) {
123+ Array<tir::LoopRV> split = sch->Split (fused, {NullOpt, //
124+ Integer (thread_extent_y), //
125+ Integer (thread_extent_x), //
126+ Integer (vector_lane)});
127+ sch->Vectorize (split[3 ]);
128+ sch->Bind (split[2 ], " threadIdx.x" );
129+ sch->Bind (split[1 ], " threadIdx.y" );
130+ } else {
131+ Array<tir::LoopRV> split = sch->Split (fused, {NullOpt, //
132+ Integer (thread_extent_y), //
133+ Integer (thread_extent_x)});
134+ sch->Bind (split[2 ], " threadIdx.x" );
135+ sch->Bind (split[1 ], " threadIdx.y" );
136+ }
137+ } else {
138+ if (vector_lane > 1 ) {
139+ Array<tir::LoopRV> split = sch->Split (fused, {NullOpt, //
140+ Integer (thread_extent_x), //
141+ Integer (vector_lane)});
142+ sch->Vectorize (split[2 ]);
143+ sch->Bind (split[1 ], " threadIdx.x" );
144+ } else {
145+ Array<tir::LoopRV> split = sch->Split (fused, {NullOpt, Integer (thread_extent_x)});
146+ sch->Bind (split[1 ], " threadIdx.x" );
147+ }
148+ }
149+ };
150+ tasks.push_back (task);
139151 }
140152 for (auto && task : tasks) {
141153 task ();
0 commit comments