[BugTIR] fix thread_sync occurs in letstmt #16454
cc @vinx13 @spectrometerHBH please take a moment to review this when you see it.
src/tir/transforms/storage_access.cc
Outdated
void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) {
  allow_append_ = true;
  curr_stmt_.access.clear();
Shouldn't call clear here; it should only be called after a statement has been fully handled.
Can you check my new solution?
A_shared_1[ax0] = A[blockIdx_x * 512 + ax0]
in_thread_A_temp_1 = T.Buffer((1,), data=in_thread_A_temp, scope="local")
in_thread_A_temp_1[0] = T.float32(0)
with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp:
In this case, before the change, StorageAccessVisitor calls VisitExpr_(const BufferLoadNode* op) directly, rather than from a parent visitor such as VisitStmt_(const BufferStoreNode* op) traversing its children, so ICHECK(allow_append_) << op << " " << scope.to_string(); fails.
If we add VisitStmt_(const LetStmtNode* op) and traverse the children of the LetStmtNode, a different check fails: the traversal first visits the BufferLoadNode for A_shared, which increases curr_stmt_.access by 1, and then visits the BufferStoreNode for in_thread_A_temp, where ICHECK_EQ(curr_stmt_.access.size(), 0U); fails.
Do you have any insights on how to solve this problem? @vinx13
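To make the two failing checks easier to reason about, here is a toy Python model of the invariants described above. It is not TVM code: ToyAccessVisitor, visit_let, flush_value and the other names are hypothetical stand-ins, and flushing the bound value as its own finished entry is only one possible way to satisfy both checks, not necessarily the fix this PR ends up with.

```python
class ToyAccessVisitor:
    """Toy stand-in for the allow_append_ / curr_stmt_.access invariants."""

    def __init__(self):
        self.allow_append = False  # models allow_append_
        self.curr_access = []      # models curr_stmt_.access
        self.finished = []         # completed per-statement access lists

    @staticmethod
    def _check(cond, msg):
        # Plays the role of ICHECK: raise when an invariant is violated.
        if not cond:
            raise RuntimeError(msg)

    def _finish_statement(self):
        # Record the accesses of the statement just handled, then reset.
        self.finished.append(list(self.curr_access))
        self.curr_access.clear()
        self.allow_append = False

    def visit_load(self, buf):
        # Models ICHECK(allow_append_): a load needs an enclosing statement.
        self._check(self.allow_append, f"load of {buf} outside any statement")
        self.curr_access.append(("read", buf))

    def visit_store(self, buf, loads=()):
        # Models ICHECK_EQ(curr_stmt_.access.size(), 0U).
        self._check(not self.curr_access, "stale accesses from an earlier node")
        self.allow_append = True
        for src in loads:
            self.visit_load(src)
        self.curr_access.append(("write", buf))
        self._finish_statement()

    def visit_let(self, value_loads, body, flush_value):
        # Hypothetical LetStmt handler: visit the bound value, then the body.
        self.allow_append = True
        for src in value_loads:
            self.visit_load(src)
        if flush_value:
            # Assumption: treat the bound value as a finished statement.
            self._finish_statement()
        else:
            self.allow_append = False
        body(self)


def store_body(visitor):
    # Body of the let: a store to the local accumulator.
    visitor.visit_store("in_thread_A_temp")


def fails(action):
    # Return True if the action trips one of the modelled checks.
    try:
        action()
    except RuntimeError:
        return True
    return False


# Case 1: no LetStmt handler, so the load is visited with allow_append unset.
bare_load_fails = fails(lambda: ToyAccessVisitor().visit_load("A_shared"))

# Case 2: the handler visits the value but leaves its access pending.
unflushed_let_fails = fails(
    lambda: ToyAccessVisitor().visit_let(["A_shared"], store_body, flush_value=False)
)

# Case 3: the handler flushes the value's accesses before visiting the body.
flushed = ToyAccessVisitor()
flushed.visit_let(["A_shared"], store_body, flush_value=True)
```

In this model, cases 1 and 2 correspond to the two ICHECK failures, while case 3 passes because curr_access is empty again by the time the store in the body is visited.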
See the original discussion:
LayerNorm Error in thread_storage_sync when read x into shared memory
I tried to read x into shared memory to accelerate layernorm (script here), but an error occurs in the thread_storage_sync pass. I found that it is caused by an error in lowering LetStmt; lowered script