//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;

namespace {
struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)

  TestVectorToVectorLowering() = default;
  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
      : PassWrapper(pass) {}
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect>();
    registry.insert<vector::VectorDialect>();
  }

  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (unroll) {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

private:
  // Return the target shape based on op type.
  static std::optional<SmallVector<int64_t>> getShape(Operation *op) {
    if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t>(3, 2);
    // For transfer ops, just propagate the shape coming from
    // InsertStridedSlices/ExtractStridedSlices.
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
      VectorType dstVec;
      for (Operation *users : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (dstVec && dstVec != vecType)
          return std::nullopt;
        dstVec = vecType;
      }
      return SmallVector<int64_t>(dstVec.getShape());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return std::nullopt;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t>(shape);
    }
    return std::nullopt;
  }

  static LogicalResult filter(Operation *op) {
    return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
                       ContractionOp, TransferReadOp, TransferWriteOp>(op));
  }
};
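// Typical driver from a lit test (a sketch; the exact RUN line is not part
// of this file):
//   mlir-opt %s --test-vector-to-vector-lowering=unroll
// This unrolls the ops accepted by `filter` to the shapes returned by
// `getShape`, then cleans up with the canonicalization patterns above.
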
struct TestVectorContractionPrepareForMMTLowering
    : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorContractionPrepareForMMTLowering)

  StringRef getArgument() const final {
    return "test-vector-contraction-prepare-for-mmt-lowering";
  }
  StringRef getDescription() const final {
    return "Test vector.contraction matmul canonicalization for MMT lowering.";
  }
  TestVectorContractionPrepareForMMTLowering() = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    vector::VectorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
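// Note: "MMT" denotes a matmul with the RHS transposed (A * B^T), the form
// that populateVectorContractCanonicalizeMatmulToMMT rewrites row-major
// vector.contract matmuls into by inserting vector.transpose ops.
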
struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));

    if (unrollBasedOnType) {
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
                                         4);
        Type lhsType = contractOp.getLhsType().getElementType();
        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
        return nativeShape;
      };

      UnrollVectorOptions opts;
      opts.setNativeShapeFn(nativeShapeFn)
          .setFilterConstraint(
              [](Operation *op) { return success(isa<ContractionOp>(op)); });

      if (!unrollOrder.empty()) {
        opts.setUnrollTraversalOrderFn(
            [this](Operation *op) -> std::optional<SmallVector<int64_t>> {
              vector::ContractionOp contractOp =
                  cast<vector::ContractionOp>(op);
              if (contractOp.getIteratorTypes().size() == unrollOrder.size())
                return SmallVector<int64_t>(unrollOrder.begin(),
                                            unrollOrder.end());
              return std::nullopt;
            });
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      auto nativeShapeFn =
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        auto contractOp = dyn_cast<ContractionOp>(op);
        if (!contractOp)
          return std::nullopt;
        return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("set the unroll order")};

  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};
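// Sketch of an invocation exercising the type-based path above (the option
// values are illustrative, not taken from an actual test):
//   mlir-opt %s \
//     --test-vector-unrolling-patterns="unroll-based-on-type unroll-order=1,0,2"
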
struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferUnrollingPatterns)

  TestVectorTransferUnrollingPatterns() = default;
  TestVectorTransferUnrollingPatterns(
      const TestVectorTransferUnrollingPatterns &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    UnrollVectorOptions opts;
    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
        .setFilterConstraint([](Operation *op) {
          return success(isa<vector::TransferReadOp, vector::TransferWriteOp,
                             vector::GatherOp>(op));
        });
    if (reverseUnrollOrder.getValue()) {
      opts.setUnrollTraversalOrderFn(
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
            int64_t numLoops = 0;
            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
              numLoops = readOp.getVectorType().getRank();
            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
              numLoops = writeOp.getVectorType().getRank();
            else if (auto gatherOp = dyn_cast<vector::GatherOp>(op))
              numLoops = gatherOp.getVectorType().getRank();
            else
              return std::nullopt;
            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
            return llvm::to_vector(order);
          });
    }
    populateVectorUnrollPatterns(patterns, opts);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  Option<bool> reverseUnrollOrder{
      *this, "reverse-unroll-order",
      llvm::cl::desc(
          "reverse the order of unrolling of vector transfer operations"),
      llvm::cl::init(false)};
};
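// For example, with the <2x2> native shape above, a vector.transfer_read of
// vector<4x4xf32> is rewritten into four vector<2x2xf32> reads whose results
// are recombined with vector.insert_strided_slice; reverse-unroll-order makes
// the unrolled tiles be traversed with the last vector dimension first.
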
struct TestScalarVectorTransferLoweringPatterns
    : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestScalarVectorTransferLoweringPatterns)

  TestScalarVectorTransferLoweringPatterns() = default;
  TestScalarVectorTransferLoweringPatterns(
      const TestScalarVectorTransferLoweringPatterns &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final {
    return "test-scalar-vector-transfer-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering of scalar vector transfers to memref loads/stores.";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, memref::MemRefDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  Option<bool> allowMultipleUses{
      *this, "allow-multiple-uses",
      llvm::cl::desc("Fold transfer operations with multiple uses"),
      llvm::cl::init(false)};

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    vector::populateScalarVectorTransferLoweringPatterns(
        patterns, /*benefit=*/1, allowMultipleUses.getValue());
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
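// Sketch of the kind of rewrite these patterns perform: a single-element
// transfer such as
//   %v = vector.transfer_read %m[%i], %pad : memref<?xf32>, vector<1xf32>
// becomes a memref.load of the scalar followed by a vector.broadcast.
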
struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)

  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
  }
  void runOnOperation() override {
    IRRewriter rewriter(&getContext());
    transferOpflowOpt(rewriter, getOperation());
  }
};
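// transferOpflowOpt performs store-to-load forwarding and dead-store style
// elimination on vector.transfer_read/vector.transfer_write pairs, removing
// transfers that dominance and aliasing checks prove redundant.
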
struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferCollapseInnerMostContiguousDims)

  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
  }

  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }

  StringRef getDescription() const final {
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
struct TestVectorSinkPatterns
    : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)

  TestVectorSinkPatterns() = default;
  TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
  }

  StringRef getArgument() const final { return "test-vector-sink-patterns"; }

  StringRef getDescription() const final {
    return "Test lowering patterns that eliminate redundant broadcast "
           "and transpose operations.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateSinkVectorOpsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
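// Example rewrite exercised by this pass (sketch):
//   arith.addf(vector.broadcast %a, vector.broadcast %b)
//     -> vector.broadcast(arith.addf %a, %b)
// i.e. elementwise ops are reordered past broadcasts/transposes so that the
// redundant shape-changing ops can fold away.
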
struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorReduceToContractPatternsPatterns)

  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert multireduce op to contract and combine "
           "broadcast/transpose to contract";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorReductionToContractPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
struct TestVectorChainedReductionFoldingPatterns
    : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorChainedReductionFoldingPatterns)

  StringRef getArgument() const final {
    return "test-vector-chained-reduction-folding-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to fold chained vector reductions";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateChainedVectorReductionFoldingPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
struct TestVectorBreakDownReductionPatterns
    : public PassWrapper<TestVectorBreakDownReductionPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorBreakDownReductionPatterns)

  StringRef getArgument() const final {
    return "test-vector-break-down-reduction-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to break down vector reductions into arith "
           "reductions";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateBreakDownVectorReductionPatterns(patterns,
                                             /*maxNumElementsToExtract=*/2);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
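// With maxNumElementsToExtract = 2, e.g. a vector.reduction <add> over
// vector<2xf32> is expanded into two vector.extract ops combined by
// arith.addf; reductions over longer vectors are left untouched.
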
struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFlattenVectorTransferPatterns)

  TestFlattenVectorTransferPatterns() = default;
  TestFlattenVectorTransferPatterns(
      const TestFlattenVectorTransferPatterns &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final {
    return "test-vector-transfer-flatten-patterns";
  }

  StringRef getDescription() const final {
    return "Test patterns to rewrite contiguous row-major N-dimensional "
           "vector.transfer_{read,write} ops into 1D transfers";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
    registry.insert<affine::AffineDialect>();
    registry.insert<vector::VectorDialect>();
  }

  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc(
          "Minimum vector bitwidth to enable the flattening transformation. "
          "For scalable vectors this is the base size, i.e. the size "
          "corresponding to vscale=1."),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
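// Sketch of the flattening rewrite: a contiguous row-major transfer such as
//   vector.transfer_read %m[...] : memref<4x4xf32>, vector<4x4xf32>
// is rewritten to a 1-D read of vector<16xf32> through a collapsed
// (memref.collapse_shape) view, with a vector.shape_cast back to the original
// type, subject to the target-vector-bitwidth option above.
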
struct TestVectorScanLowering
    : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)

  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the scan op in the vector "
           "dialect";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorScanLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer.
  MemRefType memrefType;
  if (auto vectorType = dyn_cast<VectorType>(type)) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name.
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());

  builder.restoreInsertionPoint(ip);
  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}
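// For a vector<32xf32> warp result this creates, at the top of the module, a
// global along the lines of (sketch):
//   memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
// following the "__shared_<shape>x<elt>" naming scheme and the shared memory
// address space (3) used above.
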
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           CombiningKind kind, uint32_t size) {
  // First reduce on a single thread to get per lane reduction value.
  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
  // Parallel reduction using butterfly shuffles.
  for (uint64_t i = 1; i < size; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, laneVal, i,
                                                 /*width=*/size,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
  }
  return laneVal;
}
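// For size = 32 the loop above performs 5 butterfly rounds (XOR masks 1, 2,
// 4, 8, 16), after which every lane holds the reduction of the whole warp.
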
struct TestVectorDistribution
    : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
                    affine::AffineDialect>();
  }

  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
  StringRef getDescription() const final {
    return "Test vector warp distribute transformation and lowering patterns";
  }
  TestVectorDistribution() = default;
  TestVectorDistribution(const TestVectorDistribution &pass)
      : PassWrapper(pass) {}

  Option<bool> warpOpToSCF{
      *this, "rewrite-warp-ops-to-scf-if",
      llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
      llvm::cl::init(false)};

  Option<bool> distributeTransferWriteOps{
      *this, "distribute-transfer-write",
      llvm::cl::desc("Test distribution of transfer write"),
      llvm::cl::init(false)};

  Option<unsigned> maxTransferWriteElements{
      *this, "max-transfer-write-elements",
      llvm::cl::desc("Maximum number of transfer write elements to distribute"),
      llvm::cl::init(1)};

  Option<bool> hoistUniform{*this, "hoist-uniform",
                            llvm::cl::desc("Test hoist uniform"),
                            llvm::cl::init(false)};

  Option<bool> propagateDistribution{
      *this, "propagate-distribution",
      llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    getOperation().walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
        if (hoistUniform) {
          moveScalarUniformCode(warpOp);
        }
        WalkResult::interrupt();
      }
    });
    MLIRContext *ctx = &getContext();
    auto distributionFn = [](Value val) {
      // Create an identity dim map of the same rank as the vector.
      VectorType vecType = dyn_cast<VectorType>(val.getType());
      int64_t vecRank = vecType ? vecType.getRank() : 0;
      OpBuilder builder(val.getContext());
      if (vecRank == 0)
        return AffineMap::get(val.getContext());
      return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
    };
    auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                        Value srcIdx, int64_t warpSz) {
      assert((val.getType().isF32() || val.getType().isInteger(32)) &&
             "unsupported shuffle type");
      Type i32Type = builder.getIntegerType(32);
      Value srcIdxI32 =
          builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
      Value warpSzI32 = builder.create<arith::ConstantOp>(
          loc, builder.getIntegerAttr(i32Type, warpSz));
      Value result = builder
                         .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
                                                 gpu::ShuffleMode::IDX)
                         .getResult(0);
      return result;
    };
    if (distributeTransferWriteOps && propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(
          patterns, distributionFn, shuffleFn, /*benefit=*/1,
          /*readBenefit=*/0);
      vector::populateDistributeReduction(patterns, warpReduction, 1);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    } else if (distributeTransferWriteOps) {
      RewritePatternSet patterns(ctx);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
                                                maxTransferWriteElements);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    } else if (propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(
          patterns, distributionFn, shuffleFn);
      vector::populateDistributeReduction(patterns, warpReduction);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }
    WarpExecuteOnLane0LoweringOptions options;
    options.warpAllocationFn = allocateGlobalSharedMemory;
    options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
                                      WarpExecuteOnLane0Op warpOp) {
      builder.create<gpu::BarrierOp>(loc);
    };
    // Test on one pattern in isolation.
    if (warpOpToSCF) {
      populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }
  }
};
struct TestVectorExtractStridedSliceLowering
    : public PassWrapper<TestVectorExtractStridedSliceLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorExtractStridedSliceLowering)

  StringRef getArgument() const final {
    return "test-vector-extract-strided-slice-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that convert vector.extract_strided_slice "
           "into a chain of vector.extract and vector.insert ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
struct TestVectorBreakDownBitCast
    : public PassWrapper<TestVectorBreakDownBitCast,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)

  StringRef getArgument() const final {
    return "test-vector-break-down-bitcast";
  }
  StringRef getDescription() const final {
    return "Test pattern that breaks down vector.bitcast ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
      return op.getSourceVectorType().getShape().back() > 4;
    });
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
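// The callback above restricts the breakdown to bitcasts whose source vector
// has more than 4 innermost elements, e.g. vector<8xf32> -> vector<16xf16>
// is split into smaller bitcasts while a vector<4xf32> source is left intact.
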
struct TestCreateVectorBroadcast
    : public PassWrapper<TestCreateVectorBroadcast,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)

  StringRef getArgument() const final { return "test-create-vector-broadcast"; }
  StringRef getDescription() const final {
    return "Test creation of vector.broadcast ops";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }

  void runOnOperation() override {
    getOperation()->walk([](Operation *op) {
      if (op->getName().getStringRef() != "test_create_broadcast")
        return;
      auto targetShape =
          cast<VectorType>(op->getResult(0).getType()).getShape();
      auto arrayAttr =
          cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims"))
              .asArrayRef();
      llvm::SetVector<int64_t> broadcastedDims;
      broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
      OpBuilder b(op);
      Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
          b, op->getOperand(0), targetShape, broadcastedDims);
      op->getResult(0).replaceAllUsesWith(bcast);
      op->erase();
    });
  }
};
struct TestVectorGatherLowering
    : public PassWrapper<TestVectorGatherLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)

  StringRef getArgument() const final { return "test-vector-gather-lowering"; }
  StringRef getDescription() const final {
    return "Test patterns that lower the gather op in the vector dialect to "
           "conditional loads";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect,
                    memref::MemRefDialect, scf::SCFDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorGatherLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
struct TestFoldArithExtensionIntoVectorContractPatterns
    : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFoldArithExtensionIntoVectorContractPatterns)

  StringRef getArgument() const final {
    return "test-fold-arith-extf-into-vector-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns that fold arithmetic extension ops into vector "
           "contract ops";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect,
                    nvgpu::NVGPUDialect,
                    memref::MemRefDialect, scf::SCFDialect,
                    tensor::TensorDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFoldArithExtensionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
struct TestVectorEmulateMaskedLoadStore final
    : public PassWrapper<TestVectorEmulateMaskedLoadStore,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)

  StringRef getArgument() const override {
    return "test-vector-emulate-masked-load-store";
  }
  StringRef getDescription() const override {
    return "Test patterns that emulate the maskedload/maskedstore op by "
           "memref.load/store and scf.if";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
                scf::SCFDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorMaskedLoadStoreEmulationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
struct TestVectorLinearize final
    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)

  TestVectorLinearize() = default;
  TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}

  StringRef getArgument() const override { return "test-vector-linearize"; }
  StringRef getDescription() const override {
    return "Linearizes ND vectors for N >= 2 into 1D vectors";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }

  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc(
          "Minimum vector bitwidth to enable the flattening transformation"),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};
  void runOnOperation() override {
    auto *context = &getContext();

    TypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    vector::populateVectorLinearizeTypeConversionsAndLegality(
        typeConverter, patterns, target, targetVectorBitwidth);
    vector::populateVectorLinearizeShuffleLikeOpsPatterns(
        typeConverter, patterns, target, targetVectorBitwidth);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};
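// E.g. an op on vector<2x2xf32> is converted to operate on vector<4xf32>,
// with multi-dimensional inserts/extracts rewritten to vector.shuffle ops by
// the shuffle-like-ops patterns (subject to the target-vector-bitwidth
// option above).
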
struct TestEliminateVectorMasks
    : public PassWrapper<TestEliminateVectorMasks,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)

  TestEliminateVectorMasks() = default;
  TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
      : PassWrapper(pass) {}

  Option<unsigned> vscaleMin{
      *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
      llvm::cl::init(1)};
  Option<unsigned> vscaleMax{
      *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
      llvm::cl::init(16)};

  StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
  StringRef getDescription() const final {
    return "Test eliminating vector masks";
  }
  void runOnOperation() override {
    IRRewriter rewriter(&getContext());
    eliminateVectorMasks(rewriter, getOperation(),
                         VscaleRange{vscaleMin, vscaleMax});
  }
};
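// Given the [vscale-min, vscale-max] range, masks that are provably all-true
// for every vscale in range (e.g. a vector.create_mask whose operands match
// the full vector size) are replaced with constant all-true masks, letting
// the enclosing vector.mask ops fold away.
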
} // namespace

namespace mlir {
namespace test {
void registerTestVectorLowerings() {
  PassRegistration<TestVectorToVectorLowering>();

  PassRegistration<TestVectorContractionPrepareForMMTLowering>();

  PassRegistration<TestVectorUnrollingPatterns>();

  PassRegistration<TestVectorTransferUnrollingPatterns>();

  PassRegistration<TestScalarVectorTransferLoweringPatterns>();

  PassRegistration<TestVectorTransferOpt>();

  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

  PassRegistration<TestVectorSinkPatterns>();

  PassRegistration<TestVectorReduceToContractPatternsPatterns>();

  PassRegistration<TestVectorChainedReductionFoldingPatterns>();

  PassRegistration<TestVectorBreakDownReductionPatterns>();

  PassRegistration<TestFlattenVectorTransferPatterns>();

  PassRegistration<TestVectorScanLowering>();

  PassRegistration<TestVectorDistribution>();

  PassRegistration<TestVectorExtractStridedSliceLowering>();

  PassRegistration<TestVectorBreakDownBitCast>();

  PassRegistration<TestCreateVectorBroadcast>();

  PassRegistration<TestVectorGatherLowering>();

  PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

  PassRegistration<TestVectorEmulateMaskedLoadStore>();

  PassRegistration<TestVectorLinearize>();

  PassRegistration<TestEliminateVectorMasks>();
}
} // namespace test
} // namespace mlir