1 //===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements patterns to convert the Arith dialect to the EmitC
// dialect.
12 //===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
25 //===----------------------------------------------------------------------===//
26 // Conversion Patterns
27 //===----------------------------------------------------------------------===//
30 class ArithConstantOpConversionPattern
31 : public OpConversionPattern
<arith::ConstantOp
> {
33 using OpConversionPattern::OpConversionPattern
;
36 matchAndRewrite(arith::ConstantOp arithConst
,
37 arith::ConstantOp::Adaptor adaptor
,
38 ConversionPatternRewriter
&rewriter
) const override
{
39 Type newTy
= this->getTypeConverter()->convertType(arithConst
.getType());
41 return rewriter
.notifyMatchFailure(arithConst
, "type conversion failed");
42 rewriter
.replaceOpWithNewOp
<emitc::ConstantOp
>(arithConst
, newTy
,
48 /// Get the signed or unsigned type corresponding to \p ty.
49 Type
adaptIntegralTypeSignedness(Type ty
, bool needsUnsigned
) {
50 if (isa
<IntegerType
>(ty
)) {
51 if (ty
.isUnsignedInteger() != needsUnsigned
) {
52 auto signedness
= needsUnsigned
53 ? IntegerType::SignednessSemantics::Unsigned
54 : IntegerType::SignednessSemantics::Signed
;
55 return IntegerType::get(ty
.getContext(), ty
.getIntOrFloatBitWidth(),
58 } else if (emitc::isPointerWideType(ty
)) {
59 if (isa
<emitc::SizeTType
>(ty
) != needsUnsigned
) {
61 return emitc::SizeTType::get(ty
.getContext());
62 return emitc::PtrDiffTType::get(ty
.getContext());
68 /// Insert a cast operation to type \p ty if \p val does not have this type.
69 Value
adaptValueType(Value val
, ConversionPatternRewriter
&rewriter
, Type ty
) {
70 return rewriter
.createOrFold
<emitc::CastOp
>(val
.getLoc(), ty
, val
);
73 class CmpFOpConversion
: public OpConversionPattern
<arith::CmpFOp
> {
75 using OpConversionPattern::OpConversionPattern
;
78 matchAndRewrite(arith::CmpFOp op
, OpAdaptor adaptor
,
79 ConversionPatternRewriter
&rewriter
) const override
{
81 if (!isa
<FloatType
>(adaptor
.getRhs().getType())) {
82 return rewriter
.notifyMatchFailure(op
.getLoc(),
83 "cmpf currently only supported on "
84 "floats, not tensors/vectors thereof");
87 bool unordered
= false;
88 emitc::CmpPredicate predicate
;
89 switch (op
.getPredicate()) {
90 case arith::CmpFPredicate::AlwaysFalse
: {
91 auto constant
= rewriter
.create
<emitc::ConstantOp
>(
92 op
.getLoc(), rewriter
.getI1Type(),
93 rewriter
.getBoolAttr(/*value=*/false));
94 rewriter
.replaceOp(op
, constant
);
97 case arith::CmpFPredicate::OEQ
:
99 predicate
= emitc::CmpPredicate::eq
;
101 case arith::CmpFPredicate::OGT
:
103 predicate
= emitc::CmpPredicate::gt
;
105 case arith::CmpFPredicate::OGE
:
107 predicate
= emitc::CmpPredicate::ge
;
109 case arith::CmpFPredicate::OLT
:
111 predicate
= emitc::CmpPredicate::lt
;
113 case arith::CmpFPredicate::OLE
:
115 predicate
= emitc::CmpPredicate::le
;
117 case arith::CmpFPredicate::ONE
:
119 predicate
= emitc::CmpPredicate::ne
;
121 case arith::CmpFPredicate::ORD
: {
122 // ordered, i.e. none of the operands is NaN
123 auto cmp
= createCheckIsOrdered(rewriter
, op
.getLoc(), adaptor
.getLhs(),
125 rewriter
.replaceOp(op
, cmp
);
128 case arith::CmpFPredicate::UEQ
:
130 predicate
= emitc::CmpPredicate::eq
;
132 case arith::CmpFPredicate::UGT
:
134 predicate
= emitc::CmpPredicate::gt
;
136 case arith::CmpFPredicate::UGE
:
138 predicate
= emitc::CmpPredicate::ge
;
140 case arith::CmpFPredicate::ULT
:
142 predicate
= emitc::CmpPredicate::lt
;
144 case arith::CmpFPredicate::ULE
:
146 predicate
= emitc::CmpPredicate::le
;
148 case arith::CmpFPredicate::UNE
:
150 predicate
= emitc::CmpPredicate::ne
;
152 case arith::CmpFPredicate::UNO
: {
153 // unordered, i.e. either operand is nan
154 auto cmp
= createCheckIsUnordered(rewriter
, op
.getLoc(), adaptor
.getLhs(),
156 rewriter
.replaceOp(op
, cmp
);
159 case arith::CmpFPredicate::AlwaysTrue
: {
160 auto constant
= rewriter
.create
<emitc::ConstantOp
>(
161 op
.getLoc(), rewriter
.getI1Type(),
162 rewriter
.getBoolAttr(/*value=*/true));
163 rewriter
.replaceOp(op
, constant
);
168 // Compare the values naively
170 rewriter
.create
<emitc::CmpOp
>(op
.getLoc(), op
.getType(), predicate
,
171 adaptor
.getLhs(), adaptor
.getRhs());
173 // Adjust the results for unordered/ordered semantics
175 auto isUnordered
= createCheckIsUnordered(
176 rewriter
, op
.getLoc(), adaptor
.getLhs(), adaptor
.getRhs());
177 rewriter
.replaceOpWithNewOp
<emitc::LogicalOrOp
>(op
, op
.getType(),
178 isUnordered
, cmpResult
);
182 auto isOrdered
= createCheckIsOrdered(rewriter
, op
.getLoc(),
183 adaptor
.getLhs(), adaptor
.getRhs());
184 rewriter
.replaceOpWithNewOp
<emitc::LogicalAndOp
>(op
, op
.getType(),
185 isOrdered
, cmpResult
);
190 /// Return a value that is true if \p operand is NaN.
191 Value
isNaN(ConversionPatternRewriter
&rewriter
, Location loc
,
192 Value operand
) const {
193 // A value is NaN exactly when it compares unequal to itself.
194 return rewriter
.create
<emitc::CmpOp
>(
195 loc
, rewriter
.getI1Type(), emitc::CmpPredicate::ne
, operand
, operand
);
198 /// Return a value that is true if \p operand is not NaN.
199 Value
isNotNaN(ConversionPatternRewriter
&rewriter
, Location loc
,
200 Value operand
) const {
201 // A value is not NaN exactly when it compares equal to itself.
202 return rewriter
.create
<emitc::CmpOp
>(
203 loc
, rewriter
.getI1Type(), emitc::CmpPredicate::eq
, operand
, operand
);
206 /// Return a value that is true if the operands \p first and \p second are
207 /// unordered (i.e., at least one of them is NaN).
208 Value
createCheckIsUnordered(ConversionPatternRewriter
&rewriter
,
209 Location loc
, Value first
, Value second
) const {
210 auto firstIsNaN
= isNaN(rewriter
, loc
, first
);
211 auto secondIsNaN
= isNaN(rewriter
, loc
, second
);
212 return rewriter
.create
<emitc::LogicalOrOp
>(loc
, rewriter
.getI1Type(),
213 firstIsNaN
, secondIsNaN
);
216 /// Return a value that is true if the operands \p first and \p second are
217 /// both ordered (i.e., none one of them is NaN).
218 Value
createCheckIsOrdered(ConversionPatternRewriter
&rewriter
, Location loc
,
219 Value first
, Value second
) const {
220 auto firstIsNotNaN
= isNotNaN(rewriter
, loc
, first
);
221 auto secondIsNotNaN
= isNotNaN(rewriter
, loc
, second
);
222 return rewriter
.create
<emitc::LogicalAndOp
>(loc
, rewriter
.getI1Type(),
223 firstIsNotNaN
, secondIsNotNaN
);
227 class CmpIOpConversion
: public OpConversionPattern
<arith::CmpIOp
> {
229 using OpConversionPattern::OpConversionPattern
;
231 bool needsUnsignedCmp(arith::CmpIPredicate pred
) const {
233 case arith::CmpIPredicate::eq
:
234 case arith::CmpIPredicate::ne
:
235 case arith::CmpIPredicate::slt
:
236 case arith::CmpIPredicate::sle
:
237 case arith::CmpIPredicate::sgt
:
238 case arith::CmpIPredicate::sge
:
240 case arith::CmpIPredicate::ult
:
241 case arith::CmpIPredicate::ule
:
242 case arith::CmpIPredicate::ugt
:
243 case arith::CmpIPredicate::uge
:
246 llvm_unreachable("unknown cmpi predicate kind");
249 emitc::CmpPredicate
toEmitCPred(arith::CmpIPredicate pred
) const {
251 case arith::CmpIPredicate::eq
:
252 return emitc::CmpPredicate::eq
;
253 case arith::CmpIPredicate::ne
:
254 return emitc::CmpPredicate::ne
;
255 case arith::CmpIPredicate::slt
:
256 case arith::CmpIPredicate::ult
:
257 return emitc::CmpPredicate::lt
;
258 case arith::CmpIPredicate::sle
:
259 case arith::CmpIPredicate::ule
:
260 return emitc::CmpPredicate::le
;
261 case arith::CmpIPredicate::sgt
:
262 case arith::CmpIPredicate::ugt
:
263 return emitc::CmpPredicate::gt
;
264 case arith::CmpIPredicate::sge
:
265 case arith::CmpIPredicate::uge
:
266 return emitc::CmpPredicate::ge
;
268 llvm_unreachable("unknown cmpi predicate kind");
272 matchAndRewrite(arith::CmpIOp op
, OpAdaptor adaptor
,
273 ConversionPatternRewriter
&rewriter
) const override
{
275 Type type
= adaptor
.getLhs().getType();
276 if (!type
|| !(isa
<IntegerType
>(type
) || emitc::isPointerWideType(type
))) {
277 return rewriter
.notifyMatchFailure(
278 op
, "expected integer or size_t/ssize_t/ptrdiff_t type");
281 bool needsUnsigned
= needsUnsignedCmp(op
.getPredicate());
282 emitc::CmpPredicate pred
= toEmitCPred(op
.getPredicate());
284 Type arithmeticType
= adaptIntegralTypeSignedness(type
, needsUnsigned
);
285 Value lhs
= adaptValueType(adaptor
.getLhs(), rewriter
, arithmeticType
);
286 Value rhs
= adaptValueType(adaptor
.getRhs(), rewriter
, arithmeticType
);
288 rewriter
.replaceOpWithNewOp
<emitc::CmpOp
>(op
, op
.getType(), pred
, lhs
, rhs
);
293 class NegFOpConversion
: public OpConversionPattern
<arith::NegFOp
> {
295 using OpConversionPattern::OpConversionPattern
;
298 matchAndRewrite(arith::NegFOp op
, OpAdaptor adaptor
,
299 ConversionPatternRewriter
&rewriter
) const override
{
301 auto adaptedOp
= adaptor
.getOperand();
302 auto adaptedOpType
= adaptedOp
.getType();
304 if (isa
<TensorType
>(adaptedOpType
) || isa
<VectorType
>(adaptedOpType
)) {
305 return rewriter
.notifyMatchFailure(
307 "negf currently only supports scalar types, not vectors or tensors");
310 if (!emitc::isSupportedFloatType(adaptedOpType
)) {
311 return rewriter
.notifyMatchFailure(
312 op
.getLoc(), "floating-point type is not supported by EmitC");
315 rewriter
.replaceOpWithNewOp
<emitc::UnaryMinusOp
>(op
, adaptedOpType
,
321 template <typename ArithOp
, bool castToUnsigned
>
322 class CastConversion
: public OpConversionPattern
<ArithOp
> {
324 using OpConversionPattern
<ArithOp
>::OpConversionPattern
;
327 matchAndRewrite(ArithOp op
, typename
ArithOp::Adaptor adaptor
,
328 ConversionPatternRewriter
&rewriter
) const override
{
330 Type opReturnType
= this->getTypeConverter()->convertType(op
.getType());
331 if (!opReturnType
|| !(isa
<IntegerType
>(opReturnType
) ||
332 emitc::isPointerWideType(opReturnType
)))
333 return rewriter
.notifyMatchFailure(
334 op
, "expected integer or size_t/ssize_t/ptrdiff_t result type");
336 if (adaptor
.getOperands().size() != 1) {
337 return rewriter
.notifyMatchFailure(
338 op
, "CastConversion only supports unary ops");
341 Type operandType
= adaptor
.getIn().getType();
342 if (!operandType
|| !(isa
<IntegerType
>(operandType
) ||
343 emitc::isPointerWideType(operandType
)))
344 return rewriter
.notifyMatchFailure(
345 op
, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
347 // Signed (sign-extending) casts from i1 are not supported.
348 if (operandType
.isInteger(1) && !castToUnsigned
)
349 return rewriter
.notifyMatchFailure(op
,
350 "operation not supported on i1 type");
352 // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
353 // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
355 if (opReturnType
.isInteger(1)) {
356 Type attrType
= (emitc::isPointerWideType(operandType
))
357 ? rewriter
.getIndexType()
359 auto constOne
= rewriter
.create
<emitc::ConstantOp
>(
360 op
.getLoc(), operandType
, rewriter
.getOneAttr(attrType
));
361 auto oneAndOperand
= rewriter
.create
<emitc::BitwiseAndOp
>(
362 op
.getLoc(), operandType
, adaptor
.getIn(), constOne
);
363 rewriter
.replaceOpWithNewOp
<emitc::CastOp
>(op
, opReturnType
,
369 (isa
<IntegerType
>(operandType
) && isa
<IntegerType
>(opReturnType
) &&
370 operandType
.getIntOrFloatBitWidth() >
371 opReturnType
.getIntOrFloatBitWidth());
372 bool doUnsigned
= castToUnsigned
|| isTruncation
;
374 // Adapt the signedness of the result (bitwidth-preserving cast)
375 // This is needed e.g., if the return type is signless.
376 Type castDestType
= adaptIntegralTypeSignedness(opReturnType
, doUnsigned
);
378 // Adapt the signedness of the operand (bitwidth-preserving cast)
379 Type castSrcType
= adaptIntegralTypeSignedness(operandType
, doUnsigned
);
380 Value actualOp
= adaptValueType(adaptor
.getIn(), rewriter
, castSrcType
);
382 // Actual cast (may change bitwidth)
383 auto cast
= rewriter
.template create
<emitc::CastOp
>(op
.getLoc(),
384 castDestType
, actualOp
);
386 // Cast to the expected output type
387 auto result
= adaptValueType(cast
, rewriter
, opReturnType
);
389 rewriter
.replaceOp(op
, result
);
394 template <typename ArithOp
>
395 class UnsignedCastConversion
: public CastConversion
<ArithOp
, true> {
396 using CastConversion
<ArithOp
, true>::CastConversion
;
399 template <typename ArithOp
>
400 class SignedCastConversion
: public CastConversion
<ArithOp
, false> {
401 using CastConversion
<ArithOp
, false>::CastConversion
;
404 template <typename ArithOp
, typename EmitCOp
>
405 class ArithOpConversion final
: public OpConversionPattern
<ArithOp
> {
407 using OpConversionPattern
<ArithOp
>::OpConversionPattern
;
410 matchAndRewrite(ArithOp arithOp
, typename
ArithOp::Adaptor adaptor
,
411 ConversionPatternRewriter
&rewriter
) const override
{
413 Type newTy
= this->getTypeConverter()->convertType(arithOp
.getType());
415 return rewriter
.notifyMatchFailure(arithOp
,
416 "converting result type failed");
417 rewriter
.template replaceOpWithNewOp
<EmitCOp
>(arithOp
, newTy
,
418 adaptor
.getOperands());
424 template <class ArithOp
, class EmitCOp
>
425 class BinaryUIOpConversion final
: public OpConversionPattern
<ArithOp
> {
427 using OpConversionPattern
<ArithOp
>::OpConversionPattern
;
430 matchAndRewrite(ArithOp uiBinOp
, typename
ArithOp::Adaptor adaptor
,
431 ConversionPatternRewriter
&rewriter
) const override
{
432 Type newRetTy
= this->getTypeConverter()->convertType(uiBinOp
.getType());
434 return rewriter
.notifyMatchFailure(uiBinOp
,
435 "converting result type failed");
436 if (!isa
<IntegerType
>(newRetTy
)) {
437 return rewriter
.notifyMatchFailure(uiBinOp
, "expected integer type");
440 adaptIntegralTypeSignedness(newRetTy
, /*needsUnsigned=*/true);
442 return rewriter
.notifyMatchFailure(uiBinOp
,
443 "converting result type failed");
444 Value lhsAdapted
= adaptValueType(uiBinOp
.getLhs(), rewriter
, unsignedType
);
445 Value rhsAdapted
= adaptValueType(uiBinOp
.getRhs(), rewriter
, unsignedType
);
448 rewriter
.create
<EmitCOp
>(uiBinOp
.getLoc(), unsignedType
,
449 ArrayRef
<Value
>{lhsAdapted
, rhsAdapted
});
450 Value resultAdapted
= adaptValueType(newDivOp
, rewriter
, newRetTy
);
451 rewriter
.replaceOp(uiBinOp
, resultAdapted
);
456 template <typename ArithOp
, typename EmitCOp
>
457 class IntegerOpConversion final
: public OpConversionPattern
<ArithOp
> {
459 using OpConversionPattern
<ArithOp
>::OpConversionPattern
;
462 matchAndRewrite(ArithOp op
, typename
ArithOp::Adaptor adaptor
,
463 ConversionPatternRewriter
&rewriter
) const override
{
465 Type type
= this->getTypeConverter()->convertType(op
.getType());
466 if (!type
|| !(isa
<IntegerType
>(type
) || emitc::isPointerWideType(type
))) {
467 return rewriter
.notifyMatchFailure(
468 op
, "expected integer or size_t/ssize_t/ptrdiff_t type");
471 if (type
.isInteger(1)) {
472 // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
473 return rewriter
.notifyMatchFailure(op
, "i1 type is not implemented");
476 Type arithmeticType
= type
;
477 if ((type
.isSignlessInteger() || type
.isSignedInteger()) &&
478 !bitEnumContainsAll(op
.getOverflowFlags(),
479 arith::IntegerOverflowFlags::nsw
)) {
480 // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
481 // we compute in unsigned integers to avoid UB.
482 arithmeticType
= rewriter
.getIntegerType(type
.getIntOrFloatBitWidth(),
486 Value lhs
= adaptValueType(adaptor
.getLhs(), rewriter
, arithmeticType
);
487 Value rhs
= adaptValueType(adaptor
.getRhs(), rewriter
, arithmeticType
);
489 Value arithmeticResult
= rewriter
.template create
<EmitCOp
>(
490 op
.getLoc(), arithmeticType
, lhs
, rhs
);
492 Value result
= adaptValueType(arithmeticResult
, rewriter
, type
);
494 rewriter
.replaceOp(op
, result
);
499 template <typename ArithOp
, typename EmitCOp
>
500 class BitwiseOpConversion
: public OpConversionPattern
<ArithOp
> {
502 using OpConversionPattern
<ArithOp
>::OpConversionPattern
;
505 matchAndRewrite(ArithOp op
, typename
ArithOp::Adaptor adaptor
,
506 ConversionPatternRewriter
&rewriter
) const override
{
508 Type type
= this->getTypeConverter()->convertType(op
.getType());
509 if (!isa_and_nonnull
<IntegerType
>(type
)) {
510 return rewriter
.notifyMatchFailure(
512 "expected integer type, vector/tensor support not yet implemented");
515 // Bitwise ops can be performed directly on booleans
516 if (type
.isInteger(1)) {
517 rewriter
.replaceOpWithNewOp
<EmitCOp
>(op
, type
, adaptor
.getLhs(),
522 // Bitwise ops are defined by the C standard on unsigned operands.
523 Type arithmeticType
=
524 adaptIntegralTypeSignedness(type
, /*needsUnsigned=*/true);
526 Value lhs
= adaptValueType(adaptor
.getLhs(), rewriter
, arithmeticType
);
527 Value rhs
= adaptValueType(adaptor
.getRhs(), rewriter
, arithmeticType
);
529 Value arithmeticResult
= rewriter
.template create
<EmitCOp
>(
530 op
.getLoc(), arithmeticType
, lhs
, rhs
);
532 Value result
= adaptValueType(arithmeticResult
, rewriter
, type
);
534 rewriter
.replaceOp(op
, result
);
539 template <typename ArithOp
, typename EmitCOp
, bool isUnsignedOp
>
540 class ShiftOpConversion
: public OpConversionPattern
<ArithOp
> {
542 using OpConversionPattern
<ArithOp
>::OpConversionPattern
;
545 matchAndRewrite(ArithOp op
, typename
ArithOp::Adaptor adaptor
,
546 ConversionPatternRewriter
&rewriter
) const override
{
548 Type type
= this->getTypeConverter()->convertType(op
.getType());
549 if (!type
|| !(isa
<IntegerType
>(type
) || emitc::isPointerWideType(type
))) {
550 return rewriter
.notifyMatchFailure(
551 op
, "expected integer or size_t/ssize_t/ptrdiff_t type");
554 if (type
.isInteger(1)) {
555 return rewriter
.notifyMatchFailure(op
, "i1 type is not implemented");
558 Type arithmeticType
= adaptIntegralTypeSignedness(type
, isUnsignedOp
);
560 Value lhs
= adaptValueType(adaptor
.getLhs(), rewriter
, arithmeticType
);
561 // Shift amount interpreted as unsigned per Arith dialect spec.
562 Type rhsType
= adaptIntegralTypeSignedness(adaptor
.getRhs().getType(),
563 /*needsUnsigned=*/true);
564 Value rhs
= adaptValueType(adaptor
.getRhs(), rewriter
, rhsType
);
566 // Add a runtime check for overflow
568 if (emitc::isPointerWideType(type
)) {
569 Value eight
= rewriter
.create
<emitc::ConstantOp
>(
570 op
.getLoc(), rhsType
, rewriter
.getIndexAttr(8));
571 emitc::CallOpaqueOp sizeOfCall
= rewriter
.create
<emitc::CallOpaqueOp
>(
572 op
.getLoc(), rhsType
, "sizeof", ArrayRef
<Value
>{eight
});
573 width
= rewriter
.create
<emitc::MulOp
>(op
.getLoc(), rhsType
, eight
,
574 sizeOfCall
.getResult(0));
576 width
= rewriter
.create
<emitc::ConstantOp
>(
577 op
.getLoc(), rhsType
,
578 rewriter
.getIntegerAttr(rhsType
, type
.getIntOrFloatBitWidth()));
581 Value excessCheck
= rewriter
.create
<emitc::CmpOp
>(
582 op
.getLoc(), rewriter
.getI1Type(), emitc::CmpPredicate::lt
, rhs
, width
);
584 // Any concrete value is a valid refinement of poison.
585 Value poison
= rewriter
.create
<emitc::ConstantOp
>(
586 op
.getLoc(), arithmeticType
,
587 (isa
<IntegerType
>(arithmeticType
)
588 ? rewriter
.getIntegerAttr(arithmeticType
, 0)
589 : rewriter
.getIndexAttr(0)));
591 emitc::ExpressionOp ternary
= rewriter
.create
<emitc::ExpressionOp
>(
592 op
.getLoc(), arithmeticType
, /*do_not_inline=*/false);
593 Block
&bodyBlock
= ternary
.getBodyRegion().emplaceBlock();
594 auto currentPoint
= rewriter
.getInsertionPoint();
595 rewriter
.setInsertionPointToStart(&bodyBlock
);
596 Value arithmeticResult
=
597 rewriter
.create
<EmitCOp
>(op
.getLoc(), arithmeticType
, lhs
, rhs
);
598 Value resultOrPoison
= rewriter
.create
<emitc::ConditionalOp
>(
599 op
.getLoc(), arithmeticType
, excessCheck
, arithmeticResult
, poison
);
600 rewriter
.create
<emitc::YieldOp
>(op
.getLoc(), resultOrPoison
);
601 rewriter
.setInsertionPoint(op
->getBlock(), currentPoint
);
603 Value result
= adaptValueType(ternary
, rewriter
, type
);
605 rewriter
.replaceOp(op
, result
);
610 template <typename ArithOp
, typename EmitCOp
>
611 class SignedShiftOpConversion final
612 : public ShiftOpConversion
<ArithOp
, EmitCOp
, false> {
613 using ShiftOpConversion
<ArithOp
, EmitCOp
, false>::ShiftOpConversion
;
616 template <typename ArithOp
, typename EmitCOp
>
617 class UnsignedShiftOpConversion final
618 : public ShiftOpConversion
<ArithOp
, EmitCOp
, true> {
619 using ShiftOpConversion
<ArithOp
, EmitCOp
, true>::ShiftOpConversion
;
622 class SelectOpConversion
: public OpConversionPattern
<arith::SelectOp
> {
624 using OpConversionPattern
<arith::SelectOp
>::OpConversionPattern
;
627 matchAndRewrite(arith::SelectOp selectOp
, OpAdaptor adaptor
,
628 ConversionPatternRewriter
&rewriter
) const override
{
630 Type dstType
= getTypeConverter()->convertType(selectOp
.getType());
632 return rewriter
.notifyMatchFailure(selectOp
, "type conversion failed");
634 if (!adaptor
.getCondition().getType().isInteger(1))
635 return rewriter
.notifyMatchFailure(
637 "can only be converted if condition is a scalar of type i1");
639 rewriter
.replaceOpWithNewOp
<emitc::ConditionalOp
>(selectOp
, dstType
,
640 adaptor
.getOperands());
646 // Floating-point to integer conversions.
647 template <typename CastOp
>
648 class FtoICastOpConversion
: public OpConversionPattern
<CastOp
> {
650 FtoICastOpConversion(const TypeConverter
&typeConverter
, MLIRContext
*context
)
651 : OpConversionPattern
<CastOp
>(typeConverter
, context
) {}
654 matchAndRewrite(CastOp castOp
, typename
CastOp::Adaptor adaptor
,
655 ConversionPatternRewriter
&rewriter
) const override
{
657 Type operandType
= adaptor
.getIn().getType();
658 if (!emitc::isSupportedFloatType(operandType
))
659 return rewriter
.notifyMatchFailure(castOp
,
660 "unsupported cast source type");
662 Type dstType
= this->getTypeConverter()->convertType(castOp
.getType());
664 return rewriter
.notifyMatchFailure(castOp
, "type conversion failed");
666 // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
667 // truncated to 0, whereas a boolean conversion would return true.
668 if (!emitc::isSupportedIntegerType(dstType
) || dstType
.isInteger(1))
669 return rewriter
.notifyMatchFailure(castOp
,
670 "unsupported cast destination type");
672 // Convert to unsigned if it's the "ui" variant
673 // Signless is interpreted as signed, so no need to cast for "si"
674 Type actualResultType
= dstType
;
675 if (isa
<arith::FPToUIOp
>(castOp
)) {
677 rewriter
.getIntegerType(operandType
.getIntOrFloatBitWidth(),
681 Value result
= rewriter
.create
<emitc::CastOp
>(
682 castOp
.getLoc(), actualResultType
, adaptor
.getOperands());
684 if (isa
<arith::FPToUIOp
>(castOp
)) {
685 result
= rewriter
.create
<emitc::CastOp
>(castOp
.getLoc(), dstType
, result
);
687 rewriter
.replaceOp(castOp
, result
);
693 // Integer to floating-point conversions.
694 template <typename CastOp
>
695 class ItoFCastOpConversion
: public OpConversionPattern
<CastOp
> {
697 ItoFCastOpConversion(const TypeConverter
&typeConverter
, MLIRContext
*context
)
698 : OpConversionPattern
<CastOp
>(typeConverter
, context
) {}
701 matchAndRewrite(CastOp castOp
, typename
CastOp::Adaptor adaptor
,
702 ConversionPatternRewriter
&rewriter
) const override
{
703 // Vectors in particular are not supported
704 Type operandType
= adaptor
.getIn().getType();
705 if (!emitc::isSupportedIntegerType(operandType
))
706 return rewriter
.notifyMatchFailure(castOp
,
707 "unsupported cast source type");
709 Type dstType
= this->getTypeConverter()->convertType(castOp
.getType());
711 return rewriter
.notifyMatchFailure(castOp
, "type conversion failed");
713 if (!emitc::isSupportedFloatType(dstType
))
714 return rewriter
.notifyMatchFailure(castOp
,
715 "unsupported cast destination type");
717 // Convert to unsigned if it's the "ui" variant
718 // Signless is interpreted as signed, so no need to cast for "si"
719 Type actualOperandType
= operandType
;
720 if (isa
<arith::UIToFPOp
>(castOp
)) {
722 rewriter
.getIntegerType(operandType
.getIntOrFloatBitWidth(),
725 Value fpCastOperand
= adaptor
.getIn();
726 if (actualOperandType
!= operandType
) {
727 fpCastOperand
= rewriter
.template create
<emitc::CastOp
>(
728 castOp
.getLoc(), actualOperandType
, fpCastOperand
);
730 rewriter
.replaceOpWithNewOp
<emitc::CastOp
>(castOp
, dstType
, fpCastOperand
);
738 //===----------------------------------------------------------------------===//
739 // Pattern population
740 //===----------------------------------------------------------------------===//
742 void mlir::populateArithToEmitCPatterns(TypeConverter
&typeConverter
,
743 RewritePatternSet
&patterns
) {
744 MLIRContext
*ctx
= patterns
.getContext();
746 mlir::populateEmitCSizeTTypeConversions(typeConverter
);
750 ArithConstantOpConversionPattern
,
751 ArithOpConversion
<arith::AddFOp
, emitc::AddOp
>,
752 ArithOpConversion
<arith::DivFOp
, emitc::DivOp
>,
753 ArithOpConversion
<arith::DivSIOp
, emitc::DivOp
>,
754 ArithOpConversion
<arith::MulFOp
, emitc::MulOp
>,
755 ArithOpConversion
<arith::RemSIOp
, emitc::RemOp
>,
756 ArithOpConversion
<arith::SubFOp
, emitc::SubOp
>,
757 BinaryUIOpConversion
<arith::DivUIOp
, emitc::DivOp
>,
758 BinaryUIOpConversion
<arith::RemUIOp
, emitc::RemOp
>,
759 IntegerOpConversion
<arith::AddIOp
, emitc::AddOp
>,
760 IntegerOpConversion
<arith::MulIOp
, emitc::MulOp
>,
761 IntegerOpConversion
<arith::SubIOp
, emitc::SubOp
>,
762 BitwiseOpConversion
<arith::AndIOp
, emitc::BitwiseAndOp
>,
763 BitwiseOpConversion
<arith::OrIOp
, emitc::BitwiseOrOp
>,
764 BitwiseOpConversion
<arith::XOrIOp
, emitc::BitwiseXorOp
>,
765 UnsignedShiftOpConversion
<arith::ShLIOp
, emitc::BitwiseLeftShiftOp
>,
766 SignedShiftOpConversion
<arith::ShRSIOp
, emitc::BitwiseRightShiftOp
>,
767 UnsignedShiftOpConversion
<arith::ShRUIOp
, emitc::BitwiseRightShiftOp
>,
772 // Truncation is guaranteed for unsigned types.
773 UnsignedCastConversion
<arith::TruncIOp
>,
774 SignedCastConversion
<arith::ExtSIOp
>,
775 UnsignedCastConversion
<arith::ExtUIOp
>,
776 SignedCastConversion
<arith::IndexCastOp
>,
777 UnsignedCastConversion
<arith::IndexCastUIOp
>,
778 ItoFCastOpConversion
<arith::SIToFPOp
>,
779 ItoFCastOpConversion
<arith::UIToFPOp
>,
780 FtoICastOpConversion
<arith::FPToSIOp
>,
781 FtoICastOpConversion
<arith::FPToUIOp
>
782 >(typeConverter
, ctx
);