[TorchToLinalg] Fix multi batch matmul conversion to Linalg #4319
Co-authored by: [email protected]
Improve usage of static shape information to avoid unnecessary broadcast.
Thanks for submitting the PR!
Let me know if you have trouble finding an existing utility function, since it would be better to re-use existing util functions rather than filling the current file with one-off functions.
I'd also prefer if we added a lit test to cover this change.
static int64_t getDimFromValue(Value dimValue) {
  if (auto constOp = dimValue.getDefiningOp<arith::ConstantOp>()) {
    if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
      return intAttr.getInt();
    }
  }
  return ShapedType::kDynamic;
}
I feel like this exact function likely already exists in some Utils.cpp file.
There may even be a computeBroadcastShape util function at this point, so I'd take a look and see if we can re-use some of that.
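For context on what static shape information buys here, below is a minimal standalone sketch of broadcasting two batch-dimension lists while tracking which operand actually needs a broadcast. It is not the PR's code and not the existing torch-mlir utility: `broadcastBatchDims`, `BatchBroadcast`, and the local `kDynamic` constant are hypothetical names for illustration, with `kDynamic` standing in for `ShapedType::kDynamic` so the snippet compiles without MLIR.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Assumption: local stand-in for ShapedType::kDynamic (no MLIR dependency).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical result type: the broadcasted batch shape plus flags telling
// the caller which operand could not be proven to already have that shape.
struct BatchBroadcast {
  bool valid = true;              // false on a static size mismatch
  std::vector<int64_t> shape;     // broadcasted batch dims
  bool lhsNeedsBroadcast = false;
  bool rhsNeedsBroadcast = false;
};

// Hypothetical helper: right-aligned broadcast of two batch-dim lists.
BatchBroadcast broadcastBatchDims(const std::vector<int64_t> &lhs,
                                  const std::vector<int64_t> &rhs) {
  BatchBroadcast result;
  const std::size_t rank = std::max(lhs.size(), rhs.size());
  const std::size_t lhsPad = rank - lhs.size();
  const std::size_t rhsPad = rank - rhs.size();
  result.shape.assign(rank, 1);

  for (std::size_t i = 0; i < rank; ++i) {
    // Missing leading dims behave like size 1.
    const int64_t l = i < lhsPad ? 1 : lhs[i - lhsPad];
    const int64_t r = i < rhsPad ? 1 : rhs[i - rhsPad];
    int64_t out;
    if (l == kDynamic && r == kDynamic) {
      // Nothing is known statically, so stay conservative for both sides.
      out = kDynamic;
      result.lhsNeedsBroadcast = true;
      result.rhsNeedsBroadcast = true;
    } else if (l == r) {
      out = l;                    // equal static sizes: no broadcast needed
    } else if (l == 1) {
      out = r;
    } else if (r == 1) {
      out = l;
    } else if (l == kDynamic) {
      out = r;                    // r is static and not 1; l may be 1 at runtime
    } else if (r == kDynamic) {
      out = l;
    } else {
      result.valid = false;       // two different static sizes, neither is 1
      return result;
    }
    result.shape[i] = out;
    result.lhsNeedsBroadcast = result.lhsNeedsBroadcast || (l != out);
    result.rhsNeedsBroadcast = result.rhsNeedsBroadcast || (r != out);
  }
  return result;
}
```

When both flags come back false, the operands already agree on every batch dimension and the conversion could skip the broadcast entirely, which is the kind of saving the PR description refers to.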
Cool! I will have a look.
Getting rid of the useless dynamic broadcast is great, thanks. When you clean up the lit test a bit, I'll stamp and merge.
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
This might be a bit flaky and adds a bunch of checks we don't actually care about in the logic of the conversion.
I'd just verify the relevant ops. E.g. check the collapse, expand, and linalg batch mm ops. The specific values for fill ops etc. aren't super important to check here.