[RISCV] Fix the code alignment for GroupFloatVectors. NFC
[llvm-project.git] / mlir / test / CAPI / pass.c
blobc4974488c5f2dd4f36684f679dc2650bf4e4a123
1 //===- pass.c - Simple test of C APIs -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
/* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
 */
13 #include "mlir-c/Pass.h"
14 #include "mlir-c/IR.h"
15 #include "mlir-c/Registration.h"
16 #include "mlir-c/Transforms.h"
18 #include <assert.h>
19 #include <math.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
24 void testRunPassOnModule() {
25 MlirContext ctx = mlirContextCreate();
26 mlirRegisterAllDialects(ctx);
28 MlirModule module = mlirModuleCreateParse(
29 ctx,
30 // clang-format off
31 mlirStringRefCreateFromCString(
32 "func @foo(%arg0 : i32) -> i32 { \n"
33 " %res = arith.addi %arg0, %arg0 : i32 \n"
34 " return %res : i32 \n"
35 "}"));
36 // clang-format on
37 if (mlirModuleIsNull(module)) {
38 fprintf(stderr, "Unexpected failure parsing module.\n");
39 exit(EXIT_FAILURE);
42 // Run the print-op-stats pass on the top-level module:
43 // CHECK-LABEL: Operations encountered:
44 // CHECK: arith.addi , 1
45 // CHECK: builtin.func , 1
46 // CHECK: std.return , 1
48 MlirPassManager pm = mlirPassManagerCreate(ctx);
49 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
50 mlirPassManagerAddOwnedPass(pm, printOpStatPass);
51 MlirLogicalResult success = mlirPassManagerRun(pm, module);
52 if (mlirLogicalResultIsFailure(success)) {
53 fprintf(stderr, "Unexpected failure running pass manager.\n");
54 exit(EXIT_FAILURE);
56 mlirPassManagerDestroy(pm);
58 mlirModuleDestroy(module);
59 mlirContextDestroy(ctx);
62 void testRunPassOnNestedModule() {
63 MlirContext ctx = mlirContextCreate();
64 mlirRegisterAllDialects(ctx);
66 MlirModule module =
67 mlirModuleCreateParse(ctx,
68 // clang-format off
69 mlirStringRefCreateFromCString(
70 "func @foo(%arg0 : i32) -> i32 { \n"
71 " %res = arith.addi %arg0, %arg0 : i32 \n"
72 " return %res : i32 \n"
73 "} \n"
74 "module { \n"
75 " func @bar(%arg0 : f32) -> f32 { \n"
76 " %res = arith.addf %arg0, %arg0 : f32 \n"
77 " return %res : f32 \n"
78 " } \n"
79 "}"));
80 // clang-format on
81 if (mlirModuleIsNull(module))
82 exit(1);
84 // Run the print-op-stats pass on functions under the top-level module:
85 // CHECK-LABEL: Operations encountered:
86 // CHECK: arith.addi , 1
87 // CHECK: builtin.func , 1
88 // CHECK: std.return , 1
90 MlirPassManager pm = mlirPassManagerCreate(ctx);
91 MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
92 pm, mlirStringRefCreateFromCString("builtin.func"));
93 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
94 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
95 MlirLogicalResult success = mlirPassManagerRun(pm, module);
96 if (mlirLogicalResultIsFailure(success))
97 exit(2);
98 mlirPassManagerDestroy(pm);
100 // Run the print-op-stats pass on functions under the nested module:
101 // CHECK-LABEL: Operations encountered:
102 // CHECK: arith.addf , 1
103 // CHECK: builtin.func , 1
104 // CHECK: std.return , 1
106 MlirPassManager pm = mlirPassManagerCreate(ctx);
107 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
108 pm, mlirStringRefCreateFromCString("builtin.module"));
109 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
110 nestedModulePm, mlirStringRefCreateFromCString("builtin.func"));
111 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
112 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
113 MlirLogicalResult success = mlirPassManagerRun(pm, module);
114 if (mlirLogicalResultIsFailure(success))
115 exit(2);
116 mlirPassManagerDestroy(pm);
119 mlirModuleDestroy(module);
120 mlirContextDestroy(ctx);
123 static void printToStderr(MlirStringRef str, void *userData) {
124 (void)userData;
125 fwrite(str.data, 1, str.length, stderr);
128 void testPrintPassPipeline() {
129 MlirContext ctx = mlirContextCreate();
130 MlirPassManager pm = mlirPassManagerCreate(ctx);
131 // Populate the pass-manager
132 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
133 pm, mlirStringRefCreateFromCString("builtin.module"));
134 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
135 nestedModulePm, mlirStringRefCreateFromCString("builtin.func"));
136 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
137 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
139 // Print the top level pass manager
140 // CHECK: Top-level: builtin.module(builtin.func(print-op-stats))
141 fprintf(stderr, "Top-level: ");
142 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
143 NULL);
144 fprintf(stderr, "\n");
146 // Print the pipeline nested one level down
147 // CHECK: Nested Module: builtin.func(print-op-stats)
148 fprintf(stderr, "Nested Module: ");
149 mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
150 fprintf(stderr, "\n");
152 // Print the pipeline nested two levels down
153 // CHECK: Nested Module>Func: print-op-stats
154 fprintf(stderr, "Nested Module>Func: ");
155 mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
156 fprintf(stderr, "\n");
158 mlirPassManagerDestroy(pm);
159 mlirContextDestroy(ctx);
162 void testParsePassPipeline() {
163 MlirContext ctx = mlirContextCreate();
164 MlirPassManager pm = mlirPassManagerCreate(ctx);
165 // Try parse a pipeline.
166 MlirLogicalResult status = mlirParsePassPipeline(
167 mlirPassManagerGetAsOpPassManager(pm),
168 mlirStringRefCreateFromCString(
169 "builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))"));
170 // Expect a failure, we haven't registered the print-op-stats pass yet.
171 if (mlirLogicalResultIsSuccess(status)) {
172 fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
173 exit(EXIT_FAILURE);
175 // Try again after registrating the pass.
176 mlirRegisterTransformsPrintOpStats();
177 status = mlirParsePassPipeline(
178 mlirPassManagerGetAsOpPassManager(pm),
179 mlirStringRefCreateFromCString(
180 "builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))"));
181 // Expect a failure, we haven't registered the print-op-stats pass yet.
182 if (mlirLogicalResultIsFailure(status)) {
183 fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
184 exit(EXIT_FAILURE);
187 // CHECK: Round-trip: builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))
188 fprintf(stderr, "Round-trip: ");
189 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
190 NULL);
191 fprintf(stderr, "\n");
192 mlirPassManagerDestroy(pm);
193 mlirContextDestroy(ctx);
// Runs the tests in a fixed order. FileCheck matches the combined stderr
// output in sequence, so the order must stay in sync with the expectations
// in each test. testParsePassPipeline registers the print-op-stats pass itself
// and expects it to be unregistered beforehand, so it should remain last.
// Tests that detect an unexpected result exit early with EXIT_FAILURE.
int main(void) {
  testRunPassOnModule();
  testRunPassOnNestedModule();
  testPrintPassPipeline();
  testParsePassPipeline();
  return EXIT_SUCCESS;
}