@@ -96,6 +96,87 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
   }
 };
 
+/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
+/// MemRef with updated indices that model the strided access.
+///
+/// ```mlir
+///   %subview = memref.subview %M (...)
+///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// ```
+/// ==>
+/// ```mlir
+///   %collapse_shape = memref.collapse_shape %M (...)
+///     : memref<100x3xf32> into memref<300xf32>
+///   %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
+///   %gather = vector.gather %collapse_shape[%new_idxs] (...)
+///     : memref<300xf32> (...)
+/// ```
+///
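+/// For example, gather indices [0, 2, 5] into the strided subview above
+/// become indices [0, 6, 15] into the collapsed 300-element memref.
+///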
+/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
+/// but should be fairly straightforward to extend beyond that.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Value base = op.getBase();
+
+    // TODO: Strided accesses might be coming from other ops as well
+    auto subview = base.getDefiningOp<memref::SubViewOp>();
+    if (!subview)
+      return failure();
+
+    auto sourceType = subview.getSource().getType();
+
+    // TODO: Allow ranks > 2.
+    if (sourceType.getRank() != 2)
+      return failure();
+
+    // Get strides
+    auto layout = subview.getResult().getType().getLayout();
+    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+    if (!stridedLayoutAttr)
+      return failure();
+
+    // TODO: Allow the access to be strided in multiple dimensions.
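+    // E.g. a rank-1 result like memref<100xf32, strided<[3]>> has a single
+    // stride; a rank-2 result such as strided<[30, 3]> is rejected here.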
+    if (stridedLayoutAttr.getStrides().size() != 1)
+      return failure();
+
+    int64_t srcTrailingDim = sourceType.getShape().back();
+
+    // Assume that the stride matches the trailing dimension of the source
+    // memref.
+    // TODO: Relax this assumption.
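+    // E.g. for the memref<100x3xf32> source above this means a stride of 3,
+    // i.e. a subview selecting a single column.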
+    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
+      return failure();
+
+    // 1. Collapse the input memref so that it's "flat".
+    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
+    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
+        op.getLoc(), subview.getSource(), reassoc);
+
+    // 2. Generate new gather indices that will model the
+    // strided access.
+    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
+    VectorType vType = op.getIndexVec().getType();
+    Value mulCst = rewriter.create<arith::ConstantOp>(
+        op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
+
+    Value newIdxs =
+        rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
+
+    // 3. Create an updated gather op with the collapsed input memref and the
+    // updated indices.
+    Value newGather = rewriter.create<vector::GatherOp>(
+        op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
+        newIdxs, op.getMask(), op.getPassThru());
+    rewriter.replaceOp(op, newGather);
+
+    return success();
+  }
+};
+
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
@@ -115,6 +196,16 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
     Value condMask = op.getMask();
     Value base = op.getBase();
+
+    // vector.load requires the most minor memref dim to have unit stride
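+    // (e.g. a memref<100xf32, strided<[3]>> base is rejected below, while an
+    // identity layout or strided<[1], offset: ?> is fine).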
+    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+      if (auto stridesAttr =
+              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
+        if (stridesAttr.getStrides().back() != 1)
+          return failure();
+      }
+    }
+
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
         loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
         op.getIndexVec());
@@ -168,6 +259,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
-                                                          benefit);
+  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
+               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
 }