1 //===- TestReifyValueBounds.cpp - Test value bounds reification -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "TestDialect.h"
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
13 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
14 #include "mlir/Dialect/Arith/Transforms/Transforms.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
21 #include "mlir/Pass/Pass.h"
23 #define PASS_NAME "test-affine-reify-value-bounds"
using namespace mlir;
using namespace mlir::affine;
using mlir::presburger::BoundType;
31 /// This pass applies the permutation on the first maximal perfect nest.
32 struct TestReifyValueBounds
33 : public PassWrapper
<TestReifyValueBounds
, OperationPass
<func::FuncOp
>> {
34 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds
)
36 StringRef
getArgument() const final
{ return PASS_NAME
; }
37 StringRef
getDescription() const final
{
38 return "Tests ValueBoundsOpInterface with affine dialect reification";
40 TestReifyValueBounds() = default;
41 TestReifyValueBounds(const TestReifyValueBounds
&pass
) : PassWrapper(pass
){};
43 void getDependentDialects(DialectRegistry
®istry
) const override
{
44 registry
.insert
<affine::AffineDialect
, tensor::TensorDialect
,
45 memref::MemRefDialect
>();
48 void runOnOperation() override
;
51 Option
<bool> reifyToFuncArgs
{
52 *this, "reify-to-func-args",
53 llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)};
55 Option
<bool> useArithOps
{*this, "use-arith-ops",
56 llvm::cl::desc("Reify with arith dialect ops"),
57 llvm::cl::init(false)};
62 static ValueBoundsConstraintSet::ComparisonOperator
63 invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp
) {
64 if (cmp
== ValueBoundsConstraintSet::ComparisonOperator::LT
)
65 return ValueBoundsConstraintSet::ComparisonOperator::GE
;
66 if (cmp
== ValueBoundsConstraintSet::ComparisonOperator::LE
)
67 return ValueBoundsConstraintSet::ComparisonOperator::GT
;
68 if (cmp
== ValueBoundsConstraintSet::ComparisonOperator::GT
)
69 return ValueBoundsConstraintSet::ComparisonOperator::LE
;
70 if (cmp
== ValueBoundsConstraintSet::ComparisonOperator::GE
)
71 return ValueBoundsConstraintSet::ComparisonOperator::LT
;
72 llvm_unreachable("unsupported comparison operator");
75 /// Look for "test.reify_bound" ops in the input and replace their results with
76 /// the reified values.
77 static LogicalResult
testReifyValueBounds(func::FuncOp funcOp
,
80 IRRewriter
rewriter(funcOp
.getContext());
81 WalkResult result
= funcOp
.walk([&](test::ReifyBoundOp op
) {
82 auto boundType
= op
.getBoundType();
83 Value value
= op
.getVar();
84 std::optional
<int64_t> dim
= op
.getDim();
85 bool constant
= op
.getConstant();
86 bool scalable
= op
.getScalable();
88 // Prepare stop condition. By default, reify in terms of the op's
89 // operands. No stop condition is used when a constant was requested.
90 std::function
<bool(Value
, std::optional
<int64_t>,
91 ValueBoundsConstraintSet
& cstr
)>
92 stopCondition
= [&](Value v
, std::optional
<int64_t> d
,
93 ValueBoundsConstraintSet
&cstr
) {
94 // Reify in terms of SSA values that are different from `value`.
97 if (reifyToFuncArgs
) {
98 // Reify in terms of function block arguments.
99 stopCondition
= [](Value v
, std::optional
<int64_t> d
,
100 ValueBoundsConstraintSet
&cstr
) {
101 auto bbArg
= dyn_cast
<BlockArgument
>(v
);
104 return isa
<FunctionOpInterface
>(bbArg
.getParentBlock()->getParentOp());
109 rewriter
.setInsertionPointAfter(op
);
110 FailureOr
<OpFoldResult
> reified
= failure();
112 auto reifiedConst
= ValueBoundsConstraintSet::computeConstantBound(
113 boundType
, {value
, dim
}, /*stopCondition=*/nullptr);
114 if (succeeded(reifiedConst
))
115 reified
= FailureOr
<OpFoldResult
>(rewriter
.getIndexAttr(*reifiedConst
));
116 } else if (scalable
) {
117 auto loc
= op
->getLoc();
118 auto reifiedScalable
=
119 vector::ScalableValueBoundsConstraintSet::computeScalableBound(
120 value
, dim
, *op
.getVscaleMin(), *op
.getVscaleMax(), boundType
);
121 if (succeeded(reifiedScalable
)) {
122 SmallVector
<std::pair
<Value
, std::optional
<int64_t>>, 1> vscaleOperand
;
123 if (reifiedScalable
->map
.getNumInputs() == 1) {
124 // The only possible input to the bound is vscale.
125 vscaleOperand
.push_back(std::make_pair(
126 rewriter
.create
<vector::VectorScaleOp
>(loc
), std::nullopt
));
128 reified
= affine::materializeComputedBound(
129 rewriter
, loc
, reifiedScalable
->map
, vscaleOperand
);
133 reified
= arith::reifyValueBound(rewriter
, op
->getLoc(), boundType
,
134 op
.getVariable(), stopCondition
);
136 reified
= reifyValueBound(rewriter
, op
->getLoc(), boundType
,
137 op
.getVariable(), stopCondition
);
140 if (failed(reified
)) {
141 op
->emitOpError("could not reify bound");
142 return WalkResult::interrupt();
145 // Replace the op with the reified bound.
146 if (auto val
= llvm::dyn_cast_if_present
<Value
>(*reified
)) {
147 rewriter
.replaceOp(op
, val
);
148 return WalkResult::skip();
150 Value constOp
= rewriter
.create
<arith::ConstantIndexOp
>(
151 op
->getLoc(), cast
<IntegerAttr
>(reified
->get
<Attribute
>()).getInt());
152 rewriter
.replaceOp(op
, constOp
);
153 return WalkResult::skip();
155 return failure(result
.wasInterrupted());
158 /// Look for "test.compare" ops and emit errors/remarks.
159 static LogicalResult
testEquality(func::FuncOp funcOp
) {
160 IRRewriter
rewriter(funcOp
.getContext());
161 WalkResult result
= funcOp
.walk([&](test::CompareOp op
) {
162 auto cmpType
= op
.getComparisonOperator();
163 if (op
.getCompose()) {
164 if (cmpType
!= ValueBoundsConstraintSet::EQ
) {
166 "comparison operator must be EQ when 'composed' is specified");
167 return WalkResult::interrupt();
169 FailureOr
<int64_t> delta
= affine::fullyComposeAndComputeConstantDelta(
170 op
->getOperand(0), op
->getOperand(1));
172 op
->emitError("could not determine equality");
173 } else if (*delta
== 0) {
174 op
->emitRemark("equal");
176 op
->emitRemark("different");
178 return WalkResult::advance();
181 auto compare
= [&](ValueBoundsConstraintSet::ComparisonOperator cmp
) {
182 return ValueBoundsConstraintSet::compare(op
.getLhs(), cmp
, op
.getRhs());
184 if (compare(cmpType
)) {
185 op
->emitRemark("true");
186 } else if (cmpType
!= ValueBoundsConstraintSet::EQ
&&
187 compare(invertComparisonOperator(cmpType
))) {
188 op
->emitRemark("false");
189 } else if (cmpType
== ValueBoundsConstraintSet::EQ
&&
190 (compare(ValueBoundsConstraintSet::ComparisonOperator::LT
) ||
191 compare(ValueBoundsConstraintSet::ComparisonOperator::GT
))) {
192 op
->emitRemark("false");
194 op
->emitError("unknown");
196 return WalkResult::advance();
198 return failure(result
.wasInterrupted());
201 void TestReifyValueBounds::runOnOperation() {
203 testReifyValueBounds(getOperation(), reifyToFuncArgs
, useArithOps
)))
205 if (failed(testEquality(getOperation())))
210 void registerTestAffineReifyValueBoundsPass() {
211 PassRegistration
<TestReifyValueBounds
>();