@@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp(
23712371 return res;
23722372}
23732373
2374- BroadcastableToResult
2375- mlir::vector::isBroadcastableTo ( Type srcType, VectorType dstVectorType,
2376- std::pair<int , int > *mismatchingDims) {
2374+ BroadcastableToResult mlir::vector::isBroadcastableTo (
2375+ Type srcType, VectorType dstVectorType,
2376+ std::pair<VectorDim, VectorDim > *mismatchingDims) {
23772377 // Broadcast scalar to vector of the same element type.
23782378 if (srcType.isIntOrIndexOrFloat () && dstVectorType &&
23792379 getElementTypeOrSelf (srcType) == getElementTypeOrSelf (dstVectorType))
@@ -2390,13 +2390,31 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
23902390 // Source has an exact match or singleton value for all trailing dimensions
23912391 // (all leading dimensions are simply duplicated).
23922392 int64_t lead = dstRank - srcRank;
2393- for (int64_t r = 0 ; r < srcRank; ++r) {
2394- int64_t srcDim = srcVectorType.getDimSize (r);
2395- int64_t dstDim = dstVectorType.getDimSize (lead + r);
2396- if (srcDim != 1 && srcDim != dstDim) {
2397- if (mismatchingDims) {
2398- mismatchingDims->first = srcDim;
2399- mismatchingDims->second = dstDim;
2393+ for (int64_t dimIdx = 0 ; dimIdx < srcRank; ++dimIdx) {
2394+ // Have mismatching dims (in the sense of vector.broadcast semantics) been
2395+ // encountered?
2396+ bool foundMismatchingDims = false ;
2397+
2398+ // Check fixed-width dims.
2399+ int64_t srcDim = srcVectorType.getDimSize (dimIdx);
2400+ int64_t dstDim = dstVectorType.getDimSize (lead + dimIdx);
2401+ if (srcDim != 1 && srcDim != dstDim)
2402+ foundMismatchingDims = true ;
2403+
2404+ // Check scalable flags.
2405+ bool srcDimScalableFlag = srcVectorType.getScalableDims ()[dimIdx];
2406+ bool dstDimScalableFlag = dstVectorType.getScalableDims ()[lead + dimIdx];
2407+ if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1 ) ||
2408+ (srcDimScalableFlag != dstDimScalableFlag))
2409+ foundMismatchingDims = true ;
2410+
2411+ if (foundMismatchingDims) {
2412+ if (mismatchingDims != nullptr ) {
2413+ mismatchingDims->first .dim = srcDim;
2414+ mismatchingDims->first .isScalable = srcDimScalableFlag;
2415+
2416+ mismatchingDims->second .dim = dstDim;
2417+ mismatchingDims->second .isScalable = dstDimScalableFlag;
24002418 }
24012419 return BroadcastableToResult::DimensionMismatch;
24022420 }
@@ -2406,16 +2424,22 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
24062424}
24072425
24082426LogicalResult BroadcastOp::verify () {
2409- std::pair<int , int > mismatchingDims;
2427+ std::pair<VectorDim, VectorDim > mismatchingDims;
24102428 BroadcastableToResult res = isBroadcastableTo (
24112429 getSourceType (), getResultVectorType (), &mismatchingDims);
24122430 if (res == BroadcastableToResult::Success)
24132431 return success ();
24142432 if (res == BroadcastableToResult::SourceRankHigher)
24152433 return emitOpError (" source rank higher than destination rank" );
2416- if (res == BroadcastableToResult::DimensionMismatch)
2434+ if (res == BroadcastableToResult::DimensionMismatch) {
24172435 return emitOpError (" dimension mismatch (" )
2418- << mismatchingDims.first << " vs. " << mismatchingDims.second << " )" ;
2436+ << (mismatchingDims.first .isScalable ? " [" : " " )
2437+ << mismatchingDims.first .dim
2438+ << (mismatchingDims.first .isScalable ? " ]" : " " ) << " vs. "
2439+ << (mismatchingDims.second .isScalable ? " [" : " " )
2440+ << mismatchingDims.second .dim
2441+ << (mismatchingDims.second .isScalable ? " ]" : " " ) << " )" ;
2442+ }
24192443 if (res == BroadcastableToResult::SourceTypeNotAVector)
24202444 return emitOpError (" source type is not a vector" );
24212445 llvm_unreachable (" unexpected vector.broadcast op error" );
0 commit comments