Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Affine / TestReifyValueBounds.cpp
blob34513cd418e4c208b4d776137dcebc64b856fca7
1 //===- TestReifyValueBounds.cpp - Test value bounds reification -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
13 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
14 #include "mlir/Dialect/Arith/Transforms/Transforms.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
21 #include "mlir/Pass/Pass.h"
23 #define PASS_NAME "test-affine-reify-value-bounds"
25 using namespace mlir;
26 using namespace mlir::affine;
27 using mlir::presburger::BoundType;
29 namespace {
31 /// This pass applies the permutation on the first maximal perfect nest.
32 struct TestReifyValueBounds
33 : public PassWrapper<TestReifyValueBounds, OperationPass<func::FuncOp>> {
34 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds)
36 StringRef getArgument() const final { return PASS_NAME; }
37 StringRef getDescription() const final {
38 return "Tests ValueBoundsOpInterface with affine dialect reification";
40 TestReifyValueBounds() = default;
41 TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
43 void getDependentDialects(DialectRegistry &registry) const override {
44 registry.insert<affine::AffineDialect, tensor::TensorDialect,
45 memref::MemRefDialect>();
48 void runOnOperation() override;
50 private:
51 Option<bool> reifyToFuncArgs{
52 *this, "reify-to-func-args",
53 llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)};
55 Option<bool> useArithOps{*this, "use-arith-ops",
56 llvm::cl::desc("Reify with arith dialect ops"),
57 llvm::cl::init(false)};
60 } // namespace
62 static ValueBoundsConstraintSet::ComparisonOperator
63 invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
64 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT)
65 return ValueBoundsConstraintSet::ComparisonOperator::GE;
66 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE)
67 return ValueBoundsConstraintSet::ComparisonOperator::GT;
68 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT)
69 return ValueBoundsConstraintSet::ComparisonOperator::LE;
70 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE)
71 return ValueBoundsConstraintSet::ComparisonOperator::LT;
72 llvm_unreachable("unsupported comparison operator");
75 /// Look for "test.reify_bound" ops in the input and replace their results with
76 /// the reified values.
77 static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
78 bool reifyToFuncArgs,
79 bool useArithOps) {
80 IRRewriter rewriter(funcOp.getContext());
81 WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) {
82 auto boundType = op.getBoundType();
83 Value value = op.getVar();
84 std::optional<int64_t> dim = op.getDim();
85 bool constant = op.getConstant();
86 bool scalable = op.getScalable();
88 // Prepare stop condition. By default, reify in terms of the op's
89 // operands. No stop condition is used when a constant was requested.
90 std::function<bool(Value, std::optional<int64_t>,
91 ValueBoundsConstraintSet & cstr)>
92 stopCondition = [&](Value v, std::optional<int64_t> d,
93 ValueBoundsConstraintSet &cstr) {
94 // Reify in terms of SSA values that are different from `value`.
95 return v != value;
97 if (reifyToFuncArgs) {
98 // Reify in terms of function block arguments.
99 stopCondition = [](Value v, std::optional<int64_t> d,
100 ValueBoundsConstraintSet &cstr) {
101 auto bbArg = dyn_cast<BlockArgument>(v);
102 if (!bbArg)
103 return false;
104 return isa<FunctionOpInterface>(bbArg.getParentBlock()->getParentOp());
108 // Reify value bound
109 rewriter.setInsertionPointAfter(op);
110 FailureOr<OpFoldResult> reified = failure();
111 if (constant) {
112 auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
113 boundType, {value, dim}, /*stopCondition=*/nullptr);
114 if (succeeded(reifiedConst))
115 reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
116 } else if (scalable) {
117 auto loc = op->getLoc();
118 auto reifiedScalable =
119 vector::ScalableValueBoundsConstraintSet::computeScalableBound(
120 value, dim, *op.getVscaleMin(), *op.getVscaleMax(), boundType);
121 if (succeeded(reifiedScalable)) {
122 SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand;
123 if (reifiedScalable->map.getNumInputs() == 1) {
124 // The only possible input to the bound is vscale.
125 vscaleOperand.push_back(std::make_pair(
126 rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
128 reified = affine::materializeComputedBound(
129 rewriter, loc, reifiedScalable->map, vscaleOperand);
131 } else {
132 if (useArithOps) {
133 reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType,
134 op.getVariable(), stopCondition);
135 } else {
136 reified = reifyValueBound(rewriter, op->getLoc(), boundType,
137 op.getVariable(), stopCondition);
140 if (failed(reified)) {
141 op->emitOpError("could not reify bound");
142 return WalkResult::interrupt();
145 // Replace the op with the reified bound.
146 if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
147 rewriter.replaceOp(op, val);
148 return WalkResult::skip();
150 Value constOp = rewriter.create<arith::ConstantIndexOp>(
151 op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
152 rewriter.replaceOp(op, constOp);
153 return WalkResult::skip();
155 return failure(result.wasInterrupted());
158 /// Look for "test.compare" ops and emit errors/remarks.
159 static LogicalResult testEquality(func::FuncOp funcOp) {
160 IRRewriter rewriter(funcOp.getContext());
161 WalkResult result = funcOp.walk([&](test::CompareOp op) {
162 auto cmpType = op.getComparisonOperator();
163 if (op.getCompose()) {
164 if (cmpType != ValueBoundsConstraintSet::EQ) {
165 op->emitOpError(
166 "comparison operator must be EQ when 'composed' is specified");
167 return WalkResult::interrupt();
169 FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
170 op->getOperand(0), op->getOperand(1));
171 if (failed(delta)) {
172 op->emitError("could not determine equality");
173 } else if (*delta == 0) {
174 op->emitRemark("equal");
175 } else {
176 op->emitRemark("different");
178 return WalkResult::advance();
181 auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
182 return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
184 if (compare(cmpType)) {
185 op->emitRemark("true");
186 } else if (cmpType != ValueBoundsConstraintSet::EQ &&
187 compare(invertComparisonOperator(cmpType))) {
188 op->emitRemark("false");
189 } else if (cmpType == ValueBoundsConstraintSet::EQ &&
190 (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
191 compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
192 op->emitRemark("false");
193 } else {
194 op->emitError("unknown");
196 return WalkResult::advance();
198 return failure(result.wasInterrupted());
201 void TestReifyValueBounds::runOnOperation() {
202 if (failed(
203 testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps)))
204 signalPassFailure();
205 if (failed(testEquality(getOperation())))
206 signalPassFailure();
209 namespace mlir {
210 void registerTestAffineReifyValueBoundsPass() {
211 PassRegistration<TestReifyValueBounds>();
213 } // namespace mlir