1717#include " mlir/Dialect/MemRef/Transforms/Transforms.h"
1818#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1919#include " mlir/Dialect/Vector/IR/VectorOps.h"
20+ #include " mlir/Support/MathExtras.h"
2021#include " mlir/Transforms/DialectConversion.h"
2122#include " llvm/Support/FormatVariadic.h"
2223#include " llvm/Support/MathExtras.h"
@@ -209,6 +210,76 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
209210 return success ();
210211 }
211212};
213+
214+ // ===----------------------------------------------------------------------===//
215+ // ConvertMemRefSubview
216+ // ===----------------------------------------------------------------------===//
217+
218+ // / Emulating narrow ints on subview have limited support, supporting only
219+ // / static offset and size and stride of 1. Ideally, the subview should be
220+ // / folded away before running narrow type emulation, and this pattern would
221+ // / never run. This pattern is mostly used for testing pruposes.
222+ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
223+ using OpConversionPattern::OpConversionPattern;
224+
225+ LogicalResult
226+ matchAndRewrite (memref::SubViewOp op, OpAdaptor adaptor,
227+ ConversionPatternRewriter &rewriter) const override {
228+ MemRefType newTy =
229+ dyn_cast<MemRefType>(getTypeConverter ()->convertType (op.getType ()));
230+ if (!newTy) {
231+ return rewriter.notifyMatchFailure (
232+ op->getLoc (),
233+ llvm::formatv (" failed to convert memref type: {0}" , op.getType ()));
234+ }
235+
236+ auto convertedElementType = newTy.getElementType ();
237+ auto oldElementType = op.getType ().getElementType ();
238+ int srcBits = oldElementType.getIntOrFloatBitWidth ();
239+ int dstBits = convertedElementType.getIntOrFloatBitWidth ();
240+ if (dstBits % srcBits != 0 ) {
241+ return rewriter.notifyMatchFailure (
242+ op, " only dstBits % srcBits == 0 supported" );
243+ }
244+
245+ // Only support offset for 1-D subview.
246+ if (op.getType ().getRank () != 1 ) {
247+ return rewriter.notifyMatchFailure (
248+ op->getLoc (), " subview with rank > 1 is not supported" );
249+ }
250+
251+ // Only support stride of 1.
252+ if (op.getStaticStride (0 ) != 1 ) {
253+ return rewriter.notifyMatchFailure (
254+ op->getLoc (), " subview with stride != 1 is not supported" );
255+ }
256+
257+ int64_t size = op.getStaticSize (0 );
258+ int64_t offset = op.getStaticOffset (0 );
259+ // Only support static sizes and offsets.
260+ if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic ) {
261+ return rewriter.notifyMatchFailure (
262+ op->getLoc (), " subview with dynamic size or offset is not supported" );
263+ }
264+
265+ int elementsPerByte = dstBits / srcBits;
266+ if (offset % elementsPerByte != 0 ) {
267+ return rewriter.notifyMatchFailure (
268+ op->getLoc (),
269+ " subview with offset not multiple of elementsPerByte is not "
270+ " supported" );
271+ }
272+
273+ size = ceilDiv (size, elementsPerByte);
274+ offset = offset / elementsPerByte;
275+
276+ rewriter.replaceOpWithNewOp <memref::SubViewOp>(
277+ op, newTy, *adaptor.getODSOperands (0 ).begin (), offset, size,
278+ op.getStaticStrides ());
279+ return success ();
280+ }
281+ };
282+
212283} // end anonymous namespace
213284
214285// ===----------------------------------------------------------------------===//
@@ -220,9 +291,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
220291 RewritePatternSet &patterns) {
221292
222293 // Populate `memref.*` conversion patterns.
223- patterns
224- . add <ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
225- typeConverter, patterns.getContext ());
294+ patterns. add <ConvertMemRefAlloc, ConvertMemRefLoad,
295+ ConvertMemRefAssumeAlignment, ConvertMemRefSubview >(
296+ typeConverter, patterns.getContext ());
226297 memref::populateResolveExtractStridedMetadataPatterns (patterns);
227298}
228299
@@ -271,9 +342,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
271342 return std::nullopt ;
272343
273344 StridedLayoutAttr layoutAttr;
345+ // If the offset is 0, we do not need a strided layout as the stride is
346+ // 1, so we only use the strided layout if the offset is not 0.
274347 if (offset != 0 ) {
275- layoutAttr = StridedLayoutAttr::get (ty.getContext (), offset,
276- ArrayRef<int64_t >{1 });
348+ if (offset == ShapedType::kDynamic ) {
349+ layoutAttr = StridedLayoutAttr::get (ty.getContext (), offset,
350+ ArrayRef<int64_t >{1 });
351+ } else {
352+ // Check if the number of bytes are a multiple of the loadStoreWidth
353+ // and if so, divide it by the loadStoreWidth to get the offset.
354+ if ((offset * width) % loadStoreWidth != 0 )
355+ return std::nullopt ;
356+ offset = (offset * width) / loadStoreWidth;
357+
358+ layoutAttr = StridedLayoutAttr::get (ty.getContext (), offset,
359+ ArrayRef<int64_t >{1 });
360+ }
277361 }
278362
279363 return MemRefType::get (getLinearizedShape (ty, width, loadStoreWidth),
0 commit comments