@@ -138,6 +138,12 @@ extern "C" {
138138 dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
139139 dimRank, tensor); \
140140 } \
141+ case Action::kFromReader : { \
142+ assert (ptr && " Received nullptr for SparseTensorReader object" ); \
143+ SparseTensorReader &reader = *static_cast <SparseTensorReader *>(ptr); \
144+ return static_cast <void *>(reader.readSparseTensor <P, C, V>( \
145+ lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \
146+ } \
141147 case Action::kToCOO : { \
142148 assert (ptr && " Received nullptr for SparseTensorStorage object" ); \
143149 auto &tensor = *static_cast <SparseTensorStorage<P, C, V> *>(ptr); \
@@ -442,113 +448,6 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
442448MLIR_SPARSETENSOR_FOREVERY_V_O (IMPL_GETNEXT)
443449#undef IMPL_GETNEXT
444450
445- void *_mlir_ciface_newSparseTensorFromReader (
446- void *p, StridedMemRefType<index_type, 1 > *lvlSizesRef,
447- StridedMemRefType<DimLevelType, 1 > *lvlTypesRef,
448- StridedMemRefType<index_type, 1 > *dim2lvlRef,
449- StridedMemRefType<index_type, 1 > *lvl2dimRef, OverheadType posTp,
450- OverheadType crdTp, PrimaryType valTp) {
451- assert (p);
452- SparseTensorReader &reader = *static_cast <SparseTensorReader *>(p);
453- ASSERT_NO_STRIDE (lvlSizesRef);
454- ASSERT_NO_STRIDE (lvlTypesRef);
455- ASSERT_NO_STRIDE (dim2lvlRef);
456- ASSERT_NO_STRIDE (lvl2dimRef);
457- const uint64_t dimRank = reader.getRank ();
458- const uint64_t lvlRank = MEMREF_GET_USIZE (lvlSizesRef);
459- ASSERT_USIZE_EQ (lvlTypesRef, lvlRank);
460- ASSERT_USIZE_EQ (dim2lvlRef, dimRank);
461- ASSERT_USIZE_EQ (lvl2dimRef, lvlRank);
462- (void )dimRank;
463- const index_type *lvlSizes = MEMREF_GET_PAYLOAD (lvlSizesRef);
464- const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD (lvlTypesRef);
465- const index_type *dim2lvl = MEMREF_GET_PAYLOAD (dim2lvlRef);
466- const index_type *lvl2dim = MEMREF_GET_PAYLOAD (lvl2dimRef);
467- #define CASE (p, c, v, P, C, V ) \
468- if (posTp == OverheadType::p && crdTp == OverheadType::c && \
469- valTp == PrimaryType::v) \
470- return static_cast <void *>(reader.readSparseTensor <P, C, V>( \
471- lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
472- #define CASE_SECSAME (p, v, P, V ) CASE(p, p, v, P, P, V)
473- // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
474- // This is safe because of the static_assert above.
475- if (posTp == OverheadType::kIndex )
476- posTp = OverheadType::kU64 ;
477- if (crdTp == OverheadType::kIndex )
478- crdTp = OverheadType::kU64 ;
479- // Double matrices with all combinations of overhead storage.
480- CASE (kU64 , kU64 , kF64 , uint64_t , uint64_t , double );
481- CASE (kU64 , kU32 , kF64 , uint64_t , uint32_t , double );
482- CASE (kU64 , kU16 , kF64 , uint64_t , uint16_t , double );
483- CASE (kU64 , kU8 , kF64 , uint64_t , uint8_t , double );
484- CASE (kU32 , kU64 , kF64 , uint32_t , uint64_t , double );
485- CASE (kU32 , kU32 , kF64 , uint32_t , uint32_t , double );
486- CASE (kU32 , kU16 , kF64 , uint32_t , uint16_t , double );
487- CASE (kU32 , kU8 , kF64 , uint32_t , uint8_t , double );
488- CASE (kU16 , kU64 , kF64 , uint16_t , uint64_t , double );
489- CASE (kU16 , kU32 , kF64 , uint16_t , uint32_t , double );
490- CASE (kU16 , kU16 , kF64 , uint16_t , uint16_t , double );
491- CASE (kU16 , kU8 , kF64 , uint16_t , uint8_t , double );
492- CASE (kU8 , kU64 , kF64 , uint8_t , uint64_t , double );
493- CASE (kU8 , kU32 , kF64 , uint8_t , uint32_t , double );
494- CASE (kU8 , kU16 , kF64 , uint8_t , uint16_t , double );
495- CASE (kU8 , kU8 , kF64 , uint8_t , uint8_t , double );
496- // Float matrices with all combinations of overhead storage.
497- CASE (kU64 , kU64 , kF32 , uint64_t , uint64_t , float );
498- CASE (kU64 , kU32 , kF32 , uint64_t , uint32_t , float );
499- CASE (kU64 , kU16 , kF32 , uint64_t , uint16_t , float );
500- CASE (kU64 , kU8 , kF32 , uint64_t , uint8_t , float );
501- CASE (kU32 , kU64 , kF32 , uint32_t , uint64_t , float );
502- CASE (kU32 , kU32 , kF32 , uint32_t , uint32_t , float );
503- CASE (kU32 , kU16 , kF32 , uint32_t , uint16_t , float );
504- CASE (kU32 , kU8 , kF32 , uint32_t , uint8_t , float );
505- CASE (kU16 , kU64 , kF32 , uint16_t , uint64_t , float );
506- CASE (kU16 , kU32 , kF32 , uint16_t , uint32_t , float );
507- CASE (kU16 , kU16 , kF32 , uint16_t , uint16_t , float );
508- CASE (kU16 , kU8 , kF32 , uint16_t , uint8_t , float );
509- CASE (kU8 , kU64 , kF32 , uint8_t , uint64_t , float );
510- CASE (kU8 , kU32 , kF32 , uint8_t , uint32_t , float );
511- CASE (kU8 , kU16 , kF32 , uint8_t , uint16_t , float );
512- CASE (kU8 , kU8 , kF32 , uint8_t , uint8_t , float );
513- // Two-byte floats with both overheads of the same type.
514- CASE_SECSAME (kU64 , kF16 , uint64_t , f16 );
515- CASE_SECSAME (kU64 , kBF16 , uint64_t , bf16 );
516- CASE_SECSAME (kU32 , kF16 , uint32_t , f16 );
517- CASE_SECSAME (kU32 , kBF16 , uint32_t , bf16 );
518- CASE_SECSAME (kU16 , kF16 , uint16_t , f16 );
519- CASE_SECSAME (kU16 , kBF16 , uint16_t , bf16 );
520- CASE_SECSAME (kU8 , kF16 , uint8_t , f16 );
521- CASE_SECSAME (kU8 , kBF16 , uint8_t , bf16 );
522- // Integral matrices with both overheads of the same type.
523- CASE_SECSAME (kU64 , kI64 , uint64_t , int64_t );
524- CASE_SECSAME (kU64 , kI32 , uint64_t , int32_t );
525- CASE_SECSAME (kU64 , kI16 , uint64_t , int16_t );
526- CASE_SECSAME (kU64 , kI8 , uint64_t , int8_t );
527- CASE_SECSAME (kU32 , kI64 , uint32_t , int64_t );
528- CASE_SECSAME (kU32 , kI32 , uint32_t , int32_t );
529- CASE_SECSAME (kU32 , kI16 , uint32_t , int16_t );
530- CASE_SECSAME (kU32 , kI8 , uint32_t , int8_t );
531- CASE_SECSAME (kU16 , kI64 , uint16_t , int64_t );
532- CASE_SECSAME (kU16 , kI32 , uint16_t , int32_t );
533- CASE_SECSAME (kU16 , kI16 , uint16_t , int16_t );
534- CASE_SECSAME (kU16 , kI8 , uint16_t , int8_t );
535- CASE_SECSAME (kU8 , kI64 , uint8_t , int64_t );
536- CASE_SECSAME (kU8 , kI32 , uint8_t , int32_t );
537- CASE_SECSAME (kU8 , kI16 , uint8_t , int16_t );
538- CASE_SECSAME (kU8 , kI8 , uint8_t , int8_t );
539- // Complex matrices with wide overhead.
540- CASE_SECSAME (kU64 , kC64 , uint64_t , complex64);
541- CASE_SECSAME (kU64 , kC32 , uint64_t , complex32);
542-
543- // Unsupported case (add above if needed).
544- MLIR_SPARSETENSOR_FATAL (
545- " unsupported combination of types: <P=%d, C=%d, V=%d>\n " ,
546- static_cast <int >(posTp), static_cast <int >(crdTp),
547- static_cast <int >(valTp));
548- #undef CASE_SECSAME
549- #undef CASE
550- }
551-
552451void _mlir_ciface_outSparseTensorWriterMetaData (
553452 void *p, index_type dimRank, index_type nse,
554453 StridedMemRefType<index_type, 1 > *dimSizesRef) {
@@ -635,34 +534,10 @@ char *getTensorFilename(index_type id) {
635534 return env;
636535}
637536
638- void readSparseTensorShape (char *filename, std::vector<uint64_t > *out) {
639- assert (out && " Received nullptr for out-parameter" );
640- SparseTensorReader reader (filename);
641- reader.openFile ();
642- reader.readHeader ();
643- reader.closeFile ();
644- const uint64_t dimRank = reader.getRank ();
645- const uint64_t *dimSizes = reader.getDimSizes ();
646- out->reserve (dimRank);
647- out->assign (dimSizes, dimSizes + dimRank);
648- }
649-
650- index_type getSparseTensorReaderRank (void *p) {
651- return static_cast <SparseTensorReader *>(p)->getRank ();
652- }
653-
654- bool getSparseTensorReaderIsSymmetric (void *p) {
655- return static_cast <SparseTensorReader *>(p)->isSymmetric ();
656- }
657-
658537index_type getSparseTensorReaderNSE (void *p) {
659538 return static_cast <SparseTensorReader *>(p)->getNSE ();
660539}
661540
662- index_type getSparseTensorReaderDimSize (void *p, index_type d) {
663- return static_cast <SparseTensorReader *>(p)->getDimSize (d);
664- }
665-
666541void delSparseTensorReader (void *p) {
667542 delete static_cast <SparseTensorReader *>(p);
668543}
0 commit comments