//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::linalg;
33 struct TestLinalgTransforms
34 : public PassWrapper
<TestLinalgTransforms
, OperationPass
<func::FuncOp
>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms
)
37 TestLinalgTransforms() = default;
38 TestLinalgTransforms(const TestLinalgTransforms
&pass
) : PassWrapper(pass
) {}
40 void getDependentDialects(DialectRegistry
®istry
) const override
{
42 registry
.insert
<affine::AffineDialect
,
43 bufferization::BufferizationDialect
,
44 memref::MemRefDialect
,
46 linalg::LinalgDialect
,
47 vector::VectorDialect
,
51 StringRef
getArgument() const final
{
52 return "test-linalg-transform-patterns";
54 StringRef
getDescription() const final
{
55 return "Test Linalg transformation patterns by applying them greedily.";
58 void runOnOperation() override
;
60 Option
<bool> testPatterns
{*this, "test-patterns",
61 llvm::cl::desc("Test a mixed set of patterns"),
62 llvm::cl::init(false)};
63 Option
<bool> testVectorTransferForwardingPatterns
{
64 *this, "test-vector-transfer-forwarding-patterns",
66 "Test a fused pass that forwards memref.copy to vector.transfer"),
67 llvm::cl::init(false)};
68 Option
<bool> testGenericToVectorPattern
{
69 *this, "test-linalg-to-vector-patterns",
70 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
71 "in vector.contract form"),
72 llvm::cl::init(false)};
73 Option
<bool> testDecomposePadTensor
{
74 *this, "test-decompose-pad-tensor",
75 llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
76 llvm::cl::init(false)};
77 Option
<bool> testDecomposeTensorPackOp
{
78 *this, "test-decompose-tensor-pack",
79 llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
80 "of tensor and Linalg ops"),
81 llvm::cl::init(false)};
82 Option
<bool> testDecomposeTensorUnPackOp
{
83 *this, "test-decompose-tensor-unpack",
85 "Test transform that generalizes unpack ops into a sequence "
86 "of tensor and Linalg ops"),
87 llvm::cl::init(false)};
88 Option
<bool> testSwapSubTensorPadTensor
{
89 *this, "test-swap-subtensor-padtensor",
90 llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
91 "tensor.pad(subtensor)"),
92 llvm::cl::init(false)};
93 ListOption
<int64_t> peeledLoops
{
94 *this, "peeled-loops",
95 llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
96 ListOption
<int64_t> tileSizes
{
98 llvm::cl::desc("Linalg tile sizes for test-tile-pattern")};
99 Option
<bool> skipPartial
{
100 *this, "skip-partial",
101 llvm::cl::desc("Skip loops inside partial iterations during peeling"),
102 llvm::cl::init(false)};
103 Option
<std::string
> loopType
{
105 llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
107 llvm::cl::init("for")};
108 Option
<bool> testBubbleUpExtractSliceOpPattern
{
109 *this, "test-bubble-up-extract-slice-op-pattern",
110 llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
111 "extract_slice + linalgOp"),
112 llvm::cl::init(false)};
113 Option
<bool> testSwapExtractSliceWithFill
{
114 *this, "test-swap-extract-slice-with-fill-pattern",
116 "Test patterns to swap tensor.extract_slice(linalg.fill())"),
117 llvm::cl::init(false)};
118 Option
<bool> testEraseUnusedOperandsAndResults
{
119 *this, "test-erase-unused-operands-and-results",
120 llvm::cl::desc("Test patterns to erase unused operands and results"),
121 llvm::cl::init(false)};
122 Option
<bool> testEraseUnnecessaryInputs
{
123 *this, "test-erase-unnecessary-inputs",
124 llvm::cl::desc("Test patterns to erase unnecessary inputs"),
125 llvm::cl::init(false)};
126 Option
<bool> testWinogradConv2D
{
127 *this, "test-winograd-conv2d",
128 llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
129 llvm::cl::init(false)};
130 Option
<bool> testDecomposeWinogradOps
{
131 *this, "test-decompose-winograd-ops",
132 llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
136 static void applyPatterns(func::FuncOp funcOp
) {
137 MLIRContext
*ctx
= funcOp
.getContext();
138 RewritePatternSet
patterns(ctx
);
140 //===--------------------------------------------------------------------===//
141 // Linalg distribution patterns.
142 //===--------------------------------------------------------------------===//
143 LinalgLoopDistributionOptions distributionOptions
;
145 //===--------------------------------------------------------------------===//
146 // Linalg to vector contraction patterns.
147 //===--------------------------------------------------------------------===//
148 patterns
.add
<CopyVectorizationPattern
>(ctx
);
150 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
153 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp
) {
154 RewritePatternSet
forwardPattern(funcOp
.getContext());
155 forwardPattern
.add
<LinalgCopyVTRForwardingPattern
>(funcOp
.getContext());
156 forwardPattern
.add
<LinalgCopyVTWForwardingPattern
>(funcOp
.getContext());
157 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(forwardPattern
));
160 static void applyLinalgToVectorPatterns(func::FuncOp funcOp
) {
161 RewritePatternSet
patterns(funcOp
.getContext());
162 auto *ctx
= funcOp
.getContext();
163 patterns
.add
<CopyVectorizationPattern
>(ctx
);
164 populatePadOpVectorizationPatterns(patterns
);
165 populateConvolutionVectorizationPatterns(patterns
);
166 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
169 static void applyDecomposePadPatterns(func::FuncOp funcOp
) {
170 RewritePatternSet
patterns(funcOp
.getContext());
171 patterns
.add
<DecomposePadOpPattern
>(funcOp
.getContext());
172 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
175 static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp
) {
176 RewritePatternSet
patterns(funcOp
.getContext());
177 patterns
.add
<DecomposeOuterUnitDimsPackOpPattern
>(funcOp
.getContext());
178 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
181 static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp
) {
182 RewritePatternSet
patterns(funcOp
.getContext());
183 patterns
.add
<DecomposeOuterUnitDimsUnPackOpPattern
>(funcOp
.getContext());
184 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
187 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp
) {
188 RewritePatternSet
patterns(funcOp
.getContext());
189 patterns
.add
<ExtractSliceOfPadTensorSwapPattern
>(funcOp
.getContext());
190 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
193 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp
) {
194 RewritePatternSet
patterns(funcOp
.getContext());
195 populateBubbleUpExtractSliceOpPatterns(patterns
);
196 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
199 static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp
) {
200 RewritePatternSet
patterns(funcOp
.getContext());
201 populateSwapExtractSliceWithFillPatterns(patterns
);
202 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
205 static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp
) {
206 RewritePatternSet
patterns(funcOp
.getContext());
207 populateEraseUnusedOperandsAndResultsPatterns(patterns
);
208 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
211 static void applyEraseUnnecessaryInputs(func::FuncOp funcOp
) {
212 RewritePatternSet
patterns(funcOp
.getContext());
213 populateEraseUnnecessaryInputsPatterns(patterns
);
214 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
217 static void applyWinogradConv2D(func::FuncOp funcOp
) {
218 RewritePatternSet
patterns(funcOp
.getContext());
219 populateWinogradConv2DPatterns(patterns
, /*m=*/4, /*r=*/3);
220 populateWinogradConv2DPatterns(patterns
, /*m=*/2, /*r=*/5);
221 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
224 static void applyDecomposeWinogradOps(func::FuncOp funcOp
) {
225 RewritePatternSet
patterns(funcOp
.getContext());
226 populateDecomposeWinogradOpsPatterns(patterns
);
227 (void)applyPatternsAndFoldGreedily(funcOp
, std::move(patterns
));
230 /// Apply transformations specified as patterns.
231 void TestLinalgTransforms::runOnOperation() {
233 return applyPatterns(getOperation());
234 if (testVectorTransferForwardingPatterns
)
235 return applyVectorTransferForwardingPatterns(getOperation());
236 if (testGenericToVectorPattern
)
237 return applyLinalgToVectorPatterns(getOperation());
238 if (testDecomposePadTensor
)
239 return applyDecomposePadPatterns(getOperation());
240 if (testDecomposeTensorPackOp
)
241 return applyDecomposeTensorPackPatterns(getOperation());
242 if (testDecomposeTensorUnPackOp
)
243 return applyDecomposeTensorUnPackPatterns(getOperation());
244 if (testSwapSubTensorPadTensor
)
245 return applyExtractSliceOfPadTensorSwapPattern(getOperation());
246 if (testBubbleUpExtractSliceOpPattern
)
247 return applyBubbleUpExtractSliceOpPattern(getOperation());
248 if (testSwapExtractSliceWithFill
)
249 return applySwapExtractSliceWithFillPattern(getOperation());
250 if (testEraseUnusedOperandsAndResults
)
251 return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
252 if (testEraseUnnecessaryInputs
)
253 return applyEraseUnnecessaryInputs(getOperation());
254 if (testWinogradConv2D
)
255 return applyWinogradConv2D(getOperation());
256 if (testDecomposeWinogradOps
)
257 return applyDecomposeWinogradOps(getOperation());
262 void registerTestLinalgTransforms() {
263 PassRegistration
<TestLinalgTransforms
>();