//===- pass.c - Simple test of C APIs -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

/* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
 */

#include "mlir-c/Pass.h"
#include "mlir-c/Dialect/Func.h"
#include "mlir-c/IR.h"
#include "mlir-c/RegisterEverything.h"
#include "mlir-c/Transforms.h"

#include <stdio.h>
#include <stdlib.h>
25 static void registerAllUpstreamDialects(MlirContext ctx
) {
26 MlirDialectRegistry registry
= mlirDialectRegistryCreate();
27 mlirRegisterAllDialects(registry
);
28 mlirContextAppendDialectRegistry(ctx
, registry
);
29 mlirDialectRegistryDestroy(registry
);
32 void testRunPassOnModule(void) {
33 MlirContext ctx
= mlirContextCreate();
34 registerAllUpstreamDialects(ctx
);
36 const char *funcAsm
= //
37 "func.func @foo(%arg0 : i32) -> i32 { \n"
38 " %res = arith.addi %arg0, %arg0 : i32 \n"
39 " return %res : i32 \n"
42 mlirOperationCreateParse(ctx
, mlirStringRefCreateFromCString(funcAsm
),
43 mlirStringRefCreateFromCString("funcAsm"));
44 if (mlirOperationIsNull(func
)) {
45 fprintf(stderr
, "Unexpected failure parsing asm.\n");
49 // Run the print-op-stats pass on the top-level module:
50 // CHECK-LABEL: Operations encountered:
51 // CHECK: arith.addi , 1
52 // CHECK: func.func , 1
53 // CHECK: func.return , 1
55 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
56 MlirPass printOpStatPass
= mlirCreateTransformsPrintOpStats();
57 mlirPassManagerAddOwnedPass(pm
, printOpStatPass
);
58 MlirLogicalResult success
= mlirPassManagerRunOnOp(pm
, func
);
59 if (mlirLogicalResultIsFailure(success
)) {
60 fprintf(stderr
, "Unexpected failure running pass manager.\n");
63 mlirPassManagerDestroy(pm
);
65 mlirOperationDestroy(func
);
66 mlirContextDestroy(ctx
);
69 void testRunPassOnNestedModule(void) {
70 MlirContext ctx
= mlirContextCreate();
71 registerAllUpstreamDialects(ctx
);
73 const char *moduleAsm
= //
75 " func.func @foo(%arg0 : i32) -> i32 { \n"
76 " %res = arith.addi %arg0, %arg0 : i32 \n"
77 " return %res : i32 \n"
80 " func.func @bar(%arg0 : f32) -> f32 { \n"
81 " %res = arith.addf %arg0, %arg0 : f32 \n"
82 " return %res : f32 \n"
86 MlirOperation module
=
87 mlirOperationCreateParse(ctx
, mlirStringRefCreateFromCString(moduleAsm
),
88 mlirStringRefCreateFromCString("moduleAsm"));
89 if (mlirOperationIsNull(module
))
92 // Run the print-op-stats pass on functions under the top-level module:
93 // CHECK-LABEL: Operations encountered:
94 // CHECK: arith.addi , 1
95 // CHECK: func.func , 1
96 // CHECK: func.return , 1
98 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
99 MlirOpPassManager nestedFuncPm
= mlirPassManagerGetNestedUnder(
100 pm
, mlirStringRefCreateFromCString("func.func"));
101 MlirPass printOpStatPass
= mlirCreateTransformsPrintOpStats();
102 mlirOpPassManagerAddOwnedPass(nestedFuncPm
, printOpStatPass
);
103 MlirLogicalResult success
= mlirPassManagerRunOnOp(pm
, module
);
104 if (mlirLogicalResultIsFailure(success
))
106 mlirPassManagerDestroy(pm
);
108 // Run the print-op-stats pass on functions under the nested module:
109 // CHECK-LABEL: Operations encountered:
110 // CHECK: arith.addf , 1
111 // CHECK: func.func , 1
112 // CHECK: func.return , 1
114 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
115 MlirOpPassManager nestedModulePm
= mlirPassManagerGetNestedUnder(
116 pm
, mlirStringRefCreateFromCString("builtin.module"));
117 MlirOpPassManager nestedFuncPm
= mlirOpPassManagerGetNestedUnder(
118 nestedModulePm
, mlirStringRefCreateFromCString("func.func"));
119 MlirPass printOpStatPass
= mlirCreateTransformsPrintOpStats();
120 mlirOpPassManagerAddOwnedPass(nestedFuncPm
, printOpStatPass
);
121 MlirLogicalResult success
= mlirPassManagerRunOnOp(pm
, module
);
122 if (mlirLogicalResultIsFailure(success
))
124 mlirPassManagerDestroy(pm
);
127 mlirOperationDestroy(module
);
128 mlirContextDestroy(ctx
);
131 static void printToStderr(MlirStringRef str
, void *userData
) {
133 fwrite(str
.data
, 1, str
.length
, stderr
);
136 static void dontPrint(MlirStringRef str
, void *userData
) {
141 void testPrintPassPipeline(void) {
142 MlirContext ctx
= mlirContextCreate();
143 MlirPassManager pm
= mlirPassManagerCreateOnOperation(
144 ctx
, mlirStringRefCreateFromCString("any"));
145 // Populate the pass-manager
146 MlirOpPassManager nestedModulePm
= mlirPassManagerGetNestedUnder(
147 pm
, mlirStringRefCreateFromCString("builtin.module"));
148 MlirOpPassManager nestedFuncPm
= mlirOpPassManagerGetNestedUnder(
149 nestedModulePm
, mlirStringRefCreateFromCString("func.func"));
150 MlirPass printOpStatPass
= mlirCreateTransformsPrintOpStats();
151 mlirOpPassManagerAddOwnedPass(nestedFuncPm
, printOpStatPass
);
153 // Print the top level pass manager
154 // CHECK: Top-level: any(
155 // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false}))
157 fprintf(stderr
, "Top-level: ");
158 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm
), printToStderr
,
160 fprintf(stderr
, "\n");
162 // Print the pipeline nested one level down
163 // CHECK: Nested Module: builtin.module(func.func(print-op-stats{json=false}))
164 fprintf(stderr
, "Nested Module: ");
165 mlirPrintPassPipeline(nestedModulePm
, printToStderr
, NULL
);
166 fprintf(stderr
, "\n");
168 // Print the pipeline nested two levels down
169 // CHECK: Nested Module>Func: func.func(print-op-stats{json=false})
170 fprintf(stderr
, "Nested Module>Func: ");
171 mlirPrintPassPipeline(nestedFuncPm
, printToStderr
, NULL
);
172 fprintf(stderr
, "\n");
174 mlirPassManagerDestroy(pm
);
175 mlirContextDestroy(ctx
);
178 void testParsePassPipeline(void) {
179 MlirContext ctx
= mlirContextCreate();
180 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
181 // Try parse a pipeline.
182 MlirLogicalResult status
= mlirParsePassPipeline(
183 mlirPassManagerGetAsOpPassManager(pm
),
184 mlirStringRefCreateFromCString(
185 "builtin.module(func.func(print-op-stats{json=false}))"),
186 printToStderr
, NULL
);
187 // Expect a failure, we haven't registered the print-op-stats pass yet.
188 if (mlirLogicalResultIsSuccess(status
)) {
191 "Unexpected success parsing pipeline without registering the pass\n");
194 // Try again after registrating the pass.
195 mlirRegisterTransformsPrintOpStats();
196 status
= mlirParsePassPipeline(
197 mlirPassManagerGetAsOpPassManager(pm
),
198 mlirStringRefCreateFromCString(
199 "builtin.module(func.func(print-op-stats{json=false}))"),
200 printToStderr
, NULL
);
201 // Expect a failure, we haven't registered the print-op-stats pass yet.
202 if (mlirLogicalResultIsFailure(status
)) {
204 "Unexpected failure parsing pipeline after registering the pass\n");
208 // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}))
209 fprintf(stderr
, "Round-trip: ");
210 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm
), printToStderr
,
212 fprintf(stderr
, "\n");
214 // Try appending a pass:
215 status
= mlirOpPassManagerAddPipeline(
216 mlirPassManagerGetAsOpPassManager(pm
),
217 mlirStringRefCreateFromCString("func.func(print-op-stats{json=false})"),
218 printToStderr
, NULL
);
219 if (mlirLogicalResultIsFailure(status
)) {
220 fprintf(stderr
, "Unexpected failure appending pipeline\n");
223 // CHECK: Appended: builtin.module(
224 // CHECK-SAME: func.func(print-op-stats{json=false}),
225 // CHECK-SAME: func.func(print-op-stats{json=false})
227 fprintf(stderr
, "Appended: ");
228 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm
), printToStderr
,
230 fprintf(stderr
, "\n");
232 mlirPassManagerDestroy(pm
);
233 mlirContextDestroy(ctx
);
236 void testParseErrorCapture(void) {
237 // CHECK-LABEL: testParseErrorCapture:
238 fprintf(stderr
, "\nTEST: testParseErrorCapture:\n");
240 MlirContext ctx
= mlirContextCreate();
241 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
242 MlirOpPassManager opm
= mlirPassManagerGetAsOpPassManager(pm
);
243 MlirStringRef invalidPipeline
= mlirStringRefCreateFromCString("invalid");
245 // CHECK: mlirParsePassPipeline:
246 // CHECK: expected pass pipeline to be wrapped with the anchor operation type
247 fprintf(stderr
, "mlirParsePassPipeline:\n");
248 if (mlirLogicalResultIsSuccess(
249 mlirParsePassPipeline(opm
, invalidPipeline
, printToStderr
, NULL
)))
251 fprintf(stderr
, "\n");
253 // CHECK: mlirOpPassManagerAddPipeline:
254 // CHECK: 'invalid' does not refer to a registered pass or pass pipeline
255 fprintf(stderr
, "mlirOpPassManagerAddPipeline:\n");
256 if (mlirLogicalResultIsSuccess(mlirOpPassManagerAddPipeline(
257 opm
, invalidPipeline
, printToStderr
, NULL
)))
259 fprintf(stderr
, "\n");
261 // Make sure all output is going through the callback.
262 // CHECK: dontPrint: <>
263 fprintf(stderr
, "dontPrint: <");
264 if (mlirLogicalResultIsSuccess(
265 mlirParsePassPipeline(opm
, invalidPipeline
, dontPrint
, NULL
)))
267 if (mlirLogicalResultIsSuccess(
268 mlirOpPassManagerAddPipeline(opm
, invalidPipeline
, dontPrint
, NULL
)))
270 fprintf(stderr
, ">\n");
272 mlirPassManagerDestroy(pm
);
273 mlirContextDestroy(ctx
);
/// Counters shared by the external-pass test callbacks. Each callback bumps
/// the counter for the hook it implements, so a test can check how many times
/// MLIR invoked each lifecycle hook of a pass.
struct TestExternalPassUserData {
  int constructCallCount;
  int destructCallCount;
  int initializeCallCount;
  int cloneCallCount;
  int runCallCount;
};

typedef struct TestExternalPassUserData TestExternalPassUserData;
285 void testConstructExternalPass(void *userData
) {
286 ++((TestExternalPassUserData
*)userData
)->constructCallCount
;
289 void testDestructExternalPass(void *userData
) {
290 ++((TestExternalPassUserData
*)userData
)->destructCallCount
;
293 MlirLogicalResult
testInitializeExternalPass(MlirContext ctx
, void *userData
) {
294 ++((TestExternalPassUserData
*)userData
)->initializeCallCount
;
295 return mlirLogicalResultSuccess();
298 MlirLogicalResult
testInitializeFailingExternalPass(MlirContext ctx
,
300 ++((TestExternalPassUserData
*)userData
)->initializeCallCount
;
301 return mlirLogicalResultFailure();
304 void *testCloneExternalPass(void *userData
) {
305 ++((TestExternalPassUserData
*)userData
)->cloneCallCount
;
309 void testRunExternalPass(MlirOperation op
, MlirExternalPass pass
,
311 ++((TestExternalPassUserData
*)userData
)->runCallCount
;
314 void testRunExternalFuncPass(MlirOperation op
, MlirExternalPass pass
,
316 ++((TestExternalPassUserData
*)userData
)->runCallCount
;
317 MlirStringRef opName
= mlirIdentifierStr(mlirOperationGetName(op
));
318 if (!mlirStringRefEqual(opName
,
319 mlirStringRefCreateFromCString("func.func"))) {
320 mlirExternalPassSignalFailure(pass
);
324 void testRunFailingExternalPass(MlirOperation op
, MlirExternalPass pass
,
326 ++((TestExternalPassUserData
*)userData
)->runCallCount
;
327 mlirExternalPassSignalFailure(pass
);
330 MlirExternalPassCallbacks
makeTestExternalPassCallbacks(
331 MlirLogicalResult (*initializePass
)(MlirContext ctx
, void *userData
),
332 void (*runPass
)(MlirOperation op
, MlirExternalPass
, void *userData
)) {
333 return (MlirExternalPassCallbacks
){testConstructExternalPass
,
334 testDestructExternalPass
, initializePass
,
335 testCloneExternalPass
, runPass
};
338 void testExternalPass(void) {
339 MlirContext ctx
= mlirContextCreate();
340 registerAllUpstreamDialects(ctx
);
342 const char *moduleAsm
= //
344 " func.func @foo(%arg0 : i32) -> i32 { \n"
345 " %res = arith.addi %arg0, %arg0 : i32 \n"
346 " return %res : i32 \n"
349 MlirOperation module
=
350 mlirOperationCreateParse(ctx
, mlirStringRefCreateFromCString(moduleAsm
),
351 mlirStringRefCreateFromCString("moduleAsm"));
352 if (mlirOperationIsNull(module
)) {
353 fprintf(stderr
, "Unexpected failure parsing module.\n");
357 MlirStringRef description
= mlirStringRefCreateFromCString("");
358 MlirStringRef emptyOpName
= mlirStringRefCreateFromCString("");
360 MlirTypeIDAllocator typeIDAllocator
= mlirTypeIDAllocatorCreate();
362 // Run a generic pass
364 MlirTypeID passID
= mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator
);
365 MlirStringRef name
= mlirStringRefCreateFromCString("TestExternalPass");
366 MlirStringRef argument
=
367 mlirStringRefCreateFromCString("test-external-pass");
368 TestExternalPassUserData userData
= {0};
370 MlirPass externalPass
= mlirCreateExternalPass(
371 passID
, name
, argument
, description
, emptyOpName
, 0, NULL
,
372 makeTestExternalPassCallbacks(NULL
, testRunExternalPass
), &userData
);
374 if (userData
.constructCallCount
!= 1) {
375 fprintf(stderr
, "Expected constructCallCount to be 1\n");
379 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
380 mlirPassManagerAddOwnedPass(pm
, externalPass
);
381 MlirLogicalResult success
= mlirPassManagerRunOnOp(pm
, module
);
382 if (mlirLogicalResultIsFailure(success
)) {
383 fprintf(stderr
, "Unexpected failure running external pass.\n");
387 if (userData
.runCallCount
!= 1) {
388 fprintf(stderr
, "Expected runCallCount to be 1\n");
392 mlirPassManagerDestroy(pm
);
394 if (userData
.destructCallCount
!= userData
.constructCallCount
) {
395 fprintf(stderr
, "Expected destructCallCount to be equal to "
396 "constructCallCount\n");
401 // Run a func operation pass
403 MlirTypeID passID
= mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator
);
404 MlirStringRef name
= mlirStringRefCreateFromCString("TestExternalFuncPass");
405 MlirStringRef argument
=
406 mlirStringRefCreateFromCString("test-external-func-pass");
407 TestExternalPassUserData userData
= {0};
408 MlirDialectHandle funcHandle
= mlirGetDialectHandle__func__();
409 MlirStringRef funcOpName
= mlirStringRefCreateFromCString("func.func");
411 MlirPass externalPass
= mlirCreateExternalPass(
412 passID
, name
, argument
, description
, funcOpName
, 1, &funcHandle
,
413 makeTestExternalPassCallbacks(NULL
, testRunExternalFuncPass
),
416 if (userData
.constructCallCount
!= 1) {
417 fprintf(stderr
, "Expected constructCallCount to be 1\n");
421 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
422 MlirOpPassManager nestedFuncPm
=
423 mlirPassManagerGetNestedUnder(pm
, funcOpName
);
424 mlirOpPassManagerAddOwnedPass(nestedFuncPm
, externalPass
);
425 MlirLogicalResult success
= mlirPassManagerRunOnOp(pm
, module
);
426 if (mlirLogicalResultIsFailure(success
)) {
427 fprintf(stderr
, "Unexpected failure running external operation pass.\n");
431 // Since this is a nested pass, it can be cloned and run in parallel
432 if (userData
.cloneCallCount
!= userData
.constructCallCount
- 1) {
433 fprintf(stderr
, "Expected constructCallCount to be 1\n");
437 // The pass should only be run once this there is only one func op
438 if (userData
.runCallCount
!= 1) {
439 fprintf(stderr
, "Expected runCallCount to be 1\n");
443 mlirPassManagerDestroy(pm
);
445 if (userData
.destructCallCount
!= userData
.constructCallCount
) {
446 fprintf(stderr
, "Expected destructCallCount to be equal to "
447 "constructCallCount\n");
452 // Run a pass with `initialize` set
454 MlirTypeID passID
= mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator
);
455 MlirStringRef name
= mlirStringRefCreateFromCString("TestExternalPass");
456 MlirStringRef argument
=
457 mlirStringRefCreateFromCString("test-external-pass");
458 TestExternalPassUserData userData
= {0};
460 MlirPass externalPass
= mlirCreateExternalPass(
461 passID
, name
, argument
, description
, emptyOpName
, 0, NULL
,
462 makeTestExternalPassCallbacks(testInitializeExternalPass
,
463 testRunExternalPass
),
466 if (userData
.constructCallCount
!= 1) {
467 fprintf(stderr
, "Expected constructCallCount to be 1\n");
471 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
472 mlirPassManagerAddOwnedPass(pm
, externalPass
);
473 MlirLogicalResult success
= mlirPassManagerRunOnOp(pm
, module
);
474 if (mlirLogicalResultIsFailure(success
)) {
475 fprintf(stderr
, "Unexpected failure running external pass.\n");
479 if (userData
.initializeCallCount
!= 1) {
480 fprintf(stderr
, "Expected initializeCallCount to be 1\n");
484 if (userData
.runCallCount
!= 1) {
485 fprintf(stderr
, "Expected runCallCount to be 1\n");
489 mlirPassManagerDestroy(pm
);
491 if (userData
.destructCallCount
!= userData
.constructCallCount
) {
492 fprintf(stderr
, "Expected destructCallCount to be equal to "
493 "constructCallCount\n");
498 // Run a pass that fails during `initialize`
500 MlirTypeID passID
= mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator
);
502 mlirStringRefCreateFromCString("TestExternalFailingPass");
503 MlirStringRef argument
=
504 mlirStringRefCreateFromCString("test-external-failing-pass");
505 TestExternalPassUserData userData
= {0};
507 MlirPass externalPass
= mlirCreateExternalPass(
508 passID
, name
, argument
, description
, emptyOpName
, 0, NULL
,
509 makeTestExternalPassCallbacks(testInitializeFailingExternalPass
,
510 testRunExternalPass
),
513 if (userData
.constructCallCount
!= 1) {
514 fprintf(stderr
, "Expected constructCallCount to be 1\n");
518 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
519 mlirPassManagerAddOwnedPass(pm
, externalPass
);
520 MlirLogicalResult success
= mlirPassManagerRunOnOp(pm
, module
);
521 if (mlirLogicalResultIsSuccess(success
)) {
524 "Expected failure running pass manager on failing external pass.\n");
528 if (userData
.initializeCallCount
!= 1) {
529 fprintf(stderr
, "Expected initializeCallCount to be 1\n");
533 if (userData
.runCallCount
!= 0) {
534 fprintf(stderr
, "Expected runCallCount to be 0\n");
538 mlirPassManagerDestroy(pm
);
540 if (userData
.destructCallCount
!= userData
.constructCallCount
) {
541 fprintf(stderr
, "Expected destructCallCount to be equal to "
542 "constructCallCount\n");
547 // Run a pass that fails during `run`
549 MlirTypeID passID
= mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator
);
551 mlirStringRefCreateFromCString("TestExternalFailingPass");
552 MlirStringRef argument
=
553 mlirStringRefCreateFromCString("test-external-failing-pass");
554 TestExternalPassUserData userData
= {0};
556 MlirPass externalPass
= mlirCreateExternalPass(
557 passID
, name
, argument
, description
, emptyOpName
, 0, NULL
,
558 makeTestExternalPassCallbacks(NULL
, testRunFailingExternalPass
),
561 if (userData
.constructCallCount
!= 1) {
562 fprintf(stderr
, "Expected constructCallCount to be 1\n");
566 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
567 mlirPassManagerAddOwnedPass(pm
, externalPass
);
568 MlirLogicalResult success
= mlirPassManagerRunOnOp(pm
, module
);
569 if (mlirLogicalResultIsSuccess(success
)) {
572 "Expected failure running pass manager on failing external pass.\n");
576 if (userData
.runCallCount
!= 1) {
577 fprintf(stderr
, "Expected runCallCount to be 1\n");
581 mlirPassManagerDestroy(pm
);
583 if (userData
.destructCallCount
!= userData
.constructCallCount
) {
584 fprintf(stderr
, "Expected destructCallCount to be equal to "
585 "constructCallCount\n");
590 mlirTypeIDAllocatorDestroy(typeIDAllocator
);
591 mlirOperationDestroy(module
);
592 mlirContextDestroy(ctx
);
596 testRunPassOnModule();
597 testRunPassOnNestedModule();
598 testPrintPassPipeline();
599 testParsePassPipeline();
600 testParseErrorCapture();