Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / IndexToLLVM / IndexToLLVM.cpp
blob0473bb59fa6aa3600b4d3655b21337bd0af00a81
1 //===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/Index/IR/IndexAttrs.h"
14 #include "mlir/Dialect/Index/IR/IndexDialect.h"
15 #include "mlir/Dialect/Index/IR/IndexOps.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Pass/Pass.h"
19 using namespace mlir;
20 using namespace index;
22 namespace {
24 //===----------------------------------------------------------------------===//
25 // ConvertIndexCeilDivS
26 //===----------------------------------------------------------------------===//
28 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
29 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
30 struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> {
31 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
33 LogicalResult
34 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
35 ConversionPatternRewriter &rewriter) const override {
36 Location loc = op.getLoc();
37 Value n = adaptor.getLhs();
38 Value m = adaptor.getRhs();
39 Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
40 Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
41 Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
43 // Compute `x`.
44 Value mPos =
45 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero);
46 Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne);
48 // Compute the positive result.
49 Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x);
50 Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m);
51 Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne);
53 // Compute the negative result.
54 Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n);
55 Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m);
56 Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM);
58 // Pick the positive result if `n` and `m` have the same sign and `n` is
59 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
60 Value nPos =
61 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero);
62 Value sameSign =
63 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos);
64 Value nNonZero =
65 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
66 Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero);
67 rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes);
68 return success();
72 //===----------------------------------------------------------------------===//
73 // ConvertIndexCeilDivU
74 //===----------------------------------------------------------------------===//
76 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`.
77 struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> {
78 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
80 LogicalResult
81 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
82 ConversionPatternRewriter &rewriter) const override {
83 Location loc = op.getLoc();
84 Value n = adaptor.getLhs();
85 Value m = adaptor.getRhs();
86 Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
87 Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
89 // Compute the non-zero result.
90 Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one);
91 Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m);
92 Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one);
94 // Pick the result.
95 Value cmp =
96 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero);
97 rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne);
98 return success();
102 //===----------------------------------------------------------------------===//
103 // ConvertIndexFloorDivS
104 //===----------------------------------------------------------------------===//
106 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
107 /// `n*m < 0 ? -1 - (x-n)/m : n/m`.
108 struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> {
109 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
111 LogicalResult
112 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
113 ConversionPatternRewriter &rewriter) const override {
114 Location loc = op.getLoc();
115 Value n = adaptor.getLhs();
116 Value m = adaptor.getRhs();
117 Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
118 Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
119 Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
121 // Compute `x`.
122 Value mNeg =
123 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
124 Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne);
126 // Compute the negative result.
127 Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n);
128 Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m);
129 Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM);
131 // Compute the positive result.
132 Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m);
134 // Pick the negative result if `n` and `m` have different signs and `n` is
135 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
136 Value nNeg =
137 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
138 Value diffSign =
139 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
140 Value nNonZero =
141 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
142 Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero);
143 rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes);
144 return success();
148 //===----------------------------------------------------------------------===//
149 // ConvertIndexCast
150 //===----------------------------------------------------------------------===//
152 /// Convert a cast op. If the materialized index type is the same as the other
153 /// type, fold away the op. Otherwise, truncate or extend the op as appropriate.
154 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
155 /// zero extend when the result bitwidth is larger.
156 template <typename CastOp, typename ExtOp>
157 struct ConvertIndexCast : public mlir::ConvertOpToLLVMPattern<CastOp> {
158 using mlir::ConvertOpToLLVMPattern<CastOp>::ConvertOpToLLVMPattern;
160 LogicalResult
161 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
162 ConversionPatternRewriter &rewriter) const override {
163 Type in = adaptor.getInput().getType();
164 Type out = this->getTypeConverter()->convertType(op.getType());
165 if (in == out)
166 rewriter.replaceOp(op, adaptor.getInput());
167 else if (in.getIntOrFloatBitWidth() > out.getIntOrFloatBitWidth())
168 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, out, adaptor.getInput());
169 else
170 rewriter.replaceOpWithNewOp<ExtOp>(op, out, adaptor.getInput());
171 return success();
175 using ConvertIndexCastS = ConvertIndexCast<CastSOp, LLVM::SExtOp>;
176 using ConvertIndexCastU = ConvertIndexCast<CastUOp, LLVM::ZExtOp>;
178 //===----------------------------------------------------------------------===//
179 // ConvertIndexCmp
180 //===----------------------------------------------------------------------===//
182 /// Assert that the LLVM comparison enum lines up with index's enum.
183 static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs,
184 IndexCmpPredicate rhs) {
185 return static_cast<int>(lhs) == static_cast<int>(rhs);
188 static_assert(
189 LLVM::getMaxEnumValForICmpPredicate() ==
190 getMaxEnumValForIndexCmpPredicate() &&
191 checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) &&
192 checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) &&
193 checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) &&
194 checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) &&
195 checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) &&
196 checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) &&
197 checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) &&
198 checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) &&
199 checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) &&
200 checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT),
201 "LLVM ICmpPredicate mismatches IndexCmpPredicate");
203 struct ConvertIndexCmp : public mlir::ConvertOpToLLVMPattern<CmpOp> {
204 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
206 LogicalResult
207 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
208 ConversionPatternRewriter &rewriter) const override {
209 // The LLVM enum has the same values as the index predicate enums.
210 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
211 op, *LLVM::symbolizeICmpPredicate(static_cast<uint32_t>(op.getPred())),
212 adaptor.getLhs(), adaptor.getRhs());
213 return success();
217 //===----------------------------------------------------------------------===//
218 // ConvertIndexSizeOf
219 //===----------------------------------------------------------------------===//
221 /// Lower `index.sizeof` to a constant with the value of the index bitwidth.
222 struct ConvertIndexSizeOf : public mlir::ConvertOpToLLVMPattern<SizeOfOp> {
223 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
225 LogicalResult
226 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter) const override {
228 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
229 op, getTypeConverter()->getIndexType(),
230 getTypeConverter()->getIndexTypeBitwidth());
231 return success();
235 //===----------------------------------------------------------------------===//
236 // ConvertIndexConstant
237 //===----------------------------------------------------------------------===//
239 /// Convert an index constant. Truncate the value as appropriate.
240 struct ConvertIndexConstant : public mlir::ConvertOpToLLVMPattern<ConstantOp> {
241 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
243 LogicalResult
244 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const override {
246 Type type = getTypeConverter()->getIndexType();
247 APInt value = op.getValue().trunc(type.getIntOrFloatBitWidth());
248 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
249 op, type, IntegerAttr::get(type, value));
250 return success();
254 //===----------------------------------------------------------------------===//
255 // Trivial Conversions
256 //===----------------------------------------------------------------------===//
258 using ConvertIndexAdd = mlir::OneToOneConvertToLLVMPattern<AddOp, LLVM::AddOp>;
259 using ConvertIndexSub = mlir::OneToOneConvertToLLVMPattern<SubOp, LLVM::SubOp>;
260 using ConvertIndexMul = mlir::OneToOneConvertToLLVMPattern<MulOp, LLVM::MulOp>;
261 using ConvertIndexDivS =
262 mlir::OneToOneConvertToLLVMPattern<DivSOp, LLVM::SDivOp>;
263 using ConvertIndexDivU =
264 mlir::OneToOneConvertToLLVMPattern<DivUOp, LLVM::UDivOp>;
265 using ConvertIndexRemS =
266 mlir::OneToOneConvertToLLVMPattern<RemSOp, LLVM::SRemOp>;
267 using ConvertIndexRemU =
268 mlir::OneToOneConvertToLLVMPattern<RemUOp, LLVM::URemOp>;
269 using ConvertIndexMaxS =
270 mlir::OneToOneConvertToLLVMPattern<MaxSOp, LLVM::SMaxOp>;
271 using ConvertIndexMaxU =
272 mlir::OneToOneConvertToLLVMPattern<MaxUOp, LLVM::UMaxOp>;
273 using ConvertIndexMinS =
274 mlir::OneToOneConvertToLLVMPattern<MinSOp, LLVM::SMinOp>;
275 using ConvertIndexMinU =
276 mlir::OneToOneConvertToLLVMPattern<MinUOp, LLVM::UMinOp>;
277 using ConvertIndexShl = mlir::OneToOneConvertToLLVMPattern<ShlOp, LLVM::ShlOp>;
278 using ConvertIndexShrS =
279 mlir::OneToOneConvertToLLVMPattern<ShrSOp, LLVM::AShrOp>;
280 using ConvertIndexShrU =
281 mlir::OneToOneConvertToLLVMPattern<ShrUOp, LLVM::LShrOp>;
282 using ConvertIndexAnd = mlir::OneToOneConvertToLLVMPattern<AndOp, LLVM::AndOp>;
283 using ConvertIndexOr = mlir::OneToOneConvertToLLVMPattern<OrOp, LLVM::OrOp>;
284 using ConvertIndexXor = mlir::OneToOneConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
285 using ConvertIndexBoolConstant =
286 mlir::OneToOneConvertToLLVMPattern<BoolConstantOp, LLVM::ConstantOp>;
288 } // namespace
290 //===----------------------------------------------------------------------===//
291 // Pattern Population
292 //===----------------------------------------------------------------------===//
294 void index::populateIndexToLLVMConversionPatterns(
295 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
296 patterns.insert<
297 // clang-format off
298 ConvertIndexAdd,
299 ConvertIndexSub,
300 ConvertIndexMul,
301 ConvertIndexDivS,
302 ConvertIndexDivU,
303 ConvertIndexRemS,
304 ConvertIndexRemU,
305 ConvertIndexMaxS,
306 ConvertIndexMaxU,
307 ConvertIndexMinS,
308 ConvertIndexMinU,
309 ConvertIndexShl,
310 ConvertIndexShrS,
311 ConvertIndexShrU,
312 ConvertIndexAnd,
313 ConvertIndexOr,
314 ConvertIndexXor,
315 ConvertIndexCeilDivS,
316 ConvertIndexCeilDivU,
317 ConvertIndexFloorDivS,
318 ConvertIndexCastS,
319 ConvertIndexCastU,
320 ConvertIndexCmp,
321 ConvertIndexSizeOf,
322 ConvertIndexConstant,
323 ConvertIndexBoolConstant
324 // clang-format on
325 >(typeConverter);
328 //===----------------------------------------------------------------------===//
329 // ODS-Generated Definitions
330 //===----------------------------------------------------------------------===//
332 namespace mlir {
333 #define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS
334 #include "mlir/Conversion/Passes.h.inc"
335 } // namespace mlir
337 //===----------------------------------------------------------------------===//
338 // Pass Definition
339 //===----------------------------------------------------------------------===//
341 namespace {
342 struct ConvertIndexToLLVMPass
343 : public impl::ConvertIndexToLLVMPassBase<ConvertIndexToLLVMPass> {
344 using Base::Base;
346 void runOnOperation() override;
348 } // namespace
350 void ConvertIndexToLLVMPass::runOnOperation() {
351 // Configure dialect conversion.
352 ConversionTarget target(getContext());
353 target.addIllegalDialect<IndexDialect>();
354 target.addLegalDialect<LLVM::LLVMDialect>();
356 // Set LLVM lowering options.
357 LowerToLLVMOptions options(&getContext());
358 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
359 options.overrideIndexBitwidth(indexBitwidth);
360 LLVMTypeConverter typeConverter(&getContext(), options);
362 // Populate patterns and run the conversion.
363 RewritePatternSet patterns(&getContext());
364 populateIndexToLLVMConversionPatterns(typeConverter, patterns);
366 if (failed(
367 applyPartialConversion(getOperation(), target, std::move(patterns))))
368 return signalPassFailure();
371 //===----------------------------------------------------------------------===//
372 // ConvertToLLVMPatternInterface implementation
373 //===----------------------------------------------------------------------===//
375 namespace {
376 /// Implement the interface to convert Index to LLVM.
377 struct IndexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
378 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
379 void loadDependentDialects(MLIRContext *context) const final {
380 context->loadDialect<LLVM::LLVMDialect>();
383 /// Hook for derived dialect interface to provide conversion patterns
384 /// and mark dialect legal for the conversion target.
385 void populateConvertToLLVMConversionPatterns(
386 ConversionTarget &target, LLVMTypeConverter &typeConverter,
387 RewritePatternSet &patterns) const final {
388 populateIndexToLLVMConversionPatterns(typeConverter, patterns);
391 } // namespace
393 void mlir::index::registerConvertIndexToLLVMInterface(
394 DialectRegistry &registry) {
395 registry.addExtension(+[](MLIRContext *ctx, index::IndexDialect *dialect) {
396 dialect->addInterfaces<IndexToLLVMDialectInterface>();