Skip to content

Conversation

@aartbik
Copy link
Contributor

@aartbik aartbik commented Oct 19, 2023

This makes sure

  • GEN MAP dim=2 lvl=4
    (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
    --
    (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3)

is indeed encoded as

MAP-REF (dim=2, lvl=4) isperm=0
d2l = [ d0/2 d1/2 d0%2 d1%2 ]
ld2 = [ l2+2l0 l3+2l1 ]

@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Oct 19, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 19, 2023

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

Changes

This makes sure

  • GEN MAP dim=2 lvl=4
    (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
    --
    (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3)

is indeed encoded as

MAP-REF (dim=2, lvl=4) isperm=0
d2l = [ d0/2 d1/2 d0%2 d1%2 ]
ld2 = [ l2+2l0 l3+2l1 ]


Full diff: https://github.com/llvm/llvm-project/pull/69540.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp (+42-20)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 98b412c8ec9eb5b..b1b1d67ac2d420d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
   // This code deals with permutations as well as non-permutations that
   // arise from rank changing blocking.
   const auto dimToLvl = stt.getDimToLvl();
+  const auto lvlToDim = stt.getLvlToDim();
   SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
   SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
   SmallVector<Value> lvlSizesValues(lvlRank);
@@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
     Dimension d = 0;
     uint64_t cf = 0, cm = 0;
     switch (exp.getKind()) {
-    case AffineExprKind::DimId:
+    case AffineExprKind::DimId: {
       d = exp.cast<AffineDimExpr>().getPosition();
       break;
-    case AffineExprKind::FloorDiv:
-      d = exp.cast<AffineBinaryOpExpr>()
-              .getLHS()
-              .cast<AffineDimExpr>()
-              .getPosition();
-      cf = exp.cast<AffineBinaryOpExpr>()
-               .getRHS()
-               .cast<AffineConstantExpr>()
-               .getValue();
+    }
+    case AffineExprKind::FloorDiv: {
+      auto floor = exp.cast<AffineBinaryOpExpr>();
+      d = floor.getLHS().cast<AffineDimExpr>().getPosition();
+      cf = floor.getRHS().cast<AffineConstantExpr>().getValue();
       break;
-    case AffineExprKind::Mod:
-      d = exp.cast<AffineBinaryOpExpr>()
-              .getLHS()
-              .cast<AffineDimExpr>()
-              .getPosition();
-      cm = exp.cast<AffineBinaryOpExpr>()
-               .getRHS()
-               .cast<AffineConstantExpr>()
-               .getValue();
+    }
+    case AffineExprKind::Mod: {
+      auto mod = exp.cast<AffineBinaryOpExpr>();
+      d = mod.getLHS().cast<AffineDimExpr>().getPosition();
+      cm = mod.getRHS().cast<AffineConstantExpr>().getValue();
       break;
+    }
     default:
       llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
     }
     dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
-    lvl2dimValues[d] = constantIndex(builder, loc, l); // FIXME, use lvlToDim
     // Compute the level sizes.
     //    (1) l = d        : size(d)
     //    (2) l = d / c    : size(d) / c
@@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
     }
     lvlSizesValues[l] = lvlSz;
   }
+  // Generate lvl2dim.
+  assert(dimRank == lvlToDim.getNumResults());
+  for (Dimension d = 0; d < dimRank; d++) {
+    AffineExpr exp = lvlToDim.getResult(d);
+    // We expect:
+    //    (1) d = l
+    //    (2) d = l' * c + l
+    Level l = 0, ll = 0;
+    uint64_t c = 0;
+    switch (exp.getKind()) {
+    case AffineExprKind::DimId: {
+      l = exp.cast<AffineDimExpr>().getPosition();
+      break;
+    }
+    case AffineExprKind::Add: {
+      // Always mul on lhs, symbol/constant on rhs.
+      auto add = exp.cast<AffineBinaryOpExpr>();
+      assert(add.getLHS().getKind() == AffineExprKind::Mul);
+      auto mul = add.getLHS().cast<AffineBinaryOpExpr>();
+      ll = mul.getLHS().cast<AffineDimExpr>().getPosition();
+      c = mul.getRHS().cast<AffineConstantExpr>().getValue();
+      l = add.getRHS().cast<AffineDimExpr>().getPosition();
+      break;
+    }
+    default:
+      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
+    }
+    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
+  }
   // Return buffers.
   dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
   lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:sparse Sparse compiler in MLIR mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants