@@ -2390,29 +2390,29 @@ BroadcastableToResult mlir::vector::isBroadcastableTo(
23902390 // Source has an exact match or singleton value for all trailing dimensions
23912391 // (all leading dimensions are simply duplicated).
23922392 int64_t lead = dstRank - srcRank;
2393- for (int64_t r = 0 ; r < srcRank; ++r ) {
2393+ for (int64_t dimIdx = 0 ; dimIdx < srcRank; ++dimIdx ) {
23942394 bool mismatch = false ;
23952395
2396- // Check fixed-width dims
2397- int64_t srcDim = srcVectorType.getDimSize (r );
2398- int64_t dstDim = dstVectorType.getDimSize (lead + r );
2399- if (( srcDim != 1 && srcDim != dstDim) )
2396+ // Check fixed-width dims.
2397+ int64_t srcDim = srcVectorType.getDimSize (dimIdx );
2398+ int64_t dstDim = dstVectorType.getDimSize (lead + dimIdx );
2399+ if (srcDim != 1 && srcDim != dstDim)
24002400 mismatch = true ;
24012401
2402- // Check scalable flags
2403- bool srcDimScalableFlag = srcVectorType.getScalableDims ()[r ];
2404- bool dstDimScalableFlag = dstVectorType.getScalableDims ()[lead + r ];
2402+ // Check scalable flags.
2403+ bool srcDimScalableFlag = srcVectorType.getScalableDims ()[dimIdx ];
2404+ bool dstDimScalableFlag = dstVectorType.getScalableDims ()[lead + dimIdx ];
24052405 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1 ) ||
24062406 (srcDimScalableFlag != dstDimScalableFlag))
24072407 mismatch = true ;
24082408
24092409 if (mismatch) {
2410- if (mismatchingDims) {
2410+ if (mismatchingDims != nullptr ) {
24112411 mismatchingDims->first .dim = srcDim;
2412- mismatchingDims->first .scalableFlag = srcDimScalableFlag;
2412+ mismatchingDims->first .isScalable = srcDimScalableFlag;
24132413
24142414 mismatchingDims->second .dim = dstDim;
2415- mismatchingDims->second .scalableFlag = dstDimScalableFlag;
2415+ mismatchingDims->second .isScalable = dstDimScalableFlag;
24162416 }
24172417 return BroadcastableToResult::DimensionMismatch;
24182418 }
@@ -2430,15 +2430,14 @@ LogicalResult BroadcastOp::verify() {
24302430 if (res == BroadcastableToResult::SourceRankHigher)
24312431 return emitOpError (" source rank higher than destination rank" );
24322432 if (res == BroadcastableToResult::DimensionMismatch) {
2433- std::string msg =
2434- (Twine (" dimension mismatch (" ) +
2435- (mismatchingDims.first .scalableFlag ? " [" : " " ) +
2436- std::to_string (mismatchingDims.first .dim ) +
2437- (mismatchingDims.first .scalableFlag ? " ]" : " " ) + " vs. " +
2438- (mismatchingDims.second .scalableFlag ? " [" : " " ) +
2439- std::to_string (mismatchingDims.second .dim ) +
2440- (mismatchingDims.second .scalableFlag ? " ]" : " " ) + " )" )
2441- .str ();
2433+ std::string msg = (Twine (" dimension mismatch (" ) +
2434+ (mismatchingDims.first .isScalable ? " [" : " " ) +
2435+ std::to_string (mismatchingDims.first .dim ) +
2436+ (mismatchingDims.first .isScalable ? " ]" : " " ) + " vs. " +
2437+ (mismatchingDims.second .isScalable ? " [" : " " ) +
2438+ std::to_string (mismatchingDims.second .dim ) +
2439+ (mismatchingDims.second .isScalable ? " ]" : " " ) + " )" )
2440+ .str ();
24422441 return emitOpError (msg);
24432442 }
24442443 if (res == BroadcastableToResult::SourceTypeNotAVector)
0 commit comments