//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparsifier can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization is concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}

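// For illustration only: with vectorLength 16 and an f32 element type, the
// helpers above produce vector<16xf32>, or the scalable form vector<[16]xf32>
// when enableVLAVectorization is set.
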
/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

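// For illustration only: when the trip count cannot be proven evenly divided,
// the mask generated above looks roughly like (example names, vector length
// 16):
//
//   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%step]
//   %mask = vector.create_mask %end : vector<16xi1>
//
// which disables the lanes that would run past %hi in the final vector
// iteration.
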
/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}

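// For illustration only: with a scalar last subscript, genVectorLoad above
// emits a masked load; with a vector of indices it emits a gather, replacing
// the trailing subscript by constant 0 and passing the index vector
// separately. Roughly (example names, vector length 16, f32 elements):
//
//   %v = vector.maskedload %mem[%i], %mask, %pass
//          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
//   %v = vector.gather %mem[%c0][%idxs], %mask, %pass
//          : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
//            into vector<16xf32>
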
/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the
/// combining kind of reduction on success in `kind`.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

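// For illustration only: a subtraction update x = x - a[i] is reported as an
// ADD reduction above (and only when the iteration value is the left
// operand); each vector lane then accumulates its own partial x - ... value,
// and the final horizontal ADD across lanes still yields init - sum(a[i]),
// so no dedicated SUB combining kind is needed.
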
/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

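// For illustration only: with vectorLength 4 and scalar start value r,
// genVectorReducInit above starts an ADD (or XOR) reduction from r inserted
// at lane 0 of an all-zero vector, a MUL reduction from r inserted into an
// all-one vector, and an AND/OR reduction from the broadcast of r to all
// lanes; in every case a single horizontal vector.reduction afterwards
// recovers the correct scalar result.
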
/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that are already 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = load.getOperand(0);
      }
      // Inspect.
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

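// For illustration only: for an indirect subscript a[ind[i]] with 32-bit
// stored coordinates, the analysis above accepts the innermost subscript and
// the codegen pass loads the coordinates as a vector, zero extends them to
// the index width (64-bit, unless enableSIMDIndex32 is set), and feeds the
// result into a gather/scatter. Roughly (example names):
//
//   %crd  = vector.maskedload %ind[%i], %mask, %pad
//             : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
//   %idxs = arith.extui %crd : vector<16xi32> to vector<16xi64>
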
#define UNAOP(xxx) \
  if (isa<xxx>(def)) { \
    if (codegen) \
      vexp = rewriter.create<xxx>(loc, vx); \
    return true; \
  }

#define TYPEDUNAOP(xxx) \
  if (auto x = dyn_cast<xxx>(def)) { \
    if (codegen) { \
      VectorType vtp = vectorType(vl, x.getType()); \
      vexp = rewriter.create<xxx>(loc, vtp, vx); \
    } \
    return true; \
  }

#define BINOP(xxx) \
  if (isa<xxx>(def)) { \
    if (codegen) \
      vexp = rewriter.create<xxx>(loc, vx, vy); \
    return true; \
  }

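// For illustration only: UNAOP(math::AbsFOp) in vectorizeExpr below expands
// to the following analysis/codegen stanza (TYPEDUNAOP and BINOP follow the
// same pattern, with an explicit result type or a second operand):
//
//   if (isa<math::AbsFOp>(def)) {
//     if (codegen)
//       vexp = rewriter.create<math::AbsFOp>(loc, vx);
//     return true;
//   }
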
/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr = rewriter.create<vector::StepOp>(loc, vtp);
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with a single yield statement (as below) could be generated
  // when a custom reduce is used with a unary operation.
  //   for (...)
  //     yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops are terminated either by a non-empty yield operation
  // (reduction loop) or by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

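// For illustration only: a sparsifier-generated dot-product reduction such as
//
//   %sum = scf.for %i = %lo to %hi step %c1 iter_args(%s = %init) -> (f32) {
//     %a = memref.load %A[%i] : memref<?xf32>
//     %b = memref.load %B[%i] : memref<?xf32>
//     %m = arith.mulf %a, %b : f32
//     %n = arith.addf %s, %m : f32
//     scf.yield %n : f32
//   }
//
// is rewritten by vectorizeStmt above (for vector length 16) into a loop that
// steps by 16, accumulates a vector<16xf32> partial sum under the iteration
// mask, and is followed by a single horizontal vector.reduction outside the
// loop.
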
/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that is generated by
    // the sparsifier, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)              v = for { }
///   u = expand(s)      ->    for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  vector::populateVectorStepLoweringPatterns(patterns);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
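
// For illustration only: a pass would typically drive these rules through the
// greedy pattern driver, e.g. from inside its runOnOperation():
//
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));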