// [libc++abi] Build cxxabi with sanitizers (#119612)
// llvm-project.git / mlir / test / lib / Dialect / Tensor / TestTensorTransforms.cpp
// blob 34de600132f5dec637bee0fa8a9c5eda326b4b03
1 //===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Tensor transformations.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
18 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
19 #include "mlir/Dialect/Transform/IR/TransformOps.h"
20 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 using namespace mlir;
26 namespace {
27 struct TestTensorTransforms
28 : public PassWrapper<TestTensorTransforms, OperationPass<>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorTransforms)
31 TestTensorTransforms() = default;
32 TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
34 void getDependentDialects(DialectRegistry &registry) const override {
35 registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
36 transform::TransformDialect>();
39 StringRef getArgument() const final {
40 return "test-tensor-transform-patterns";
42 StringRef getDescription() const final {
43 return "Test Tensor transformation patterns by applying them greedily.";
46 void runOnOperation() override;
48 Option<bool> testFoldConstantExtractSlice{
49 *this, "test-fold-constant-extract-slice",
50 llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
51 llvm::cl::init(false)};
53 Option<bool> testFoldConsecutiveInsertExtractSlice{
54 *this, "test-fold-consecutive-insert-extract-slice",
55 llvm::cl::desc(
56 "Test folding consecutive tensor.insert_slice/tensor.extract_slice"),
57 llvm::cl::init(false)};
59 Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
60 *this, "test-rewrite-extract-slice-from-collapse-shape",
61 llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
62 "with loop nest"),
63 llvm::cl::init(false)};
65 Option<bool> testDropRedundantInsertSliceRankExpansion{
66 *this, "test-drop-redundant-insert-slice-rank-expansion",
67 llvm::cl::desc("Test dropping redundant insert_slice rank expansions"),
68 llvm::cl::init(false)};
70 Option<bool> testReassociativeReshapeFolding{
71 *this, "test-reassociative-reshape-folding",
72 llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
73 llvm::cl::init(false)};
75 Option<bool> testBubbleUpExpandShapePatterns{
76 *this, "test-expand-shape-bubbling",
77 llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
78 llvm::cl::init(false)};
80 Option<bool> testFoldIntoPackAndUnpack{
81 *this, "test-fold-into-pack-and-unpack",
82 llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
83 llvm::cl::init(false)};
85 Option<bool> useForeach{
86 *this, "use-foreach",
87 llvm::cl::desc(
88 "Use the scf.forall operation when generating loop nests for "
89 "the extract_slice of collapse_shape pattern"),
90 llvm::cl::init(false)};
92 Option<bool> testSimplifyPackUnpackPatterns{
93 *this, "test-simplify-pack-unpack-patterns",
94 llvm::cl::desc("Test patterns to simplify tensor.pack and tensor.unpack"),
95 llvm::cl::init(false)};
97 Option<bool> testTrackingListener{
98 *this, "test-tracking-listener",
99 llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
100 llvm::cl::init(false)};
102 } // namespace
104 static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
105 RewritePatternSet patterns(rootOp->getContext());
106 tensor::populateReassociativeReshapeFoldingPatterns(patterns);
107 (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
110 static void applyBubbleUpExpandShapePatterns(Operation *rootOp) {
111 RewritePatternSet patterns(rootOp->getContext());
112 tensor::populateBubbleUpExpandShapePatterns(patterns);
113 (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
116 static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
117 RewritePatternSet patterns(rootOp->getContext());
118 tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
119 (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
122 static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
123 RewritePatternSet patterns(rootOp->getContext());
124 tensor::ControlConstantExtractSliceFusionFn controlFn =
125 [](tensor::ExtractSliceOp op) {
126 if (!op.getSource().hasOneUse())
127 return false;
129 auto resultType = cast<ShapedType>(op.getResult().getType());
130 constexpr int64_t kConstantFoldingMaxNumElements = 1024;
131 return resultType.getNumElements() <= kConstantFoldingMaxNumElements;
134 tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
135 (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
138 static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
139 RewritePatternSet patterns(rootOp->getContext());
140 tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
141 (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
144 static void
145 applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
146 RewritePatternSet patterns(rootOp->getContext());
147 tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
148 (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
151 static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
152 RewritePatternSet patterns(rootOp->getContext());
153 tensor::populateSimplifyPackAndUnpackPatterns(patterns);
154 (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
157 namespace {
158 /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
159 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
160 /// stitches together the desired tile from slices of the source of the collapse
161 /// shape op.
162 struct RewriteExtractSliceFromCollapseShapeBase
163 : public OpRewritePattern<tensor::ExtractSliceOp> {
164 RewriteExtractSliceFromCollapseShapeBase(MLIRContext *context)
165 : mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {}
167 /// Emit a loop or gather operation that uses `helper` to take each point in
168 /// the parallel iteration space bounds, extract a slice from the source
169 /// tensor and insert it into `dest`. For examples, see below for `scf.for`
170 /// and `scf.foreach`.
171 virtual LogicalResult
172 emitReplacement(tensor::ExtractSliceOp op, Value dest,
173 tensor::ExtractSliceFromCollapseHelper &helper,
174 PatternRewriter &rewriter) const = 0;
176 LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
177 PatternRewriter &rewriter) const override {
178 auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
179 if (!collapseOp)
180 return rewriter.notifyMatchFailure(
181 op, "producer is not a tensor.collapse_shape op");
183 // Try to simplify the collapse shape using a rank-reducing slice, if
184 // possible.
185 FailureOr<Operation *> simplifiedCollapseShapeResult =
186 tensor::simplifyCollapseShapeWithRankReducingExtractSlice(collapseOp,
187 rewriter);
188 if (succeeded(simplifiedCollapseShapeResult)) {
189 auto newCollapseOp =
190 dyn_cast<tensor::CollapseShapeOp>(*simplifiedCollapseShapeResult);
191 // The collapse shape op might have been simplified away, so we can just
192 // return.
193 if (!newCollapseOp)
194 return success();
195 collapseOp = newCollapseOp;
198 // Materialize the output shape values of the slice operation.
199 ReifiedRankedShapedTypeDims reifiedShapes;
200 if (failed(reifyResultShapes(rewriter, op, reifiedShapes)))
201 return rewriter.notifyMatchFailure(op, "failed to reify result shapes");
203 // Create the destination tensor using the above values.
204 Type elementType = op.getSourceType().getElementType();
205 SmallVector<OpFoldResult> outputShape = reifiedShapes[0];
206 Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape,
207 elementType);
209 // Calculate the parameters for the tile loop nest.
210 FailureOr<tensor::ExtractSliceFromCollapseHelper> params =
211 tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp,
212 op);
213 if (failed(params))
214 return rewriter.notifyMatchFailure(
215 op, "could not calculate tiling parameters");
216 return emitReplacement(op, dest, *params, rewriter);
220 struct RewriteExtractSliceFromCollapseShapeUsingScfFor
221 : public RewriteExtractSliceFromCollapseShapeBase {
222 RewriteExtractSliceFromCollapseShapeUsingScfFor(MLIRContext *context)
223 : RewriteExtractSliceFromCollapseShapeBase(context) {}
224 LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
225 tensor::ExtractSliceFromCollapseHelper &helper,
226 PatternRewriter &rewriter) const override {
227 Location loc = op.getLoc();
228 const unsigned numTiledDims = helper.getIterationSpaceSizes().size();
229 auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
230 auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
231 SmallVector<Value> lbs(numTiledDims, zero);
232 SmallVector<Value> steps(numTiledDims, one);
234 scf::LoopNest nest = scf::buildLoopNest(
235 rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
236 [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
237 ValueRange iterArgs) -> scf::ValueVector {
238 auto [tile, insertParams] =
239 helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
241 // Insert the slice into the destination.
242 return {nestedBuilder.create<tensor::InsertSliceOp>(
243 loc, tile, iterArgs[0], insertParams)};
245 rewriter.replaceOp(op, nest.results);
247 return success();
251 struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
252 : public RewriteExtractSliceFromCollapseShapeBase {
253 RewriteExtractSliceFromCollapseShapeUsingScfForeach(MLIRContext *context)
254 : RewriteExtractSliceFromCollapseShapeBase(context) {}
255 LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
256 tensor::ExtractSliceFromCollapseHelper &helper,
257 PatternRewriter &rewriter) const override {
258 Location loc = op.getLoc();
259 auto forallOp = rewriter.create<scf::ForallOp>(
260 loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()),
261 /*outputs=*/dest,
262 /*mapping=*/std::nullopt,
263 [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
264 unsigned numThreadIdRegionArgs =
265 helper.getIterationSpaceSizes().size();
266 unsigned numOutputRegionArgs =
267 regionArgs.size() - numThreadIdRegionArgs;
268 ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs);
269 ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs);
270 assert(outputArgs.size() == 1 &&
271 "there should only be one output region argument");
272 auto [tile, insertParams] =
273 helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
274 // Insert the slice into the destination.
275 auto term = nestedBuilder.create<scf::InParallelOp>(loc);
276 nestedBuilder.setInsertionPointToStart(term.getBody());
277 nestedBuilder.create<tensor::ParallelInsertSliceOp>(
278 loc, tile, outputArgs[0], insertParams);
280 rewriter.replaceOp(op, forallOp->getResult(0));
281 return success();
284 } // namespace
286 static LogicalResult
287 applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
288 bool useForeach) {
289 RewritePatternSet patterns(rootOp->getContext());
290 if (useForeach)
291 patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>(
292 rootOp->getContext());
293 else
294 patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>(
295 rootOp->getContext());
296 return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
299 namespace {
300 class DummyTrackingListener : public transform::TrackingListener {
301 public:
302 using transform::TrackingListener::TrackingListener;
304 // Expose `findReplacementOp` as a public function, so that it can be tested.
305 Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
306 Operation *replacementOp;
307 if (!findReplacementOp(replacementOp, op, newValues).succeeded())
308 return nullptr;
309 return replacementOp;
312 } // namespace
314 static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
315 // Find replaced op.
316 Operation *replaced = nullptr;
317 WalkResult status = rootOp->walk([&](Operation *op) {
318 if (op->hasAttr("replaced")) {
319 if (replaced) {
320 op->emitError("only one 'replaced' op is allowed per test case");
321 replaced->emitRemark("other 'replaced' op");
322 return WalkResult::interrupt();
324 replaced = op;
326 return WalkResult::advance();
328 if (status.wasInterrupted())
329 return failure();
330 if (!replaced) {
331 rootOp->emitError("could not find 'replaced' op");
332 return failure();
335 // Find replacements.
336 SmallVector<Value> replacements(replaced->getNumResults(), Value());
337 status = rootOp->walk([&](Operation *op) {
338 for (int64_t i = 0; i < replaced->getNumResults(); ++i) {
339 if (auto attr = op->getAttrOfType<IntegerAttr>("replacement_" +
340 std::to_string(i))) {
341 if (replacements[i]) {
342 op->emitError("only one 'replacement_" + std::to_string(i) +
343 "' is allowed per test case");
344 replacements[i].getDefiningOp()->emitRemark("other 'replacement_" +
345 std::to_string(i) + "'");
346 return WalkResult::interrupt();
348 replacements[i] = op->getResult(attr.getInt());
351 return WalkResult::advance();
353 if (status.wasInterrupted())
354 return failure();
356 if (!llvm::all_of(replacements,
357 [](Value v) { return static_cast<bool>(v); })) {
358 replaced->emitError("insufficient replacement values");
359 return failure();
362 // Find the replacement op (if any) and emit a remark/error.
363 transform::TransformState transformState =
364 transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
365 /*payloadRoot=*/nullptr);
366 MLIRContext *context = rootOp->getContext();
367 OpBuilder builder(context);
368 OwningOpRef<transform::NamedSequenceOp> transformOp =
369 builder.create<transform::NamedSequenceOp>(
370 rootOp->getLoc(),
371 /*sym_name=*/"test_sequence",
372 /*function_type=*/
373 TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
374 /*sym_visibility*/ StringAttr::get(context, "public"),
375 /*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
376 /*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
377 DummyTrackingListener listener(transformState, transformOp.get());
378 Operation *replacement = listener.getReplacementOp(replaced, replacements);
379 if (!replacement) {
380 replaced->emitError("listener could not find replacement op");
381 return failure();
384 replacement->emitRemark("replacement found");
385 return success();
388 void TestTensorTransforms::runOnOperation() {
389 Operation *rootOp = getOperation();
390 if (testSimplifyPackUnpackPatterns)
391 applySimplifyPackUnpackPatterns(rootOp);
392 if (testFoldConstantExtractSlice)
393 applyFoldConstantExtractSlicePatterns(rootOp);
394 if (testFoldConsecutiveInsertExtractSlice)
395 applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
396 if (testDropRedundantInsertSliceRankExpansion)
397 applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
398 if (testReassociativeReshapeFolding)
399 applyReassociativeReshapeFoldingPatterns(rootOp);
400 if (testBubbleUpExpandShapePatterns)
401 applyBubbleUpExpandShapePatterns(rootOp);
402 if (testFoldIntoPackAndUnpack)
403 applyFoldIntoPackAndUnpackPatterns(rootOp);
404 if (testRewriteExtractSliceWithTiledCollapseShape) {
405 if (failed(
406 applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
407 return signalPassFailure();
409 if (testTrackingListener)
410 if (failed(testTrackingListenerReplacements(rootOp)))
411 return signalPassFailure();
414 namespace mlir {
415 namespace test {
416 void registerTestTensorTransforms() {
417 PassRegistration<TestTensorTransforms>();
419 } // namespace test
420 } // namespace mlir