[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / CAPI / IR / Pass.cpp
bloba6c9fbd08d45a6b9f097ca178b7ba0f55608f885
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir-c/Pass.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 #include <optional>
18 using namespace mlir;
20 //===----------------------------------------------------------------------===//
21 // PassManager/OpPassManager APIs.
22 //===----------------------------------------------------------------------===//
24 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
25 return wrap(new PassManager(unwrap(ctx)));
28 MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
29 MlirStringRef anchorOp) {
30 return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp)));
33 void mlirPassManagerDestroy(MlirPassManager passManager) {
34 delete unwrap(passManager);
37 MlirOpPassManager
38 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
39 return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
42 MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
43 MlirOperation op) {
44 return wrap(unwrap(passManager)->run(unwrap(op)));
47 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
48 bool printBeforeAll, bool printAfterAll,
49 bool printModuleScope,
50 bool printAfterOnlyOnChange,
51 bool printAfterOnlyOnFailure) {
52 auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
53 return printBeforeAll;
55 auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
56 return printAfterAll;
58 return unwrap(passManager)
59 ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
60 printModuleScope, printAfterOnlyOnChange,
61 printAfterOnlyOnFailure);
64 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
65 unwrap(passManager)->enableVerifier(enable);
68 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
69 MlirStringRef operationName) {
70 return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
73 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
74 MlirStringRef operationName) {
75 return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
78 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
79 unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
82 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
83 MlirPass pass) {
84 unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
87 MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
88 MlirStringRef pipelineElements,
89 MlirStringCallback callback,
90 void *userData) {
91 detail::CallbackOstream stream(callback, userData);
92 return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
93 stream));
96 void mlirPrintPassPipeline(MlirOpPassManager passManager,
97 MlirStringCallback callback, void *userData) {
98 detail::CallbackOstream stream(callback, userData);
99 unwrap(passManager)->printAsTextualPipeline(stream);
102 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
103 MlirStringRef pipeline,
104 MlirStringCallback callback,
105 void *userData) {
106 detail::CallbackOstream stream(callback, userData);
107 FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
108 if (succeeded(pm))
109 *unwrap(passManager) = std::move(*pm);
110 return wrap(pm);
113 //===----------------------------------------------------------------------===//
114 // External Pass API.
115 //===----------------------------------------------------------------------===//
117 namespace mlir {
118 class ExternalPass;
119 } // namespace mlir
120 DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
122 namespace mlir {
123 /// This pass class wraps external passes defined in other languages using the
124 /// MLIR C-interface
125 class ExternalPass : public Pass {
126 public:
127 ExternalPass(TypeID passID, StringRef name, StringRef argument,
128 StringRef description, std::optional<StringRef> opName,
129 ArrayRef<MlirDialectHandle> dependentDialects,
130 MlirExternalPassCallbacks callbacks, void *userData)
131 : Pass(passID, opName), id(passID), name(name), argument(argument),
132 description(description), dependentDialects(dependentDialects),
133 callbacks(callbacks), userData(userData) {
134 callbacks.construct(userData);
137 ~ExternalPass() override { callbacks.destruct(userData); }
139 StringRef getName() const override { return name; }
140 StringRef getArgument() const override { return argument; }
141 StringRef getDescription() const override { return description; }
143 void getDependentDialects(DialectRegistry &registry) const override {
144 MlirDialectRegistry cRegistry = wrap(&registry);
145 for (MlirDialectHandle dialect : dependentDialects)
146 mlirDialectHandleInsertDialect(dialect, cRegistry);
149 void signalPassFailure() { Pass::signalPassFailure(); }
151 protected:
152 LogicalResult initialize(MLIRContext *ctx) override {
153 if (callbacks.initialize)
154 return unwrap(callbacks.initialize(wrap(ctx), userData));
155 return success();
158 bool canScheduleOn(RegisteredOperationName opName) const override {
159 if (std::optional<StringRef> specifiedOpName = getOpName())
160 return opName.getStringRef() == specifiedOpName;
161 return true;
164 void runOnOperation() override {
165 callbacks.run(wrap(getOperation()), wrap(this), userData);
168 std::unique_ptr<Pass> clonePass() const override {
169 void *clonedUserData = callbacks.clone(userData);
170 return std::make_unique<ExternalPass>(id, name, argument, description,
171 getOpName(), dependentDialects,
172 callbacks, clonedUserData);
175 private:
176 TypeID id;
177 std::string name;
178 std::string argument;
179 std::string description;
180 std::vector<MlirDialectHandle> dependentDialects;
181 MlirExternalPassCallbacks callbacks;
182 void *userData;
184 } // namespace mlir
186 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
187 MlirStringRef argument,
188 MlirStringRef description, MlirStringRef opName,
189 intptr_t nDependentDialects,
190 MlirDialectHandle *dependentDialects,
191 MlirExternalPassCallbacks callbacks,
192 void *userData) {
193 return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
194 unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
195 opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
196 : std::nullopt,
197 {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
198 userData)));
201 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
202 unwrap(pass)->signalPassFailure();