[OpenACC] Treat 'delete' as a valid clause during parsing in C++ mode
[llvm-project.git] / mlir / lib / CAPI / IR / Pass.cpp
blob883b7e8bb832d2317189847f485900a7c44637cd
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir-c/Pass.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 #include <optional>
18 using namespace mlir;
20 //===----------------------------------------------------------------------===//
21 // PassManager/OpPassManager APIs.
22 //===----------------------------------------------------------------------===//
24 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
25 return wrap(new PassManager(unwrap(ctx)));
28 MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
29 MlirStringRef anchorOp) {
30 return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp)));
33 void mlirPassManagerDestroy(MlirPassManager passManager) {
34 delete unwrap(passManager);
37 MlirOpPassManager
38 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
39 return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
42 MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
43 MlirOperation op) {
44 return wrap(unwrap(passManager)->run(unwrap(op)));
47 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
48 bool printBeforeAll, bool printAfterAll,
49 bool printModuleScope,
50 bool printAfterOnlyOnChange,
51 bool printAfterOnlyOnFailure,
52 MlirOpPrintingFlags flags,
53 MlirStringRef treePrintingPath) {
54 auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
55 return printBeforeAll;
57 auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
58 return printAfterAll;
60 if (unwrap(treePrintingPath).empty())
61 return unwrap(passManager)
62 ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
63 printModuleScope, printAfterOnlyOnChange,
64 printAfterOnlyOnFailure, /*out=*/llvm::errs(),
65 *unwrap(flags));
67 unwrap(passManager)
68 ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
69 printModuleScope, printAfterOnlyOnChange,
70 printAfterOnlyOnFailure,
71 unwrap(treePrintingPath), *unwrap(flags));
74 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
75 unwrap(passManager)->enableVerifier(enable);
78 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
79 MlirStringRef operationName) {
80 return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
83 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
84 MlirStringRef operationName) {
85 return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
88 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
89 unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
92 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
93 MlirPass pass) {
94 unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
97 MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
98 MlirStringRef pipelineElements,
99 MlirStringCallback callback,
100 void *userData) {
101 detail::CallbackOstream stream(callback, userData);
102 return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
103 stream));
106 void mlirPrintPassPipeline(MlirOpPassManager passManager,
107 MlirStringCallback callback, void *userData) {
108 detail::CallbackOstream stream(callback, userData);
109 unwrap(passManager)->printAsTextualPipeline(stream);
112 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
113 MlirStringRef pipeline,
114 MlirStringCallback callback,
115 void *userData) {
116 detail::CallbackOstream stream(callback, userData);
117 FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
118 if (succeeded(pm))
119 *unwrap(passManager) = std::move(*pm);
120 return wrap(pm);
123 //===----------------------------------------------------------------------===//
124 // External Pass API.
125 //===----------------------------------------------------------------------===//
127 namespace mlir {
128 class ExternalPass;
129 } // namespace mlir
130 DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
132 namespace mlir {
133 /// This pass class wraps external passes defined in other languages using the
134 /// MLIR C-interface
135 class ExternalPass : public Pass {
136 public:
137 ExternalPass(TypeID passID, StringRef name, StringRef argument,
138 StringRef description, std::optional<StringRef> opName,
139 ArrayRef<MlirDialectHandle> dependentDialects,
140 MlirExternalPassCallbacks callbacks, void *userData)
141 : Pass(passID, opName), id(passID), name(name), argument(argument),
142 description(description), dependentDialects(dependentDialects),
143 callbacks(callbacks), userData(userData) {
144 callbacks.construct(userData);
147 ~ExternalPass() override { callbacks.destruct(userData); }
149 StringRef getName() const override { return name; }
150 StringRef getArgument() const override { return argument; }
151 StringRef getDescription() const override { return description; }
153 void getDependentDialects(DialectRegistry &registry) const override {
154 MlirDialectRegistry cRegistry = wrap(&registry);
155 for (MlirDialectHandle dialect : dependentDialects)
156 mlirDialectHandleInsertDialect(dialect, cRegistry);
159 void signalPassFailure() { Pass::signalPassFailure(); }
161 protected:
162 LogicalResult initialize(MLIRContext *ctx) override {
163 if (callbacks.initialize)
164 return unwrap(callbacks.initialize(wrap(ctx), userData));
165 return success();
168 bool canScheduleOn(RegisteredOperationName opName) const override {
169 if (std::optional<StringRef> specifiedOpName = getOpName())
170 return opName.getStringRef() == specifiedOpName;
171 return true;
174 void runOnOperation() override {
175 callbacks.run(wrap(getOperation()), wrap(this), userData);
178 std::unique_ptr<Pass> clonePass() const override {
179 void *clonedUserData = callbacks.clone(userData);
180 return std::make_unique<ExternalPass>(id, name, argument, description,
181 getOpName(), dependentDialects,
182 callbacks, clonedUserData);
185 private:
186 TypeID id;
187 std::string name;
188 std::string argument;
189 std::string description;
190 std::vector<MlirDialectHandle> dependentDialects;
191 MlirExternalPassCallbacks callbacks;
192 void *userData;
194 } // namespace mlir
196 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
197 MlirStringRef argument,
198 MlirStringRef description, MlirStringRef opName,
199 intptr_t nDependentDialects,
200 MlirDialectHandle *dependentDialects,
201 MlirExternalPassCallbacks callbacks,
202 void *userData) {
203 return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
204 unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
205 opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
206 : std::nullopt,
207 {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
208 userData)));
211 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
212 unwrap(pass)->signalPassFailure();