1 //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to
10 // run any optimization pass within it and only replaces the output module with
11 // the transformed version if it is smaller and interesting.
13 //===----------------------------------------------------------------------===//
15 #include "mlir/Pass/PassManager.h"
16 #include "mlir/Pass/PassRegistry.h"
17 #include "mlir/Reducer/Passes.h"
18 #include "mlir/Reducer/Tester.h"
19 #include "llvm/Support/Debug.h"
22 #define GEN_PASS_DEF_OPTREDUCTION
23 #include "mlir/Reducer/Passes.h.inc"
26 #define DEBUG_TYPE "mlir-reduce"
32 class OptReductionPass
: public impl::OptReductionBase
<OptReductionPass
> {
34 /// Runs the pass instance in the pass pipeline.
35 void runOnOperation() override
;
40 /// Runs the pass instance in the pass pipeline.
41 void OptReductionPass::runOnOperation() {
42 LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");
44 Tester
test(testerName
, testerArgs
);
46 ModuleOp module
= this->getOperation();
47 ModuleOp moduleVariant
= module
.clone();
49 OpPassManager
passManager("builtin.module");
50 if (failed(parsePassPipeline(optPass
, passManager
))) {
51 module
.emitError() << "\nfailed to parse pass pipeline";
52 return signalPassFailure();
55 std::pair
<Tester::Interestingness
, int> original
= test
.isInteresting(module
);
56 if (original
.first
!= Tester::Interestingness::True
) {
57 module
.emitError() << "\nthe original input is not interested";
58 return signalPassFailure();
61 // Temporarily push the variant under the main module and execute the pipeline
63 module
.getBody()->push_back(moduleVariant
);
64 LogicalResult pipelineResult
= runPipeline(passManager
, moduleVariant
);
65 moduleVariant
->remove();
67 if (failed(pipelineResult
)) {
68 module
.emitError() << "\nfailed to run pass pipeline";
69 return signalPassFailure();
72 std::pair
<Tester::Interestingness
, int> reduced
=
73 test
.isInteresting(moduleVariant
);
75 if (reduced
.first
== Tester::Interestingness::True
&&
76 reduced
.second
< original
.second
) {
77 module
.getBody()->clear();
78 module
.getBody()->getOperations().splice(
79 module
.getBody()->begin(), moduleVariant
.getBody()->getOperations());
80 LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n");
82 LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n");
85 moduleVariant
->destroy();
87 LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n");
90 std::unique_ptr
<Pass
> mlir::createOptReductionPass() {
91 return std::make_unique
<OptReductionPass
>();