1 //===- TestWhileOpBuilder.cpp - Pass to test WhileOp::build ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements a pass to test some builder functions of WhileOp. It
10 // tests the regression explained in https://reviews.llvm.org/D142952, where
11 // a WhileOp::build overload crashed when fed with operands of different types
12 // than the result types.
14 // To test the build function, the pass copies each WhileOp found in the body
15 // of a FuncOp and adds an additional WhileOp with the same operands and result
16 // types (but dummy computations) using the builder in question.
18 //===----------------------------------------------------------------------===//
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/Pass/Pass.h"
28 using namespace mlir::arith
;
29 using namespace mlir::scf
;
32 struct TestSCFWhileOpBuilderPass
33 : public PassWrapper
<TestSCFWhileOpBuilderPass
,
34 OperationPass
<func::FuncOp
>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFWhileOpBuilderPass
)
37 StringRef
getArgument() const final
{ return "test-scf-while-op-builder"; }
38 StringRef
getDescription() const final
{
39 return "test build functions of scf.while";
41 explicit TestSCFWhileOpBuilderPass() = default;
42 TestSCFWhileOpBuilderPass(const TestSCFWhileOpBuilderPass
&pass
) = default;
44 void runOnOperation() override
{
45 func::FuncOp func
= getOperation();
46 func
.walk([&](WhileOp whileOp
) {
47 Location loc
= whileOp
->getLoc();
48 ImplicitLocOpBuilder
builder(loc
, whileOp
);
50 // Create a WhileOp with the same operands and result types.
51 TypeRange resultTypes
= whileOp
->getResultTypes();
52 ValueRange operands
= whileOp
->getOperands();
53 builder
.create
<WhileOp
>(
54 loc
, resultTypes
, operands
, /*beforeBuilder=*/
55 [&](OpBuilder
&b
, Location loc
, ValueRange args
) {
56 // Just cast the before args into the right types for condition.
57 ImplicitLocOpBuilder
builder(loc
, b
);
59 builder
.create
<UnrealizedConversionCastOp
>(resultTypes
, args
);
60 auto cmp
= builder
.create
<ConstantIntOp
>(/*value=*/1, /*width=*/1);
61 builder
.create
<ConditionOp
>(cmp
, castOp
->getResults());
64 [&](OpBuilder
&b
, Location loc
, ValueRange args
) {
65 // Just cast the after args into the right types for yield.
66 ImplicitLocOpBuilder
builder(loc
, b
);
67 auto castOp
= builder
.create
<UnrealizedConversionCastOp
>(
68 operands
.getTypes(), args
);
69 builder
.create
<YieldOp
>(castOp
->getResults());
78 void registerTestSCFWhileOpBuilderPass() {
79 PassRegistration
<TestSCFWhileOpBuilderPass
>();