//===- pass.c - Simple test of C APIs -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

/* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
 */
#include "mlir-c/Pass.h"
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "mlir-c/Transforms.h"

#include <stdio.h>
#include <stdlib.h>
24 void testRunPassOnModule() {
25 MlirContext ctx
= mlirContextCreate();
26 mlirRegisterAllDialects(ctx
);
28 MlirModule module
= mlirModuleCreateParse(
31 mlirStringRefCreateFromCString(
32 "func @foo(%arg0 : i32) -> i32 { \n"
33 " %res = arith.addi %arg0, %arg0 : i32 \n"
34 " return %res : i32 \n"
37 if (mlirModuleIsNull(module
)) {
38 fprintf(stderr
, "Unexpected failure parsing module.\n");
42 // Run the print-op-stats pass on the top-level module:
43 // CHECK-LABEL: Operations encountered:
44 // CHECK: arith.addi , 1
45 // CHECK: builtin.func , 1
46 // CHECK: std.return , 1
48 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
49 MlirPass printOpStatPass
= mlirCreateTransformsPrintOpStats();
50 mlirPassManagerAddOwnedPass(pm
, printOpStatPass
);
51 MlirLogicalResult success
= mlirPassManagerRun(pm
, module
);
52 if (mlirLogicalResultIsFailure(success
)) {
53 fprintf(stderr
, "Unexpected failure running pass manager.\n");
56 mlirPassManagerDestroy(pm
);
58 mlirModuleDestroy(module
);
59 mlirContextDestroy(ctx
);
62 void testRunPassOnNestedModule() {
63 MlirContext ctx
= mlirContextCreate();
64 mlirRegisterAllDialects(ctx
);
67 mlirModuleCreateParse(ctx
,
69 mlirStringRefCreateFromCString(
70 "func @foo(%arg0 : i32) -> i32 { \n"
71 " %res = arith.addi %arg0, %arg0 : i32 \n"
72 " return %res : i32 \n"
75 " func @bar(%arg0 : f32) -> f32 { \n"
76 " %res = arith.addf %arg0, %arg0 : f32 \n"
77 " return %res : f32 \n"
81 if (mlirModuleIsNull(module
))
84 // Run the print-op-stats pass on functions under the top-level module:
85 // CHECK-LABEL: Operations encountered:
86 // CHECK: arith.addi , 1
87 // CHECK: builtin.func , 1
88 // CHECK: std.return , 1
90 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
91 MlirOpPassManager nestedFuncPm
= mlirPassManagerGetNestedUnder(
92 pm
, mlirStringRefCreateFromCString("builtin.func"));
93 MlirPass printOpStatPass
= mlirCreateTransformsPrintOpStats();
94 mlirOpPassManagerAddOwnedPass(nestedFuncPm
, printOpStatPass
);
95 MlirLogicalResult success
= mlirPassManagerRun(pm
, module
);
96 if (mlirLogicalResultIsFailure(success
))
98 mlirPassManagerDestroy(pm
);
100 // Run the print-op-stats pass on functions under the nested module:
101 // CHECK-LABEL: Operations encountered:
102 // CHECK: arith.addf , 1
103 // CHECK: builtin.func , 1
104 // CHECK: std.return , 1
106 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
107 MlirOpPassManager nestedModulePm
= mlirPassManagerGetNestedUnder(
108 pm
, mlirStringRefCreateFromCString("builtin.module"));
109 MlirOpPassManager nestedFuncPm
= mlirOpPassManagerGetNestedUnder(
110 nestedModulePm
, mlirStringRefCreateFromCString("builtin.func"));
111 MlirPass printOpStatPass
= mlirCreateTransformsPrintOpStats();
112 mlirOpPassManagerAddOwnedPass(nestedFuncPm
, printOpStatPass
);
113 MlirLogicalResult success
= mlirPassManagerRun(pm
, module
);
114 if (mlirLogicalResultIsFailure(success
))
116 mlirPassManagerDestroy(pm
);
119 mlirModuleDestroy(module
);
120 mlirContextDestroy(ctx
);
123 static void printToStderr(MlirStringRef str
, void *userData
) {
125 fwrite(str
.data
, 1, str
.length
, stderr
);
128 void testPrintPassPipeline() {
129 MlirContext ctx
= mlirContextCreate();
130 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
131 // Populate the pass-manager
132 MlirOpPassManager nestedModulePm
= mlirPassManagerGetNestedUnder(
133 pm
, mlirStringRefCreateFromCString("builtin.module"));
134 MlirOpPassManager nestedFuncPm
= mlirOpPassManagerGetNestedUnder(
135 nestedModulePm
, mlirStringRefCreateFromCString("builtin.func"));
136 MlirPass printOpStatPass
= mlirCreateTransformsPrintOpStats();
137 mlirOpPassManagerAddOwnedPass(nestedFuncPm
, printOpStatPass
);
139 // Print the top level pass manager
140 // CHECK: Top-level: builtin.module(builtin.func(print-op-stats))
141 fprintf(stderr
, "Top-level: ");
142 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm
), printToStderr
,
144 fprintf(stderr
, "\n");
146 // Print the pipeline nested one level down
147 // CHECK: Nested Module: builtin.func(print-op-stats)
148 fprintf(stderr
, "Nested Module: ");
149 mlirPrintPassPipeline(nestedModulePm
, printToStderr
, NULL
);
150 fprintf(stderr
, "\n");
152 // Print the pipeline nested two levels down
153 // CHECK: Nested Module>Func: print-op-stats
154 fprintf(stderr
, "Nested Module>Func: ");
155 mlirPrintPassPipeline(nestedFuncPm
, printToStderr
, NULL
);
156 fprintf(stderr
, "\n");
158 mlirPassManagerDestroy(pm
);
159 mlirContextDestroy(ctx
);
162 void testParsePassPipeline() {
163 MlirContext ctx
= mlirContextCreate();
164 MlirPassManager pm
= mlirPassManagerCreate(ctx
);
165 // Try parse a pipeline.
166 MlirLogicalResult status
= mlirParsePassPipeline(
167 mlirPassManagerGetAsOpPassManager(pm
),
168 mlirStringRefCreateFromCString(
169 "builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))"));
170 // Expect a failure, we haven't registered the print-op-stats pass yet.
171 if (mlirLogicalResultIsSuccess(status
)) {
172 fprintf(stderr
, "Unexpected success parsing pipeline without registering the pass\n");
175 // Try again after registrating the pass.
176 mlirRegisterTransformsPrintOpStats();
177 status
= mlirParsePassPipeline(
178 mlirPassManagerGetAsOpPassManager(pm
),
179 mlirStringRefCreateFromCString(
180 "builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))"));
181 // Expect a failure, we haven't registered the print-op-stats pass yet.
182 if (mlirLogicalResultIsFailure(status
)) {
183 fprintf(stderr
, "Unexpected failure parsing pipeline after registering the pass\n");
187 // CHECK: Round-trip: builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))
188 fprintf(stderr
, "Round-trip: ");
189 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm
), printToStderr
,
191 fprintf(stderr
, "\n");
192 mlirPassManagerDestroy(pm
);
193 mlirContextDestroy(ctx
);
// Test driver: runs each pass-manager C API test in order. Their stderr output
// is what the RUN line pipes into FileCheck.
int main(void) {
  testRunPassOnModule();
  testRunPassOnNestedModule();
  testPrintPassPipeline();
  testParsePassPipeline();
  return 0;
}