Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / SCF / TestSCFWrapInZeroTripCheck.cpp
blob 7e51d67702b050dc539f67abea53cf7166de19de
1 //===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the passes to test wrap-in-zero-trip-check transforms on
10 // SCF loop ops.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
17 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 using namespace mlir;
24 namespace {
26 struct TestWrapWhileLoopInZeroTripCheckPass
27 : public PassWrapper<TestWrapWhileLoopInZeroTripCheckPass,
28 OperationPass<func::FuncOp>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
30 TestWrapWhileLoopInZeroTripCheckPass)
32 StringRef getArgument() const final {
33 return "test-wrap-scf-while-loop-in-zero-trip-check";
36 StringRef getDescription() const final {
37 return "test scf::wrapWhileLoopInZeroTripCheck";
40 TestWrapWhileLoopInZeroTripCheckPass() = default;
41 TestWrapWhileLoopInZeroTripCheckPass(
42 const TestWrapWhileLoopInZeroTripCheckPass &) {}
43 explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam) {
44 forceCreateCheck = forceCreateCheckParam;
47 void runOnOperation() override {
48 func::FuncOp func = getOperation();
49 MLIRContext *context = &getContext();
50 IRRewriter rewriter(context);
51 if (forceCreateCheck) {
52 func.walk([&](scf::WhileOp op) {
53 FailureOr<scf::WhileOp> result =
54 scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
55 // Ignore not implemented failure in tests. The expected output should
56 // catch problems (e.g. transformation doesn't happen).
57 (void)result;
58 });
59 } else {
60 RewritePatternSet patterns(context);
61 scf::populateSCFRotateWhileLoopPatterns(patterns);
62 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
66 Option<bool> forceCreateCheck{
67 *this, "force-create-check",
68 llvm::cl::desc("Force to create zero-trip-check."),
69 llvm::cl::init(false)};
72 } // namespace
74 namespace mlir {
75 namespace test {
76 void registerTestSCFWrapInZeroTripCheckPasses() {
77 PassRegistration<TestWrapWhileLoopInZeroTripCheckPass>();
79 } // namespace test
80 } // namespace mlir