[AMDGPU] Mark AGPR tuple implicit in the first instr of AGPR spills. (#115285)
[llvm-project.git] / mlir / lib / Conversion / MathToLLVM / MathToLLVM.cpp
blob668f8385ac2dcf4a5ee175876269851cffda2dfd
1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
26 using namespace mlir;
28 namespace {
30 template <typename SourceOp, typename TargetOp>
31 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
33 template <typename SourceOp, typename TargetOp>
34 using ConvertFMFMathToLLVMPattern =
35 VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
37 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
38 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
39 using CopySignOpLowering =
40 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
41 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
42 using CtPopFOpLowering =
43 VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
44 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
45 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
46 using FloorOpLowering =
47 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
48 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
49 using Log10OpLowering =
50 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
51 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
52 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
53 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
54 using FPowIOpLowering =
55 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
56 using RoundEvenOpLowering =
57 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
58 using RoundOpLowering =
59 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
60 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
61 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
62 using FTruncOpLowering =
63 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
65 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
66 template <typename MathOp, typename LLVMOp>
67 struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
68 using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
69 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
71 LogicalResult
72 matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
73 ConversionPatternRewriter &rewriter) const override {
74 auto operandType = adaptor.getOperand().getType();
76 if (!operandType || !LLVM::isCompatibleType(operandType))
77 return failure();
79 auto loc = op.getLoc();
80 auto resultType = op.getResult().getType();
82 if (!isa<LLVM::LLVMArrayType>(operandType)) {
83 rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
84 false);
85 return success();
88 auto vectorType = dyn_cast<VectorType>(resultType);
89 if (!vectorType)
90 return failure();
92 return LLVM::detail::handleMultidimensionalVectors(
93 op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
94 [&](Type llvm1DVectorTy, ValueRange operands) {
95 return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
96 false);
98 rewriter);
102 using CountLeadingZerosOpLowering =
103 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
104 using CountTrailingZerosOpLowering =
105 IntOpWithFlagLowering<math::CountTrailingZerosOp,
106 LLVM::CountTrailingZerosOp>;
107 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
109 // A `expm1` is converted into `exp - 1`.
110 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
111 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
113 LogicalResult
114 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
115 ConversionPatternRewriter &rewriter) const override {
116 auto operandType = adaptor.getOperand().getType();
118 if (!operandType || !LLVM::isCompatibleType(operandType))
119 return failure();
121 auto loc = op.getLoc();
122 auto resultType = op.getResult().getType();
123 auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
124 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
125 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
126 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
128 if (!isa<LLVM::LLVMArrayType>(operandType)) {
129 LLVM::ConstantOp one;
130 if (LLVM::isCompatibleVectorType(operandType)) {
131 one = rewriter.create<LLVM::ConstantOp>(
132 loc, operandType,
133 SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
134 } else {
135 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
137 auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
138 expAttrs.getAttrs());
139 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
140 op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
141 return success();
144 auto vectorType = dyn_cast<VectorType>(resultType);
145 if (!vectorType)
146 return rewriter.notifyMatchFailure(op, "expected vector result type");
148 return LLVM::detail::handleMultidimensionalVectors(
149 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
150 [&](Type llvm1DVectorTy, ValueRange operands) {
151 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
152 auto splatAttr = SplatElementsAttr::get(
153 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
154 {numElements.isScalable()}),
155 floatOne);
156 auto one =
157 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
158 auto exp = rewriter.create<LLVM::ExpOp>(
159 loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
160 return rewriter.create<LLVM::FSubOp>(
161 loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
163 rewriter);
167 // A `log1p` is converted into `log(1 + ...)`.
168 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
169 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
171 LogicalResult
172 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
173 ConversionPatternRewriter &rewriter) const override {
174 auto operandType = adaptor.getOperand().getType();
176 if (!operandType || !LLVM::isCompatibleType(operandType))
177 return rewriter.notifyMatchFailure(op, "unsupported operand type");
179 auto loc = op.getLoc();
180 auto resultType = op.getResult().getType();
181 auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
182 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
183 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
184 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
186 if (!isa<LLVM::LLVMArrayType>(operandType)) {
187 LLVM::ConstantOp one =
188 LLVM::isCompatibleVectorType(operandType)
189 ? rewriter.create<LLVM::ConstantOp>(
190 loc, operandType,
191 SplatElementsAttr::get(cast<ShapedType>(resultType),
192 floatOne))
193 : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
195 auto add = rewriter.create<LLVM::FAddOp>(
196 loc, operandType, ValueRange{one, adaptor.getOperand()},
197 addAttrs.getAttrs());
198 rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
199 logAttrs.getAttrs());
200 return success();
203 auto vectorType = dyn_cast<VectorType>(resultType);
204 if (!vectorType)
205 return rewriter.notifyMatchFailure(op, "expected vector result type");
207 return LLVM::detail::handleMultidimensionalVectors(
208 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
209 [&](Type llvm1DVectorTy, ValueRange operands) {
210 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
211 auto splatAttr = SplatElementsAttr::get(
212 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
213 {numElements.isScalable()}),
214 floatOne);
215 auto one =
216 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
217 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
218 ValueRange{one, operands[0]},
219 addAttrs.getAttrs());
220 return rewriter.create<LLVM::LogOp>(
221 loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
223 rewriter);
227 // A `rsqrt` is converted into `1 / sqrt`.
228 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
229 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
231 LogicalResult
232 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
233 ConversionPatternRewriter &rewriter) const override {
234 auto operandType = adaptor.getOperand().getType();
236 if (!operandType || !LLVM::isCompatibleType(operandType))
237 return failure();
239 auto loc = op.getLoc();
240 auto resultType = op.getResult().getType();
241 auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
242 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
243 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
244 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
246 if (!isa<LLVM::LLVMArrayType>(operandType)) {
247 LLVM::ConstantOp one;
248 if (LLVM::isCompatibleVectorType(operandType)) {
249 one = rewriter.create<LLVM::ConstantOp>(
250 loc, operandType,
251 SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
252 } else {
253 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
255 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
256 sqrtAttrs.getAttrs());
257 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
258 op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
259 return success();
262 auto vectorType = dyn_cast<VectorType>(resultType);
263 if (!vectorType)
264 return failure();
266 return LLVM::detail::handleMultidimensionalVectors(
267 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
268 [&](Type llvm1DVectorTy, ValueRange operands) {
269 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
270 auto splatAttr = SplatElementsAttr::get(
271 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
272 {numElements.isScalable()}),
273 floatOne);
274 auto one =
275 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
276 auto sqrt = rewriter.create<LLVM::SqrtOp>(
277 loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
278 return rewriter.create<LLVM::FDivOp>(
279 loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
281 rewriter);
285 struct ConvertMathToLLVMPass
286 : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
287 using Base::Base;
289 void runOnOperation() override {
290 RewritePatternSet patterns(&getContext());
291 LLVMTypeConverter converter(&getContext());
292 populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
293 LLVMConversionTarget target(getContext());
294 if (failed(applyPartialConversion(getOperation(), target,
295 std::move(patterns))))
296 signalPassFailure();
299 } // namespace
301 void mlir::populateMathToLLVMConversionPatterns(
302 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
303 bool approximateLog1p) {
304 if (approximateLog1p)
305 patterns.add<Log1pOpLowering>(converter);
306 // clang-format off
307 patterns.add<
308 AbsFOpLowering,
309 AbsIOpLowering,
310 CeilOpLowering,
311 CopySignOpLowering,
312 CosOpLowering,
313 CountLeadingZerosOpLowering,
314 CountTrailingZerosOpLowering,
315 CtPopFOpLowering,
316 Exp2OpLowering,
317 ExpM1OpLowering,
318 ExpOpLowering,
319 FPowIOpLowering,
320 FloorOpLowering,
321 FmaOpLowering,
322 Log10OpLowering,
323 Log2OpLowering,
324 LogOpLowering,
325 PowFOpLowering,
326 RoundEvenOpLowering,
327 RoundOpLowering,
328 RsqrtOpLowering,
329 SinOpLowering,
330 SqrtOpLowering,
331 FTruncOpLowering
332 >(converter);
333 // clang-format on
336 //===----------------------------------------------------------------------===//
337 // ConvertToLLVMPatternInterface implementation
338 //===----------------------------------------------------------------------===//
340 namespace {
341 /// Implement the interface to convert Math to LLVM.
342 struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
343 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
344 void loadDependentDialects(MLIRContext *context) const final {
345 context->loadDialect<LLVM::LLVMDialect>();
348 /// Hook for derived dialect interface to provide conversion patterns
349 /// and mark dialect legal for the conversion target.
350 void populateConvertToLLVMConversionPatterns(
351 ConversionTarget &target, LLVMTypeConverter &typeConverter,
352 RewritePatternSet &patterns) const final {
353 populateMathToLLVMConversionPatterns(typeConverter, patterns);
356 } // namespace
358 void mlir::registerConvertMathToLLVMInterface(DialectRegistry &registry) {
359 registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
360 dialect->addInterfaces<MathToLLVMDialectInterface>();