[mlir][index] Add shl, shrs, and shru ops
[llvm-project.git] / mlir / lib / Dialect / Index / IR / IndexOps.cpp
blob241fa416eddabf5ab647a42dadd527c2c73a59ad
1 //===- IndexOps.cpp - Index operation definitions --------------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Index/IR/IndexOps.h"
10 #include "mlir/Dialect/Index/IR/IndexAttrs.h"
11 #include "mlir/Dialect/Index/IR/IndexDialect.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/OpImplementation.h"
15 using namespace mlir;
16 using namespace mlir::index;
18 //===----------------------------------------------------------------------===//
19 // IndexDialect
20 //===----------------------------------------------------------------------===//
22 void IndexDialect::registerOperations() {
23 addOperations<
24 #define GET_OP_LIST
25 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
26 >();
29 Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
30 Type type, Location loc) {
31 // Materialize bool constants as `i1`.
32 if (auto boolValue = dyn_cast<BoolAttr>(value)) {
33 if (!type.isSignlessInteger(1))
34 return nullptr;
35 return b.create<BoolConstantOp>(loc, type, boolValue);
38 // Materialize integer attributes as `index`.
39 if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
40 if (!indexValue.getType().isa<IndexType>() || !type.isa<IndexType>())
41 return nullptr;
42 assert(indexValue.getValue().getBitWidth() ==
43 IndexType::kInternalStorageBitWidth);
44 return b.create<ConstantOp>(loc, indexValue);
47 return nullptr;
50 //===----------------------------------------------------------------------===//
51 // Fold Utilities
52 //===----------------------------------------------------------------------===//
54 /// Fold an index operation irrespective of the target bitwidth. The
55 /// operation must satisfy the property:
56 ///
57 /// ```
58 /// trunc(f(a, b)) = f(trunc(a), trunc(b))
59 /// ```
60 ///
61 /// For all values of `a` and `b`. The function accepts a lambda that computes
62 /// the integer result, which in turn must satisfy the above property.
63 static OpFoldResult foldBinaryOpUnchecked(
64 ArrayRef<Attribute> operands,
65 function_ref<Optional<APInt>(const APInt &, const APInt &)> calculate) {
66 assert(operands.size() == 2 && "binary operation expected 2 operands");
67 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
68 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
69 if (!lhs || !rhs)
70 return {};
72 Optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
73 if (!result)
74 return {};
75 assert(result->trunc(32) ==
76 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
77 return IntegerAttr::get(IndexType::get(lhs.getContext()), std::move(*result));
80 /// Fold an index operation only if the truncated 64-bit result matches the
81 /// 32-bit result for operations that don't satisfy the above property. These
82 /// are operations where the upper bits of the operands can affect the lower
83 /// bits of the results.
84 ///
85 /// The function accepts a lambda that computes the integer result in both
86 /// 64-bit and 32-bit. If either call returns `None`, the operation is not
87 /// folded.
88 static OpFoldResult foldBinaryOpChecked(
89 ArrayRef<Attribute> operands,
90 function_ref<Optional<APInt>(const APInt &, const APInt &lhs)> calculate) {
91 assert(operands.size() == 2 && "binary operation expected 2 operands");
92 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
93 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
94 // Only fold index operands.
95 if (!lhs || !rhs)
96 return {};
98 // Compute the 64-bit result and the 32-bit result.
99 Optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
100 if (!result64)
101 return {};
102 Optional<APInt> result32 =
103 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
104 if (!result32)
105 return {};
106 // Compare the truncated 64-bit result to the 32-bit result.
107 if (result64->trunc(32) != *result32)
108 return {};
109 // The operation can be folded for these particular operands.
110 return IntegerAttr::get(IndexType::get(lhs.getContext()),
111 std::move(*result64));
114 //===----------------------------------------------------------------------===//
115 // AddOp
116 //===----------------------------------------------------------------------===//
118 OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
119 return foldBinaryOpUnchecked(
120 operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
123 //===----------------------------------------------------------------------===//
124 // SubOp
125 //===----------------------------------------------------------------------===//
127 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
128 return foldBinaryOpUnchecked(
129 operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
132 //===----------------------------------------------------------------------===//
133 // MulOp
134 //===----------------------------------------------------------------------===//
136 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
137 return foldBinaryOpUnchecked(
138 operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
141 //===----------------------------------------------------------------------===//
142 // DivSOp
143 //===----------------------------------------------------------------------===//
145 OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
146 return foldBinaryOpChecked(
147 operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
148 // Don't fold division by zero.
149 if (rhs.isZero())
150 return None;
151 return lhs.sdiv(rhs);
155 //===----------------------------------------------------------------------===//
156 // DivUOp
157 //===----------------------------------------------------------------------===//
159 OpFoldResult DivUOp::fold(ArrayRef<Attribute> operands) {
160 return foldBinaryOpChecked(
161 operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
162 // Don't fold division by zero.
163 if (rhs.isZero())
164 return None;
165 return lhs.udiv(rhs);
169 //===----------------------------------------------------------------------===//
170 // CeilDivSOp
171 //===----------------------------------------------------------------------===//
173 /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
174 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
175 static Optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
176 // Don't fold division by zero.
177 if (m.isZero())
178 return None;
179 // Short-circuit the zero case.
180 if (n.isZero())
181 return n;
183 bool mGtZ = m.sgt(0);
184 if (n.sgt(0) != mGtZ) {
185 // If the operands have different signs, compute the negative result. Signed
186 // division overflow is not possible, since if `m == -1`, `n` can be at most
187 // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
188 return -(-n).sdiv(m);
190 // Otherwise, compute the positive result. Signed division overflow is not
191 // possible since if `m == -1`, `x` will be `1`.
192 int64_t x = mGtZ ? -1 : 1;
193 return (n + x).sdiv(m) + 1;
196 OpFoldResult CeilDivSOp::fold(ArrayRef<Attribute> operands) {
197 return foldBinaryOpChecked(operands, calculateCeilDivS);
200 //===----------------------------------------------------------------------===//
201 // CeilDivUOp
202 //===----------------------------------------------------------------------===//
204 OpFoldResult CeilDivUOp::fold(ArrayRef<Attribute> operands) {
205 // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
206 return foldBinaryOpChecked(
207 operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
208 // Don't fold division by zero.
209 if (m.isZero())
210 return None;
211 // Short-circuit the zero case.
212 if (n.isZero())
213 return n;
215 return (n - 1).udiv(m) + 1;
219 //===----------------------------------------------------------------------===//
220 // FloorDivSOp
221 //===----------------------------------------------------------------------===//
223 /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
224 /// `n*m < 0 ? -1 - (x-n)/m : n/m`.
225 static Optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
226 // Don't fold division by zero.
227 if (m.isZero())
228 return None;
229 // Short-circuit the zero case.
230 if (n.isZero())
231 return n;
233 bool mLtZ = m.slt(0);
234 if (n.slt(0) == mLtZ) {
235 // If the operands have the same sign, compute the positive result.
236 return n.sdiv(m);
238 // If the operands have different signs, compute the negative result. Signed
239 // division overflow is not possible since if `m == -1`, `x` will be 1 and
240 // `n` can be at most `INT_MAX`.
241 int64_t x = mLtZ ? 1 : -1;
242 return -1 - (x - n).sdiv(m);
245 OpFoldResult FloorDivSOp::fold(ArrayRef<Attribute> operands) {
246 return foldBinaryOpChecked(operands, calculateFloorDivS);
249 //===----------------------------------------------------------------------===//
250 // RemSOp
251 //===----------------------------------------------------------------------===//
253 OpFoldResult RemSOp::fold(ArrayRef<Attribute> operands) {
254 return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
255 return lhs.srem(rhs);
259 //===----------------------------------------------------------------------===//
260 // RemUOp
261 //===----------------------------------------------------------------------===//
263 OpFoldResult RemUOp::fold(ArrayRef<Attribute> operands) {
264 return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
265 return lhs.urem(rhs);
269 //===----------------------------------------------------------------------===//
270 // MaxSOp
271 //===----------------------------------------------------------------------===//
273 OpFoldResult MaxSOp::fold(ArrayRef<Attribute> operands) {
274 return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
275 return lhs.sgt(rhs) ? lhs : rhs;
279 //===----------------------------------------------------------------------===//
280 // MaxUOp
281 //===----------------------------------------------------------------------===//
283 OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
284 return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
285 return lhs.ugt(rhs) ? lhs : rhs;
289 //===----------------------------------------------------------------------===//
290 // ShlOp
291 //===----------------------------------------------------------------------===//
293 OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
294 return foldBinaryOpUnchecked(
295 operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
296 // We cannot fold if the RHS is greater than or equal to 32 because
297 // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
298 // already treated as unsigned.
299 if (rhs.uge(32))
300 return {};
301 return lhs << rhs;
305 //===----------------------------------------------------------------------===//
306 // ShrSOp
307 //===----------------------------------------------------------------------===//
309 OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
310 return foldBinaryOpChecked(
311 operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
312 // Don't fold if RHS is greater than or equal to 32.
313 if (rhs.uge(32))
314 return {};
315 return lhs.ashr(rhs);
319 //===----------------------------------------------------------------------===//
320 // ShrUOp
321 //===----------------------------------------------------------------------===//
323 OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
324 return foldBinaryOpChecked(
325 operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
326 // Don't fold if RHS is greater than or equal to 32.
327 if (rhs.uge(32))
328 return {};
329 return lhs.lshr(rhs);
333 //===----------------------------------------------------------------------===//
334 // CastSOp
335 //===----------------------------------------------------------------------===//
337 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
338 return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>();
341 //===----------------------------------------------------------------------===//
342 // CastUOp
343 //===----------------------------------------------------------------------===//
345 bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
346 return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>();
349 //===----------------------------------------------------------------------===//
350 // CmpOp
351 //===----------------------------------------------------------------------===//
353 /// Compare two integers according to the comparison predicate.
354 bool compareIndices(const APInt &lhs, const APInt &rhs,
355 IndexCmpPredicate pred) {
356 switch (pred) {
357 case IndexCmpPredicate::EQ:
358 return lhs.eq(rhs);
359 case IndexCmpPredicate::NE:
360 return lhs.ne(rhs);
361 case IndexCmpPredicate::SGE:
362 return lhs.sge(rhs);
363 case IndexCmpPredicate::SGT:
364 return lhs.sgt(rhs);
365 case IndexCmpPredicate::SLE:
366 return lhs.sle(rhs);
367 case IndexCmpPredicate::SLT:
368 return lhs.slt(rhs);
369 case IndexCmpPredicate::UGE:
370 return lhs.uge(rhs);
371 case IndexCmpPredicate::UGT:
372 return lhs.ugt(rhs);
373 case IndexCmpPredicate::ULE:
374 return lhs.ule(rhs);
375 case IndexCmpPredicate::ULT:
376 return lhs.ult(rhs);
378 llvm_unreachable("unhandled IndexCmpPredicate predicate");
381 OpFoldResult CmpOp::fold(ArrayRef<Attribute> operands) {
382 assert(operands.size() == 2 && "compare expected 2 operands");
383 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
384 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
385 if (!lhs || !rhs)
386 return {};
388 // Perform the comparison in 64-bit and 32-bit.
389 bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
390 bool result32 = compareIndices(lhs.getValue().trunc(32),
391 rhs.getValue().trunc(32), getPred());
392 if (result64 != result32)
393 return {};
394 return BoolAttr::get(getContext(), result64);
397 //===----------------------------------------------------------------------===//
398 // ConstantOp
399 //===----------------------------------------------------------------------===//
401 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
402 return getValueAttr();
405 void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
406 build(b, state, b.getIndexType(), b.getIndexAttr(value));
409 //===----------------------------------------------------------------------===//
410 // BoolConstantOp
411 //===----------------------------------------------------------------------===//
413 OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
414 return getValueAttr();
417 //===----------------------------------------------------------------------===//
418 // ODS-Generated Definitions
419 //===----------------------------------------------------------------------===//
421 #define GET_OP_CLASSES
422 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"