1 //===- IndexOps.cpp - Index operation definitions --------------------------==//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Index/IR/IndexOps.h"
10 #include "mlir/Dialect/Index/IR/IndexAttrs.h"
11 #include "mlir/Dialect/Index/IR/IndexDialect.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/OpImplementation.h"
16 using namespace mlir::index
;
18 //===----------------------------------------------------------------------===//
20 //===----------------------------------------------------------------------===//
22 void IndexDialect::registerOperations() {
25 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
29 Operation
*IndexDialect::materializeConstant(OpBuilder
&b
, Attribute value
,
30 Type type
, Location loc
) {
31 // Materialize bool constants as `i1`.
32 if (auto boolValue
= dyn_cast
<BoolAttr
>(value
)) {
33 if (!type
.isSignlessInteger(1))
35 return b
.create
<BoolConstantOp
>(loc
, type
, boolValue
);
38 // Materialize integer attributes as `index`.
39 if (auto indexValue
= dyn_cast
<IntegerAttr
>(value
)) {
40 if (!indexValue
.getType().isa
<IndexType
>() || !type
.isa
<IndexType
>())
42 assert(indexValue
.getValue().getBitWidth() ==
43 IndexType::kInternalStorageBitWidth
);
44 return b
.create
<ConstantOp
>(loc
, indexValue
);
50 //===----------------------------------------------------------------------===//
52 //===----------------------------------------------------------------------===//
54 /// Fold an index operation irrespective of the target bitwidth. The
55 /// operation must satisfy the property:
58 /// trunc(f(a, b)) = f(trunc(a), trunc(b))
61 /// For all values of `a` and `b`. The function accepts a lambda that computes
62 /// the integer result, which in turn must satisfy the above property.
63 static OpFoldResult
foldBinaryOpUnchecked(
64 ArrayRef
<Attribute
> operands
,
65 function_ref
<Optional
<APInt
>(const APInt
&, const APInt
&)> calculate
) {
66 assert(operands
.size() == 2 && "binary operation expected 2 operands");
67 auto lhs
= dyn_cast_if_present
<IntegerAttr
>(operands
[0]);
68 auto rhs
= dyn_cast_if_present
<IntegerAttr
>(operands
[1]);
72 Optional
<APInt
> result
= calculate(lhs
.getValue(), rhs
.getValue());
75 assert(result
->trunc(32) ==
76 calculate(lhs
.getValue().trunc(32), rhs
.getValue().trunc(32)));
77 return IntegerAttr::get(IndexType::get(lhs
.getContext()), std::move(*result
));
80 /// Fold an index operation only if the truncated 64-bit result matches the
81 /// 32-bit result for operations that don't satisfy the above property. These
82 /// are operations where the upper bits of the operands can affect the lower
83 /// bits of the results.
85 /// The function accepts a lambda that computes the integer result in both
86 /// 64-bit and 32-bit. If either call returns `None`, the operation is not
88 static OpFoldResult
foldBinaryOpChecked(
89 ArrayRef
<Attribute
> operands
,
90 function_ref
<Optional
<APInt
>(const APInt
&, const APInt
&lhs
)> calculate
) {
91 assert(operands
.size() == 2 && "binary operation expected 2 operands");
92 auto lhs
= dyn_cast_if_present
<IntegerAttr
>(operands
[0]);
93 auto rhs
= dyn_cast_if_present
<IntegerAttr
>(operands
[1]);
94 // Only fold index operands.
98 // Compute the 64-bit result and the 32-bit result.
99 Optional
<APInt
> result64
= calculate(lhs
.getValue(), rhs
.getValue());
102 Optional
<APInt
> result32
=
103 calculate(lhs
.getValue().trunc(32), rhs
.getValue().trunc(32));
106 // Compare the truncated 64-bit result to the 32-bit result.
107 if (result64
->trunc(32) != *result32
)
109 // The operation can be folded for these particular operands.
110 return IntegerAttr::get(IndexType::get(lhs
.getContext()),
111 std::move(*result64
));
114 //===----------------------------------------------------------------------===//
116 //===----------------------------------------------------------------------===//
118 OpFoldResult
AddOp::fold(ArrayRef
<Attribute
> operands
) {
119 return foldBinaryOpUnchecked(
120 operands
, [](const APInt
&lhs
, const APInt
&rhs
) { return lhs
+ rhs
; });
123 //===----------------------------------------------------------------------===//
125 //===----------------------------------------------------------------------===//
127 OpFoldResult
SubOp::fold(ArrayRef
<Attribute
> operands
) {
128 return foldBinaryOpUnchecked(
129 operands
, [](const APInt
&lhs
, const APInt
&rhs
) { return lhs
- rhs
; });
132 //===----------------------------------------------------------------------===//
134 //===----------------------------------------------------------------------===//
136 OpFoldResult
MulOp::fold(ArrayRef
<Attribute
> operands
) {
137 return foldBinaryOpUnchecked(
138 operands
, [](const APInt
&lhs
, const APInt
&rhs
) { return lhs
* rhs
; });
141 //===----------------------------------------------------------------------===//
143 //===----------------------------------------------------------------------===//
145 OpFoldResult
DivSOp::fold(ArrayRef
<Attribute
> operands
) {
146 return foldBinaryOpChecked(
147 operands
, [](const APInt
&lhs
, const APInt
&rhs
) -> Optional
<APInt
> {
148 // Don't fold division by zero.
151 return lhs
.sdiv(rhs
);
155 //===----------------------------------------------------------------------===//
157 //===----------------------------------------------------------------------===//
159 OpFoldResult
DivUOp::fold(ArrayRef
<Attribute
> operands
) {
160 return foldBinaryOpChecked(
161 operands
, [](const APInt
&lhs
, const APInt
&rhs
) -> Optional
<APInt
> {
162 // Don't fold division by zero.
165 return lhs
.udiv(rhs
);
169 //===----------------------------------------------------------------------===//
171 //===----------------------------------------------------------------------===//
173 /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
174 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
175 static Optional
<APInt
> calculateCeilDivS(const APInt
&n
, const APInt
&m
) {
176 // Don't fold division by zero.
179 // Short-circuit the zero case.
183 bool mGtZ
= m
.sgt(0);
184 if (n
.sgt(0) != mGtZ
) {
185 // If the operands have different signs, compute the negative result. Signed
186 // division overflow is not possible, since if `m == -1`, `n` can be at most
187 // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
188 return -(-n
).sdiv(m
);
190 // Otherwise, compute the positive result. Signed division overflow is not
191 // possible since if `m == -1`, `x` will be `1`.
192 int64_t x
= mGtZ
? -1 : 1;
193 return (n
+ x
).sdiv(m
) + 1;
196 OpFoldResult
CeilDivSOp::fold(ArrayRef
<Attribute
> operands
) {
197 return foldBinaryOpChecked(operands
, calculateCeilDivS
);
200 //===----------------------------------------------------------------------===//
202 //===----------------------------------------------------------------------===//
204 OpFoldResult
CeilDivUOp::fold(ArrayRef
<Attribute
> operands
) {
205 // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
206 return foldBinaryOpChecked(
207 operands
, [](const APInt
&n
, const APInt
&m
) -> Optional
<APInt
> {
208 // Don't fold division by zero.
211 // Short-circuit the zero case.
215 return (n
- 1).udiv(m
) + 1;
219 //===----------------------------------------------------------------------===//
221 //===----------------------------------------------------------------------===//
223 /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
224 /// `n*m < 0 ? -1 - (x-n)/m : n/m`.
225 static Optional
<APInt
> calculateFloorDivS(const APInt
&n
, const APInt
&m
) {
226 // Don't fold division by zero.
229 // Short-circuit the zero case.
233 bool mLtZ
= m
.slt(0);
234 if (n
.slt(0) == mLtZ
) {
235 // If the operands have the same sign, compute the positive result.
238 // If the operands have different signs, compute the negative result. Signed
239 // division overflow is not possible since if `m == -1`, `x` will be 1 and
240 // `n` can be at most `INT_MAX`.
241 int64_t x
= mLtZ
? 1 : -1;
242 return -1 - (x
- n
).sdiv(m
);
245 OpFoldResult
FloorDivSOp::fold(ArrayRef
<Attribute
> operands
) {
246 return foldBinaryOpChecked(operands
, calculateFloorDivS
);
249 //===----------------------------------------------------------------------===//
251 //===----------------------------------------------------------------------===//
253 OpFoldResult
RemSOp::fold(ArrayRef
<Attribute
> operands
) {
254 return foldBinaryOpChecked(operands
, [](const APInt
&lhs
, const APInt
&rhs
) {
255 return lhs
.srem(rhs
);
259 //===----------------------------------------------------------------------===//
261 //===----------------------------------------------------------------------===//
263 OpFoldResult
RemUOp::fold(ArrayRef
<Attribute
> operands
) {
264 return foldBinaryOpChecked(operands
, [](const APInt
&lhs
, const APInt
&rhs
) {
265 return lhs
.urem(rhs
);
269 //===----------------------------------------------------------------------===//
271 //===----------------------------------------------------------------------===//
273 OpFoldResult
MaxSOp::fold(ArrayRef
<Attribute
> operands
) {
274 return foldBinaryOpChecked(operands
, [](const APInt
&lhs
, const APInt
&rhs
) {
275 return lhs
.sgt(rhs
) ? lhs
: rhs
;
279 //===----------------------------------------------------------------------===//
281 //===----------------------------------------------------------------------===//
283 OpFoldResult
MaxUOp::fold(ArrayRef
<Attribute
> operands
) {
284 return foldBinaryOpChecked(operands
, [](const APInt
&lhs
, const APInt
&rhs
) {
285 return lhs
.ugt(rhs
) ? lhs
: rhs
;
289 //===----------------------------------------------------------------------===//
291 //===----------------------------------------------------------------------===//
293 OpFoldResult
ShlOp::fold(ArrayRef
<Attribute
> operands
) {
294 return foldBinaryOpUnchecked(
295 operands
, [](const APInt
&lhs
, const APInt
&rhs
) -> Optional
<APInt
> {
296 // We cannot fold if the RHS is greater than or equal to 32 because
297 // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
298 // already treated as unsigned.
305 //===----------------------------------------------------------------------===//
307 //===----------------------------------------------------------------------===//
309 OpFoldResult
ShrSOp::fold(ArrayRef
<Attribute
> operands
) {
310 return foldBinaryOpChecked(
311 operands
, [](const APInt
&lhs
, const APInt
&rhs
) -> Optional
<APInt
> {
312 // Don't fold if RHS is greater than or equal to 32.
315 return lhs
.ashr(rhs
);
319 //===----------------------------------------------------------------------===//
321 //===----------------------------------------------------------------------===//
323 OpFoldResult
ShrUOp::fold(ArrayRef
<Attribute
> operands
) {
324 return foldBinaryOpChecked(
325 operands
, [](const APInt
&lhs
, const APInt
&rhs
) -> Optional
<APInt
> {
326 // Don't fold if RHS is greater than or equal to 32.
329 return lhs
.lshr(rhs
);
333 //===----------------------------------------------------------------------===//
335 //===----------------------------------------------------------------------===//
337 bool CastSOp::areCastCompatible(TypeRange lhsTypes
, TypeRange rhsTypes
) {
338 return lhsTypes
.front().isa
<IndexType
>() != rhsTypes
.front().isa
<IndexType
>();
341 //===----------------------------------------------------------------------===//
343 //===----------------------------------------------------------------------===//
345 bool CastUOp::areCastCompatible(TypeRange lhsTypes
, TypeRange rhsTypes
) {
346 return lhsTypes
.front().isa
<IndexType
>() != rhsTypes
.front().isa
<IndexType
>();
349 //===----------------------------------------------------------------------===//
351 //===----------------------------------------------------------------------===//
353 /// Compare two integers according to the comparison predicate.
354 bool compareIndices(const APInt
&lhs
, const APInt
&rhs
,
355 IndexCmpPredicate pred
) {
357 case IndexCmpPredicate::EQ
:
359 case IndexCmpPredicate::NE
:
361 case IndexCmpPredicate::SGE
:
363 case IndexCmpPredicate::SGT
:
365 case IndexCmpPredicate::SLE
:
367 case IndexCmpPredicate::SLT
:
369 case IndexCmpPredicate::UGE
:
371 case IndexCmpPredicate::UGT
:
373 case IndexCmpPredicate::ULE
:
375 case IndexCmpPredicate::ULT
:
378 llvm_unreachable("unhandled IndexCmpPredicate predicate");
381 OpFoldResult
CmpOp::fold(ArrayRef
<Attribute
> operands
) {
382 assert(operands
.size() == 2 && "compare expected 2 operands");
383 auto lhs
= dyn_cast_if_present
<IntegerAttr
>(operands
[0]);
384 auto rhs
= dyn_cast_if_present
<IntegerAttr
>(operands
[1]);
388 // Perform the comparison in 64-bit and 32-bit.
389 bool result64
= compareIndices(lhs
.getValue(), rhs
.getValue(), getPred());
390 bool result32
= compareIndices(lhs
.getValue().trunc(32),
391 rhs
.getValue().trunc(32), getPred());
392 if (result64
!= result32
)
394 return BoolAttr::get(getContext(), result64
);
397 //===----------------------------------------------------------------------===//
399 //===----------------------------------------------------------------------===//
401 OpFoldResult
ConstantOp::fold(ArrayRef
<Attribute
> operands
) {
402 return getValueAttr();
405 void ConstantOp::build(OpBuilder
&b
, OperationState
&state
, int64_t value
) {
406 build(b
, state
, b
.getIndexType(), b
.getIndexAttr(value
));
409 //===----------------------------------------------------------------------===//
411 //===----------------------------------------------------------------------===//
413 OpFoldResult
BoolConstantOp::fold(ArrayRef
<Attribute
> operands
) {
414 return getValueAttr();
417 //===----------------------------------------------------------------------===//
418 // ODS-Generated Definitions
419 //===----------------------------------------------------------------------===//
421 #define GET_OP_CLASSES
422 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"