[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / CAPI / pass.c
blob3aad0016b393c403a41469bf1e64907d1be3a20b
//===- pass.c - Simple test of C APIs -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

/* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
 */
13 #include "mlir-c/Pass.h"
14 #include "mlir-c/Dialect/Func.h"
15 #include "mlir-c/IR.h"
16 #include "mlir-c/RegisterEverything.h"
17 #include "mlir-c/Transforms.h"
19 #include <assert.h>
20 #include <math.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
25 static void registerAllUpstreamDialects(MlirContext ctx) {
26 MlirDialectRegistry registry = mlirDialectRegistryCreate();
27 mlirRegisterAllDialects(registry);
28 mlirContextAppendDialectRegistry(ctx, registry);
29 mlirDialectRegistryDestroy(registry);
32 void testRunPassOnModule(void) {
33 MlirContext ctx = mlirContextCreate();
34 registerAllUpstreamDialects(ctx);
36 const char *funcAsm = //
37 "func.func @foo(%arg0 : i32) -> i32 { \n"
38 " %res = arith.addi %arg0, %arg0 : i32 \n"
39 " return %res : i32 \n"
40 "} \n";
41 MlirOperation func =
42 mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(funcAsm),
43 mlirStringRefCreateFromCString("funcAsm"));
44 if (mlirOperationIsNull(func)) {
45 fprintf(stderr, "Unexpected failure parsing asm.\n");
46 exit(EXIT_FAILURE);
49 // Run the print-op-stats pass on the top-level module:
50 // CHECK-LABEL: Operations encountered:
51 // CHECK: arith.addi , 1
52 // CHECK: func.func , 1
53 // CHECK: func.return , 1
55 MlirPassManager pm = mlirPassManagerCreate(ctx);
56 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
57 mlirPassManagerAddOwnedPass(pm, printOpStatPass);
58 MlirLogicalResult success = mlirPassManagerRunOnOp(pm, func);
59 if (mlirLogicalResultIsFailure(success)) {
60 fprintf(stderr, "Unexpected failure running pass manager.\n");
61 exit(EXIT_FAILURE);
63 mlirPassManagerDestroy(pm);
65 mlirOperationDestroy(func);
66 mlirContextDestroy(ctx);
69 void testRunPassOnNestedModule(void) {
70 MlirContext ctx = mlirContextCreate();
71 registerAllUpstreamDialects(ctx);
73 const char *moduleAsm = //
74 "module { \n"
75 " func.func @foo(%arg0 : i32) -> i32 { \n"
76 " %res = arith.addi %arg0, %arg0 : i32 \n"
77 " return %res : i32 \n"
78 " } \n"
79 " module { \n"
80 " func.func @bar(%arg0 : f32) -> f32 { \n"
81 " %res = arith.addf %arg0, %arg0 : f32 \n"
82 " return %res : f32 \n"
83 " } \n"
84 " } \n"
85 "} \n";
86 MlirOperation module =
87 mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
88 mlirStringRefCreateFromCString("moduleAsm"));
89 if (mlirOperationIsNull(module))
90 exit(1);
92 // Run the print-op-stats pass on functions under the top-level module:
93 // CHECK-LABEL: Operations encountered:
94 // CHECK: arith.addi , 1
95 // CHECK: func.func , 1
96 // CHECK: func.return , 1
98 MlirPassManager pm = mlirPassManagerCreate(ctx);
99 MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
100 pm, mlirStringRefCreateFromCString("func.func"));
101 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
102 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
103 MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
104 if (mlirLogicalResultIsFailure(success))
105 exit(2);
106 mlirPassManagerDestroy(pm);
108 // Run the print-op-stats pass on functions under the nested module:
109 // CHECK-LABEL: Operations encountered:
110 // CHECK: arith.addf , 1
111 // CHECK: func.func , 1
112 // CHECK: func.return , 1
114 MlirPassManager pm = mlirPassManagerCreate(ctx);
115 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
116 pm, mlirStringRefCreateFromCString("builtin.module"));
117 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
118 nestedModulePm, mlirStringRefCreateFromCString("func.func"));
119 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
120 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
121 MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
122 if (mlirLogicalResultIsFailure(success))
123 exit(2);
124 mlirPassManagerDestroy(pm);
127 mlirOperationDestroy(module);
128 mlirContextDestroy(ctx);
131 static void printToStderr(MlirStringRef str, void *userData) {
132 (void)userData;
133 fwrite(str.data, 1, str.length, stderr);
136 static void dontPrint(MlirStringRef str, void *userData) {
137 (void)str;
138 (void)userData;
141 void testPrintPassPipeline(void) {
142 MlirContext ctx = mlirContextCreate();
143 MlirPassManager pm = mlirPassManagerCreateOnOperation(
144 ctx, mlirStringRefCreateFromCString("any"));
145 // Populate the pass-manager
146 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
147 pm, mlirStringRefCreateFromCString("builtin.module"));
148 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
149 nestedModulePm, mlirStringRefCreateFromCString("func.func"));
150 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
151 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
153 // Print the top level pass manager
154 // CHECK: Top-level: any(
155 // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false}))
156 // CHECK-SAME: )
157 fprintf(stderr, "Top-level: ");
158 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
159 NULL);
160 fprintf(stderr, "\n");
162 // Print the pipeline nested one level down
163 // CHECK: Nested Module: builtin.module(func.func(print-op-stats{json=false}))
164 fprintf(stderr, "Nested Module: ");
165 mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
166 fprintf(stderr, "\n");
168 // Print the pipeline nested two levels down
169 // CHECK: Nested Module>Func: func.func(print-op-stats{json=false})
170 fprintf(stderr, "Nested Module>Func: ");
171 mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
172 fprintf(stderr, "\n");
174 mlirPassManagerDestroy(pm);
175 mlirContextDestroy(ctx);
178 void testParsePassPipeline(void) {
179 MlirContext ctx = mlirContextCreate();
180 MlirPassManager pm = mlirPassManagerCreate(ctx);
181 // Try parse a pipeline.
182 MlirLogicalResult status = mlirParsePassPipeline(
183 mlirPassManagerGetAsOpPassManager(pm),
184 mlirStringRefCreateFromCString(
185 "builtin.module(func.func(print-op-stats{json=false}))"),
186 printToStderr, NULL);
187 // Expect a failure, we haven't registered the print-op-stats pass yet.
188 if (mlirLogicalResultIsSuccess(status)) {
189 fprintf(
190 stderr,
191 "Unexpected success parsing pipeline without registering the pass\n");
192 exit(EXIT_FAILURE);
194 // Try again after registrating the pass.
195 mlirRegisterTransformsPrintOpStats();
196 status = mlirParsePassPipeline(
197 mlirPassManagerGetAsOpPassManager(pm),
198 mlirStringRefCreateFromCString(
199 "builtin.module(func.func(print-op-stats{json=false}))"),
200 printToStderr, NULL);
201 // Expect a failure, we haven't registered the print-op-stats pass yet.
202 if (mlirLogicalResultIsFailure(status)) {
203 fprintf(stderr,
204 "Unexpected failure parsing pipeline after registering the pass\n");
205 exit(EXIT_FAILURE);
208 // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}))
209 fprintf(stderr, "Round-trip: ");
210 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
211 NULL);
212 fprintf(stderr, "\n");
214 // Try appending a pass:
215 status = mlirOpPassManagerAddPipeline(
216 mlirPassManagerGetAsOpPassManager(pm),
217 mlirStringRefCreateFromCString("func.func(print-op-stats{json=false})"),
218 printToStderr, NULL);
219 if (mlirLogicalResultIsFailure(status)) {
220 fprintf(stderr, "Unexpected failure appending pipeline\n");
221 exit(EXIT_FAILURE);
223 // CHECK: Appended: builtin.module(
224 // CHECK-SAME: func.func(print-op-stats{json=false}),
225 // CHECK-SAME: func.func(print-op-stats{json=false})
226 // CHECK-SAME: )
227 fprintf(stderr, "Appended: ");
228 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
229 NULL);
230 fprintf(stderr, "\n");
232 mlirPassManagerDestroy(pm);
233 mlirContextDestroy(ctx);
236 void testParseErrorCapture(void) {
237 // CHECK-LABEL: testParseErrorCapture:
238 fprintf(stderr, "\nTEST: testParseErrorCapture:\n");
240 MlirContext ctx = mlirContextCreate();
241 MlirPassManager pm = mlirPassManagerCreate(ctx);
242 MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(pm);
243 MlirStringRef invalidPipeline = mlirStringRefCreateFromCString("invalid");
245 // CHECK: mlirParsePassPipeline:
246 // CHECK: expected pass pipeline to be wrapped with the anchor operation type
247 fprintf(stderr, "mlirParsePassPipeline:\n");
248 if (mlirLogicalResultIsSuccess(
249 mlirParsePassPipeline(opm, invalidPipeline, printToStderr, NULL)))
250 exit(EXIT_FAILURE);
251 fprintf(stderr, "\n");
253 // CHECK: mlirOpPassManagerAddPipeline:
254 // CHECK: 'invalid' does not refer to a registered pass or pass pipeline
255 fprintf(stderr, "mlirOpPassManagerAddPipeline:\n");
256 if (mlirLogicalResultIsSuccess(mlirOpPassManagerAddPipeline(
257 opm, invalidPipeline, printToStderr, NULL)))
258 exit(EXIT_FAILURE);
259 fprintf(stderr, "\n");
261 // Make sure all output is going through the callback.
262 // CHECK: dontPrint: <>
263 fprintf(stderr, "dontPrint: <");
264 if (mlirLogicalResultIsSuccess(
265 mlirParsePassPipeline(opm, invalidPipeline, dontPrint, NULL)))
266 exit(EXIT_FAILURE);
267 if (mlirLogicalResultIsSuccess(
268 mlirOpPassManagerAddPipeline(opm, invalidPipeline, dontPrint, NULL)))
269 exit(EXIT_FAILURE);
270 fprintf(stderr, ">\n");
272 mlirPassManagerDestroy(pm);
273 mlirContextDestroy(ctx);
/// Per-pass counters recording how often each external-pass callback in this
/// file was invoked.
///
/// Instances are zero-initialized by the tests and passed as the pass's
/// `userData`.
typedef struct TestExternalPassUserData {
  int constructCallCount;  // construct callback invocations
  int destructCallCount;   // destruct callback invocations
  int initializeCallCount; // initialize callback invocations
  int cloneCallCount;      // clone callback invocations
  int runCallCount;        // run callback invocations
} TestExternalPassUserData;
285 void testConstructExternalPass(void *userData) {
286 ++((TestExternalPassUserData *)userData)->constructCallCount;
289 void testDestructExternalPass(void *userData) {
290 ++((TestExternalPassUserData *)userData)->destructCallCount;
293 MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
294 ++((TestExternalPassUserData *)userData)->initializeCallCount;
295 return mlirLogicalResultSuccess();
298 MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
299 void *userData) {
300 ++((TestExternalPassUserData *)userData)->initializeCallCount;
301 return mlirLogicalResultFailure();
304 void *testCloneExternalPass(void *userData) {
305 ++((TestExternalPassUserData *)userData)->cloneCallCount;
306 return userData;
309 void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
310 void *userData) {
311 ++((TestExternalPassUserData *)userData)->runCallCount;
314 void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
315 void *userData) {
316 ++((TestExternalPassUserData *)userData)->runCallCount;
317 MlirStringRef opName = mlirIdentifierStr(mlirOperationGetName(op));
318 if (!mlirStringRefEqual(opName,
319 mlirStringRefCreateFromCString("func.func"))) {
320 mlirExternalPassSignalFailure(pass);
324 void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
325 void *userData) {
326 ++((TestExternalPassUserData *)userData)->runCallCount;
327 mlirExternalPassSignalFailure(pass);
330 MlirExternalPassCallbacks makeTestExternalPassCallbacks(
331 MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
332 void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
333 return (MlirExternalPassCallbacks){testConstructExternalPass,
334 testDestructExternalPass, initializePass,
335 testCloneExternalPass, runPass};
338 void testExternalPass(void) {
339 MlirContext ctx = mlirContextCreate();
340 registerAllUpstreamDialects(ctx);
342 const char *moduleAsm = //
343 "module { \n"
344 " func.func @foo(%arg0 : i32) -> i32 { \n"
345 " %res = arith.addi %arg0, %arg0 : i32 \n"
346 " return %res : i32 \n"
347 " } \n"
348 "}";
349 MlirOperation module =
350 mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
351 mlirStringRefCreateFromCString("moduleAsm"));
352 if (mlirOperationIsNull(module)) {
353 fprintf(stderr, "Unexpected failure parsing module.\n");
354 exit(EXIT_FAILURE);
357 MlirStringRef description = mlirStringRefCreateFromCString("");
358 MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
360 MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
362 // Run a generic pass
364 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
365 MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
366 MlirStringRef argument =
367 mlirStringRefCreateFromCString("test-external-pass");
368 TestExternalPassUserData userData = {0};
370 MlirPass externalPass = mlirCreateExternalPass(
371 passID, name, argument, description, emptyOpName, 0, NULL,
372 makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
374 if (userData.constructCallCount != 1) {
375 fprintf(stderr, "Expected constructCallCount to be 1\n");
376 exit(EXIT_FAILURE);
379 MlirPassManager pm = mlirPassManagerCreate(ctx);
380 mlirPassManagerAddOwnedPass(pm, externalPass);
381 MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
382 if (mlirLogicalResultIsFailure(success)) {
383 fprintf(stderr, "Unexpected failure running external pass.\n");
384 exit(EXIT_FAILURE);
387 if (userData.runCallCount != 1) {
388 fprintf(stderr, "Expected runCallCount to be 1\n");
389 exit(EXIT_FAILURE);
392 mlirPassManagerDestroy(pm);
394 if (userData.destructCallCount != userData.constructCallCount) {
395 fprintf(stderr, "Expected destructCallCount to be equal to "
396 "constructCallCount\n");
397 exit(EXIT_FAILURE);
401 // Run a func operation pass
403 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
404 MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
405 MlirStringRef argument =
406 mlirStringRefCreateFromCString("test-external-func-pass");
407 TestExternalPassUserData userData = {0};
408 MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
409 MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
411 MlirPass externalPass = mlirCreateExternalPass(
412 passID, name, argument, description, funcOpName, 1, &funcHandle,
413 makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
414 &userData);
416 if (userData.constructCallCount != 1) {
417 fprintf(stderr, "Expected constructCallCount to be 1\n");
418 exit(EXIT_FAILURE);
421 MlirPassManager pm = mlirPassManagerCreate(ctx);
422 MlirOpPassManager nestedFuncPm =
423 mlirPassManagerGetNestedUnder(pm, funcOpName);
424 mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
425 MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
426 if (mlirLogicalResultIsFailure(success)) {
427 fprintf(stderr, "Unexpected failure running external operation pass.\n");
428 exit(EXIT_FAILURE);
431 // Since this is a nested pass, it can be cloned and run in parallel
432 if (userData.cloneCallCount != userData.constructCallCount - 1) {
433 fprintf(stderr, "Expected constructCallCount to be 1\n");
434 exit(EXIT_FAILURE);
437 // The pass should only be run once this there is only one func op
438 if (userData.runCallCount != 1) {
439 fprintf(stderr, "Expected runCallCount to be 1\n");
440 exit(EXIT_FAILURE);
443 mlirPassManagerDestroy(pm);
445 if (userData.destructCallCount != userData.constructCallCount) {
446 fprintf(stderr, "Expected destructCallCount to be equal to "
447 "constructCallCount\n");
448 exit(EXIT_FAILURE);
452 // Run a pass with `initialize` set
454 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
455 MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
456 MlirStringRef argument =
457 mlirStringRefCreateFromCString("test-external-pass");
458 TestExternalPassUserData userData = {0};
460 MlirPass externalPass = mlirCreateExternalPass(
461 passID, name, argument, description, emptyOpName, 0, NULL,
462 makeTestExternalPassCallbacks(testInitializeExternalPass,
463 testRunExternalPass),
464 &userData);
466 if (userData.constructCallCount != 1) {
467 fprintf(stderr, "Expected constructCallCount to be 1\n");
468 exit(EXIT_FAILURE);
471 MlirPassManager pm = mlirPassManagerCreate(ctx);
472 mlirPassManagerAddOwnedPass(pm, externalPass);
473 MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
474 if (mlirLogicalResultIsFailure(success)) {
475 fprintf(stderr, "Unexpected failure running external pass.\n");
476 exit(EXIT_FAILURE);
479 if (userData.initializeCallCount != 1) {
480 fprintf(stderr, "Expected initializeCallCount to be 1\n");
481 exit(EXIT_FAILURE);
484 if (userData.runCallCount != 1) {
485 fprintf(stderr, "Expected runCallCount to be 1\n");
486 exit(EXIT_FAILURE);
489 mlirPassManagerDestroy(pm);
491 if (userData.destructCallCount != userData.constructCallCount) {
492 fprintf(stderr, "Expected destructCallCount to be equal to "
493 "constructCallCount\n");
494 exit(EXIT_FAILURE);
498 // Run a pass that fails during `initialize`
500 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
501 MlirStringRef name =
502 mlirStringRefCreateFromCString("TestExternalFailingPass");
503 MlirStringRef argument =
504 mlirStringRefCreateFromCString("test-external-failing-pass");
505 TestExternalPassUserData userData = {0};
507 MlirPass externalPass = mlirCreateExternalPass(
508 passID, name, argument, description, emptyOpName, 0, NULL,
509 makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
510 testRunExternalPass),
511 &userData);
513 if (userData.constructCallCount != 1) {
514 fprintf(stderr, "Expected constructCallCount to be 1\n");
515 exit(EXIT_FAILURE);
518 MlirPassManager pm = mlirPassManagerCreate(ctx);
519 mlirPassManagerAddOwnedPass(pm, externalPass);
520 MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
521 if (mlirLogicalResultIsSuccess(success)) {
522 fprintf(
523 stderr,
524 "Expected failure running pass manager on failing external pass.\n");
525 exit(EXIT_FAILURE);
528 if (userData.initializeCallCount != 1) {
529 fprintf(stderr, "Expected initializeCallCount to be 1\n");
530 exit(EXIT_FAILURE);
533 if (userData.runCallCount != 0) {
534 fprintf(stderr, "Expected runCallCount to be 0\n");
535 exit(EXIT_FAILURE);
538 mlirPassManagerDestroy(pm);
540 if (userData.destructCallCount != userData.constructCallCount) {
541 fprintf(stderr, "Expected destructCallCount to be equal to "
542 "constructCallCount\n");
543 exit(EXIT_FAILURE);
547 // Run a pass that fails during `run`
549 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
550 MlirStringRef name =
551 mlirStringRefCreateFromCString("TestExternalFailingPass");
552 MlirStringRef argument =
553 mlirStringRefCreateFromCString("test-external-failing-pass");
554 TestExternalPassUserData userData = {0};
556 MlirPass externalPass = mlirCreateExternalPass(
557 passID, name, argument, description, emptyOpName, 0, NULL,
558 makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
559 &userData);
561 if (userData.constructCallCount != 1) {
562 fprintf(stderr, "Expected constructCallCount to be 1\n");
563 exit(EXIT_FAILURE);
566 MlirPassManager pm = mlirPassManagerCreate(ctx);
567 mlirPassManagerAddOwnedPass(pm, externalPass);
568 MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
569 if (mlirLogicalResultIsSuccess(success)) {
570 fprintf(
571 stderr,
572 "Expected failure running pass manager on failing external pass.\n");
573 exit(EXIT_FAILURE);
576 if (userData.runCallCount != 1) {
577 fprintf(stderr, "Expected runCallCount to be 1\n");
578 exit(EXIT_FAILURE);
581 mlirPassManagerDestroy(pm);
583 if (userData.destructCallCount != userData.constructCallCount) {
584 fprintf(stderr, "Expected destructCallCount to be equal to "
585 "constructCallCount\n");
586 exit(EXIT_FAILURE);
590 mlirTypeIDAllocatorDestroy(typeIDAllocator);
591 mlirOperationDestroy(module);
592 mlirContextDestroy(ctx);
595 int main(void) {
596 testRunPassOnModule();
597 testRunPassOnNestedModule();
598 testPrintPassPipeline();
599 testParsePassPipeline();
600 testParseErrorCapture();
601 testExternalPass();
602 return 0;