//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;

namespace {

struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)

  TestVectorToVectorLowering() = default;
  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
      : PassWrapper(pass) {}
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect>();
    registry.insert<vector::VectorDialect>();
  }

  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (unroll) {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

private:
  // Return the target shape based on op type.
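  // Elementwise ops get a native shape of 2x2 and contractions 2x2x2.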
  static std::optional<SmallVector<int64_t>> getShape(Operation *op) {
    if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t>(3, 2);
    // For transfer ops, just propagate the shape coming from
    // InsertStridedSlices/ExtractStridedSlices.
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
      VectorType dstVec;
      for (Operation *users : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (dstVec && dstVec != vecType)
          return std::nullopt;
        dstVec = vecType;
      }
      return SmallVector<int64_t>(dstVec.getShape());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return std::nullopt;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t>(shape);
    }
    return std::nullopt;
  }

  static LogicalResult filter(Operation *op) {
    return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
                       ContractionOp, TransferReadOp, TransferWriteOp>(op));
  }
};

struct TestVectorContractionPrepareForMMTLowering
    : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorContractionPrepareForMMTLowering)

  StringRef getArgument() const final {
    return "test-vector-contraction-prepare-for-mmt-lowering";
  }
  StringRef getDescription() const final {
    return "Test vector.contraction matmul canonicalization for MMT lowering.";
  }
  TestVectorContractionPrepareForMMTLowering() = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    vector::VectorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));

    if (unrollBasedOnType) {
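      // Pick the native shape from the contraction's element type: every
      // iterator dimension unrolls to 4, except the innermost, which narrows
      // to 2 for non-f16 operands.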
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
                                         4);
        Type lhsType = contractOp.getLhsType().getElementType();
        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
        return nativeShape;
      };

      UnrollVectorOptions opts;
      opts.setNativeShapeFn(nativeShapeFn)
          .setFilterConstraint(
              [](Operation *op) { return success(isa<ContractionOp>(op)); });

      if (!unrollOrder.empty()) {
        opts.setUnrollTraversalOrderFn(
            [this](Operation *op) -> std::optional<SmallVector<int64_t>> {
              vector::ContractionOp contractOp =
                  cast<vector::ContractionOp>(op);
              if (contractOp.getIteratorTypes().size() == unrollOrder.size())
                return SmallVector<int64_t>(unrollOrder.begin(),
                                            unrollOrder.end());
              return std::nullopt;
            });
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      auto nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        auto contractOp = dyn_cast<ContractionOp>(op);
        if (!contractOp)
          return std::nullopt;
        return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("set the unroll order")};

  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};

struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferUnrollingPatterns)

  TestVectorTransferUnrollingPatterns() = default;
  TestVectorTransferUnrollingPatterns(
      const TestVectorTransferUnrollingPatterns &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    UnrollVectorOptions opts;
    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
        .setFilterConstraint([](Operation *op) {
          return success(isa<vector::TransferReadOp, vector::TransferWriteOp,
                             vector::GatherOp>(op));
        });
    if (reverseUnrollOrder.getValue()) {
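      // Traverse the unrolled tiles in reversed dimension order instead of
      // the default identity order.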
      opts.setUnrollTraversalOrderFn(
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
            int64_t numLoops = 0;
            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
              numLoops = readOp.getVectorType().getRank();
            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
              numLoops = writeOp.getVectorType().getRank();
            else if (auto gatherOp = dyn_cast<vector::GatherOp>(op))
              numLoops = gatherOp.getVectorType().getRank();
            else
              return std::nullopt;
            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
            return llvm::to_vector(order);
          });
    }
    populateVectorUnrollPatterns(patterns, opts);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  Option<bool> reverseUnrollOrder{
      *this, "reverse-unroll-order",
      llvm::cl::desc(
          "reverse the order of unrolling of vector transfer operations"),
      llvm::cl::init(false)};
};

struct TestScalarVectorTransferLoweringPatterns
    : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestScalarVectorTransferLoweringPatterns)

  TestScalarVectorTransferLoweringPatterns() = default;
  TestScalarVectorTransferLoweringPatterns(
      const TestScalarVectorTransferLoweringPatterns &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final {
    return "test-scalar-vector-transfer-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering of scalar vector transfers to memref loads/stores.";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, memref::MemRefDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  Option<bool> allowMultipleUses{
      *this, "allow-multiple-uses",
      llvm::cl::desc("Fold transfer operations with multiple uses"),
      llvm::cl::init(false)};

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    vector::populateScalarVectorTransferLoweringPatterns(
        patterns, /*benefit=*/1, allowMultipleUses.getValue());
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)

  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
  }
  void runOnOperation() override {
    IRRewriter rewriter(&getContext());
    transferOpflowOpt(rewriter, getOperation());
  }
};

struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferCollapseInnerMostContiguousDims)

  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
  }

  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }

  StringRef getDescription() const final {
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorSinkPatterns
    : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)

  TestVectorSinkPatterns() = default;
  TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
  }

  StringRef getArgument() const final { return "test-vector-sink-patterns"; }

  StringRef getDescription() const final {
    return "Test lowering patterns that eliminate redundant broadcast "
           "and transpose operations.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateSinkVectorOpsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorReduceToContractPatternsPatterns)

  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert multi_reduction ops to contract ops and "
           "to fold broadcast/transpose into contract";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorReductionToContractPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorChainedReductionFoldingPatterns
    : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorChainedReductionFoldingPatterns)

  StringRef getArgument() const final {
    return "test-vector-chained-reduction-folding-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to fold chained vector reductions";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateChainedVectorReductionFoldingPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorBreakDownReductionPatterns
    : public PassWrapper<TestVectorBreakDownReductionPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorBreakDownReductionPatterns)

  StringRef getArgument() const final {
    return "test-vector-break-down-reduction-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to break down vector reductions into arith "
           "reductions";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateBreakDownVectorReductionPatterns(patterns,
                                             /*maxNumElementsToExtract=*/2);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFlattenVectorTransferPatterns)

  TestFlattenVectorTransferPatterns() = default;
  TestFlattenVectorTransferPatterns(
      const TestFlattenVectorTransferPatterns &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final {
    return "test-vector-transfer-flatten-patterns";
  }

  StringRef getDescription() const final {
    return "Test patterns to rewrite contiguous row-major N-dimensional "
           "vector.transfer_{read,write} ops into 1D transfers";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
    registry.insert<affine::AffineDialect>();
    registry.insert<vector::VectorDialect>();
  }

  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc(
          "Minimum vector bitwidth to enable the flattening transformation. "
          "For scalable vectors this is the base size, i.e. the size "
          "corresponding to vscale=1."),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorScanLowering
    : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)

  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the scan op in the vector "
           "dialect";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorScanLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
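  // Address space 3 is the GPU shared (workgroup) memory space in the NVVM
  // convention used by these tests.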
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer.
  MemRefType memrefType;
  if (auto vectorType = dyn_cast<VectorType>(type)) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name.
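  // e.g. a memref<4x8xf32> buffer yields the symbol "__shared_4x8xf32".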
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());

  builder.restoreInsertionPoint(ip);
  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}
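
/// Reduce `input` across the `size` lanes of a warp: each lane first reduces
/// its own vector to a scalar, then lanes are combined pairwise with
/// butterfly shuffles.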
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           CombiningKind kind, uint32_t size) {
  // First reduce on a single thread to get per lane reduction value.
  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
  // Parallel reduction using butterfly shuffles.
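  // e.g. for size = 32 the XOR masks are 1, 2, 4, 8, 16, so five shuffle
  // rounds combine all 32 lanes.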
  for (uint64_t i = 1; i < size; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, laneVal, i,
                                                 /*width=*/size,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
  }
  return laneVal;
}

struct TestVectorDistribution
    : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
                    affine::AffineDialect>();
  }

  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
  StringRef getDescription() const final {
    return "Test vector warp distribute transformation and lowering patterns";
  }
  TestVectorDistribution() = default;
  TestVectorDistribution(const TestVectorDistribution &pass)
      : PassWrapper(pass) {}

  Option<bool> warpOpToSCF{
      *this, "rewrite-warp-ops-to-scf-if",
      llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
      llvm::cl::init(false)};

  Option<bool> distributeTransferWriteOps{
      *this, "distribute-transfer-write",
      llvm::cl::desc("Test distribution of transfer write"),
      llvm::cl::init(false)};

  Option<unsigned> maxTransferWriteElements{
      *this, "max-transfer-write-elements",
      llvm::cl::desc("Maximum number of transfer write elements to distribute"),
      llvm::cl::init(1)};

  Option<bool> hoistUniform{*this, "hoist-uniform",
                            llvm::cl::desc("Test hoist uniform"),
                            llvm::cl::init(false)};

  Option<bool> propagateDistribution{
      *this, "propagate-distribution",
      llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    getOperation().walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
        if (hoistUniform) {
          moveScalarUniformCode(warpOp);
        }
        WalkResult::interrupt();
      }
    });

    MLIRContext *ctx = &getContext();
    auto distributionFn = [](Value val) {
      // Create an identity dim map of the same rank as the vector.
      VectorType vecType = dyn_cast<VectorType>(val.getType());
      int64_t vecRank = vecType ? vecType.getRank() : 0;
      OpBuilder builder(val.getContext());
      if (vecRank == 0)
        return AffineMap::get(val.getContext());
      return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
    };
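    // Move a value between lanes with gpu.shuffle in idx mode; only 32-bit
    // f32/i32 values are supported, as the assert below checks.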
    auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                        Value srcIdx, int64_t warpSz) {
      assert((val.getType().isF32() || val.getType().isInteger(32)) &&
             "unsupported shuffle type");
      Type i32Type = builder.getIntegerType(32);
      Value srcIdxI32 =
          builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
      Value warpSzI32 = builder.create<arith::ConstantOp>(
          loc, builder.getIntegerAttr(i32Type, warpSz));
      Value result = builder
                         .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
                                                 gpu::ShuffleMode::IDX)
                         .getResult(0);
      return result;
    };
    if (distributeTransferWriteOps && propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(
          patterns, distributionFn, shuffleFn, /*benefit=*/1,
          /*readBenefit=*/0);
      vector::populateDistributeReduction(patterns, warpReduction, 1);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    } else if (distributeTransferWriteOps) {
      RewritePatternSet patterns(ctx);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
                                                maxTransferWriteElements);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    } else if (propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(
          patterns, distributionFn, shuffleFn);
      vector::populateDistributeReduction(patterns, warpReduction);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }
    WarpExecuteOnLane0LoweringOptions options;
    options.warpAllocationFn = allocateGlobalSharedMemory;
    options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
                                      WarpExecuteOnLane0Op warpOp) {
      builder.create<gpu::BarrierOp>(loc);
    };
    // Test on one pattern in isolation.
    if (warpOpToSCF) {
      populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }
  }
};

struct TestVectorExtractStridedSliceLowering
    : public PassWrapper<TestVectorExtractStridedSliceLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorExtractStridedSliceLowering)

  StringRef getArgument() const final {
    return "test-vector-extract-strided-slice-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that convert vector.extract_strided_slice "
           "into a chain of vector.extract and vector.insert ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorBreakDownBitCast
    : public PassWrapper<TestVectorBreakDownBitCast,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)

  StringRef getArgument() const final {
    return "test-vector-break-down-bitcast";
  }
  StringRef getDescription() const final {
    return "Test pattern that breaks down vector.bitcast ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
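    // Only break down bitcasts whose source vector has more than 4 elements
    // in its innermost dimension.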
    populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
      return op.getSourceVectorType().getShape().back() > 4;
    });
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestCreateVectorBroadcast
    : public PassWrapper<TestCreateVectorBroadcast,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)

  StringRef getArgument() const final { return "test-create-vector-broadcast"; }
  StringRef getDescription() const final {
    return "Test creation of vector broadcast ops";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }

  void runOnOperation() override {
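    // Replace each test_create_broadcast op with a vector.broadcast (or its
    // folded equivalent) to the result shape, broadcasting along the
    // dimensions listed in its broadcast_dims attribute.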
    getOperation()->walk([](Operation *op) {
      if (op->getName().getStringRef() != "test_create_broadcast")
        return;
      auto targetShape =
          cast<VectorType>(op->getResult(0).getType()).getShape();
      auto arrayAttr =
          cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims"))
              .asArrayRef();
      llvm::SetVector<int64_t> broadcastedDims;
      broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
      OpBuilder b(op);
      Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
          b, op->getOperand(0), targetShape, broadcastedDims);
      op->getResult(0).replaceAllUsesWith(bcast);
      op->erase();
    });
  }
};

struct TestVectorGatherLowering
    : public PassWrapper<TestVectorGatherLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)

  StringRef getArgument() const final { return "test-vector-gather-lowering"; }
  StringRef getDescription() const final {
    return "Test patterns that lower the vector gather op into conditional "
           "loads";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect,
                    memref::MemRefDialect, scf::SCFDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorGatherLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestFoldArithExtensionIntoVectorContractPatterns
    : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFoldArithExtensionIntoVectorContractPatterns)

  StringRef getArgument() const final {
    return "test-fold-arith-extf-into-vector-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns that fold arithmetic extension ops into vector "
           "contract ops";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
                    memref::MemRefDialect, scf::SCFDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFoldArithExtensionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorEmulateMaskedLoadStore final
    : public PassWrapper<TestVectorEmulateMaskedLoadStore,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)

  StringRef getArgument() const override {
    return "test-vector-emulate-masked-load-store";
  }
  StringRef getDescription() const override {
    return "Test patterns that emulate the maskedload/maskedstore op by "
           "memref.load/store and scf.if";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
                scf::SCFDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorMaskedLoadStoreEmulationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorLinearize final
    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)

  TestVectorLinearize() = default;
  TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}

  StringRef getArgument() const override { return "test-vector-linearize"; }
  StringRef getDescription() const override {
    return "Linearizes ND vectors for N >= 2 into 1D vectors";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }

  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc(
          "Minimum vector bitwidth to enable the flattening transformation"),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};
  void runOnOperation() override {
    auto *context = &getContext();

    TypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    vector::populateVectorLinearizeTypeConversionsAndLegality(
        typeConverter, patterns, target, targetVectorBitwidth);
    vector::populateVectorLinearizeShuffleLikeOpsPatterns(
        typeConverter, patterns, target, targetVectorBitwidth);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};

struct TestEliminateVectorMasks
    : public PassWrapper<TestEliminateVectorMasks,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)

  TestEliminateVectorMasks() = default;
  TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
      : PassWrapper(pass) {}

  Option<unsigned> vscaleMin{
      *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
      llvm::cl::init(1)};
  Option<unsigned> vscaleMax{
      *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
      llvm::cl::init(16)};

  StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
  StringRef getDescription() const final {
    return "Test eliminating vector masks";
  }
  void runOnOperation() override {
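    // Fold away masks that are provably all-true for every vscale in
    // [vscale-min, vscale-max].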
    IRRewriter rewriter(&getContext());
    eliminateVectorMasks(rewriter, getOperation(),
                         VscaleRange{vscaleMin, vscaleMax});
  }
};
} // namespace

namespace mlir {
namespace test {
void registerTestVectorLowerings() {
  PassRegistration<TestVectorToVectorLowering>();

  PassRegistration<TestVectorContractionPrepareForMMTLowering>();

  PassRegistration<TestVectorUnrollingPatterns>();

  PassRegistration<TestVectorTransferUnrollingPatterns>();

  PassRegistration<TestScalarVectorTransferLoweringPatterns>();

  PassRegistration<TestVectorTransferOpt>();

  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

  PassRegistration<TestVectorSinkPatterns>();

  PassRegistration<TestVectorReduceToContractPatternsPatterns>();

  PassRegistration<TestVectorChainedReductionFoldingPatterns>();

  PassRegistration<TestVectorBreakDownReductionPatterns>();

  PassRegistration<TestFlattenVectorTransferPatterns>();

  PassRegistration<TestVectorScanLowering>();

  PassRegistration<TestVectorDistribution>();

  PassRegistration<TestVectorExtractStridedSliceLowering>();

  PassRegistration<TestVectorBreakDownBitCast>();

  PassRegistration<TestCreateVectorBroadcast>();

  PassRegistration<TestVectorGatherLowering>();

  PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

  PassRegistration<TestVectorEmulateMaskedLoadStore>();

  PassRegistration<TestVectorLinearize>();

  PassRegistration<TestEliminateVectorMasks>();
}
} // namespace test
} // namespace mlir