//===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Index/IR/IndexAttrs.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace index;
24 //===----------------------------------------------------------------------===//
25 // ConvertIndexCeilDivS
26 //===----------------------------------------------------------------------===//
28 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
29 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
30 struct ConvertIndexCeilDivS
: mlir::ConvertOpToLLVMPattern
<CeilDivSOp
> {
31 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern
;
34 matchAndRewrite(CeilDivSOp op
, CeilDivSOpAdaptor adaptor
,
35 ConversionPatternRewriter
&rewriter
) const override
{
36 Location loc
= op
.getLoc();
37 Value n
= adaptor
.getLhs();
38 Value m
= adaptor
.getRhs();
39 Value zero
= rewriter
.create
<LLVM::ConstantOp
>(loc
, n
.getType(), 0);
40 Value posOne
= rewriter
.create
<LLVM::ConstantOp
>(loc
, n
.getType(), 1);
41 Value negOne
= rewriter
.create
<LLVM::ConstantOp
>(loc
, n
.getType(), -1);
45 rewriter
.create
<LLVM::ICmpOp
>(loc
, LLVM::ICmpPredicate::sgt
, m
, zero
);
46 Value x
= rewriter
.create
<LLVM::SelectOp
>(loc
, mPos
, negOne
, posOne
);
48 // Compute the positive result.
49 Value nPlusX
= rewriter
.create
<LLVM::AddOp
>(loc
, n
, x
);
50 Value nPlusXDivM
= rewriter
.create
<LLVM::SDivOp
>(loc
, nPlusX
, m
);
51 Value posRes
= rewriter
.create
<LLVM::AddOp
>(loc
, nPlusXDivM
, posOne
);
53 // Compute the negative result.
54 Value negN
= rewriter
.create
<LLVM::SubOp
>(loc
, zero
, n
);
55 Value negNDivM
= rewriter
.create
<LLVM::SDivOp
>(loc
, negN
, m
);
56 Value negRes
= rewriter
.create
<LLVM::SubOp
>(loc
, zero
, negNDivM
);
58 // Pick the positive result if `n` and `m` have the same sign and `n` is
59 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
61 rewriter
.create
<LLVM::ICmpOp
>(loc
, LLVM::ICmpPredicate::sgt
, n
, zero
);
63 rewriter
.create
<LLVM::ICmpOp
>(loc
, LLVM::ICmpPredicate::eq
, nPos
, mPos
);
65 rewriter
.create
<LLVM::ICmpOp
>(loc
, LLVM::ICmpPredicate::ne
, n
, zero
);
66 Value cmp
= rewriter
.create
<LLVM::AndOp
>(loc
, sameSign
, nNonZero
);
67 rewriter
.replaceOpWithNewOp
<LLVM::SelectOp
>(op
, cmp
, posRes
, negRes
);
72 //===----------------------------------------------------------------------===//
73 // ConvertIndexCeilDivU
74 //===----------------------------------------------------------------------===//
76 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`.
77 struct ConvertIndexCeilDivU
: mlir::ConvertOpToLLVMPattern
<CeilDivUOp
> {
78 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern
;
81 matchAndRewrite(CeilDivUOp op
, CeilDivUOpAdaptor adaptor
,
82 ConversionPatternRewriter
&rewriter
) const override
{
83 Location loc
= op
.getLoc();
84 Value n
= adaptor
.getLhs();
85 Value m
= adaptor
.getRhs();
86 Value zero
= rewriter
.create
<LLVM::ConstantOp
>(loc
, n
.getType(), 0);
87 Value one
= rewriter
.create
<LLVM::ConstantOp
>(loc
, n
.getType(), 1);
89 // Compute the non-zero result.
90 Value minusOne
= rewriter
.create
<LLVM::SubOp
>(loc
, n
, one
);
91 Value quotient
= rewriter
.create
<LLVM::UDivOp
>(loc
, minusOne
, m
);
92 Value plusOne
= rewriter
.create
<LLVM::AddOp
>(loc
, quotient
, one
);
96 rewriter
.create
<LLVM::ICmpOp
>(loc
, LLVM::ICmpPredicate::eq
, n
, zero
);
97 rewriter
.replaceOpWithNewOp
<LLVM::SelectOp
>(op
, cmp
, zero
, plusOne
);
102 //===----------------------------------------------------------------------===//
103 // ConvertIndexFloorDivS
104 //===----------------------------------------------------------------------===//
106 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
107 /// `n*m < 0 ? -1 - (x-n)/m : n/m`.
108 struct ConvertIndexFloorDivS
: mlir::ConvertOpToLLVMPattern
<FloorDivSOp
> {
109 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern
;
112 matchAndRewrite(FloorDivSOp op
, FloorDivSOpAdaptor adaptor
,
113 ConversionPatternRewriter
&rewriter
) const override
{
114 Location loc
= op
.getLoc();
115 Value n
= adaptor
.getLhs();
116 Value m
= adaptor
.getRhs();
117 Value zero
= rewriter
.create
<LLVM::ConstantOp
>(loc
, n
.getType(), 0);
118 Value posOne
= rewriter
.create
<LLVM::ConstantOp
>(loc
, n
.getType(), 1);
119 Value negOne
= rewriter
.create
<LLVM::ConstantOp
>(loc
, n
.getType(), -1);
123 rewriter
.create
<LLVM::ICmpOp
>(loc
, LLVM::ICmpPredicate::slt
, m
, zero
);
124 Value x
= rewriter
.create
<LLVM::SelectOp
>(loc
, mNeg
, posOne
, negOne
);
126 // Compute the negative result.
127 Value xMinusN
= rewriter
.create
<LLVM::SubOp
>(loc
, x
, n
);
128 Value xMinusNDivM
= rewriter
.create
<LLVM::SDivOp
>(loc
, xMinusN
, m
);
129 Value negRes
= rewriter
.create
<LLVM::SubOp
>(loc
, negOne
, xMinusNDivM
);
131 // Compute the positive result.
132 Value posRes
= rewriter
.create
<LLVM::SDivOp
>(loc
, n
, m
);
134 // Pick the negative result if `n` and `m` have different signs and `n` is
135 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
137 rewriter
.create
<LLVM::ICmpOp
>(loc
, LLVM::ICmpPredicate::slt
, n
, zero
);
139 rewriter
.create
<LLVM::ICmpOp
>(loc
, LLVM::ICmpPredicate::ne
, nNeg
, mNeg
);
141 rewriter
.create
<LLVM::ICmpOp
>(loc
, LLVM::ICmpPredicate::ne
, n
, zero
);
142 Value cmp
= rewriter
.create
<LLVM::AndOp
>(loc
, diffSign
, nNonZero
);
143 rewriter
.replaceOpWithNewOp
<LLVM::SelectOp
>(op
, cmp
, negRes
, posRes
);
148 //===----------------------------------------------------------------------===//
150 //===----------------------------------------------------------------------===//
152 /// Convert a cast op. If the materialized index type is the same as the other
153 /// type, fold away the op. Otherwise, truncate or extend the op as appropriate.
154 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
155 /// zero extend when the result bitwidth is larger.
156 template <typename CastOp
, typename ExtOp
>
157 struct ConvertIndexCast
: public mlir::ConvertOpToLLVMPattern
<CastOp
> {
158 using mlir::ConvertOpToLLVMPattern
<CastOp
>::ConvertOpToLLVMPattern
;
161 matchAndRewrite(CastOp op
, typename
CastOp::Adaptor adaptor
,
162 ConversionPatternRewriter
&rewriter
) const override
{
163 Type in
= adaptor
.getInput().getType();
164 Type out
= this->getTypeConverter()->convertType(op
.getType());
166 rewriter
.replaceOp(op
, adaptor
.getInput());
167 else if (in
.getIntOrFloatBitWidth() > out
.getIntOrFloatBitWidth())
168 rewriter
.replaceOpWithNewOp
<LLVM::TruncOp
>(op
, out
, adaptor
.getInput());
170 rewriter
.replaceOpWithNewOp
<ExtOp
>(op
, out
, adaptor
.getInput());
175 using ConvertIndexCastS
= ConvertIndexCast
<CastSOp
, LLVM::SExtOp
>;
176 using ConvertIndexCastU
= ConvertIndexCast
<CastUOp
, LLVM::ZExtOp
>;
178 //===----------------------------------------------------------------------===//
180 //===----------------------------------------------------------------------===//
182 /// Assert that the LLVM comparison enum lines up with index's enum.
183 static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs
,
184 IndexCmpPredicate rhs
) {
185 return static_cast<int>(lhs
) == static_cast<int>(rhs
);
189 LLVM::getMaxEnumValForICmpPredicate() ==
190 getMaxEnumValForIndexCmpPredicate() &&
191 checkPredicates(LLVM::ICmpPredicate::eq
, IndexCmpPredicate::EQ
) &&
192 checkPredicates(LLVM::ICmpPredicate::ne
, IndexCmpPredicate::NE
) &&
193 checkPredicates(LLVM::ICmpPredicate::sge
, IndexCmpPredicate::SGE
) &&
194 checkPredicates(LLVM::ICmpPredicate::sgt
, IndexCmpPredicate::SGT
) &&
195 checkPredicates(LLVM::ICmpPredicate::sle
, IndexCmpPredicate::SLE
) &&
196 checkPredicates(LLVM::ICmpPredicate::slt
, IndexCmpPredicate::SLT
) &&
197 checkPredicates(LLVM::ICmpPredicate::uge
, IndexCmpPredicate::UGE
) &&
198 checkPredicates(LLVM::ICmpPredicate::ugt
, IndexCmpPredicate::UGT
) &&
199 checkPredicates(LLVM::ICmpPredicate::ule
, IndexCmpPredicate::ULE
) &&
200 checkPredicates(LLVM::ICmpPredicate::ult
, IndexCmpPredicate::ULT
),
201 "LLVM ICmpPredicate mismatches IndexCmpPredicate");
203 struct ConvertIndexCmp
: public mlir::ConvertOpToLLVMPattern
<CmpOp
> {
204 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern
;
207 matchAndRewrite(CmpOp op
, CmpOpAdaptor adaptor
,
208 ConversionPatternRewriter
&rewriter
) const override
{
209 // The LLVM enum has the same values as the index predicate enums.
210 rewriter
.replaceOpWithNewOp
<LLVM::ICmpOp
>(
211 op
, *LLVM::symbolizeICmpPredicate(static_cast<uint32_t>(op
.getPred())),
212 adaptor
.getLhs(), adaptor
.getRhs());
217 //===----------------------------------------------------------------------===//
218 // ConvertIndexSizeOf
219 //===----------------------------------------------------------------------===//
221 /// Lower `index.sizeof` to a constant with the value of the index bitwidth.
222 struct ConvertIndexSizeOf
: public mlir::ConvertOpToLLVMPattern
<SizeOfOp
> {
223 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern
;
226 matchAndRewrite(SizeOfOp op
, SizeOfOpAdaptor adaptor
,
227 ConversionPatternRewriter
&rewriter
) const override
{
228 rewriter
.replaceOpWithNewOp
<LLVM::ConstantOp
>(
229 op
, getTypeConverter()->getIndexType(),
230 getTypeConverter()->getIndexTypeBitwidth());
235 //===----------------------------------------------------------------------===//
236 // ConvertIndexConstant
237 //===----------------------------------------------------------------------===//
239 /// Convert an index constant. Truncate the value as appropriate.
240 struct ConvertIndexConstant
: public mlir::ConvertOpToLLVMPattern
<ConstantOp
> {
241 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern
;
244 matchAndRewrite(ConstantOp op
, ConstantOpAdaptor adaptor
,
245 ConversionPatternRewriter
&rewriter
) const override
{
246 Type type
= getTypeConverter()->getIndexType();
247 APInt value
= op
.getValue().trunc(type
.getIntOrFloatBitWidth());
248 rewriter
.replaceOpWithNewOp
<LLVM::ConstantOp
>(
249 op
, type
, IntegerAttr::get(type
, value
));
254 //===----------------------------------------------------------------------===//
255 // Trivial Conversions
256 //===----------------------------------------------------------------------===//
258 using ConvertIndexAdd
= mlir::OneToOneConvertToLLVMPattern
<AddOp
, LLVM::AddOp
>;
259 using ConvertIndexSub
= mlir::OneToOneConvertToLLVMPattern
<SubOp
, LLVM::SubOp
>;
260 using ConvertIndexMul
= mlir::OneToOneConvertToLLVMPattern
<MulOp
, LLVM::MulOp
>;
261 using ConvertIndexDivS
=
262 mlir::OneToOneConvertToLLVMPattern
<DivSOp
, LLVM::SDivOp
>;
263 using ConvertIndexDivU
=
264 mlir::OneToOneConvertToLLVMPattern
<DivUOp
, LLVM::UDivOp
>;
265 using ConvertIndexRemS
=
266 mlir::OneToOneConvertToLLVMPattern
<RemSOp
, LLVM::SRemOp
>;
267 using ConvertIndexRemU
=
268 mlir::OneToOneConvertToLLVMPattern
<RemUOp
, LLVM::URemOp
>;
269 using ConvertIndexMaxS
=
270 mlir::OneToOneConvertToLLVMPattern
<MaxSOp
, LLVM::SMaxOp
>;
271 using ConvertIndexMaxU
=
272 mlir::OneToOneConvertToLLVMPattern
<MaxUOp
, LLVM::UMaxOp
>;
273 using ConvertIndexMinS
=
274 mlir::OneToOneConvertToLLVMPattern
<MinSOp
, LLVM::SMinOp
>;
275 using ConvertIndexMinU
=
276 mlir::OneToOneConvertToLLVMPattern
<MinUOp
, LLVM::UMinOp
>;
277 using ConvertIndexShl
= mlir::OneToOneConvertToLLVMPattern
<ShlOp
, LLVM::ShlOp
>;
278 using ConvertIndexShrS
=
279 mlir::OneToOneConvertToLLVMPattern
<ShrSOp
, LLVM::AShrOp
>;
280 using ConvertIndexShrU
=
281 mlir::OneToOneConvertToLLVMPattern
<ShrUOp
, LLVM::LShrOp
>;
282 using ConvertIndexAnd
= mlir::OneToOneConvertToLLVMPattern
<AndOp
, LLVM::AndOp
>;
283 using ConvertIndexOr
= mlir::OneToOneConvertToLLVMPattern
<OrOp
, LLVM::OrOp
>;
284 using ConvertIndexXor
= mlir::OneToOneConvertToLLVMPattern
<XOrOp
, LLVM::XOrOp
>;
285 using ConvertIndexBoolConstant
=
286 mlir::OneToOneConvertToLLVMPattern
<BoolConstantOp
, LLVM::ConstantOp
>;
290 //===----------------------------------------------------------------------===//
291 // Pattern Population
292 //===----------------------------------------------------------------------===//
294 void index::populateIndexToLLVMConversionPatterns(
295 const LLVMTypeConverter
&typeConverter
, RewritePatternSet
&patterns
) {
315 ConvertIndexCeilDivS
,
316 ConvertIndexCeilDivU
,
317 ConvertIndexFloorDivS
,
322 ConvertIndexConstant
,
323 ConvertIndexBoolConstant
328 //===----------------------------------------------------------------------===//
329 // ODS-Generated Definitions
330 //===----------------------------------------------------------------------===//
333 #define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS
334 #include "mlir/Conversion/Passes.h.inc"
337 //===----------------------------------------------------------------------===//
339 //===----------------------------------------------------------------------===//
342 struct ConvertIndexToLLVMPass
343 : public impl::ConvertIndexToLLVMPassBase
<ConvertIndexToLLVMPass
> {
346 void runOnOperation() override
;
350 void ConvertIndexToLLVMPass::runOnOperation() {
351 // Configure dialect conversion.
352 ConversionTarget
target(getContext());
353 target
.addIllegalDialect
<IndexDialect
>();
354 target
.addLegalDialect
<LLVM::LLVMDialect
>();
356 // Set LLVM lowering options.
357 LowerToLLVMOptions
options(&getContext());
358 if (indexBitwidth
!= kDeriveIndexBitwidthFromDataLayout
)
359 options
.overrideIndexBitwidth(indexBitwidth
);
360 LLVMTypeConverter
typeConverter(&getContext(), options
);
362 // Populate patterns and run the conversion.
363 RewritePatternSet
patterns(&getContext());
364 populateIndexToLLVMConversionPatterns(typeConverter
, patterns
);
367 applyPartialConversion(getOperation(), target
, std::move(patterns
))))
368 return signalPassFailure();
371 //===----------------------------------------------------------------------===//
372 // ConvertToLLVMPatternInterface implementation
373 //===----------------------------------------------------------------------===//
376 /// Implement the interface to convert Index to LLVM.
377 struct IndexToLLVMDialectInterface
: public ConvertToLLVMPatternInterface
{
378 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface
;
379 void loadDependentDialects(MLIRContext
*context
) const final
{
380 context
->loadDialect
<LLVM::LLVMDialect
>();
383 /// Hook for derived dialect interface to provide conversion patterns
384 /// and mark dialect legal for the conversion target.
385 void populateConvertToLLVMConversionPatterns(
386 ConversionTarget
&target
, LLVMTypeConverter
&typeConverter
,
387 RewritePatternSet
&patterns
) const final
{
388 populateIndexToLLVMConversionPatterns(typeConverter
, patterns
);
393 void mlir::index::registerConvertIndexToLLVMInterface(
394 DialectRegistry
®istry
) {
395 registry
.addExtension(+[](MLIRContext
*ctx
, index::IndexDialect
*dialect
) {
396 dialect
->addInterfaces
<IndexToLLVMDialectInterface
>();