From 3e32daf4b29a60eb62665d86fec7324e5a08a6e9 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik@google.com>
Date: Mon, 16 Oct 2023 13:17:27 -0700
Subject: [PATCH 1/2] [mlir][sparse] remove sparse2sparse path in library

This cleans up all external entry points that will have to deal
with non-permutations, making any subsequent refactoring much more
local to the lib files.
---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |   1 -
 .../ExecutionEngine/SparseTensor/Storage.h    | 187 +-----------------
 .../ExecutionEngine/SparseTensorRuntime.h     |   1 -
 .../SparseTensor/CMakeLists.txt               |   1 -
 mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp |  79 --------
 .../ExecutionEngine/SparseTensor/Storage.cpp  |  13 +-
 .../ExecutionEngine/SparseTensorRuntime.cpp   |   7 -
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 -
 8 files changed, 3 insertions(+), 287 deletions(-)
 delete mode 100644 mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 0caf83a63b531..08887abcd0f10 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -145,7 +145,6 @@ enum class Action : uint32_t {
   kEmpty = 0,
   kEmptyForward = 1,
   kFromCOO = 2,
-  kSparseToSparse = 3,
   kFromReader = 4,
   kToCOO = 5,
   kPack = 7,
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index c5be3d1acc337..beff393b94033 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -12,7 +12,6 @@
 // * `SparseTensorStorage`
 // * `SparseTensorEnumeratorBase`
 // * `SparseTensorEnumerator`
-// * `SparseTensorNNZ`
 //
 //===----------------------------------------------------------------------===//
 
@@ -26,14 +25,6 @@
 #include "mlir/ExecutionEngine/SparseTensor/ErrorHandling.h"
 #include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
 
-#define ASSERT_COMPRESSED_OR_SINGLETON_LVL(l)                                  \
-  do {                                                                         \
-    const DimLevelType dlt = getLvlType(l);                                    \
-    (void)dlt;                                                                 \
-    assert((isCompressedDLT(dlt) || isSingletonDLT(dlt)) &&                    \
-           "Level is neither compressed nor singleton");                       \
-  } while (false)
-
 namespace mlir {
 namespace sparse_tensor {
 
@@ -152,18 +143,6 @@ class SparseTensorStorageBase {
   // TODO: REMOVE THIS
   const std::vector<uint64_t> &getLvl2Dim() const { return lvl2dimVec; }
 
-  /// Allocates a new enumerator. Callers must make sure to delete
-  /// the enumerator when they're done with it. The first argument
-  /// is the out-parameter for storing the newly allocated enumerator;
-  /// all other arguments are passed along to the `SparseTensorEnumerator`
-  /// ctor and must satisfy the preconditions/assertions thereof.
-#define DECL_NEWENUMERATOR(VNAME, V)                                           \
-  virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
-                             const uint64_t *, uint64_t, const uint64_t *)     \
-      const;
-  MLIR_SPARSETENSOR_FOREVERY_V(DECL_NEWENUMERATOR)
-#undef DECL_NEWENUMERATOR
-
   /// Gets positions-overhead storage for the given level.
 #define DECL_GETPOSITIONS(PNAME, P)                                            \
   virtual void getPositions(std::vector<P> **, uint64_t);
@@ -312,27 +291,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
              const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
              const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO);
 
-  /// Allocates a new sparse tensor and initializes it with the contents
-  /// of another sparse tensor.
-  //
-  // TODO: The `dimRank` and `dimShape` arguments are only used for
-  // verifying that the source tensor has the expected shape. So if we
-  // wanted to skip that verification, then we could remove those arguments.
-  // Alternatively, if we required the `dimShape` to be "sizes" instead,
-  // then that would remove any constraints on `source.getDimSizes()`
-  // (other than compatibility with `src2lvl`) as well as removing the
-  // requirement that `src2lvl` be the inverse of `lvl2dim`. Which would
-  // enable this factory to be used for performing a much larger class of
-  // transformations (which can already be handled by the `SparseTensorNNZ`
-  // implementation).
-  static SparseTensorStorage<P, C, V> *
-  newFromSparseTensor(uint64_t dimRank, const uint64_t *dimShape,
-                      uint64_t lvlRank, const uint64_t *lvlSizes,
-                      const DimLevelType *lvlTypes,
-                      const uint64_t *src2lvl, // FIXME: dim2lvl,
-                      const uint64_t *lvl2dim, uint64_t srcRank,
-                      const SparseTensorStorageBase &source);
-
   /// Allocates a new sparse tensor and initialize it with the data stored level
   /// buffers directly.
   static SparseTensorStorage<P, C, V> *packFromLvlBuffers(
@@ -361,7 +319,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
 
   /// Returns coordinate at given position.
   uint64_t getCrd(uint64_t lvl, uint64_t pos) const final {
-    ASSERT_COMPRESSED_OR_SINGLETON_LVL(lvl);
+    assert(isCompressedDLT(getLvlType(lvl)) || isSingletonDLT(getLvlType(lvl)));
     assert(pos < coordinates[lvl].size());
     return coordinates[lvl][pos]; // Converts the stored `C` into `uint64_t`.
   }
@@ -453,17 +411,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     endPath(0);
   }
 
-  /// Allocates a new enumerator for this class's `<P, C, V>` types and
-  /// erase the `<P, C>` parts from the type. Callers must make sure to
-  /// delete the enumerator when they're done with it.
-  void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t trgRank,
-                     const uint64_t *trgSizes, uint64_t srcRank,
-                     const uint64_t *src2trg) const final {
-    assert(out && "Received nullptr for out parameter");
-    *out = new SparseTensorEnumerator<P, C, V>(*this, trgRank, trgSizes,
-                                               srcRank, src2trg);
-  }
-
   /// Allocates a new COO object and initializes it with the contents
   /// of this tensor under the given mapping from the `getDimSizes()`
   /// coordinate-space to the `trgSizes` coordinate-space. Callers must
@@ -472,7 +419,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
                uint64_t srcRank,
                const uint64_t *src2trg, // FIXME: dim2lvl
                const uint64_t *lvl2dim) const {
-    // We inline `newEnumerator` to avoid virtual dispatch and allocation.
     // TODO: use MapRef here too for the translation
     SparseTensorEnumerator<P, C, V> enumerator(*this, trgRank, trgSizes,
                                                srcRank, src2trg);
@@ -584,7 +530,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// does not check that `crd` is semantically valid (i.e., in bounds
   /// for `dimSizes[lvl]` and not elsewhere occurring in the same segment).
   void writeCrd(uint64_t lvl, uint64_t pos, uint64_t crd) {
-    ASSERT_COMPRESSED_OR_SINGLETON_LVL(lvl);
+    assert(isCompressedDLT(getLvlType(lvl)) || isSingletonDLT(getLvlType(lvl)));
     // Subscript assignment to `std::vector` requires that the `pos`-th
     // entry has been initialized; thus we must be sure to check `size()`
     // here, instead of `capacity()` as would be ideal.
@@ -735,8 +681,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   SparseTensorCOO<V> *lvlCOO; // COO used during forwarding
 };
 
-#undef ASSERT_COMPRESSED_OR_SINGLETON_LVL
-
 //===----------------------------------------------------------------------===//
 //
 // SparseTensorEnumerator
@@ -1025,33 +969,6 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
                                        lvlTypes, dim2lvl, lvl2dim, lvlCOO);
 }
 
-template <typename P, typename C, typename V>
-SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromSparseTensor(
-    uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-    const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-    const uint64_t *src2lvl, // dim2lvl
-    const uint64_t *lvl2dim, uint64_t srcRank,
-    const SparseTensorStorageBase &source) {
-  // Verify that the `source` dimensions match the expected `dimShape`.
-  assert(dimShape && "Got nullptr for dimension shape");
-  assert(dimRank == source.getDimRank() && "Dimension-rank mismatch");
-  const auto &dimSizes = source.getDimSizes();
-#ifndef NDEBUG
-  for (uint64_t d = 0; d < dimRank; ++d) {
-    const uint64_t sz = dimShape[d];
-    assert((sz == 0 || sz == dimSizes[d]) &&
-           "Dimension-sizes do not match expected shape");
-  }
-#endif
-  SparseTensorEnumeratorBase<V> *lvlEnumerator;
-  source.newEnumerator(&lvlEnumerator, lvlRank, lvlSizes, srcRank, src2lvl);
-  auto *tensor = new SparseTensorStorage<P, C, V>(dimRank, dimSizes.data(),
-                                                  lvlRank, lvlTypes, src2lvl,
-                                                  lvl2dim, *lvlEnumerator);
-  delete lvlEnumerator;
-  return tensor;
-}
-
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
     uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
@@ -1128,106 +1045,6 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
   fromCOO(elements, 0, nse, 0);
 }
 
-template <typename P, typename C, typename V>
-SparseTensorStorage<P, C, V>::SparseTensorStorage(
-    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
-    const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
-    const uint64_t *lvl2dim, SparseTensorEnumeratorBase<V> &lvlEnumerator)
-    : SparseTensorStorage(dimRank, dimSizes, lvlRank,
-                          lvlEnumerator.getTrgSizes().data(), lvlTypes, dim2lvl,
-                          lvl2dim) {
-  assert(lvlRank == lvlEnumerator.getTrgRank() && "Level-rank mismatch");
-  {
-    // Initialize the statistics structure.
-    SparseTensorNNZ nnz(getLvlSizes(), getLvlTypes());
-    nnz.initialize(lvlEnumerator);
-    // Initialize "positions" overhead (and allocate "coordinates", "values").
-    uint64_t parentSz = 1; // assembled-size of the `(l - 1)`-level.
-    for (uint64_t l = 0; l < lvlRank; ++l) {
-      const auto dlt = lvlTypes[l]; // Avoid redundant bounds checking.
-      if (isCompressedDLT(dlt)) {
-        positions[l].reserve(parentSz + 1);
-        positions[l].push_back(0);
-        uint64_t currentPos = 0;
-        nnz.forallCoords(l, [this, &currentPos, l](uint64_t n) {
-          currentPos += n;
-          appendPos(l, currentPos);
-        });
-        assert(positions[l].size() == parentSz + 1 &&
-               "Final positions size doesn't match allocated size");
-        // That assertion entails `assembledSize(parentSz, l)`
-        // is now in a valid state. That is, `positions[l][parentSz]`
-        // equals the present value of `currentPos`, which is the
-        // correct assembled-size for `coordinates[l]`.
-      }
-      // Update assembled-size for the next iteration.
-      parentSz = assembledSize(parentSz, l);
-      // Ideally we need only `coordinates[l].reserve(parentSz)`, however
-      // the `std::vector` implementation forces us to initialize it too.
-      // That is, in the yieldPos loop we need random-access assignment
-      // to `coordinates[l]`; however, `std::vector`'s subscript-assignment
-      // only allows assigning to already-initialized positions.
-      if (isCompressedDLT(dlt) || isSingletonDLT(dlt))
-        coordinates[l].resize(parentSz, 0);
-      else
-        assert(isDenseDLT(dlt));
-    }
-    values.resize(parentSz, 0); // Both allocate and zero-initialize.
-  }
-  // The yieldPos loop
-  lvlEnumerator.forallElements([this](const auto &lvlCoords, V val) {
-    uint64_t parentSz = 1, parentPos = 0;
-    for (uint64_t lvlRank = getLvlRank(), l = 0; l < lvlRank; ++l) {
-      const auto dlt = getLvlTypes()[l]; // Avoid redundant bounds checking.
-      if (isCompressedDLT(dlt)) {
-        // If `parentPos == parentSz` then it's valid as an array-lookup;
-        // however, it's semantically invalid here since that entry
-        // does not represent a segment of `coordinates[l]`. Moreover, that
-        // entry must be immutable for `assembledSize` to remain valid.
-        assert(parentPos < parentSz);
-        const uint64_t currentPos = positions[l][parentPos];
-        // This increment won't overflow the `P` type, since it can't
-        // exceed the original value of `positions[l][parentPos+1]`
-        // which was already verified to be within bounds for `P`
-        // when it was written to the array.
-        positions[l][parentPos]++;
-        writeCrd(l, currentPos, lvlCoords[l]);
-        parentPos = currentPos;
-      } else if (isSingletonDLT(dlt)) {
-        writeCrd(l, parentPos, lvlCoords[l]);
-        // the new parentPos equals the old parentPos.
-      } else { // Dense level.
-        assert(isDenseDLT(dlt));
-        parentPos = parentPos * getLvlSizes()[l] + lvlCoords[l];
-      }
-      parentSz = assembledSize(parentSz, l);
-    }
-    assert(parentPos < values.size());
-    values[parentPos] = val;
-  });
-  // The finalizeYieldPos loop
-  for (uint64_t parentSz = 1, l = 0; l < lvlRank; ++l) {
-    const auto dlt = lvlTypes[l]; // Avoid redundant bounds checking.
-    if (isCompressedDLT(dlt)) {
-      assert(parentSz == positions[l].size() - 1 &&
-             "Actual positions size doesn't match the expected size");
-      // Can't check all of them, but at least we can check the last one.
-      assert(positions[l][parentSz - 1] == positions[l][parentSz] &&
-             "Positions got corrupted");
-      for (uint64_t n = 0; n < parentSz; ++n) {
-        const uint64_t parentPos = parentSz - n;
-        positions[l][parentPos] = positions[l][parentPos - 1];
-      }
-      positions[l][0] = 0;
-    } else {
-      // Both dense and singleton are no-ops for the finalizeYieldPos loop.
-      // This assertion is for future-proofing.
-      assert((isDenseDLT(dlt) || isSingletonDLT(dlt)));
-    }
-    parentSz = assembledSize(parentSz, l);
-  }
-}
-
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V>::SparseTensorStorage(
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index a470afc2f0c8c..8955b79f09197 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -47,7 +46,6 @@ extern "C" {
 ///   kEmpty          -        STS, empty
 ///   kEmptyForward   -        STS, empty, with forwarding COO
 ///   kFromCOO        COO      STS, copied from the COO source
-///   kSparseToSparse STS      STS, copied from the STS source
 ///   kToCOO          STS      COO, copied from the STS source
 ///   kPack           buffers  STS, from level buffers
 ///   kSortCOOInPlace STS      STS, sorted in place
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
index c48af17b2d94b..15024b2475b91 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
@@ -8,7 +8,6 @@ add_mlir_library(MLIRSparseTensorRuntime
 
   File.cpp
   MapRef.cpp
-  NNZ.cpp
   Storage.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp b/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp
deleted file mode 100644
index d3c3951c15468..0000000000000
--- a/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp
+++ /dev/null
@@ -1,79 +0,0 @@
-//===- NNZ.cpp - NNZ-statistics for direct sparse2sparse conversion ------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains method definitions for `SparseTensorNNZ`.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/ExecutionEngine/SparseTensor/Storage.h"
-
-using namespace mlir::sparse_tensor;
-
-SparseTensorNNZ::SparseTensorNNZ(const std::vector<uint64_t> &lvlSizes,
-                                 const std::vector<DimLevelType> &lvlTypes)
-    : lvlSizes(lvlSizes), lvlTypes(lvlTypes), nnz(getLvlRank()) {
-  assert(lvlSizes.size() == lvlTypes.size() && "Rank mismatch");
-  bool alreadyCompressed = false;
-  (void)alreadyCompressed;
-  uint64_t sz = 1; // the product of all `lvlSizes` strictly less than `l`.
-  for (uint64_t l = 0, lvlrank = getLvlRank(); l < lvlrank; ++l) {
-    const DimLevelType dlt = lvlTypes[l];
-    if (isCompressedDLT(dlt)) {
-      if (alreadyCompressed)
-        MLIR_SPARSETENSOR_FATAL(
-            "Multiple compressed levels not currently supported");
-      alreadyCompressed = true;
-      nnz[l].resize(sz, 0); // Both allocate and zero-initialize.
-    } else if (isDenseDLT(dlt)) {
-      if (alreadyCompressed)
-        MLIR_SPARSETENSOR_FATAL(
-            "Dense after compressed not currently supported");
-    } else if (isSingletonDLT(dlt)) {
-      // Singleton after Compressed causes no problems for allocating
-      // `nnz` nor for the yieldPos loop. This remains true even
-      // when adding support for multiple compressed dimensions or
-      // for dense-after-compressed.
-    } else {
-      MLIR_SPARSETENSOR_FATAL("unsupported level type: %d\n",
-                              static_cast<uint8_t>(dlt));
-    }
-    sz = detail::checkedMul(sz, lvlSizes[l]);
-  }
-}
-
-void SparseTensorNNZ::forallCoords(uint64_t stopLvl,
-                                   SparseTensorNNZ::NNZConsumer yield) const {
-  assert(stopLvl < getLvlRank() && "Level out of bounds");
-  assert(isCompressedDLT(lvlTypes[stopLvl]) &&
-         "Cannot look up non-compressed levels");
-  forallCoords(yield, stopLvl, 0, 0);
-}
-
-void SparseTensorNNZ::add(const std::vector<uint64_t> &lvlCoords) {
-  uint64_t parentPos = 0;
-  for (uint64_t l = 0, lvlrank = getLvlRank(); l < lvlrank; ++l) {
-    if (isCompressedDLT(lvlTypes[l]))
-      nnz[l][parentPos]++;
-    parentPos = parentPos * lvlSizes[l] + lvlCoords[l];
-  }
-}
-
-void SparseTensorNNZ::forallCoords(SparseTensorNNZ::NNZConsumer yield,
-                                   uint64_t stopLvl, uint64_t parentPos,
-                                   uint64_t l) const {
-  assert(l <= stopLvl);
-  if (l == stopLvl) {
-    assert(parentPos < nnz[l].size() && "Cursor is out of range");
-    yield(nnz[l][parentPos]);
-  } else {
-    const uint64_t sz = lvlSizes[l];
-    const uint64_t pstart = parentPos * sz;
-    for (uint64_t i = 0; i < sz; ++i)
-      forallCoords(yield, stopLvl, pstart + i, l + 1);
-  }
-}
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 050dff2da1fa4..f5890ebb6f3ff 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -44,21 +44,10 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
   }
 }
 
-// Helper macro for generating error messages when some
-// `SparseTensorStorage<P, C, V>` is cast to `SparseTensorStorageBase`
-// and then the wrong "partial method specialization" is called.
+// Helper macro for wrong "partial method specialization" errors.
 #define FATAL_PIV(NAME)                                                        \
   MLIR_SPARSETENSOR_FATAL("<P, C, V> type mismatch for: " #NAME);
 
-#define IMPL_NEWENUMERATOR(VNAME, V)                                           \
-  void SparseTensorStorageBase::newEnumerator(                                 \
-      SparseTensorEnumeratorBase<V> **, uint64_t, const uint64_t *, uint64_t,  \
-      const uint64_t *) const {                                                \
-    FATAL_PIV("newEnumerator" #VNAME);                                         \
-  }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_NEWENUMERATOR)
-#undef IMPL_NEWENUMERATOR
-
 #define IMPL_GETPOSITIONS(PNAME, P)                                            \
   void SparseTensorStorageBase::getPositions(std::vector<P> **, uint64_t) {    \
     FATAL_PIV("getPositions" #PNAME);                                          \
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 74ab65c143d63..6a4c0f292c5f8 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -131,13 +131,6 @@ extern "C" {
       return SparseTensorStorage<P, C, V>::newFromCOO(                        \
           dimRank, dimSizes, lvlRank, lvlTypes, dim2lvl, lvl2dim, coo);       \
     }                                                                         \
-    case Action::kSparseToSparse: {                                           \
-      assert(ptr && "Received nullptr for SparseTensorStorage object");       \
-      auto &tensor = *static_cast<SparseTensorStorageBase *>(ptr);            \
-      return SparseTensorStorage<P, C, V>::newFromSparseTensor(               \
-          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,   \
-          dimRank, tensor);                                                   \
-    }                                                                         \
     case Action::kFromReader: {                                               \
       assert(ptr && "Received nullptr for SparseTensorReader object");        \
       SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr);   \
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 63f9cdafce88b..09cf01e73ed8c 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8795,7 +8795,6 @@ cc_library(
     srcs = [
         "lib/ExecutionEngine/SparseTensor/File.cpp",
        "lib/ExecutionEngine/SparseTensor/MapRef.cpp",
-        "lib/ExecutionEngine/SparseTensor/NNZ.cpp",
        "lib/ExecutionEngine/SparseTensor/Storage.cpp",
    ],
    hdrs = [

From d3c6e524a5929f3766d5ca79b295bbcbc91cb421 Mon Sep 17 00:00:00 2001
From: Aart Bik <39774503+aartbik@users.noreply.github.com>
Date: Mon, 16 Oct 2023 13:26:19 -0700
Subject: [PATCH 2/2] Update Storage.h

---
 .../ExecutionEngine/SparseTensor/Storage.h    | 77 ------------------
 1 file changed, 77 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index beff393b94033..bafc9baa7edde 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -849,83 +849,6 @@ class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
   }
 };
 
-//===----------------------------------------------------------------------===//
-//
-// SparseTensorNNZ
-//
-//===----------------------------------------------------------------------===//
-
-/// Statistics regarding the number of nonzero subtensors in
-/// a source tensor, for direct sparse=>sparse conversion a la
-/// <https://arxiv.org/abs/2001.02609>.
-///
-/// N.B., this class stores references to the parameters passed to
-/// the constructor; thus, objects of this class must not outlive
-/// those parameters.
-///
-/// This class does not have the "dimension" vs "level" distinction, but
-/// since it is used for initializing the levels of a `SparseTensorStorage`
-/// object, we use the "level" name throughout for the sake of consistency.
-class SparseTensorNNZ final {
-public:
-  /// Allocates the statistics structure for the desired target-tensor
-  /// level structure (i.e., sizes and types). This constructor does not
-  /// actually populate the statistics, however; for that see `initialize`.
-  ///
-  /// Precondition: `lvlSizes` must not contain zeros.
-  /// Asserts: `lvlSizes.size() == lvlTypes.size()`.
-  SparseTensorNNZ(const std::vector<uint64_t> &lvlSizes,
-                  const std::vector<DimLevelType> &lvlTypes);
-
-  // We disallow copying to help avoid leaking the stored references.
-  SparseTensorNNZ(const SparseTensorNNZ &) = delete;
-  SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
-
-  /// Gets the target-tensor's level-rank.
-  uint64_t getLvlRank() const { return lvlSizes.size(); }
-
-  /// Enumerates the source tensor to fill in the statistics.
-  /// The enumerator should already incorporate the mapping from
-  /// the source tensor-dimensions to the target storage-levels.
-  ///
-  /// Asserts:
-  /// * `enumerator.getTrgRank() == getLvlRank()`.
-  /// * `enumerator.getTrgSizes() == lvlSizes`.
-  template <typename V>
-  void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
-    assert(enumerator.getTrgRank() == getLvlRank() && "Tensor rank mismatch");
-    assert(enumerator.getTrgSizes() == lvlSizes && "Tensor size mismatch");
-    enumerator.forallElements(
-        [this](const std::vector<uint64_t> &lvlCoords, V) { add(lvlCoords); });
-  }
-
-  /// The type of callback functions which receive an nnz-statistic.
-  using NNZConsumer = const std::function<void(uint64_t)> &;
-
-  /// Lexicographically enumerates all coordinates for levels strictly
-  /// less than `stopLvl`, and passes their nnz statistic to the callback.
-  /// Since our use-case only requires the statistic not the coordinates
-  /// themselves, we do not bother to construct those coordinates.
-  void forallCoords(uint64_t stopLvl, NNZConsumer yield) const;
-
-private:
-  /// Adds a new element (i.e., increment its statistics). We use
-  /// a method rather than inlining into the lambda in `initialize`,
-  /// to avoid spurious templating over `V`. And this method is private
-  /// to avoid needing to re-assert validity of `lvlCoords` (which is
-  /// guaranteed by `forallElements`).
-  void add(const std::vector<uint64_t> &lvlCoords);
-
-  /// Recursive component of the public `forallCoords`.
-  void forallCoords(NNZConsumer yield, uint64_t stopLvl, uint64_t parentPos,
-                    uint64_t l) const;
-
-  // All of these are in the target storage-order.
-  const std::vector<uint64_t> &lvlSizes;
-  const std::vector<DimLevelType> &lvlTypes;
-  std::vector<std::vector<uint64_t>> nnz;
-};
-
 //===----------------------------------------------------------------------===//
 //
 // SparseTensorStorage Factories
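
Reviewer note, not part of the patch: with `Action::kSparseToSparse`, `newFromSparseTensor`, and the `SparseTensorNNZ` machinery gone, a runtime sparse-to-sparse conversion is expected to round-trip through COO using the `kToCOO` and `kFromCOO` entry points that remain. The sketch below is only an illustration of that composition under assumptions: the helper name `convertViaCOO` is hypothetical, the COO-producing member is assumed to be the `toCOO` method whose body the @@ -472 hunk touches, and both parameter lists are inferred from the hunks above rather than copied from the headers.

  // Hypothetical sketch: sparse-to-sparse conversion via an intermediate COO.
  #include "mlir/ExecutionEngine/SparseTensor/Storage.h"

  using namespace mlir::sparse_tensor;

  template <typename P, typename C, typename V>
  SparseTensorStorage<P, C, V> *
  convertViaCOO(const SparseTensorStorage<P, C, V> &src, uint64_t dimRank,
                const uint64_t *dimSizes, uint64_t lvlRank,
                const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
                const uint64_t *dim2lvl, const uint64_t *lvl2dim) {
    // Expand the source into a COO object (what Action::kToCOO performs);
    // parameter order assumed from the toCOO hunk above.
    SparseTensorCOO<V> *coo =
        src.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl, lvl2dim);
    // Rebuild the target storage from the COO elements (the kFromCOO path,
    // with the argument order used by the runtime macro above).
    SparseTensorStorage<P, C, V> *dst = SparseTensorStorage<P, C, V>::newFromCOO(
        dimRank, dimSizes, lvlRank, lvlTypes, dim2lvl, lvl2dim, *coo);
    delete coo; // the intermediate COO is owned by the caller
    return dst;
  }

This two-step composition is what makes the remaining entry points sufficient: any source format can be flattened to COO, and any target format can be assembled from COO, so no direct STS-to-STS path is needed in the library.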