@@ -356,13 +356,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
356356
357357FailureOr<LowerUnPackOpResult> linalg::lowerUnPack (RewriterBase &rewriter,
358358 tensor::UnPackOp unPackOp) {
359- // 1. Filter out NYI cases.
360- if (!unPackOp.getOuterDimsPerm ().empty () &&
361- !isIdentityPermutation (unPackOp.getOuterDimsPerm ())) {
362- return rewriter.notifyMatchFailure (unPackOp,
363- " non-identity outer dims perm NYI" );
364- }
365-
366359 Location loc = unPackOp->getLoc ();
367360 OpBuilder::InsertionGuard g (rewriter);
368361 rewriter.setInsertionPoint (unPackOp);
@@ -391,45 +384,42 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
391384 return LowerUnPackOpResult{/* emptyOp=*/ nullptr , /* transposeOp=*/ nullptr ,
392385 /* reshapeOp=*/ nullptr , extractSliceOp};
393386 }
394- // 2. Compute the permutation vector to move the last `numPackedDims` into
395- // the `innerPosDims` of a shape of rank `packedRank`.
396- int64_t numPackedDims = unPackOp.getInnerDimsPos ().size ();
397- auto lastDims = llvm::to_vector (
398- llvm::seq<int64_t >(packedRank - numPackedDims, packedRank));
399- PackingMetadata packingMetadata =
400- computePackingMetadata (packedRank, unPackOp.getInnerDimsPos ());
401- SmallVector<int64_t > lastDimsToInsertPositionsPerm = computePermutationVector (
402- packedRank, lastDims, packingMetadata.insertPositions );
403-
404- // 3. Compute the stripMinedShape: this is the packed shape without outer and
387+
388+ // 1. Compute the permutation vector to shuffle packed shape into the shape
389+ // before any outer or inner permutations have been applied.
390+ PackingMetadata packingMetadata;
391+ SmallVector<int64_t > packedToStripMinedShapePerm =
392+ tensor::getUnPackInverseSrcPerm (unPackOp, packingMetadata);
393+
394+ // 2. Compute the stripMinedShape: this is the packed shape without outer and
405395 // inner permutations.
406396 SmallVector<int64_t > stripMinedShape (packedTensorType.getShape ());
407- applyPermutationToVector (stripMinedShape, lastDimsToInsertPositionsPerm );
397+ applyPermutationToVector (stripMinedShape, packedToStripMinedShapePerm );
408398
409- // 4 . Transpose packedShape to stripMinedShape.
399+ // 3 . Transpose packedShape to stripMinedShape.
410400 RankedTensorType stripMinedTensorType =
411401 RankedTensorType::Builder (packedTensorType).setShape (stripMinedShape);
412402 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType (
413403 stripMinedTensorType, packingMetadata.reassociations );
414404
415- // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
405+ // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
416406 // permutation.
417407 SmallVector<OpFoldResult, 4 > dims =
418408 tensor::getMixedSizes (rewriter, loc, unPackOp.getSource ());
419- applyPermutationToVector (dims, lastDimsToInsertPositionsPerm );
409+ applyPermutationToVector (dims, packedToStripMinedShapePerm );
420410 auto emptyOp = rewriter.create <tensor::EmptyOp>(
421411 loc, dims, stripMinedTensorType.getElementType ());
422412 auto transposeOp = rewriter.create <linalg::TransposeOp>(
423- loc, unPackOp.getSource (), emptyOp, lastDimsToInsertPositionsPerm );
413+ loc, unPackOp.getSource (), emptyOp, packedToStripMinedShapePerm );
424414
425415 LLVM_DEBUG (
426416 DBGSNL (); DBGSNL (); llvm::interleaveComma (packingMetadata.insertPositions ,
427417 DBGS () << " insertPositions: " );
428418 DBGSNL (); llvm::interleaveComma (packedTensorType.getShape (),
429419 DBGS () << " packedShape: " );
430420 DBGSNL ();
431- llvm::interleaveComma (lastDimsToInsertPositionsPerm ,
432- DBGS () << " lastDimsToInsertPositionsPerm : " );
421+ llvm::interleaveComma (packedToStripMinedShapePerm ,
422+ DBGS () << " packedToStripMinedShapePerm : " );
433423 DBGSNL (); llvm::interleaveComma (
434424 packingMetadata.reassociations , DBGS () << " reassociations: " ,
435425 [&](ReassociationIndices ri) {
@@ -439,24 +429,24 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
439429 llvm::interleaveComma (stripMinedShape, DBGS () << " stripMinedShape: " );
440430 DBGSNL (); DBGS () << " collapsed type: " << collapsedType; DBGSNL (););
441431
442- // 5 . Collapse from the stripMinedShape to the padded result.
432+ // 4 . Collapse from the stripMinedShape to the padded result.
443433 auto reshapeOp = rewriter.create <tensor::CollapseShapeOp>(
444434 loc, collapsedType, transposeOp->getResult (0 ),
445435 packingMetadata.reassociations );
446436
447- // 6 . ExtractSlice.
437+ // 5 . ExtractSlice.
448438 int64_t destRank = destTensorType.getRank ();
449439 auto extractSliceOp = rewriter.create <tensor::ExtractSliceOp>(
450440 loc, destTensorType, reshapeOp->getResult (0 ),
451441 SmallVector<OpFoldResult>(destRank, zero),
452442 tensor::getMixedSizes (rewriter, loc, unPackOp.getDest ()),
453443 SmallVector<OpFoldResult>(destRank, one));
454444
455- // 7 . Inject a copy to preserve DPS.
445+ // 6 . Inject a copy to preserve DPS.
456446 auto copyOp = rewriter.create <linalg::CopyOp>(
457447 loc, extractSliceOp->getResult (0 ), unPackOp.getDest ());
458448
459- // 8 . Replace unPackOp by extractSliceOp .
449+ // 7 . Replace unPackOp by copyOp .
460450 rewriter.replaceOp (unPackOp, copyOp->getResults ());
461451
462452 return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
0 commit comments