//===- SparseVectorization.cpp - Vectorization of sparsified loops -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparsifier can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ARMSve)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

} // namespace

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}
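
// Illustrative note (not normative for this pass): with vectorLength = 8, an
// f32 element type maps to 'vector<8xf32>' when enableVLAVectorization is
// false, and to the scalable type 'vector<[8]xf32>' when it is true.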

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upperbound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}
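
// Illustrative sketch only (exact IR depends on folding): for a loop with
// upper bound %hi, induction variable %iv, and step 8, the general case above
// resembles
//   %n = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%c8]
//   %m = vector.create_mask %n : vector<8xi1>
// i.e. the mask covers min(step, hi - iv) remaining iterations.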

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}
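
// Illustrative sketch only (types and flags vary): an indirect last subscript
// yields a gather of the form
//   %v = vector.gather %mem[%c0] [%idxs], %mask, %pass
//        : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
// while unit-stride subscripts yield a masked load
//   %v = vector.maskedload %mem[%iv], %mask, %pass
//        : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>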

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}
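
// Illustrative sketch only, mirroring the load case: the indirect form emits
//   vector.scatter %mem[%c0] [%idxs], %mask, %vrhs
//       : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>
// and the unit-stride form emits
//   vector.maskedstore %mem[%iv], %mask, %vrhs
//       : memref<?xf32>, vector<8xi1>, vector<8xf32>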

/// Detects a vectorizable reduction operation and, on success, returns the
/// combining kind of the reduction in `kind`.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}
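
// Note on the subtraction cases above (added for clarity): x = x - a is
// classified as an ADD reduction, since each vector lane accumulates negated
// updates that a final horizontal add combines correctly, whereas x = a - x
// is rejected because it does not form an associative reduction chain.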

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}
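
// Illustrative example (assuming vector<4xf32> and an ADD reduction): the
// initial scalar %r is inserted into an all-zero vector, e.g.
//   %zero = arith.constant dense<0.000000e+00> : vector<4xf32>
//   %init = vector.insertelement %r, %zero[%c0 : index] : vector<4xf32>
// so that a final vector.reduction <add> recovers r plus all accumulated
// contributions.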

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that are already 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(arg);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential for incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = load.getOperand(0);
      }
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}
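
// Illustrative sketch only: for an indirect subscript a[ind[i]] with 8-bit or
// 16-bit coordinates, the coordinate vector loaded above is first widened,
// e.g.
//   %ci  = vector.maskedload %ind[%i], %mask, %pad
//          : memref<?xi16>, vector<8xi1>, vector<8xi16> into vector<8xi16>
//   %cix = arith.extui %ci : vector<8xi16> to vector<8xi32>
// before it is used as the gather/scatter index vector in the last position.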

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr = rewriter.create<vector::StepOp>(loc, vtp);
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
    }
  }
  return false;
}
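
// Illustrative sketch only: for a right-hand side s * b[i], the recursive
// calls above produce roughly
//   %vs = vector.broadcast %s : f32 to vector<8xf32>
//   %vb = vector.maskedload %b[%i], %mask, %pad
//         : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
//   %vx = arith.mulf %vs, %vb : vector<8xf32>
// with the invariant %s broadcast and the loop-varying b[i] turned into a
// masked vector load.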

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with a single yield statement (as below) can be generated
  // when a custom reduce is used with a unary operation.
  //   for (...)
  //     yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops are terminated either by a non-empty yield operation
  // (reduction loop) or by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

namespace {

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single block, unit-stride for-loop that is generated by the
    // sparsifier, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)      v = for { }
///   u = expand(s) -> for (v) { }
///   s2 = s + u
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  vector::populateVectorStepLoweringPatterns(patterns);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
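
// Usage sketch (illustrative only; the actual driver lives in the sparse
// tensor pipeline):
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/8,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));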