//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <numeric>
#include <optional>
#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOSCF
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Return true if this transfer op operates on a source tensor.
static bool isTensorOp(VectorTransferOpInterface xferOp) {
  if (isa<RankedTensorType>(xferOp.getShapedType())) {
    if (isa<vector::TransferWriteOp>(xferOp)) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
                                  PatternRewriter &rewriter) const {
    if (isTensorOp(xferOp) && !options.lowerTensors) {
      return rewriter.notifyMatchFailure(
          xferOp, "lowering tensor transfers is disabled");
    }
    return success();
  }

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of std::nullopt indicates a broadcast.
template <typename OpTy>
static std::optional<int64_t> unpackedDim(OpTy xferOp) {
  // TODO: support 0-d corner case.
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  // TODO: support 0-d corner case.
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

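// For illustration: dropping the first result of (d0, d1, d2) -> (d2, d1)
// yields (d0, d1, d2) -> (d1) for the new (N-1)-D transfer op.
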
/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///
/// `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.getIndices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.has_value();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.getIndices()[*dim];
    indices[*dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

/// Yield `value` from the surrounding loop if `hasRetVal` is true; otherwise
/// yield nothing.
static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.getMask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
}

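// The generated check is a single element extract from the 1-D mask, roughly
// (sketch): %in = vector.extractelement %mask[%iv : index] : vector<Nxi1>
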
/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
///
/// %vec = vector.transfer_read %A[%a, %b] %cst
///     : vector<5x4xf32>, memref<?x?xf32>
///
/// An if check similar to this will be generated inside the loop:
///
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
///
/// If the transfer is 1D and has a mask, this function generates a more complex
/// check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim; // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim =
        vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.getIndices()[*dim];
    Value memrefIdx =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<arith::AndIOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is not the target
/// rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Allocate temporary buffers for data (vector) and mask (if present).
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope = getAutomaticAllocationScope(xferOp);
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.getMask()) {
    auto maskType = MemRefType::get({}, xferOp.getMask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
  }

  return result;
}

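// For illustration: for a transfer op on vector<5x4xf32>, the data buffer is a
// 0-D memref of vectors, i.e. (sketch):
//   %buf = memref.alloca() : memref<vector<5x4xf32>>
// allocated at the start of the enclosing AutomaticAllocationScope region.
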
/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
  auto vectorType = dyn_cast<VectorType>(type.getElementType());
  // Vectors with leading scalable dims are not supported.
  // It may be possible to support these in the future by using dynamic memref
  // dims.
  if (vectorType.getScalableDims().front())
    return failure();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::Builder(vectorType).dropDim(0));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.getMask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  ///
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  ///
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  ///
  /// Is rewritten to:
  ///
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.getSource(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
        xferOp.getPadding(), Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  ///
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getSource() : Value();
  }
};

template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  // Currently the unpacking of the leading dimension into the memref is not
  // supported for scalable dimensions.
  if (xferOp.getVectorType().getScalableDims().front())
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

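// For illustration: with targetRank = 1, a transfer_read producing
// vector<5x4x3xf32> (on a memref source with matching element type) passes
// these checks, gets labeled by the Prepare* patterns below, and is then
// unpacked one dimension at a time until only 1-D transfers remain.
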
/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
///
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : vector<5x4xf32>, memref<?x?x?xf32>
///
/// is rewritten to:
///
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.getMask()) {
      dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
///
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
///
/// is rewritten to:
///
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
                                     buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getVectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.getMask()) {
      rewriter.modifyOpInPlace(xferOp, [&]() {
        xferOp.getMaskMutable().assign(buffers.maskBuffer);
      });
    }

    return success();
  }
};

/// Decompose an n-D PrintOp into a loop of elementary/scalar prints. This
/// allows printing both 1D scalable vectors and n-D fixed size vectors.
///
/// E.g.:
///
/// vector.print %v : vector<[4]xi32>
///
/// is rewritten to:
///
/// %c0 = arith.constant 0 : index
/// %c4 = arith.constant 4 : index
/// %c1 = arith.constant 1 : index
/// %vscale = vector.vscale
/// %length = arith.muli %vscale, %c4 : index
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
///   %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
///   vector.print %el : i32 punctuation <no_punctuation>
///   %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
///   scf.if %notLastIndex {
///     vector.print punctuation <comma>
///   }
/// }
/// vector.print punctuation <close>
struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    if (!printOp.getSource())
      return failure();

    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType)
      return failure();

    // Currently >= 2D scalable vectors are not supported.
    // These can't be lowered to LLVM (as LLVM does not support scalable vectors
    // of scalable vectors), and due to limitations of current ops can't be
    // indexed with SSA values or flattened. This may change after
    // https://reviews.llvm.org/D155034, though there still needs to be a path
    // for lowering to LLVM.
    if (vectorType.getRank() > 1 && vectorType.isScalable())
      return failure();

    auto loc = printOp.getLoc();
    auto value = printOp.getSource();

    if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
      // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
      // avoid issues extend them to a more standard size.
      // https://github.com/llvm/llvm-project/issues/30613
      auto width = intTy.getWidth();
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
      auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                         intTy.getSignedness());
      // arith can only take signless integers, so we must cast back and forth.
      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
      value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
                                                 value);
      if (value.getType() != signlessTargetVectorType) {
        if (width == 1 || intTy.isUnsigned())
          value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
                                                  value);
        else
          value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
                                                  value);
      }
      value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
      vectorType = targetVectorType;
    }

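    // For illustration: the legalWidth computation maps i1 -> i8, i7 -> i8,
    // i32 -> i32, and i33 -> i64 (NextPowerOf2(std::max(8u, width) - 1) keeps
    // exact powers of two >= 8 unchanged and rounds everything else up).
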
    auto scalableDimensions = vectorType.getScalableDims();
    auto shape = vectorType.getShape();
    constexpr int64_t singletonShape[] = {1};
    if (vectorType.getRank() == 0)
      shape = singletonShape;

    if (vectorType.getRank() != 1) {
      // Flatten n-D vectors to 1D. This is done to allow indexing with a
      // non-constant value (which can currently only be done via
      // vector.extractelement for 1D vectors).
      auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                        std::multiplies<int64_t>());
      auto flatVectorType =
          VectorType::get({flatLength}, vectorType.getElementType());
      value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
    }

    vector::PrintOp firstClose;
    SmallVector<Value, 8> loopIndices;
    for (unsigned d = 0; d < shape.size(); d++) {
      // Setup loop bounds and step.
      Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
      Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
      if (!scalableDimensions.empty() && scalableDimensions[d]) {
        auto vscale = rewriter.create<vector::VectorScaleOp>(
            loc, rewriter.getIndexType());
        upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
      }
      auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);

      // Create a loop to print the elements surrounded by parentheses.
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
      auto loop =
          rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
      auto printClose = rewriter.create<vector::PrintOp>(
          loc, vector::PrintPunctuation::Close);
      if (!firstClose)
        firstClose = printClose;

      auto loopIdx = loop.getInductionVar();
      loopIndices.push_back(loopIdx);

      // Print a comma after all but the last element.
      rewriter.setInsertionPointToStart(loop.getBody());
      auto notLastIndex = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
      rewriter.create<scf::IfOp>(loc, notLastIndex,
                                 [&](OpBuilder &builder, Location loc) {
                                   builder.create<vector::PrintOp>(
                                       loc, vector::PrintPunctuation::Comma);
                                   builder.create<scf::YieldOp>(loc);
                                 });

      rewriter.setInsertionPointToStart(loop.getBody());
    }

    // Compute the flattened index.
    // Note: For rank > 1 vectors this assumes non-scalable.
    Value flatIndex;
    auto currentStride = 1;
    for (int d = shape.size() - 1; d >= 0; d--) {
      auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
      auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
      if (flatIndex)
        flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
      else
        flatIndex = index;
      currentStride *= shape[d];
    }

    // Print the scalar elements in the inner most loop.
    auto element =
        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
    rewriter.create<vector::PrintOp>(loc, element,
                                     vector::PrintPunctuation::NoPunctuation);

    rewriter.setInsertionPointAfter(firstClose);
    rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
    rewriter.eraseOp(printOp);
    return success();
  }

  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
    return IntegerType::get(intTy.getContext(), intTy.getWidth(),
                            IntegerType::Signless);
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///
/// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
/// vector.transfer_write %vec ...
///
/// The following cast is generated:
///
/// %casted = vector.type_cast %0
///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the generated
/// scf.for loop yields the new tensor value. E.g.:
///
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
                                       SmallVectorImpl<Value> &loadIndices,
                                       Value iv) {
    assert(xferOp.getMask() && "Expected transfer op to have mask");

    // Add load indices from the previous iteration.
    // The mask buffer depends on the permutation map, which makes determining
    // the indices quite complex, so this is why we need to "look back" to the
    // previous iteration to find the right indices.
    Value maskBuffer = getMaskBuffer(xferOp);
    for (Operation *user : maskBuffer.getUsers()) {
      // If there is no previous load op, then the indices are empty.
      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
        Operation::operand_range prevIndices = loadOp.getIndices();
        loadIndices.append(prevIndices.begin(), prevIndices.end());
        break;
      }
    }

    // In case of broadcast: Use the same indices to load from the memref
    // as before.
    if (!xferOp.isBroadcastDim(0))
      loadIndices.push_back(iv);
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
    if (failed(castedDataType))
      return failure();

    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      Value maskBuffer = getMaskBuffer(xferOp);
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be in-bounds.)
        castedMaskBuffer = maskBuffer;
      } else {
        // It's safe to assume the mask buffer can be unpacked if the data
        // buffer was unpacked.
        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType->getDimSize(castedDataType->getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
                                           loadIndices, iv);
                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.modifyOpInPlace(newXfer, [&]() {
                    newXfer.getMaskMutable().assign(mask);
                  });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

/// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
/// and ConstantMaskOp.
template <typename VscaleConstantBuilder>
static FailureOr<SmallVector<OpFoldResult>>
getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
  if (!mask)
    return SmallVector<OpFoldResult>{};
  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
      return OpFoldResult(dimSize);
    });
  }
  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
    int dimIdx = 0;
    VectorType maskType = constantMask.getVectorType();
    auto indexType = IndexType::get(mask.getContext());
    return llvm::map_to_vector(
        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
          // A scalable dim in a constant_mask means vscale x dimSize.
          if (maskType.getScalableDims()[dimIdx++])
            return OpFoldResult(createVscaleMultiple(dimSize));
          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
        });
  }
  return failure();
}

/// Scalable vector lowering of transfer_write(transpose). This lowering only
/// supports rank 2 (scalable) vectors, but can be used in conjunction with
/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
/// unrolls until the first scalable dimension.
///
/// E.g.:
///
/// %transpose = vector.transpose %vec, [1, 0]
///     : vector<4x[4]xf32> to vector<[4]x4xf32>
/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
///     : vector<[4]x4xf32>, memref<?x?xf32>
///
/// is rewritten to:
///
/// %c1 = arith.constant 1 : index
/// %c4 = arith.constant 4 : index
/// %c0 = arith.constant 0 : index
/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
/// %vscale = vector.vscale
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// scf.for %idx = %c0 to %c4_vscale step %c1 {
///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
///   %slice_i = affine.apply #map(%idx)[%i]
///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
///       : vector<4xf32>, memref<?x?xf32>
/// }
struct ScalableTransposeTransferWriteConversion
    : VectorToSCFPattern<vector::TransferWriteOp> {
  using VectorToSCFPattern::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (failed(checkLowerTensors(writeOp, rewriter)))
      return failure();

    VectorType vectorType = writeOp.getVectorType();

    // Note: By comparing the scalable dims to an ArrayRef of length two this
    // implicitly checks the rank (is also two).
    ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
    if (scalableFlags != ArrayRef<bool>{true, false}) {
      return rewriter.notifyMatchFailure(
          writeOp, "expected vector of the form vector<[N]xMxty>");
    }

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity()) {
      return rewriter.notifyMatchFailure(
          writeOp, "non-identity permutations are unsupported (lower first)");
    }

    // Note: This pattern is only lowering the leading dimension (to a loop),
    // so we only check if the leading dimension is in bounds. The in-bounds
    // attribute for the trailing dimension will be propagated.
    if (!writeOp.isDimInBounds(0)) {
      return rewriter.notifyMatchFailure(
          writeOp, "out-of-bounds dims are unsupported (use masking)");
    }

    Value vector = writeOp.getVector();
    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
    if (!transposeOp ||
        transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
      return rewriter.notifyMatchFailure(writeOp, "source not transpose");
    }

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
    if (failed(maskDims)) {
      return rewriter.notifyMatchFailure(writeOp,
                                         "failed to resolve mask dims");
    }

    int64_t fixedDimSize = vectorType.getDimSize(1);
    auto fixedDimOffsets = llvm::seq(fixedDimSize);

    // Extract all slices from the source of the transpose.
    auto transposeSource = transposeOp.getVector();
    SmallVector<Value> transposeSourceSlices =
        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
          return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
        });

    // Loop bounds and step.
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto ub =
        maskDims->empty()
            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
            : vector::getAsValues(rewriter, loc, maskDims->front()).front();
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    // Generate a new mask for the slice.
    VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
    Value sliceMask = nullptr;
    if (!maskDims->empty()) {
      sliceMask = rewriter.create<vector::CreateMaskOp>(
          loc, sliceType.clone(rewriter.getI1Type()),
          ArrayRef<OpFoldResult>(*maskDims).drop_front());
    }

    Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
    ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
    auto result = rewriter.create<scf::ForOp>(
        loc, lb, ub, step, initLoopArgs,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
          // Indices for the new transfer op.
          SmallVector<Value, 8> xferIndices;
          getXferIndices(b, writeOp, iv, xferIndices);

          // Extract a transposed slice from the source vector.
          SmallVector<Value> transposeElements =
              llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
                return b.create<vector::ExtractOp>(
                    loc, transposeSourceSlices[idx], iv);
              });
          auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
                                                           transposeElements);

          // Create the transfer_write for the slice.
          Value dest =
              loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
          auto newWriteOp = b.create<vector::TransferWriteOp>(
              loc, sliceVec, dest, xferIndices,
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
          if (sliceMask)
            newWriteOp.getMaskMutable().assign(sliceMask);

          // Yield from the loop.
          b.create<scf::YieldOp>(loc, loopIterArgs.empty()
                                          ? ValueRange{}
                                          : newWriteOp.getResult());
        });

    if (isTensorOp(writeOp))
      rewriter.replaceOp(writeOp, result);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new transfer
/// op (for the current iteration `i`) and assign it.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.getMask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.getMaskMutable().assign(xferOp.getMask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
    newXferOp.getMaskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
///
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
///
/// is rewritten to IR such as (simplified):
///
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is created.
/// Instead, the new TransferReadOp results are inserted into that vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Get or build the vector into which the newly created TransferReadOp
  /// results are inserted.
  Value buildResultVector(PatternRewriter &rewriter,
                          TransferReadOp xferOp) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
    Location loc = xferOp.getLoc();
    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                            xferOp.getPadding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      auto pos = insertOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less or equal to target rank");
    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return rewriter.notifyMatchFailure(
          xferOp, "not yet supported: element type mismatch");
    auto xferVecType = xferOp.getVectorType();
    if (xferVecType.getScalableDims()[0]) {
      // Cannot unroll a scalable dimension at compile time.
      return rewriter.notifyMatchFailure(
          xferOp, "scalable dimensions cannot be unrolled");
    }

    auto insertOp = getInsertOp(xferOp);
    auto vec = buildResultVector(rewriter, xferOp);
    auto vecType = dyn_cast<VectorType>(vec.getType());

    VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);

    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<OpFoldResult, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(rewriter.getIndexAttr(i));

            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.getSource(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.getPadding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
///
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
///
/// is rewritten to IR such as (simplified):
///
/// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated during
/// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.getVector();
    return xferOp.getVector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.getVector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      auto pos = extractOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    VectorType inputVectorTy = xferOp.getVectorType();

    if (inputVectorTy.getRank() <= options.targetRank)
      return failure();

    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (inputVectorTy.getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    if (inputVectorTy.getScalableDims()[0]) {
      // Cannot unroll a scalable dimension at compile time.
      return failure();
    }

    int64_t dimSize = inputVectorTy.getShape()[0];
    Value source = xferOp.getSource(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<OpFoldResult, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(b.getI64IntegerAttr(i));

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            Value xferVec;
            if (inputVectorTy.getRank() == 1) {
              // When target-rank=0, unrolling would cause the vector input
              // argument into `transfer_write` to become a scalar. We solve
              // this by broadcasting the scalar to a 0D vector.
              xferVec = b.create<vector::BroadcastOp>(
                  loc, VectorType::get({}, extracted.getType()), extracted);
            } else {
              xferVec = extracted;
            }
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, xferVec, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};

} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of std::nullopt indicates a
/// broadcast.
template <typename OpTy>
static std::optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.getIndices();
  auto map = xferOp.getPermutationMap();
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}

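// For illustration: for indices [%a, %b] and permutation map (d0, d1) -> (d0),
// the indices become [%a + iv, %b] and the returned dimension is 0; for a
// broadcast (constant 0 result) the indices are unchanged and std::nullopt is
// returned.
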
/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // the padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val =
              b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                     xferOp.getPadding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
          b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a pattern
///       to ConvertVectorToLLVM.
///
/// E.g.:
///
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
///
/// Is rewritten to approximately the following pseudo-IR:
///
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.getPermutationMap();
    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    if (vecType.isScalable()) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
    }
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d

void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  if (options.lowerScalable) {
    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
        patterns.getContext(), options);
  }
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
                                                         options);
}

namespace {

struct ConvertVectorToSCFPass
    : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerTensors = options.lowerTensors;
    this->lowerScalable = options.lowerScalable;
  }

  void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
    options.lowerScalable = lowerScalable;

    // Lower permutation maps first.
    RewritePatternSet lowerTransferPatterns(&getContext());
    mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
        lowerTransferPatterns);
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(lowerTransferPatterns));

    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}