[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Conversion / ArithToEmitC / ArithToEmitC.cpp
blob50384d9a08e5d971b5da51c9c38edf0c2dd8853c
1 //===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the Arith dialect to the EmitC
10 // dialect.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/EmitC/IR/EmitC.h"
18 #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/Transforms/DialectConversion.h"
23 using namespace mlir;
25 //===----------------------------------------------------------------------===//
26 // Conversion Patterns
27 //===----------------------------------------------------------------------===//
29 namespace {
30 class ArithConstantOpConversionPattern
31 : public OpConversionPattern<arith::ConstantOp> {
32 public:
33 using OpConversionPattern::OpConversionPattern;
35 LogicalResult
36 matchAndRewrite(arith::ConstantOp arithConst,
37 arith::ConstantOp::Adaptor adaptor,
38 ConversionPatternRewriter &rewriter) const override {
39 Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
40 if (!newTy)
41 return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
42 rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
43 adaptor.getValue());
44 return success();
48 /// Get the signed or unsigned type corresponding to \p ty.
49 Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
50 if (isa<IntegerType>(ty)) {
51 if (ty.isUnsignedInteger() != needsUnsigned) {
52 auto signedness = needsUnsigned
53 ? IntegerType::SignednessSemantics::Unsigned
54 : IntegerType::SignednessSemantics::Signed;
55 return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
56 signedness);
58 } else if (emitc::isPointerWideType(ty)) {
59 if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
60 if (needsUnsigned)
61 return emitc::SizeTType::get(ty.getContext());
62 return emitc::PtrDiffTType::get(ty.getContext());
65 return ty;
68 /// Insert a cast operation to type \p ty if \p val does not have this type.
69 Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
70 return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
73 class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
74 public:
75 using OpConversionPattern::OpConversionPattern;
77 LogicalResult
78 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
79 ConversionPatternRewriter &rewriter) const override {
81 if (!isa<FloatType>(adaptor.getRhs().getType())) {
82 return rewriter.notifyMatchFailure(op.getLoc(),
83 "cmpf currently only supported on "
84 "floats, not tensors/vectors thereof");
87 bool unordered = false;
88 emitc::CmpPredicate predicate;
89 switch (op.getPredicate()) {
90 case arith::CmpFPredicate::AlwaysFalse: {
91 auto constant = rewriter.create<emitc::ConstantOp>(
92 op.getLoc(), rewriter.getI1Type(),
93 rewriter.getBoolAttr(/*value=*/false));
94 rewriter.replaceOp(op, constant);
95 return success();
97 case arith::CmpFPredicate::OEQ:
98 unordered = false;
99 predicate = emitc::CmpPredicate::eq;
100 break;
101 case arith::CmpFPredicate::OGT:
102 unordered = false;
103 predicate = emitc::CmpPredicate::gt;
104 break;
105 case arith::CmpFPredicate::OGE:
106 unordered = false;
107 predicate = emitc::CmpPredicate::ge;
108 break;
109 case arith::CmpFPredicate::OLT:
110 unordered = false;
111 predicate = emitc::CmpPredicate::lt;
112 break;
113 case arith::CmpFPredicate::OLE:
114 unordered = false;
115 predicate = emitc::CmpPredicate::le;
116 break;
117 case arith::CmpFPredicate::ONE:
118 unordered = false;
119 predicate = emitc::CmpPredicate::ne;
120 break;
121 case arith::CmpFPredicate::ORD: {
122 // ordered, i.e. none of the operands is NaN
123 auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
124 adaptor.getRhs());
125 rewriter.replaceOp(op, cmp);
126 return success();
128 case arith::CmpFPredicate::UEQ:
129 unordered = true;
130 predicate = emitc::CmpPredicate::eq;
131 break;
132 case arith::CmpFPredicate::UGT:
133 unordered = true;
134 predicate = emitc::CmpPredicate::gt;
135 break;
136 case arith::CmpFPredicate::UGE:
137 unordered = true;
138 predicate = emitc::CmpPredicate::ge;
139 break;
140 case arith::CmpFPredicate::ULT:
141 unordered = true;
142 predicate = emitc::CmpPredicate::lt;
143 break;
144 case arith::CmpFPredicate::ULE:
145 unordered = true;
146 predicate = emitc::CmpPredicate::le;
147 break;
148 case arith::CmpFPredicate::UNE:
149 unordered = true;
150 predicate = emitc::CmpPredicate::ne;
151 break;
152 case arith::CmpFPredicate::UNO: {
153 // unordered, i.e. either operand is nan
154 auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
155 adaptor.getRhs());
156 rewriter.replaceOp(op, cmp);
157 return success();
159 case arith::CmpFPredicate::AlwaysTrue: {
160 auto constant = rewriter.create<emitc::ConstantOp>(
161 op.getLoc(), rewriter.getI1Type(),
162 rewriter.getBoolAttr(/*value=*/true));
163 rewriter.replaceOp(op, constant);
164 return success();
168 // Compare the values naively
169 auto cmpResult =
170 rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
171 adaptor.getLhs(), adaptor.getRhs());
173 // Adjust the results for unordered/ordered semantics
174 if (unordered) {
175 auto isUnordered = createCheckIsUnordered(
176 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
177 rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
178 isUnordered, cmpResult);
179 return success();
182 auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
183 adaptor.getLhs(), adaptor.getRhs());
184 rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
185 isOrdered, cmpResult);
186 return success();
189 private:
190 /// Return a value that is true if \p operand is NaN.
191 Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
192 Value operand) const {
193 // A value is NaN exactly when it compares unequal to itself.
194 return rewriter.create<emitc::CmpOp>(
195 loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
198 /// Return a value that is true if \p operand is not NaN.
199 Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
200 Value operand) const {
201 // A value is not NaN exactly when it compares equal to itself.
202 return rewriter.create<emitc::CmpOp>(
203 loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
206 /// Return a value that is true if the operands \p first and \p second are
207 /// unordered (i.e., at least one of them is NaN).
208 Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
209 Location loc, Value first, Value second) const {
210 auto firstIsNaN = isNaN(rewriter, loc, first);
211 auto secondIsNaN = isNaN(rewriter, loc, second);
212 return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
213 firstIsNaN, secondIsNaN);
216 /// Return a value that is true if the operands \p first and \p second are
217 /// both ordered (i.e., none one of them is NaN).
218 Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
219 Value first, Value second) const {
220 auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
221 auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
222 return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
223 firstIsNotNaN, secondIsNotNaN);
227 class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
228 public:
229 using OpConversionPattern::OpConversionPattern;
231 bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
232 switch (pred) {
233 case arith::CmpIPredicate::eq:
234 case arith::CmpIPredicate::ne:
235 case arith::CmpIPredicate::slt:
236 case arith::CmpIPredicate::sle:
237 case arith::CmpIPredicate::sgt:
238 case arith::CmpIPredicate::sge:
239 return false;
240 case arith::CmpIPredicate::ult:
241 case arith::CmpIPredicate::ule:
242 case arith::CmpIPredicate::ugt:
243 case arith::CmpIPredicate::uge:
244 return true;
246 llvm_unreachable("unknown cmpi predicate kind");
249 emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
250 switch (pred) {
251 case arith::CmpIPredicate::eq:
252 return emitc::CmpPredicate::eq;
253 case arith::CmpIPredicate::ne:
254 return emitc::CmpPredicate::ne;
255 case arith::CmpIPredicate::slt:
256 case arith::CmpIPredicate::ult:
257 return emitc::CmpPredicate::lt;
258 case arith::CmpIPredicate::sle:
259 case arith::CmpIPredicate::ule:
260 return emitc::CmpPredicate::le;
261 case arith::CmpIPredicate::sgt:
262 case arith::CmpIPredicate::ugt:
263 return emitc::CmpPredicate::gt;
264 case arith::CmpIPredicate::sge:
265 case arith::CmpIPredicate::uge:
266 return emitc::CmpPredicate::ge;
268 llvm_unreachable("unknown cmpi predicate kind");
271 LogicalResult
272 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
273 ConversionPatternRewriter &rewriter) const override {
275 Type type = adaptor.getLhs().getType();
276 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
277 return rewriter.notifyMatchFailure(
278 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
281 bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
282 emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
284 Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
285 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
286 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
288 rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
289 return success();
293 class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
294 public:
295 using OpConversionPattern::OpConversionPattern;
297 LogicalResult
298 matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
299 ConversionPatternRewriter &rewriter) const override {
301 auto adaptedOp = adaptor.getOperand();
302 auto adaptedOpType = adaptedOp.getType();
304 if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
305 return rewriter.notifyMatchFailure(
306 op.getLoc(),
307 "negf currently only supports scalar types, not vectors or tensors");
310 if (!emitc::isSupportedFloatType(adaptedOpType)) {
311 return rewriter.notifyMatchFailure(
312 op.getLoc(), "floating-point type is not supported by EmitC");
315 rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType,
316 adaptedOp);
317 return success();
321 template <typename ArithOp, bool castToUnsigned>
322 class CastConversion : public OpConversionPattern<ArithOp> {
323 public:
324 using OpConversionPattern<ArithOp>::OpConversionPattern;
326 LogicalResult
327 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
328 ConversionPatternRewriter &rewriter) const override {
330 Type opReturnType = this->getTypeConverter()->convertType(op.getType());
331 if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
332 emitc::isPointerWideType(opReturnType)))
333 return rewriter.notifyMatchFailure(
334 op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
336 if (adaptor.getOperands().size() != 1) {
337 return rewriter.notifyMatchFailure(
338 op, "CastConversion only supports unary ops");
341 Type operandType = adaptor.getIn().getType();
342 if (!operandType || !(isa<IntegerType>(operandType) ||
343 emitc::isPointerWideType(operandType)))
344 return rewriter.notifyMatchFailure(
345 op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
347 // Signed (sign-extending) casts from i1 are not supported.
348 if (operandType.isInteger(1) && !castToUnsigned)
349 return rewriter.notifyMatchFailure(op,
350 "operation not supported on i1 type");
352 // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
353 // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
354 // truncation.
355 if (opReturnType.isInteger(1)) {
356 Type attrType = (emitc::isPointerWideType(operandType))
357 ? rewriter.getIndexType()
358 : operandType;
359 auto constOne = rewriter.create<emitc::ConstantOp>(
360 op.getLoc(), operandType, rewriter.getOneAttr(attrType));
361 auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
362 op.getLoc(), operandType, adaptor.getIn(), constOne);
363 rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
364 oneAndOperand);
365 return success();
368 bool isTruncation =
369 (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
370 operandType.getIntOrFloatBitWidth() >
371 opReturnType.getIntOrFloatBitWidth());
372 bool doUnsigned = castToUnsigned || isTruncation;
374 // Adapt the signedness of the result (bitwidth-preserving cast)
375 // This is needed e.g., if the return type is signless.
376 Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
378 // Adapt the signedness of the operand (bitwidth-preserving cast)
379 Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
380 Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
382 // Actual cast (may change bitwidth)
383 auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
384 castDestType, actualOp);
386 // Cast to the expected output type
387 auto result = adaptValueType(cast, rewriter, opReturnType);
389 rewriter.replaceOp(op, result);
390 return success();
394 template <typename ArithOp>
395 class UnsignedCastConversion : public CastConversion<ArithOp, true> {
396 using CastConversion<ArithOp, true>::CastConversion;
399 template <typename ArithOp>
400 class SignedCastConversion : public CastConversion<ArithOp, false> {
401 using CastConversion<ArithOp, false>::CastConversion;
404 template <typename ArithOp, typename EmitCOp>
405 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
406 public:
407 using OpConversionPattern<ArithOp>::OpConversionPattern;
409 LogicalResult
410 matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
411 ConversionPatternRewriter &rewriter) const override {
413 Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
414 if (!newTy)
415 return rewriter.notifyMatchFailure(arithOp,
416 "converting result type failed");
417 rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
418 adaptor.getOperands());
420 return success();
424 template <class ArithOp, class EmitCOp>
425 class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
426 public:
427 using OpConversionPattern<ArithOp>::OpConversionPattern;
429 LogicalResult
430 matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
431 ConversionPatternRewriter &rewriter) const override {
432 Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
433 if (!newRetTy)
434 return rewriter.notifyMatchFailure(uiBinOp,
435 "converting result type failed");
436 if (!isa<IntegerType>(newRetTy)) {
437 return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
439 Type unsignedType =
440 adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
441 if (!unsignedType)
442 return rewriter.notifyMatchFailure(uiBinOp,
443 "converting result type failed");
444 Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
445 Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
447 auto newDivOp =
448 rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
449 ArrayRef<Value>{lhsAdapted, rhsAdapted});
450 Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
451 rewriter.replaceOp(uiBinOp, resultAdapted);
452 return success();
456 template <typename ArithOp, typename EmitCOp>
457 class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
458 public:
459 using OpConversionPattern<ArithOp>::OpConversionPattern;
461 LogicalResult
462 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
463 ConversionPatternRewriter &rewriter) const override {
465 Type type = this->getTypeConverter()->convertType(op.getType());
466 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
467 return rewriter.notifyMatchFailure(
468 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
471 if (type.isInteger(1)) {
472 // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
473 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
476 Type arithmeticType = type;
477 if ((type.isSignlessInteger() || type.isSignedInteger()) &&
478 !bitEnumContainsAll(op.getOverflowFlags(),
479 arith::IntegerOverflowFlags::nsw)) {
480 // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
481 // we compute in unsigned integers to avoid UB.
482 arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
483 /*isSigned=*/false);
486 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
487 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
489 Value arithmeticResult = rewriter.template create<EmitCOp>(
490 op.getLoc(), arithmeticType, lhs, rhs);
492 Value result = adaptValueType(arithmeticResult, rewriter, type);
494 rewriter.replaceOp(op, result);
495 return success();
499 template <typename ArithOp, typename EmitCOp>
500 class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
501 public:
502 using OpConversionPattern<ArithOp>::OpConversionPattern;
504 LogicalResult
505 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
506 ConversionPatternRewriter &rewriter) const override {
508 Type type = this->getTypeConverter()->convertType(op.getType());
509 if (!isa_and_nonnull<IntegerType>(type)) {
510 return rewriter.notifyMatchFailure(
512 "expected integer type, vector/tensor support not yet implemented");
515 // Bitwise ops can be performed directly on booleans
516 if (type.isInteger(1)) {
517 rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
518 adaptor.getRhs());
519 return success();
522 // Bitwise ops are defined by the C standard on unsigned operands.
523 Type arithmeticType =
524 adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
526 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
527 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
529 Value arithmeticResult = rewriter.template create<EmitCOp>(
530 op.getLoc(), arithmeticType, lhs, rhs);
532 Value result = adaptValueType(arithmeticResult, rewriter, type);
534 rewriter.replaceOp(op, result);
535 return success();
539 template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
540 class ShiftOpConversion : public OpConversionPattern<ArithOp> {
541 public:
542 using OpConversionPattern<ArithOp>::OpConversionPattern;
544 LogicalResult
545 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
546 ConversionPatternRewriter &rewriter) const override {
548 Type type = this->getTypeConverter()->convertType(op.getType());
549 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
550 return rewriter.notifyMatchFailure(
551 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
554 if (type.isInteger(1)) {
555 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
558 Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
560 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
561 // Shift amount interpreted as unsigned per Arith dialect spec.
562 Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
563 /*needsUnsigned=*/true);
564 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
566 // Add a runtime check for overflow
567 Value width;
568 if (emitc::isPointerWideType(type)) {
569 Value eight = rewriter.create<emitc::ConstantOp>(
570 op.getLoc(), rhsType, rewriter.getIndexAttr(8));
571 emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
572 op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
573 width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
574 sizeOfCall.getResult(0));
575 } else {
576 width = rewriter.create<emitc::ConstantOp>(
577 op.getLoc(), rhsType,
578 rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
581 Value excessCheck = rewriter.create<emitc::CmpOp>(
582 op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
584 // Any concrete value is a valid refinement of poison.
585 Value poison = rewriter.create<emitc::ConstantOp>(
586 op.getLoc(), arithmeticType,
587 (isa<IntegerType>(arithmeticType)
588 ? rewriter.getIntegerAttr(arithmeticType, 0)
589 : rewriter.getIndexAttr(0)));
591 emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
592 op.getLoc(), arithmeticType, /*do_not_inline=*/false);
593 Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
594 auto currentPoint = rewriter.getInsertionPoint();
595 rewriter.setInsertionPointToStart(&bodyBlock);
596 Value arithmeticResult =
597 rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
598 Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
599 op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
600 rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
601 rewriter.setInsertionPoint(op->getBlock(), currentPoint);
603 Value result = adaptValueType(ternary, rewriter, type);
605 rewriter.replaceOp(op, result);
606 return success();
610 template <typename ArithOp, typename EmitCOp>
611 class SignedShiftOpConversion final
612 : public ShiftOpConversion<ArithOp, EmitCOp, false> {
613 using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
616 template <typename ArithOp, typename EmitCOp>
617 class UnsignedShiftOpConversion final
618 : public ShiftOpConversion<ArithOp, EmitCOp, true> {
619 using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
622 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
623 public:
624 using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
626 LogicalResult
627 matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
628 ConversionPatternRewriter &rewriter) const override {
630 Type dstType = getTypeConverter()->convertType(selectOp.getType());
631 if (!dstType)
632 return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
634 if (!adaptor.getCondition().getType().isInteger(1))
635 return rewriter.notifyMatchFailure(
636 selectOp,
637 "can only be converted if condition is a scalar of type i1");
639 rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
640 adaptor.getOperands());
642 return success();
646 // Floating-point to integer conversions.
647 template <typename CastOp>
648 class FtoICastOpConversion : public OpConversionPattern<CastOp> {
649 public:
650 FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
651 : OpConversionPattern<CastOp>(typeConverter, context) {}
653 LogicalResult
654 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
655 ConversionPatternRewriter &rewriter) const override {
657 Type operandType = adaptor.getIn().getType();
658 if (!emitc::isSupportedFloatType(operandType))
659 return rewriter.notifyMatchFailure(castOp,
660 "unsupported cast source type");
662 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
663 if (!dstType)
664 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
666 // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
667 // truncated to 0, whereas a boolean conversion would return true.
668 if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
669 return rewriter.notifyMatchFailure(castOp,
670 "unsupported cast destination type");
672 // Convert to unsigned if it's the "ui" variant
673 // Signless is interpreted as signed, so no need to cast for "si"
674 Type actualResultType = dstType;
675 if (isa<arith::FPToUIOp>(castOp)) {
676 actualResultType =
677 rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
678 /*isSigned=*/false);
681 Value result = rewriter.create<emitc::CastOp>(
682 castOp.getLoc(), actualResultType, adaptor.getOperands());
684 if (isa<arith::FPToUIOp>(castOp)) {
685 result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
687 rewriter.replaceOp(castOp, result);
689 return success();
693 // Integer to floating-point conversions.
694 template <typename CastOp>
695 class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
696 public:
697 ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
698 : OpConversionPattern<CastOp>(typeConverter, context) {}
700 LogicalResult
701 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
702 ConversionPatternRewriter &rewriter) const override {
703 // Vectors in particular are not supported
704 Type operandType = adaptor.getIn().getType();
705 if (!emitc::isSupportedIntegerType(operandType))
706 return rewriter.notifyMatchFailure(castOp,
707 "unsupported cast source type");
709 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
710 if (!dstType)
711 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
713 if (!emitc::isSupportedFloatType(dstType))
714 return rewriter.notifyMatchFailure(castOp,
715 "unsupported cast destination type");
717 // Convert to unsigned if it's the "ui" variant
718 // Signless is interpreted as signed, so no need to cast for "si"
719 Type actualOperandType = operandType;
720 if (isa<arith::UIToFPOp>(castOp)) {
721 actualOperandType =
722 rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
723 /*isSigned=*/false);
725 Value fpCastOperand = adaptor.getIn();
726 if (actualOperandType != operandType) {
727 fpCastOperand = rewriter.template create<emitc::CastOp>(
728 castOp.getLoc(), actualOperandType, fpCastOperand);
730 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
732 return success();
736 } // namespace
738 //===----------------------------------------------------------------------===//
739 // Pattern population
740 //===----------------------------------------------------------------------===//
742 void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
743 RewritePatternSet &patterns) {
744 MLIRContext *ctx = patterns.getContext();
746 mlir::populateEmitCSizeTTypeConversions(typeConverter);
748 // clang-format off
749 patterns.add<
750 ArithConstantOpConversionPattern,
751 ArithOpConversion<arith::AddFOp, emitc::AddOp>,
752 ArithOpConversion<arith::DivFOp, emitc::DivOp>,
753 ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
754 ArithOpConversion<arith::MulFOp, emitc::MulOp>,
755 ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
756 ArithOpConversion<arith::SubFOp, emitc::SubOp>,
757 BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
758 BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
759 IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
760 IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
761 IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
762 BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
763 BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
764 BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
765 UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
766 SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
767 UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
768 CmpFOpConversion,
769 CmpIOpConversion,
770 NegFOpConversion,
771 SelectOpConversion,
772 // Truncation is guaranteed for unsigned types.
773 UnsignedCastConversion<arith::TruncIOp>,
774 SignedCastConversion<arith::ExtSIOp>,
775 UnsignedCastConversion<arith::ExtUIOp>,
776 SignedCastConversion<arith::IndexCastOp>,
777 UnsignedCastConversion<arith::IndexCastUIOp>,
778 ItoFCastOpConversion<arith::SIToFPOp>,
779 ItoFCastOpConversion<arith::UIToFPOp>,
780 FtoICastOpConversion<arith::FPToSIOp>,
781 FtoICastOpConversion<arith::FPToUIOp>
782 >(typeConverter, ctx);
783 // clang-format on