1 //===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements the passes to test wrap-in-zero-trip-check transforms on
// SCF loop ops.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
17 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 struct TestWrapWhileLoopInZeroTripCheckPass
27 : public PassWrapper
<TestWrapWhileLoopInZeroTripCheckPass
,
28 OperationPass
<func::FuncOp
>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
30 TestWrapWhileLoopInZeroTripCheckPass
)
32 StringRef
getArgument() const final
{
33 return "test-wrap-scf-while-loop-in-zero-trip-check";
36 StringRef
getDescription() const final
{
37 return "test scf::wrapWhileLoopInZeroTripCheck";
40 TestWrapWhileLoopInZeroTripCheckPass() = default;
41 TestWrapWhileLoopInZeroTripCheckPass(
42 const TestWrapWhileLoopInZeroTripCheckPass
&) {}
43 explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam
) {
44 forceCreateCheck
= forceCreateCheckParam
;
47 void runOnOperation() override
{
48 func::FuncOp func
= getOperation();
49 MLIRContext
*context
= &getContext();
50 IRRewriter
rewriter(context
);
51 if (forceCreateCheck
) {
52 func
.walk([&](scf::WhileOp op
) {
53 FailureOr
<scf::WhileOp
> result
=
54 scf::wrapWhileLoopInZeroTripCheck(op
, rewriter
, forceCreateCheck
);
55 // Ignore not implemented failure in tests. The expected output should
56 // catch problems (e.g. transformation doesn't happen).
60 RewritePatternSet
patterns(context
);
61 scf::populateSCFRotateWhileLoopPatterns(patterns
);
62 (void)applyPatternsAndFoldGreedily(func
, std::move(patterns
));
66 Option
<bool> forceCreateCheck
{
67 *this, "force-create-check",
68 llvm::cl::desc("Force to create zero-trip-check."),
69 llvm::cl::init(false)};
76 void registerTestSCFWrapInZeroTripCheckPasses() {
77 PassRegistration
<TestWrapWhileLoopInZeroTripCheckPass
>();