//===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Pass.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/Pass/PassManager.h"

#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

using namespace mlir;
//===----------------------------------------------------------------------===//
// PassManager/OpPassManager APIs.
//===----------------------------------------------------------------------===//
24 MlirPassManager
mlirPassManagerCreate(MlirContext ctx
) {
25 return wrap(new PassManager(unwrap(ctx
)));
28 MlirPassManager
mlirPassManagerCreateOnOperation(MlirContext ctx
,
29 MlirStringRef anchorOp
) {
30 return wrap(new PassManager(unwrap(ctx
), unwrap(anchorOp
)));
33 void mlirPassManagerDestroy(MlirPassManager passManager
) {
34 delete unwrap(passManager
);
38 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager
) {
39 return wrap(static_cast<OpPassManager
*>(unwrap(passManager
)));
42 MlirLogicalResult
mlirPassManagerRunOnOp(MlirPassManager passManager
,
44 return wrap(unwrap(passManager
)->run(unwrap(op
)));
47 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager
,
48 bool printBeforeAll
, bool printAfterAll
,
49 bool printModuleScope
,
50 bool printAfterOnlyOnChange
,
51 bool printAfterOnlyOnFailure
) {
52 auto shouldPrintBeforePass
= [printBeforeAll
](Pass
*, Operation
*) {
53 return printBeforeAll
;
55 auto shouldPrintAfterPass
= [printAfterAll
](Pass
*, Operation
*) {
58 return unwrap(passManager
)
59 ->enableIRPrinting(shouldPrintBeforePass
, shouldPrintAfterPass
,
60 printModuleScope
, printAfterOnlyOnChange
,
61 printAfterOnlyOnFailure
);
64 void mlirPassManagerEnableVerifier(MlirPassManager passManager
, bool enable
) {
65 unwrap(passManager
)->enableVerifier(enable
);
68 MlirOpPassManager
mlirPassManagerGetNestedUnder(MlirPassManager passManager
,
69 MlirStringRef operationName
) {
70 return wrap(&unwrap(passManager
)->nest(unwrap(operationName
)));
73 MlirOpPassManager
mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager
,
74 MlirStringRef operationName
) {
75 return wrap(&unwrap(passManager
)->nest(unwrap(operationName
)));
78 void mlirPassManagerAddOwnedPass(MlirPassManager passManager
, MlirPass pass
) {
79 unwrap(passManager
)->addPass(std::unique_ptr
<Pass
>(unwrap(pass
)));
82 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager
,
84 unwrap(passManager
)->addPass(std::unique_ptr
<Pass
>(unwrap(pass
)));
87 MlirLogicalResult
mlirOpPassManagerAddPipeline(MlirOpPassManager passManager
,
88 MlirStringRef pipelineElements
,
89 MlirStringCallback callback
,
91 detail::CallbackOstream
stream(callback
, userData
);
92 return wrap(parsePassPipeline(unwrap(pipelineElements
), *unwrap(passManager
),
96 void mlirPrintPassPipeline(MlirOpPassManager passManager
,
97 MlirStringCallback callback
, void *userData
) {
98 detail::CallbackOstream
stream(callback
, userData
);
99 unwrap(passManager
)->printAsTextualPipeline(stream
);
102 MlirLogicalResult
mlirParsePassPipeline(MlirOpPassManager passManager
,
103 MlirStringRef pipeline
,
104 MlirStringCallback callback
,
106 detail::CallbackOstream
stream(callback
, userData
);
107 FailureOr
<OpPassManager
> pm
= parsePassPipeline(unwrap(pipeline
), stream
);
109 *unwrap(passManager
) = std::move(*pm
);
//===----------------------------------------------------------------------===//
// External Pass API.
//===----------------------------------------------------------------------===//
120 DEFINE_C_API_PTR_METHODS(MlirExternalPass
, mlir::ExternalPass
)
123 /// This pass class wraps external passes defined in other languages using the
125 class ExternalPass
: public Pass
{
127 ExternalPass(TypeID passID
, StringRef name
, StringRef argument
,
128 StringRef description
, std::optional
<StringRef
> opName
,
129 ArrayRef
<MlirDialectHandle
> dependentDialects
,
130 MlirExternalPassCallbacks callbacks
, void *userData
)
131 : Pass(passID
, opName
), id(passID
), name(name
), argument(argument
),
132 description(description
), dependentDialects(dependentDialects
),
133 callbacks(callbacks
), userData(userData
) {
134 callbacks
.construct(userData
);
137 ~ExternalPass() override
{ callbacks
.destruct(userData
); }
139 StringRef
getName() const override
{ return name
; }
140 StringRef
getArgument() const override
{ return argument
; }
141 StringRef
getDescription() const override
{ return description
; }
143 void getDependentDialects(DialectRegistry
®istry
) const override
{
144 MlirDialectRegistry cRegistry
= wrap(®istry
);
145 for (MlirDialectHandle dialect
: dependentDialects
)
146 mlirDialectHandleInsertDialect(dialect
, cRegistry
);
149 void signalPassFailure() { Pass::signalPassFailure(); }
152 LogicalResult
initialize(MLIRContext
*ctx
) override
{
153 if (callbacks
.initialize
)
154 return unwrap(callbacks
.initialize(wrap(ctx
), userData
));
158 bool canScheduleOn(RegisteredOperationName opName
) const override
{
159 if (std::optional
<StringRef
> specifiedOpName
= getOpName())
160 return opName
.getStringRef() == specifiedOpName
;
164 void runOnOperation() override
{
165 callbacks
.run(wrap(getOperation()), wrap(this), userData
);
168 std::unique_ptr
<Pass
> clonePass() const override
{
169 void *clonedUserData
= callbacks
.clone(userData
);
170 return std::make_unique
<ExternalPass
>(id
, name
, argument
, description
,
171 getOpName(), dependentDialects
,
172 callbacks
, clonedUserData
);
178 std::string argument
;
179 std::string description
;
180 std::vector
<MlirDialectHandle
> dependentDialects
;
181 MlirExternalPassCallbacks callbacks
;
186 MlirPass
mlirCreateExternalPass(MlirTypeID passID
, MlirStringRef name
,
187 MlirStringRef argument
,
188 MlirStringRef description
, MlirStringRef opName
,
189 intptr_t nDependentDialects
,
190 MlirDialectHandle
*dependentDialects
,
191 MlirExternalPassCallbacks callbacks
,
193 return wrap(static_cast<mlir::Pass
*>(new mlir::ExternalPass(
194 unwrap(passID
), unwrap(name
), unwrap(argument
), unwrap(description
),
195 opName
.length
> 0 ? std::optional
<StringRef
>(unwrap(opName
))
197 {dependentDialects
, static_cast<size_t>(nDependentDialects
)}, callbacks
,
201 void mlirExternalPassSignalFailure(MlirExternalPass pass
) {
202 unwrap(pass
)->signalPassFailure();