1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
11 //===----------------------------------------------------------------------===//
13 #include <numeric>
14 #include <optional>
15 #include <type_traits>
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/Dialect/Vector/IR/VectorOps.h"
25 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
26 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/ImplicitLocOpBuilder.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "mlir/Transforms/Passes.h"
34 namespace mlir {
35 #define GEN_PASS_DEF_CONVERTVECTORTOSCF
36 #include "mlir/Conversion/Passes.h.inc"
37 } // namespace mlir
39 using namespace mlir;
40 using vector::TransferReadOp;
41 using vector::TransferWriteOp;
43 namespace {
45 /// Attribute name used for labeling transfer ops during progressive lowering.
46 static const char kPassLabel[] = "__vector_to_scf_lowering__";
48 /// Return true if this transfer op operates on a source tensor.
49 static bool isTensorOp(VectorTransferOpInterface xferOp) {
50 if (isa<RankedTensorType>(xferOp.getShapedType())) {
51 if (isa<vector::TransferWriteOp>(xferOp)) {
52 // TransferWriteOps on tensors have a result.
53 assert(xferOp->getNumResults() > 0);
55 return true;
57 return false;
60 /// Patterns that inherit from this struct have access to
61 /// VectorTransferToSCFOptions.
62 template <typename OpTy>
63 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
64 explicit VectorToSCFPattern(MLIRContext *context,
65 VectorTransferToSCFOptions opt)
66 : OpRewritePattern<OpTy>(context), options(opt) {}
68 LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
69 PatternRewriter &rewriter) const {
70 if (isTensorOp(xferOp) && !options.lowerTensors) {
71 return rewriter.notifyMatchFailure(
72 xferOp, "lowering tensor transfers is disabled");
74 return success();
77 VectorTransferToSCFOptions options;
80 /// Given a vector transfer op, calculate which dimension of the `source`
81 /// memref should be unpacked in the next application of TransferOpConversion.
82 /// A return value of std::nullopt indicates a broadcast.
83 template <typename OpTy>
84 static std::optional<int64_t> unpackedDim(OpTy xferOp) {
85 // TODO: support 0-d corner case.
86 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
87 auto map = xferOp.getPermutationMap();
88 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
89 return expr.getPosition();
91 assert(xferOp.isBroadcastDim(0) &&
92 "Expected AffineDimExpr or AffineConstantExpr");
93 return std::nullopt;
96 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
97 /// map is identical to the current permutation map, but the first result is
98 /// omitted.
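/// For example (illustrative): a permutation map `(d0, d1, d2) -> (d2, d1)`
/// becomes `(d0, d1, d2) -> (d1)`; the dims are kept and only the leading
/// result is dropped.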
99 template <typename OpTy>
100 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
101 // TODO: support 0-d corner case.
102 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
103 auto map = xferOp.getPermutationMap();
104 return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
105 b.getContext());
108 /// Calculate the indices for the new vector transfer op.
110 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
111 /// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
112 /// ^^^^^^
113 /// `iv` is the iteration variable of the (new) surrounding loop.
114 template <typename OpTy>
115 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
116 SmallVector<Value, 8> &indices) {
117 typename OpTy::Adaptor adaptor(xferOp);
118 // Corresponding memref dim of the vector dim that is unpacked.
119 auto dim = unpackedDim(xferOp);
120 auto prevIndices = adaptor.getIndices();
121 indices.append(prevIndices.begin(), prevIndices.end());
123 Location loc = xferOp.getLoc();
124 bool isBroadcast = !dim.has_value();
125 if (!isBroadcast) {
126 AffineExpr d0, d1;
127 bindDims(xferOp.getContext(), d0, d1);
128 Value offset = adaptor.getIndices()[*dim];
129 indices[*dim] =
130 affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
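/// Create an scf.yield op. If `hasRetVal` is true, yield `value`; otherwise
/// yield no operands.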
134 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
135 Value value) {
136 if (hasRetVal) {
137 assert(value && "Expected non-empty value");
138 b.create<scf::YieldOp>(loc, value);
139 } else {
140 b.create<scf::YieldOp>(loc);
144 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
145 /// is set to true. No such check is generated under the following circumstances:
146 /// * xferOp does not have a mask.
147 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
148 /// computed and attached to the new transfer op in the pattern.)
149 /// * The to-be-unpacked dim of xferOp is a broadcast.
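/// A sketch of the generated check (illustrative names and types):
/// ```
/// %bit = vector.extractelement %mask[%iv : index] : vector<16xi1>
/// ```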
150 template <typename OpTy>
151 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
152 if (!xferOp.getMask())
153 return Value();
154 if (xferOp.getMaskType().getRank() != 1)
155 return Value();
156 if (xferOp.isBroadcastDim(0))
157 return Value();
159 Location loc = xferOp.getLoc();
160 return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
163 /// Helper function for TransferOpConversion and TransferOp1dConversion.
164 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
165 /// specified dimension `dim` with the loop iteration variable `iv`.
166 /// E.g., when unpacking dimension 0 from:
167 /// ```
168 /// %vec = vector.transfer_read %A[%a, %b] %cst
169 /// : vector<5x4xf32>, memref<?x?xf32>
170 /// ```
171 /// An if check similar to this will be generated inside the loop:
172 /// ```
173 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
174 /// if (%a + iv < %d) {
175 /// (in-bounds case)
176 /// } else {
177 /// (out-of-bounds case)
178 /// }
179 /// ```
181 /// If the transfer is 1D and has a mask, this function generates a more complex
182 /// check that also accounts for potentially masked-out elements.
184 /// This function variant returns the value returned by `inBoundsCase` or
185 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
186 /// `resultTypes`.
187 template <typename OpTy>
188 static Value generateInBoundsCheck(
189 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
190 TypeRange resultTypes,
191 function_ref<Value(OpBuilder &, Location)> inBoundsCase,
192 function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
193 bool hasRetVal = !resultTypes.empty();
194 Value cond; // Condition to be built...
196 // Condition check 1: Access in-bounds?
197 bool isBroadcast = !dim; // No in-bounds check for broadcasts.
198 Location loc = xferOp.getLoc();
199 ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
200 if (!xferOp.isDimInBounds(0) && !isBroadcast) {
201 Value memrefDim =
202 vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
203 AffineExpr d0, d1;
204 bindDims(xferOp.getContext(), d0, d1);
205 Value base = xferOp.getIndices()[*dim];
206 Value memrefIdx =
207 affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
208 cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
209 memrefIdx);
212 // Condition check 2: Masked in?
213 if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
214 if (cond)
215 cond = lb.create<arith::AndIOp>(cond, maskCond);
216 else
217 cond = maskCond;
220 // If the condition is non-empty, generate an SCF::IfOp.
221 if (cond) {
222 auto check = lb.create<scf::IfOp>(
223 cond,
224 /*thenBuilder=*/
225 [&](OpBuilder &b, Location loc) {
226 maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
228 /*elseBuilder=*/
229 [&](OpBuilder &b, Location loc) {
230 if (outOfBoundsCase) {
231 maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
232 } else {
233 b.create<scf::YieldOp>(loc);
237 return hasRetVal ? check.getResult(0) : Value();
240 // Condition is empty, no need for an SCF::IfOp.
241 return inBoundsCase(b, loc);
244 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
245 /// a return value. Consequently, this function does not have a return value.
246 template <typename OpTy>
247 static void generateInBoundsCheck(
248 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
249 function_ref<void(OpBuilder &, Location)> inBoundsCase,
250 function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
251 generateInBoundsCheck(
252 b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
253 /*inBoundsCase=*/
254 [&](OpBuilder &b, Location loc) {
255 inBoundsCase(b, loc);
256 return Value();
258 /*outOfBoundsCase=*/
259 [&](OpBuilder &b, Location loc) {
260 if (outOfBoundsCase)
261 outOfBoundsCase(b, loc);
262 return Value();
266 /// Given an ArrayAttr, return a copy where the first element is dropped.
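/// E.g. (illustrative): an in_bounds attribute `[true, false, true]` becomes
/// `[false, true]`.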
267 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
268 if (!attr)
269 return attr;
270 return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
273 /// Add the pass label to a vector transfer op if its rank is not the target
274 /// rank.
275 template <typename OpTy>
276 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
277 unsigned targetRank) {
278 if (newXferOp.getVectorType().getRank() > targetRank)
279 newXferOp->setAttr(kPassLabel, b.getUnitAttr());
282 namespace lowering_n_d {
284 /// Helper data structure for data and mask buffers.
285 struct BufferAllocs {
286 Value dataBuffer;
287 Value maskBuffer;
290 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
291 static Operation *getAutomaticAllocationScope(Operation *op) {
292 Operation *scope =
293 op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
294 assert(scope && "Expected op to be inside automatic allocation scope");
295 return scope;
298 /// Allocate temporary buffers for data (vector) and mask (if present).
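/// A sketch of the generated IR for a masked transfer (illustrative names):
/// ```
/// %data = memref.alloca() : memref<vector<5x4xf32>>
/// %maskBuf = memref.alloca() : memref<vector<5xi1>>
/// memref.store %mask, %maskBuf[] : memref<vector<5xi1>>
/// %maskVal = memref.load %maskBuf[] : memref<vector<5xi1>>
/// ```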
299 template <typename OpTy>
300 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
301 Location loc = xferOp.getLoc();
302 OpBuilder::InsertionGuard guard(b);
303 Operation *scope = getAutomaticAllocationScope(xferOp);
304 assert(scope->getNumRegions() == 1 &&
305 "AutomaticAllocationScope with >1 regions");
306 b.setInsertionPointToStart(&scope->getRegion(0).front());
308 BufferAllocs result;
309 auto bufferType = MemRefType::get({}, xferOp.getVectorType());
310 result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
312 if (xferOp.getMask()) {
313 auto maskType = MemRefType::get({}, xferOp.getMask().getType());
314 auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
315 b.setInsertionPoint(xferOp);
316 b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
317 result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
320 return result;
323 /// Given a MemRefType with VectorType element type, unpack one dimension from
324 /// the VectorType into the MemRefType.
326 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
327 static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
328 auto vectorType = dyn_cast<VectorType>(type.getElementType());
329 // Vectors with leading scalable dims are not supported.
330 // It may be possible to support these in future by using dynamic memref dims.
331 if (vectorType.getScalableDims().front())
332 return failure();
333 auto memrefShape = type.getShape();
334 SmallVector<int64_t, 8> newMemrefShape;
335 newMemrefShape.append(memrefShape.begin(), memrefShape.end());
336 newMemrefShape.push_back(vectorType.getDimSize(0));
337 return MemRefType::get(newMemrefShape,
338 VectorType::Builder(vectorType).dropDim(0));
341 /// Given a transfer op, find the memref from which the mask is loaded. This
342 /// is similar to Strategy<TransferWriteOp>::getBuffer.
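/// The expected IR pattern is (illustrative names):
/// ```
/// %mask = memref.load %maskBuffer[] : memref<vector<5xi1>>
/// // ... %mask is used as the mask operand of the transfer op ...
/// ```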
343 template <typename OpTy>
344 static Value getMaskBuffer(OpTy xferOp) {
345 assert(xferOp.getMask() && "Expected that transfer op has mask");
346 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
347 assert(loadOp && "Expected transfer op mask produced by LoadOp");
348 return loadOp.getMemRef();
351 /// Codegen strategy, depending on the operation.
352 template <typename OpTy>
353 struct Strategy;
355 /// Codegen strategy for vector TransferReadOp.
356 template <>
357 struct Strategy<TransferReadOp> {
358 /// Find the StoreOp that is used for writing the current TransferReadOp's
359 /// result to the temporary buffer allocation.
360 static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
361 assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
362 auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
363 assert(storeOp && "Expected TransferReadOp result used by StoreOp");
364 return storeOp;
367 /// Find the temporary buffer allocation. All labeled TransferReadOps are
368 /// used like this, where %buf is either the buffer allocation or a type cast
369 /// of the buffer allocation:
370 /// ```
371 /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
372 /// memref.store %vec, %buf[...] ...
373 /// ```
374 static Value getBuffer(TransferReadOp xferOp) {
375 return getStoreOp(xferOp).getMemRef();
378 /// Retrieve the indices of the current StoreOp that stores into the buffer.
379 static void getBufferIndices(TransferReadOp xferOp,
380 SmallVector<Value, 8> &indices) {
381 auto storeOp = getStoreOp(xferOp);
382 auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
383 indices.append(prevIndices.begin(), prevIndices.end());
386 /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
387 /// accesses on the to-be-unpacked dimension.
389 /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
390 /// variable `iv`.
391 /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
393 /// E.g.:
394 /// ```
395 /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
396 /// : memref<?x?x?xf32>, vector<4x3xf32>
397 /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
398 /// ```
399 /// Is rewritten to:
400 /// ```
401 /// %casted = vector.type_cast %buf
402 /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
403 /// for %j = 0 to 4 {
404 /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
405 /// : memref<?x?x?xf32>, vector<3xf32>
406 /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
407 /// }
408 /// ```
410 /// Note: The loop and type cast are generated in TransferOpConversion.
411 /// The original TransferReadOp and store op are deleted in `cleanup`.
412 /// Note: The `mask` operand is set in TransferOpConversion.
413 static TransferReadOp rewriteOp(OpBuilder &b,
414 VectorTransferToSCFOptions options,
415 TransferReadOp xferOp, Value buffer, Value iv,
416 ValueRange /*loopState*/) {
417 SmallVector<Value, 8> storeIndices;
418 getBufferIndices(xferOp, storeIndices);
419 storeIndices.push_back(iv);
421 SmallVector<Value, 8> xferIndices;
422 getXferIndices(b, xferOp, iv, xferIndices);
424 Location loc = xferOp.getLoc();
425 auto bufferType = dyn_cast<ShapedType>(buffer.getType());
426 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
427 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
428 auto newXferOp = b.create<vector::TransferReadOp>(
429 loc, vecType, xferOp.getSource(), xferIndices,
430 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
431 xferOp.getPadding(), Value(), inBoundsAttr);
433 maybeApplyPassLabel(b, newXferOp, options.targetRank);
435 b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
436 return newXferOp;
439 /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
440 /// padding value to the temporary buffer.
441 static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
442 Value buffer, Value iv,
443 ValueRange /*loopState*/) {
444 SmallVector<Value, 8> storeIndices;
445 getBufferIndices(xferOp, storeIndices);
446 storeIndices.push_back(iv);
448 Location loc = xferOp.getLoc();
449 auto bufferType = dyn_cast<ShapedType>(buffer.getType());
450 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
451 auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
452 b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
454 return Value();
457 /// Cleanup after rewriting the op.
458 static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
459 scf::ForOp /*forOp*/) {
460 rewriter.eraseOp(getStoreOp(xferOp));
461 rewriter.eraseOp(xferOp);
464 /// Return the initial loop state for the generated scf.for loop.
465 static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
468 /// Codegen strategy for vector TransferWriteOp.
469 template <>
470 struct Strategy<TransferWriteOp> {
471 /// Find the temporary buffer allocation. All labeled TransferWriteOps are
472 /// used like this, where %buf is either the buffer allocation or a type cast
473 /// of the buffer allocation:
474 /// ```
475 /// %vec = memref.load %buf[...] ...
476 /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
477 /// ```
478 static Value getBuffer(TransferWriteOp xferOp) {
479 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
480 assert(loadOp && "Expected transfer op vector produced by LoadOp");
481 return loadOp.getMemRef();
484 /// Retrieve the indices of the current LoadOp that loads from the buffer.
485 static void getBufferIndices(TransferWriteOp xferOp,
486 SmallVector<Value, 8> &indices) {
487 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
488 auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
489 indices.append(prevIndices.begin(), prevIndices.end());
492 /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
493 /// accesses on the to-be-unpacked dimension.
495 /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
496 /// using the loop iteration variable `iv`.
497 /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
498 /// to memory.
500 /// Note: For more details, see comments on Strategy<TransferReadOp>.
501 static TransferWriteOp rewriteOp(OpBuilder &b,
502 VectorTransferToSCFOptions options,
503 TransferWriteOp xferOp, Value buffer,
504 Value iv, ValueRange loopState) {
505 SmallVector<Value, 8> loadIndices;
506 getBufferIndices(xferOp, loadIndices);
507 loadIndices.push_back(iv);
509 SmallVector<Value, 8> xferIndices;
510 getXferIndices(b, xferOp, iv, xferIndices);
512 Location loc = xferOp.getLoc();
513 auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
514 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
515 auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
516 Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
517 auto newXferOp = b.create<vector::TransferWriteOp>(
518 loc, type, vec, source, xferIndices,
519 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
520 inBoundsAttr);
522 maybeApplyPassLabel(b, newXferOp, options.targetRank);
524 return newXferOp;
527 /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
528 static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
529 Value buffer, Value iv,
530 ValueRange loopState) {
531 return isTensorOp(xferOp) ? loopState[0] : Value();
534 /// Cleanup after rewriting the op.
535 static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
536 scf::ForOp forOp) {
537 if (isTensorOp(xferOp)) {
538 assert(forOp->getNumResults() == 1 && "Expected one for loop result");
539 rewriter.replaceOp(xferOp, forOp->getResult(0));
540 } else {
541 rewriter.eraseOp(xferOp);
545 /// Return the initial loop state for the generated scf.for loop.
546 static Value initialLoopState(TransferWriteOp xferOp) {
547 return isTensorOp(xferOp) ? xferOp.getSource() : Value();
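/// Return failure if the transfer op cannot or should not be prepared for
/// progressive lowering: it is already labeled, its vector rank is already at
/// or below the target rank, its leading dimension is scalable, it is a
/// tensor transfer while tensor lowering is disabled, or it changes the
/// element type.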
551 template <typename OpTy>
552 LogicalResult checkPrepareXferOp(OpTy xferOp,
553 VectorTransferToSCFOptions options) {
554 if (xferOp->hasAttr(kPassLabel))
555 return failure();
556 if (xferOp.getVectorType().getRank() <= options.targetRank)
557 return failure();
558 // Currently the unpacking of the leading dimension into the memref is not
559 // supported for scalable dimensions.
560 if (xferOp.getVectorType().getScalableDims().front())
561 return failure();
562 if (isTensorOp(xferOp) && !options.lowerTensors)
563 return failure();
564 // Transfer ops that modify the element type are not supported atm.
565 if (xferOp.getVectorType().getElementType() !=
566 xferOp.getShapedType().getElementType())
567 return failure();
568 return success();
571 /// Prepare a TransferReadOp for progressive lowering.
573 /// 1. Allocate a temporary buffer.
574 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
575 /// 3. Store the result of the TransferReadOp into the temporary buffer.
576 /// 4. Load the result from the temporary buffer and replace all uses of the
577 /// original TransferReadOp with this load.
579 /// E.g.:
580 /// ```
581 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
582 /// : vector<5x4xf32>, memref<?x?x?xf32>
583 /// ```
584 /// is rewritten to:
585 /// ```
586 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
587 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
588 /// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
589 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
590 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
591 /// ```
593 /// Note: A second temporary buffer may be allocated for the `mask` operand.
594 struct PrepareTransferReadConversion
595 : public VectorToSCFPattern<TransferReadOp> {
596 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
598 LogicalResult matchAndRewrite(TransferReadOp xferOp,
599 PatternRewriter &rewriter) const override {
600 if (checkPrepareXferOp(xferOp, options).failed())
601 return failure();
603 auto buffers = allocBuffers(rewriter, xferOp);
604 auto *newXfer = rewriter.clone(*xferOp.getOperation());
605 newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
606 if (xferOp.getMask()) {
607 dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
608 buffers.maskBuffer);
611 Location loc = xferOp.getLoc();
612 rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
613 buffers.dataBuffer);
614 rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
616 return success();
620 /// Prepare a TransferWriteOp for progressive lowering.
622 /// 1. Allocate a temporary buffer.
623 /// 2. Store the vector into the buffer.
624 /// 3. Load the vector from the buffer again.
625 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
626 /// marking it eligible for progressive lowering via TransferOpConversion.
628 /// E.g.:
629 /// ```
630 /// vector.transfer_write %vec, %A[%a, %b, %c]
631 /// : vector<5x4xf32>, memref<?x?x?xf32>
632 /// ```
633 /// is rewritten to:
634 /// ```
635 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
636 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
637 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
638 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
639 /// : vector<5x4xf32>, memref<?x?x?xf32>
640 /// ```
642 /// Note: A second temporary buffer may be allocated for the `mask` operand.
643 struct PrepareTransferWriteConversion
644 : public VectorToSCFPattern<TransferWriteOp> {
645 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
647 LogicalResult matchAndRewrite(TransferWriteOp xferOp,
648 PatternRewriter &rewriter) const override {
649 if (checkPrepareXferOp(xferOp, options).failed())
650 return failure();
652 Location loc = xferOp.getLoc();
653 auto buffers = allocBuffers(rewriter, xferOp);
654 rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
655 buffers.dataBuffer);
656 auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
657 rewriter.modifyOpInPlace(xferOp, [&]() {
658 xferOp.getVectorMutable().assign(loadedVec);
659 xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
662 if (xferOp.getMask()) {
663 rewriter.modifyOpInPlace(xferOp, [&]() {
664 xferOp.getMaskMutable().assign(buffers.maskBuffer);
668 return success();
672 /// Decompose an n-D PrintOp into a loop of elementary/scalar prints. This allows
673 /// printing both 1D scalable vectors and n-D fixed size vectors.
675 /// E.g.:
676 /// ```
677 /// vector.print %v : vector<[4]xi32>
678 /// ```
679 /// is rewritten to:
680 /// ```
681 /// %c0 = arith.constant 0 : index
682 /// %c4 = arith.constant 4 : index
683 /// %c1 = arith.constant 1 : index
684 /// %vscale = vector.vscale
685 /// %length = arith.muli %vscale, %c4 : index
686 /// %lastIndex = arith.subi %length, %c1 : index
687 /// vector.print punctuation <open>
688 /// scf.for %i = %c0 to %length step %c1 {
689 /// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
690 /// vector.print %el : i32 punctuation <no_punctuation>
691 /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
692 /// scf.if %notLastIndex {
693 /// vector.print punctuation <comma>
694 /// }
695 /// }
696 /// vector.print punctuation <close>
697 /// vector.print
698 /// ```
699 struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
700 using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
701 LogicalResult matchAndRewrite(vector::PrintOp printOp,
702 PatternRewriter &rewriter) const override {
703 if (!printOp.getSource())
704 return failure();
706 VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
707 if (!vectorType)
708 return failure();
710 // Currently >= 2D scalable vectors are not supported.
711 // These can't be lowered to LLVM (as LLVM does not support scalable vectors
712 // of scalable vectors), and due to limitations of current ops can't be
713 // indexed with SSA values or flattened. This may change after
714 // https://reviews.llvm.org/D155034, though there still needs to be a path
715 // for lowering to LLVM.
716 if (vectorType.getRank() > 1 && vectorType.isScalable())
717 return failure();
719 auto loc = printOp.getLoc();
720 auto value = printOp.getSource();
722 if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
723 // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
724 // avoid issues extend them to a more standard size.
725 // https://github.com/llvm/llvm-project/issues/30613
726 auto width = intTy.getWidth();
727 auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
728 auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
729 intTy.getSignedness());
730 // arith can only take signless integers, so we must cast back and forth.
731 auto signlessSourceVectorType =
732 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
733 auto signlessTargetVectorType =
734 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
735 auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
736 value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
737 value);
738 if (value.getType() != signlessTargetVectorType) {
739 if (width == 1 || intTy.isUnsigned())
740 value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
741 value);
742 else
743 value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
744 value);
746 value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
747 vectorType = targetVectorType;
750 auto scalableDimensions = vectorType.getScalableDims();
751 auto shape = vectorType.getShape();
752 constexpr int64_t singletonShape[] = {1};
753 if (vectorType.getRank() == 0)
754 shape = singletonShape;
756 if (vectorType.getRank() != 1) {
757 // Flatten n-D vectors to 1D. This is done to allow indexing with a
758 // non-constant value (which can currently only be done via
759 // vector.extractelement for 1D vectors).
760 auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
761 std::multiplies<int64_t>());
762 auto flatVectorType =
763 VectorType::get({flatLength}, vectorType.getElementType());
764 value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
767 vector::PrintOp firstClose;
768 SmallVector<Value, 8> loopIndices;
769 for (unsigned d = 0; d < shape.size(); d++) {
770 // Setup loop bounds and step.
771 Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
772 Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
773 Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
774 if (!scalableDimensions.empty() && scalableDimensions[d]) {
775 auto vscale = rewriter.create<vector::VectorScaleOp>(
776 loc, rewriter.getIndexType());
777 upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
779 auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);
781 // Create a loop to print the elements surrounded by parentheses.
782 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
783 auto loop =
784 rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
785 auto printClose = rewriter.create<vector::PrintOp>(
786 loc, vector::PrintPunctuation::Close);
787 if (!firstClose)
788 firstClose = printClose;
790 auto loopIdx = loop.getInductionVar();
791 loopIndices.push_back(loopIdx);
793 // Print a comma after all but the last element.
794 rewriter.setInsertionPointToStart(loop.getBody());
795 auto notLastIndex = rewriter.create<arith::CmpIOp>(
796 loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
797 rewriter.create<scf::IfOp>(loc, notLastIndex,
798 [&](OpBuilder &builder, Location loc) {
799 builder.create<vector::PrintOp>(
800 loc, vector::PrintPunctuation::Comma);
801 builder.create<scf::YieldOp>(loc);
804 rewriter.setInsertionPointToStart(loop.getBody());
807 // Compute the flattened index.
808 // Note: For vectors of rank > 1, this assumes non-scalable dimensions.
809 Value flatIndex;
810 auto currentStride = 1;
811 for (int d = shape.size() - 1; d >= 0; d--) {
812 auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
813 auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
814 if (flatIndex)
815 flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
816 else
817 flatIndex = index;
818 currentStride *= shape[d];
821 // Print the scalar elements in the innermost loop.
822 auto element =
823 rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
824 rewriter.create<vector::PrintOp>(loc, element,
825 vector::PrintPunctuation::NoPunctuation);
827 rewriter.setInsertionPointAfter(firstClose);
828 rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
829 rewriter.eraseOp(printOp);
830 return success();
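/// Return a signless integer type with the same width as `intTy`.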
833 static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
834 return IntegerType::get(intTy.getContext(), intTy.getWidth(),
835 IntegerType::Signless);
839 /// Progressive lowering of vector transfer ops: Unpack one dimension.
841 /// 1. Unpack one dimension from the current buffer type and cast the buffer
842 /// to that new type. E.g.:
843 /// ```
844 /// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
845 /// vector.transfer_write %vec ...
846 /// ```
847 /// The following cast is generated:
848 /// ```
849 /// %casted = vector.type_cast %0
850 /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
851 /// ```
852 /// 2. Generate a for loop and rewrite the transfer op according to the
853 /// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
854 /// out-of-bounds, generate an if-check and handle both cases separately.
855 /// 3. Clean up according to the corresponding Strategy<OpTy>.
857 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
858 /// source (as opposed to a memref source), then each iteration of the generated
859 /// scf.for loop yields the new tensor value. E.g.:
860 /// ```
861 /// %result = scf.for i = 0 to 5 {
862 /// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
863 /// %1 = vector.transfer_write %0, %source[...]
864 /// : vector<4x3xf32>, tensor<5x4x3xf32>
865 /// scf.yield %1 : tensor<5x4x3xf32>
866 /// }
867 /// ```
868 template <typename OpTy>
869 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
870 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
872 void initialize() {
873 // This pattern recursively unpacks one dimension at a time. The recursion
874 // is bounded because the rank is strictly decreasing.
875 this->setHasBoundedRewriteRecursion();
878 static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
879 SmallVectorImpl<Value> &loadIndices,
880 Value iv) {
881 assert(xferOp.getMask() && "Expected transfer op to have mask");
883 // Add load indices from the previous iteration.
884 // The mask buffer depends on the permutation map, which makes determining
885 // the indices quite complex, which is why we need to "look back" to the
886 // previous iteration to find the right indices.
887 Value maskBuffer = getMaskBuffer(xferOp);
888 for (Operation *user : maskBuffer.getUsers()) {
889 // If there is no previous load op, then the indices are empty.
890 if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
891 Operation::operand_range prevIndices = loadOp.getIndices();
892 loadIndices.append(prevIndices.begin(), prevIndices.end());
893 break;
897 // In case of broadcast: Use the same indices to load from the memref
898 // as before.
899 if (!xferOp.isBroadcastDim(0))
900 loadIndices.push_back(iv);
903 LogicalResult matchAndRewrite(OpTy xferOp,
904 PatternRewriter &rewriter) const override {
905 if (!xferOp->hasAttr(kPassLabel))
906 return failure();
908 // Find and cast data buffer. How the buffer can be found depends on OpTy.
909 ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
910 Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
911 auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
912 FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
913 if (failed(castedDataType))
914 return failure();
916 auto castedDataBuffer =
917 locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
919 // If the xferOp has a mask: Find and cast mask buffer.
920 Value castedMaskBuffer;
921 if (xferOp.getMask()) {
922 Value maskBuffer = getMaskBuffer(xferOp);
923 if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
924 // Do not unpack a dimension of the mask, if:
925 // * To-be-unpacked transfer op dimension is a broadcast.
926 // * Mask is 1D, i.e., the mask cannot be further unpacked.
927 // (That means that all remaining dimensions of the transfer op must
928 // be broadcasted.)
929 castedMaskBuffer = maskBuffer;
930 } else {
931 // It's safe to assume the mask buffer can be unpacked if the data
932 // buffer was unpacked.
933 auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
934 MemRefType castedMaskType = *unpackOneDim(maskBufferType);
935 castedMaskBuffer =
936 locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
940 // Loop bounds and step.
941 auto lb = locB.create<arith::ConstantIndexOp>(0);
942 auto ub = locB.create<arith::ConstantIndexOp>(
943 castedDataType->getDimSize(castedDataType->getRank() - 1));
944 auto step = locB.create<arith::ConstantIndexOp>(1);
945 // TransferWriteOps that operate on tensors return the modified tensor and
946 // require a loop state.
947 auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
949 // Generate for loop.
950 auto result = locB.create<scf::ForOp>(
951 lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
952 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
953 Type stateType = loopState.empty() ? Type() : loopState[0].getType();
955 auto result = generateInBoundsCheck(
956 b, xferOp, iv, unpackedDim(xferOp),
957 stateType ? TypeRange(stateType) : TypeRange(),
958 /*inBoundsCase=*/
959 [&](OpBuilder &b, Location loc) {
960 // Create new transfer op.
961 OpTy newXfer = Strategy<OpTy>::rewriteOp(
962 b, this->options, xferOp, castedDataBuffer, iv, loopState);
964 // If old transfer op has a mask: Set mask on new transfer op.
965 // Special case: If the mask of the old transfer op is 1D and
966 // the unpacked dim is not a broadcast, no mask is needed on
967 // the new transfer op.
968 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
969 xferOp.getMaskType().getRank() > 1)) {
970 OpBuilder::InsertionGuard guard(b);
971 b.setInsertionPoint(newXfer); // Insert load before newXfer.
973 SmallVector<Value, 8> loadIndices;
974 getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
975 loadIndices, iv);
976 auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
977 loadIndices);
978 rewriter.modifyOpInPlace(newXfer, [&]() {
979 newXfer.getMaskMutable().assign(mask);
983 return loopState.empty() ? Value() : newXfer->getResult(0);
985 /*outOfBoundsCase=*/
986 [&](OpBuilder &b, Location /*loc*/) {
987 return Strategy<OpTy>::handleOutOfBoundsDim(
988 b, xferOp, castedDataBuffer, iv, loopState);
991 maybeYieldValue(b, loc, !loopState.empty(), result);
994 Strategy<OpTy>::cleanup(rewriter, xferOp, result);
995 return success();
999 /// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
1000 /// and ConstantMaskOp.
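/// E.g. (illustrative): for `vector.create_mask %c2, %c3 : vector<4x8xi1>`,
/// the returned sizes are `[%c2, %c3]`; for a scalable dim of a
/// `vector.constant_mask`, the size is `vscale * dimSize`.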
1001 template <typename VscaleConstantBuilder>
1002 static FailureOr<SmallVector<OpFoldResult>>
1003 getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
1004 if (!mask)
1005 return SmallVector<OpFoldResult>{};
1006 if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
1007 return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
1008 return OpFoldResult(dimSize);
1011 if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
1012 int dimIdx = 0;
1013 VectorType maskType = constantMask.getVectorType();
1014 auto indexType = IndexType::get(mask.getContext());
1015 return llvm::map_to_vector(
1016 constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
1017 // A scalable dim in a constant_mask means vscale x dimSize.
1018 if (maskType.getScalableDims()[dimIdx++])
1019 return OpFoldResult(createVscaleMultiple(dimSize));
1020 return OpFoldResult(IntegerAttr::get(indexType, dimSize));
1023 return failure();
1026 /// Scalable vector lowering of transfer_write(transpose). This lowering only
1027 /// supports rank 2 (scalable) vectors, but can be used in conjunction with
1028 /// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
1029 /// unrolls until the first scalable dimension.
1031 /// Example:
1033 /// BEFORE:
1034 /// ```mlir
1035 /// %transpose = vector.transpose %vec, [1, 0]
1036 /// : vector<4x[4]xf32> to vector<[4]x4xf32>
1037 /// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
1038 /// : vector<[4]x4xf32>, memref<?x?xf32>
1039 /// ```
1041 /// AFTER:
1042 /// ```mlir
1043 /// %c1 = arith.constant 1 : index
1044 /// %c4 = arith.constant 4 : index
1045 /// %c0 = arith.constant 0 : index
1046 /// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
1047 /// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
1048 /// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
1049 /// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
1050 /// %vscale = vector.vscale
1051 /// %c4_vscale = arith.muli %vscale, %c4 : index
1052 /// scf.for %idx = %c0 to %c4_vscale step %c1 {
1053 /// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
1054 /// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
1055 /// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
1056 /// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
1057 /// %slice_i = affine.apply #map(%idx)[%i]
1058 /// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
1059 /// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
1060 /// : vector<4xf32>, memref<?x?xf32>
1061 /// }
1062 /// ```
1063 struct ScalableTransposeTransferWriteConversion
1064 : VectorToSCFPattern<vector::TransferWriteOp> {
1065 using VectorToSCFPattern::VectorToSCFPattern;
1067 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
1068 PatternRewriter &rewriter) const override {
1069 if (failed(checkLowerTensors(writeOp, rewriter)))
1070 return failure();
1072 VectorType vectorType = writeOp.getVectorType();
1074 // Note: By comparing the scalable dims to an ArrayRef of length two, this
1075 // implicitly checks that the rank is also two.
1076 ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
1077 if (scalableFlags != ArrayRef<bool>{true, false}) {
1078 return rewriter.notifyMatchFailure(
1079 writeOp, "expected vector of the form vector<[N]xMxty>");
1082 auto permutationMap = writeOp.getPermutationMap();
1083 if (!permutationMap.isIdentity()) {
1084 return rewriter.notifyMatchFailure(
1085 writeOp, "non-identity permutations are unsupported (lower first)");
1088 // Note: This pattern is only lowering the leading dimension (to a loop),
1089 // so we only check if the leading dimension is in bounds. The in-bounds
1090 // attribute for the trailing dimension will be propagated.
1091 if (!writeOp.isDimInBounds(0)) {
1092 return rewriter.notifyMatchFailure(
1093 writeOp, "out-of-bounds dims are unsupported (use masking)");
1096 Value vector = writeOp.getVector();
1097 auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
1098 if (!transposeOp ||
1099 transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
1100 return rewriter.notifyMatchFailure(writeOp, "source not transpose");
1103 auto loc = writeOp.getLoc();
1104 auto createVscaleMultiple =
1105 vector::makeVscaleConstantBuilder(rewriter, loc);
1107 auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
1108 if (failed(maskDims)) {
1109 return rewriter.notifyMatchFailure(writeOp,
1110 "failed to resolve mask dims");
1113 int64_t fixedDimSize = vectorType.getDimSize(1);
1114 auto fixedDimOffsets = llvm::seq(fixedDimSize);
1116 // Extract all slices from the source of the transpose.
1117 auto transposeSource = transposeOp.getVector();
1118 SmallVector<Value> transposeSourceSlices =
1119 llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1120 return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
1123 // Loop bounds and step.
1124 auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1125 auto ub =
1126 maskDims->empty()
1127 ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
1128 : vector::getAsValues(rewriter, loc, maskDims->front()).front();
1129 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1131 // Generate a new mask for the slice.
1132 VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
1133 Value sliceMask = nullptr;
1134 if (!maskDims->empty()) {
1135 sliceMask = rewriter.create<vector::CreateMaskOp>(
1136 loc, sliceType.clone(rewriter.getI1Type()),
1137 ArrayRef<OpFoldResult>(*maskDims).drop_front());
1140 Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
1141 ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
1142 auto result = rewriter.create<scf::ForOp>(
1143 loc, lb, ub, step, initLoopArgs,
1144 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
1145 // Indices for the new transfer op.
1146 SmallVector<Value, 8> xferIndices;
1147 getXferIndices(b, writeOp, iv, xferIndices);
1149 // Extract a transposed slice from the source vector.
1150 SmallVector<Value> transposeElements =
1151 llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1152 return b.create<vector::ExtractOp>(
1153 loc, transposeSourceSlices[idx], iv);
1155 auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
1156 transposeElements);
1158 // Create the transfer_write for the slice.
1159 Value dest =
1160 loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
1161 auto newWriteOp = b.create<vector::TransferWriteOp>(
1162 loc, sliceVec, dest, xferIndices,
1163 ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
1164 if (sliceMask)
1165 newWriteOp.getMaskMutable().assign(sliceMask);
1167 // Yield from the loop.
1168 b.create<scf::YieldOp>(loc, loopIterArgs.empty()
1169 ? ValueRange{}
1170 : newWriteOp.getResult());
1173 if (isTensorOp(writeOp))
1174 rewriter.replaceOp(writeOp, result);
1175 else
1176 rewriter.eraseOp(writeOp);
1178 return success();
1182 } // namespace lowering_n_d
1184 namespace lowering_n_d_unrolled {
1186 /// If the original transfer op has a mask, compute the mask of the new transfer
1187 /// op (for the current iteration `i`) and assign it.
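/// A sketch of the unpacking for an (N>1)-D mask (illustrative names):
/// ```
/// %newMask = vector.extract %mask[2] : vector<4xi1> from vector<5x4xi1>
/// ```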
1188 template <typename OpTy>
1189 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
1190 int64_t i) {
1191 if (!xferOp.getMask())
1192 return;
1194 if (xferOp.isBroadcastDim(0)) {
1195 // To-be-unpacked dimension is a broadcast, which does not have a
1196 // corresponding mask dimension. Mask attribute remains unchanged.
1197 newXferOp.getMaskMutable().assign(xferOp.getMask());
1198 return;
1201 if (xferOp.getMaskType().getRank() > 1) {
1202 // Unpack one dimension of the mask.
1203 OpBuilder::InsertionGuard guard(b);
1204 b.setInsertionPoint(newXferOp); // Insert load before newXfer.
1206 llvm::SmallVector<int64_t, 1> indices({i});
1207 Location loc = xferOp.getLoc();
1208 auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
1209 newXferOp.getMaskMutable().assign(newMask);
1212 // If we end up here: The mask of the old transfer op is 1D and the unpacked
1213 // dim is not a broadcast, so no mask is needed on the new transfer op.
1214 // `generateInBoundsCheck` will have evaluated the mask already.
1217 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
1218 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
1219 /// memref buffer is allocated and the SCF loop is fully unrolled.
1222 /// E.g.:
1223 /// ```
1224 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
1225 /// : memref<?x?x?xf32>, vector<5x4xf32>
1226 /// ```
1227 /// is rewritten to IR such as (simplified):
1228 /// ```
1229 /// %v_init = splat %padding : vector<5x4xf32>
1230 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
1231 /// : memref<?x?x?xf32>, vector<4xf32>
1232 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
1233 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
1234 /// : memref<?x?x?xf32>, vector<4xf32>
1235 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
1236 /// ...
1237 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
1238 /// : memref<?x?x?xf32>, vector<4xf32>
1239 /// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
1240 /// ```
1242 /// Note: As an optimization, if the result of the original TransferReadOp
1243 /// was directly inserted into another vector, no new %v_init vector is created.
1244 /// Instead, the new TransferReadOp results are inserted into that vector.
1245 struct UnrollTransferReadConversion
1246 : public VectorToSCFPattern<TransferReadOp> {
1247 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
1249 void initialize() {
1250 // This pattern recursively unpacks one dimension at a time. The recursion
1251 // is bounded because the rank is strictly decreasing.
1252 setHasBoundedRewriteRecursion();
1255 /// Get or build the vector into which the newly created TransferReadOp
1256 /// results are inserted.
1257 Value buildResultVector(PatternRewriter &rewriter,
1258 TransferReadOp xferOp) const {
1259 if (auto insertOp = getInsertOp(xferOp))
1260 return insertOp.getDest();
1261 Location loc = xferOp.getLoc();
1262 return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1263 xferOp.getPadding());
1266 /// If the result of the TransferReadOp has exactly one user, which is a
1267 /// vector::InsertOp, return that operation.
1268 vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
1269 if (xferOp->hasOneUse()) {
1270 Operation *xferOpUser = *xferOp->getUsers().begin();
1271 if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
1272 return insertOp;
1275 return vector::InsertOp();
1278 /// If the result of the TransferReadOp has exactly one user, which is a
1279 /// vector::InsertOp, return that operation's indices.
1280 void getInsertionIndices(TransferReadOp xferOp,
1281 SmallVectorImpl<OpFoldResult> &indices) const {
1282 if (auto insertOp = getInsertOp(xferOp)) {
1283 auto pos = insertOp.getMixedPosition();
1284 indices.append(pos.begin(), pos.end());
1288 /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1289 /// accesses, and broadcasts and transposes in permutation maps.
1290 LogicalResult matchAndRewrite(TransferReadOp xferOp,
1291 PatternRewriter &rewriter) const override {
1292 if (xferOp.getVectorType().getRank() <= options.targetRank)
1293 return rewriter.notifyMatchFailure(
1294 xferOp, "vector rank is less or equal to target rank");
1295 if (failed(checkLowerTensors(xferOp, rewriter)))
1296 return failure();
1297 // Transfer ops that modify the element type are not supported atm.
1298 if (xferOp.getVectorType().getElementType() !=
1299 xferOp.getShapedType().getElementType())
1300 return rewriter.notifyMatchFailure(
1301 xferOp, "not yet supported: element type mismatch");
1302 auto xferVecType = xferOp.getVectorType();
1303 if (xferVecType.getScalableDims()[0]) {
1304 // Cannot unroll a scalable dimension at compile time.
1305 return rewriter.notifyMatchFailure(
1306 xferOp, "scalable dimensions cannot be unrolled");
1309 auto insertOp = getInsertOp(xferOp);
1310 auto vec = buildResultVector(rewriter, xferOp);
1311 auto vecType = dyn_cast<VectorType>(vec.getType());
1313 VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
1315 int64_t dimSize = xferVecType.getShape()[0];
1317 // Generate fully unrolled loop of transfer ops.
1318 Location loc = xferOp.getLoc();
1319 for (int64_t i = 0; i < dimSize; ++i) {
1320 Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1322 vec = generateInBoundsCheck(
1323 rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
1324 /*inBoundsCase=*/
1325 [&](OpBuilder &b, Location loc) {
1326 // Indices for the new transfer op.
1327 SmallVector<Value, 8> xferIndices;
1328 getXferIndices(b, xferOp, iv, xferIndices);
1330 // Indices for the new vector.insert op.
1331 SmallVector<OpFoldResult, 8> insertionIndices;
1332 getInsertionIndices(xferOp, insertionIndices);
1333 insertionIndices.push_back(rewriter.getIndexAttr(i));
1335 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1336 auto newXferOp = b.create<vector::TransferReadOp>(
1337 loc, newXferVecType, xferOp.getSource(), xferIndices,
1338 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1339 xferOp.getPadding(), Value(), inBoundsAttr);
1340 maybeAssignMask(b, xferOp, newXferOp, i);
1341 return b.create<vector::InsertOp>(loc, newXferOp, vec,
1342 insertionIndices);
1344 /*outOfBoundsCase=*/
1345 [&](OpBuilder &b, Location loc) {
1346 // Loop through original (unmodified) vector.
1347 return vec;
1351 if (insertOp) {
1352 // Rewrite single user of the old TransferReadOp, which was an InsertOp.
1353 rewriter.replaceOp(insertOp, vec);
1354 rewriter.eraseOp(xferOp);
1355 } else {
1356 rewriter.replaceOp(xferOp, vec);
1359 return success();
1363 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
1364 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
1365 /// memref buffer is allocated and the SCF loop is fully unrolled.
1368 /// E.g.:
1369 /// ```
1370 /// vector.transfer_write %vec, %A[%a, %b, %c]
1371 /// : vector<5x4xf32>, memref<?x?x?xf32>
1372 /// ```
1373 /// is rewritten to IR such as (simplified):
1374 /// ```
1375 /// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
1376 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
1377 /// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
1378 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
1379 /// ...
1380 /// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
1381 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
1382 /// ```
1384 /// Note: As an optimization, if the vector of the original TransferWriteOp
1385 /// was directly extracted from another vector via an ExtractOp `a`, extract
1386 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
1387 /// doing so, `a` may become dead, and the number of ExtractOps generated during
1388 /// recursive application of this pattern will be minimal.
1389 struct UnrollTransferWriteConversion
1390 : public VectorToSCFPattern<TransferWriteOp> {
1391 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
1393 void initialize() {
1394 // This pattern recursively unpacks one dimension at a time. The recursion
1395 // is bounded because the rank is strictly decreasing.
1396 setHasBoundedRewriteRecursion();
1399 /// Return the vector from which newly generated ExtractOps will extract.
1400 Value getDataVector(TransferWriteOp xferOp) const {
1401 if (auto extractOp = getExtractOp(xferOp))
1402 return extractOp.getVector();
1403 return xferOp.getVector();
1406 /// If the input of the given TransferWriteOp is an ExtractOp, return it.
1407 vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1408 if (auto *op = xferOp.getVector().getDefiningOp())
1409 return dyn_cast<vector::ExtractOp>(op);
1410 return vector::ExtractOp();
1413 /// If the input of the given TransferWriteOp is an ExtractOp, return its
1414 /// indices.
1415 void getExtractionIndices(TransferWriteOp xferOp,
1416 SmallVectorImpl<OpFoldResult> &indices) const {
1417 if (auto extractOp = getExtractOp(xferOp)) {
1418 auto pos = extractOp.getMixedPosition();
1419 indices.append(pos.begin(), pos.end());
1423 /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1424 /// accesses, and broadcasts and transposes in permutation maps.
1425 LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1426 PatternRewriter &rewriter) const override {
1427 VectorType inputVectorTy = xferOp.getVectorType();
1429 if (inputVectorTy.getRank() <= options.targetRank)
1430 return failure();
1432 if (failed(checkLowerTensors(xferOp, rewriter)))
1433 return failure();
1434 // Transfer ops that modify the element type are not supported atm.
1435 if (inputVectorTy.getElementType() !=
1436 xferOp.getShapedType().getElementType())
1437 return failure();
1439 auto vec = getDataVector(xferOp);
1440 if (inputVectorTy.getScalableDims()[0]) {
1441 // Cannot unroll a scalable dimension at compile time.
1442 return failure();
1445 int64_t dimSize = inputVectorTy.getShape()[0];
1446 Value source = xferOp.getSource(); // memref or tensor to be written to.
1447 auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1449 // Generate fully unrolled loop of transfer ops.
1450 Location loc = xferOp.getLoc();
1451 for (int64_t i = 0; i < dimSize; ++i) {
1452 Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1454 auto updatedSource = generateInBoundsCheck(
1455 rewriter, xferOp, iv, unpackedDim(xferOp),
1456 isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1457 /*inBoundsCase=*/
1458 [&](OpBuilder &b, Location loc) {
1459 // Indices for the new transfer op.
1460 SmallVector<Value, 8> xferIndices;
1461 getXferIndices(b, xferOp, iv, xferIndices);
1463 // Indices for the new vector.extract op.
1464 SmallVector<OpFoldResult, 8> extractionIndices;
1465 getExtractionIndices(xferOp, extractionIndices);
1466 extractionIndices.push_back(b.getI64IntegerAttr(i));
1468 auto extracted =
1469 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1470 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1471 Value xferVec;
1472 if (inputVectorTy.getRank() == 1) {
1473 // When target-rank=0, unrolling would causes the vector input
1474 // argument into `transfer_write` to become a scalar. We solve
1475 // this by broadcasting the scalar to a 0D vector.
1476 xferVec = b.create<vector::BroadcastOp>(
1477 loc, VectorType::get({}, extracted.getType()), extracted);
1478 } else {
1479 xferVec = extracted;
1481 auto newXferOp = b.create<vector::TransferWriteOp>(
1482 loc, sourceType, xferVec, source, xferIndices,
1483 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1484 inBoundsAttr);
1486 maybeAssignMask(b, xferOp, newXferOp, i);
1488 return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1490 /*outOfBoundsCase=*/
1491 [&](OpBuilder &b, Location loc) {
1492 return isTensorOp(xferOp) ? source : Value();
1495 if (isTensorOp(xferOp))
1496 source = updatedSource;
1499 if (isTensorOp(xferOp))
1500 rewriter.replaceOp(xferOp, source);
1501 else
1502 rewriter.eraseOp(xferOp);
1504 return success();
} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of std::nullopt indicates a
/// broadcast.
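///
/// For intuition (an illustrative sketch, not taken from an actual test): for
/// a transfer on %A[%a, %b] with permutation map (d0, d1) -> (d0) and
/// induction variable %iv, the resulting indices are [%a + %iv, %b] (the sum
/// is materialized via affine.apply) and the returned dimension is 0.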
template <typename OpTy>
static std::optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.getIndices();
  auto map = xferOp.getPermutationMap();
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (it was initialized
    // with the padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val =
              b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize the vector with the padding value.
    Location loc = xferOp.getLoc();
    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                     xferOp.getPadding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
          b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector loads/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a
///       pattern to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
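///
/// The read side is handled analogously (sketch only; the actual IR is an
/// scf.for with an iter_arg, a memref.load, and a vector.insertelement):
/// ```
/// %init = vector.splat %padding : vector<9xf32>
/// %res = for i = 0 to 9 iter_args(%v = %init) {
///   %t = memref.load %arg0[%a + i, %b] : memref<?x?xf32>
///   %v2 = vector.insertelement %t, %v[i] : vector<9xf32>
///   yield %v2
/// }
/// ```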
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.getPermutationMap();
    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM
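
    // Illustrative note (assumed typical inputs, not exhaustive): a
    // minor-identity, unit-stride transfer such as
    //   vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<4xf32>
    // is rejected above and left to ConvertVectorToLLVM, whereas a transfer
    // along a non-innermost dimension (e.g. permutation_map =
    // affine_map<(d0, d1) -> (d0)> on memref<?x?xf32>) falls through to the
    // scalar loop generated below.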

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    if (vecType.isScalable()) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
    }
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace
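
/// A minimal standalone usage sketch (assuming the caller already has the
/// relevant dialects loaded; the surrounding names `ctx` and `op` are
/// illustrative, not part of this file):
///
///   RewritePatternSet patterns(ctx);
///   VectorTransferToSCFOptions opts;
///   opts.targetRank = 1;
///   populateVectorToSCFConversionPatterns(patterns, opts);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));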
void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  if (options.lowerScalable) {
    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
        patterns.getContext(), options);
  }
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
                                                         options);
}

namespace {

struct ConvertVectorToSCFPass
    : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerTensors = options.lowerTensors;
    this->lowerScalable = options.lowerScalable;
  }

  void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
    options.lowerScalable = lowerScalable;

    // Lower permutation maps first.
    RewritePatternSet lowerTransferPatterns(&getContext());
    mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
        lowerTransferPatterns);
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(lowerTransferPatterns));

    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace
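
/// Example command-line invocation of the registered pass (the option names
/// are an assumption based on the pass's tablegen declaration, which lives
/// outside this file):
///
///   mlir-opt -convert-vector-to-scf="full-unroll=true target-rank=1" in.mlir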
std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);