[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / lib / Transforms / TestControlFlowSink.cpp
blobad34b6c2ffdf8b852c5d12359595a5a6cc74ab04
1 //===- TestControlFlowSink.cpp - Test control-flow sink pass --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass tests the control-flow sink utilities by implementing an example
10 // control-flow sink pass.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/IR/Dominance.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/ControlFlowSinkUtils.h"
19 using namespace mlir;
21 namespace {
22 /// An example control-flow sink pass to test the control-flow sink utilites.
23 /// This pass will sink ops named `test.sink_me` and tag them with an attribute
24 /// `was_sunk` into the first region of `test.sink_target` ops.
25 struct TestControlFlowSinkPass
26 : public PassWrapper<TestControlFlowSinkPass, OperationPass<func::FuncOp>> {
27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestControlFlowSinkPass)
29 /// Get the command-line argument of the test pass.
30 StringRef getArgument() const final { return "test-control-flow-sink"; }
31 /// Get the description of the test pass.
32 StringRef getDescription() const final {
33 return "Test control-flow sink pass";
36 /// Runs the pass on the function.
37 void runOnOperation() override {
38 auto &domInfo = getAnalysis<DominanceInfo>();
39 auto shouldMoveIntoRegion = [](Operation *op, Region *region) {
40 return region->getRegionNumber() == 0 &&
41 op->getName().getStringRef() == "test.sink_me";
43 auto moveIntoRegion = [](Operation *op, Region *region) {
44 Block &entry = region->front();
45 op->moveBefore(&entry, entry.begin());
46 op->setAttr("was_sunk",
47 Builder(op).getI32IntegerAttr(region->getRegionNumber()));
50 getOperation()->walk([&](Operation *op) {
51 if (op->getName().getStringRef() != "test.sink_target")
52 return;
53 SmallVector<Region *> regions =
54 llvm::to_vector(RegionRange(op->getRegions()));
55 controlFlowSink(regions, domInfo, shouldMoveIntoRegion, moveIntoRegion);
56 });
59 } // end anonymous namespace
61 namespace mlir {
62 namespace test {
63 void registerTestControlFlowSink() {
64 PassRegistration<TestControlFlowSinkPass>();
66 } // end namespace test
67 } // end namespace mlir