1 //===- TestControlFlowSink.cpp - Test control-flow sink pass --------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This pass tests the control-flow sink utilities by implementing an example
10 // control-flow sink pass.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/IR/Dominance.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/ControlFlowSinkUtils.h"
22 /// An example control-flow sink pass to test the control-flow sink utilites.
23 /// This pass will sink ops named `test.sink_me` and tag them with an attribute
24 /// `was_sunk` into the first region of `test.sink_target` ops.
25 struct TestControlFlowSinkPass
26 : public PassWrapper
<TestControlFlowSinkPass
, OperationPass
<func::FuncOp
>> {
27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestControlFlowSinkPass
)
29 /// Get the command-line argument of the test pass.
30 StringRef
getArgument() const final
{ return "test-control-flow-sink"; }
31 /// Get the description of the test pass.
32 StringRef
getDescription() const final
{
33 return "Test control-flow sink pass";
36 /// Runs the pass on the function.
37 void runOnOperation() override
{
38 auto &domInfo
= getAnalysis
<DominanceInfo
>();
39 auto shouldMoveIntoRegion
= [](Operation
*op
, Region
*region
) {
40 return region
->getRegionNumber() == 0 &&
41 op
->getName().getStringRef() == "test.sink_me";
43 auto moveIntoRegion
= [](Operation
*op
, Region
*region
) {
44 Block
&entry
= region
->front();
45 op
->moveBefore(&entry
, entry
.begin());
46 op
->setAttr("was_sunk",
47 Builder(op
).getI32IntegerAttr(region
->getRegionNumber()));
50 getOperation()->walk([&](Operation
*op
) {
51 if (op
->getName().getStringRef() != "test.sink_target")
53 SmallVector
<Region
*> regions
=
54 llvm::to_vector(RegionRange(op
->getRegions()));
55 controlFlowSink(regions
, domInfo
, shouldMoveIntoRegion
, moveIntoRegion
);
59 } // end anonymous namespace
63 void registerTestControlFlowSink() {
64 PassRegistration
<TestControlFlowSinkPass
>();
66 } // end namespace test
67 } // end namespace mlir