1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir-c/Pass.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
20 //===----------------------------------------------------------------------===//
21 // PassManager/OpPassManager APIs.
22 //===----------------------------------------------------------------------===//
24 MlirPassManager
mlirPassManagerCreate(MlirContext ctx
) {
25 return wrap(new PassManager(unwrap(ctx
)));
28 MlirPassManager
mlirPassManagerCreateOnOperation(MlirContext ctx
,
29 MlirStringRef anchorOp
) {
30 return wrap(new PassManager(unwrap(ctx
), unwrap(anchorOp
)));
33 void mlirPassManagerDestroy(MlirPassManager passManager
) {
34 delete unwrap(passManager
);
38 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager
) {
39 return wrap(static_cast<OpPassManager
*>(unwrap(passManager
)));
42 MlirLogicalResult
mlirPassManagerRunOnOp(MlirPassManager passManager
,
44 return wrap(unwrap(passManager
)->run(unwrap(op
)));
47 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager
,
48 bool printBeforeAll
, bool printAfterAll
,
49 bool printModuleScope
,
50 bool printAfterOnlyOnChange
,
51 bool printAfterOnlyOnFailure
,
52 MlirOpPrintingFlags flags
,
53 MlirStringRef treePrintingPath
) {
54 auto shouldPrintBeforePass
= [printBeforeAll
](Pass
*, Operation
*) {
55 return printBeforeAll
;
57 auto shouldPrintAfterPass
= [printAfterAll
](Pass
*, Operation
*) {
60 if (unwrap(treePrintingPath
).empty())
61 return unwrap(passManager
)
62 ->enableIRPrinting(shouldPrintBeforePass
, shouldPrintAfterPass
,
63 printModuleScope
, printAfterOnlyOnChange
,
64 printAfterOnlyOnFailure
, /*out=*/llvm::errs(),
68 ->enableIRPrintingToFileTree(shouldPrintBeforePass
, shouldPrintAfterPass
,
69 printModuleScope
, printAfterOnlyOnChange
,
70 printAfterOnlyOnFailure
,
71 unwrap(treePrintingPath
), *unwrap(flags
));
74 void mlirPassManagerEnableVerifier(MlirPassManager passManager
, bool enable
) {
75 unwrap(passManager
)->enableVerifier(enable
);
78 MlirOpPassManager
mlirPassManagerGetNestedUnder(MlirPassManager passManager
,
79 MlirStringRef operationName
) {
80 return wrap(&unwrap(passManager
)->nest(unwrap(operationName
)));
83 MlirOpPassManager
mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager
,
84 MlirStringRef operationName
) {
85 return wrap(&unwrap(passManager
)->nest(unwrap(operationName
)));
88 void mlirPassManagerAddOwnedPass(MlirPassManager passManager
, MlirPass pass
) {
89 unwrap(passManager
)->addPass(std::unique_ptr
<Pass
>(unwrap(pass
)));
92 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager
,
94 unwrap(passManager
)->addPass(std::unique_ptr
<Pass
>(unwrap(pass
)));
97 MlirLogicalResult
mlirOpPassManagerAddPipeline(MlirOpPassManager passManager
,
98 MlirStringRef pipelineElements
,
99 MlirStringCallback callback
,
101 detail::CallbackOstream
stream(callback
, userData
);
102 return wrap(parsePassPipeline(unwrap(pipelineElements
), *unwrap(passManager
),
106 void mlirPrintPassPipeline(MlirOpPassManager passManager
,
107 MlirStringCallback callback
, void *userData
) {
108 detail::CallbackOstream
stream(callback
, userData
);
109 unwrap(passManager
)->printAsTextualPipeline(stream
);
112 MlirLogicalResult
mlirParsePassPipeline(MlirOpPassManager passManager
,
113 MlirStringRef pipeline
,
114 MlirStringCallback callback
,
116 detail::CallbackOstream
stream(callback
, userData
);
117 FailureOr
<OpPassManager
> pm
= parsePassPipeline(unwrap(pipeline
), stream
);
119 *unwrap(passManager
) = std::move(*pm
);
123 //===----------------------------------------------------------------------===//
124 // External Pass API.
125 //===----------------------------------------------------------------------===//
130 DEFINE_C_API_PTR_METHODS(MlirExternalPass
, mlir::ExternalPass
)
133 /// This pass class wraps external passes defined in other languages using the
135 class ExternalPass
: public Pass
{
137 ExternalPass(TypeID passID
, StringRef name
, StringRef argument
,
138 StringRef description
, std::optional
<StringRef
> opName
,
139 ArrayRef
<MlirDialectHandle
> dependentDialects
,
140 MlirExternalPassCallbacks callbacks
, void *userData
)
141 : Pass(passID
, opName
), id(passID
), name(name
), argument(argument
),
142 description(description
), dependentDialects(dependentDialects
),
143 callbacks(callbacks
), userData(userData
) {
144 callbacks
.construct(userData
);
147 ~ExternalPass() override
{ callbacks
.destruct(userData
); }
149 StringRef
getName() const override
{ return name
; }
150 StringRef
getArgument() const override
{ return argument
; }
151 StringRef
getDescription() const override
{ return description
; }
153 void getDependentDialects(DialectRegistry
®istry
) const override
{
154 MlirDialectRegistry cRegistry
= wrap(®istry
);
155 for (MlirDialectHandle dialect
: dependentDialects
)
156 mlirDialectHandleInsertDialect(dialect
, cRegistry
);
159 void signalPassFailure() { Pass::signalPassFailure(); }
162 LogicalResult
initialize(MLIRContext
*ctx
) override
{
163 if (callbacks
.initialize
)
164 return unwrap(callbacks
.initialize(wrap(ctx
), userData
));
168 bool canScheduleOn(RegisteredOperationName opName
) const override
{
169 if (std::optional
<StringRef
> specifiedOpName
= getOpName())
170 return opName
.getStringRef() == specifiedOpName
;
174 void runOnOperation() override
{
175 callbacks
.run(wrap(getOperation()), wrap(this), userData
);
178 std::unique_ptr
<Pass
> clonePass() const override
{
179 void *clonedUserData
= callbacks
.clone(userData
);
180 return std::make_unique
<ExternalPass
>(id
, name
, argument
, description
,
181 getOpName(), dependentDialects
,
182 callbacks
, clonedUserData
);
188 std::string argument
;
189 std::string description
;
190 std::vector
<MlirDialectHandle
> dependentDialects
;
191 MlirExternalPassCallbacks callbacks
;
196 MlirPass
mlirCreateExternalPass(MlirTypeID passID
, MlirStringRef name
,
197 MlirStringRef argument
,
198 MlirStringRef description
, MlirStringRef opName
,
199 intptr_t nDependentDialects
,
200 MlirDialectHandle
*dependentDialects
,
201 MlirExternalPassCallbacks callbacks
,
203 return wrap(static_cast<mlir::Pass
*>(new mlir::ExternalPass(
204 unwrap(passID
), unwrap(name
), unwrap(argument
), unwrap(description
),
205 opName
.length
> 0 ? std::optional
<StringRef
>(unwrap(opName
))
207 {dependentDialects
, static_cast<size_t>(nDependentDialects
)}, callbacks
,
211 void mlirExternalPassSignalFailure(MlirExternalPass pass
) {
212 unwrap(pass
)->signalPassFailure();