//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Utils/IndexingUtils.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
29 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
30 #include "mlir/Conversion/Passes.h.inc"
35 #define DEBUG_TYPE "math-to-funcs"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
39 // Pattern to convert vector operations to scalar operations.
40 template <typename Op
>
41 struct VecOpToScalarOp
: public OpRewritePattern
<Op
> {
43 using OpRewritePattern
<Op
>::OpRewritePattern
;
45 LogicalResult
matchAndRewrite(Op op
, PatternRewriter
&rewriter
) const final
;
48 // Callback type for getting pre-generated FuncOp implementing
49 // an operation of the given type.
50 using GetFuncCallbackTy
= function_ref
<func::FuncOp(Operation
*, Type
)>;
52 // Pattern to convert scalar IPowIOp into a call of outlined
53 // software implementation.
54 class IPowIOpLowering
: public OpRewritePattern
<math::IPowIOp
> {
56 IPowIOpLowering(MLIRContext
*context
, GetFuncCallbackTy cb
)
57 : OpRewritePattern
<math::IPowIOp
>(context
), getFuncOpCallback(cb
) {}
59 /// Convert IPowI into a call to a local function implementing
60 /// the power operation. The local function computes a scalar result,
61 /// so vector forms of IPowI are linearized.
62 LogicalResult
matchAndRewrite(math::IPowIOp op
,
63 PatternRewriter
&rewriter
) const final
;
66 GetFuncCallbackTy getFuncOpCallback
;
69 // Pattern to convert scalar FPowIOp into a call of outlined
70 // software implementation.
71 class FPowIOpLowering
: public OpRewritePattern
<math::FPowIOp
> {
73 FPowIOpLowering(MLIRContext
*context
, GetFuncCallbackTy cb
)
74 : OpRewritePattern
<math::FPowIOp
>(context
), getFuncOpCallback(cb
) {}
76 /// Convert FPowI into a call to a local function implementing
77 /// the power operation. The local function computes a scalar result,
78 /// so vector forms of FPowI are linearized.
79 LogicalResult
matchAndRewrite(math::FPowIOp op
,
80 PatternRewriter
&rewriter
) const final
;
83 GetFuncCallbackTy getFuncOpCallback
;
86 // Pattern to convert scalar ctlz into a call of outlined software
88 class CtlzOpLowering
: public OpRewritePattern
<math::CountLeadingZerosOp
> {
90 CtlzOpLowering(MLIRContext
*context
, GetFuncCallbackTy cb
)
91 : OpRewritePattern
<math::CountLeadingZerosOp
>(context
),
92 getFuncOpCallback(cb
) {}
94 /// Convert ctlz into a call to a local function implementing
95 /// the count leading zeros operation.
96 LogicalResult
matchAndRewrite(math::CountLeadingZerosOp op
,
97 PatternRewriter
&rewriter
) const final
;
100 GetFuncCallbackTy getFuncOpCallback
;
104 template <typename Op
>
106 VecOpToScalarOp
<Op
>::matchAndRewrite(Op op
, PatternRewriter
&rewriter
) const {
107 Type opType
= op
.getType();
108 Location loc
= op
.getLoc();
109 auto vecType
= dyn_cast
<VectorType
>(opType
);
112 return rewriter
.notifyMatchFailure(op
, "not a vector operation");
113 if (!vecType
.hasRank())
114 return rewriter
.notifyMatchFailure(op
, "unknown vector rank");
115 ArrayRef
<int64_t> shape
= vecType
.getShape();
116 int64_t numElements
= vecType
.getNumElements();
118 Type resultElementType
= vecType
.getElementType();
119 Attribute initValueAttr
;
120 if (isa
<FloatType
>(resultElementType
))
121 initValueAttr
= FloatAttr::get(resultElementType
, 0.0);
123 initValueAttr
= IntegerAttr::get(resultElementType
, 0);
124 Value result
= rewriter
.create
<arith::ConstantOp
>(
125 loc
, DenseElementsAttr::get(vecType
, initValueAttr
));
126 SmallVector
<int64_t> strides
= computeStrides(shape
);
127 for (int64_t linearIndex
= 0; linearIndex
< numElements
; ++linearIndex
) {
128 SmallVector
<int64_t> positions
= delinearize(linearIndex
, strides
);
129 SmallVector
<Value
> operands
;
130 for (Value input
: op
->getOperands())
132 rewriter
.create
<vector::ExtractOp
>(loc
, input
, positions
));
134 rewriter
.create
<Op
>(loc
, vecType
.getElementType(), operands
);
136 rewriter
.create
<vector::InsertOp
>(loc
, scalarOp
, result
, positions
);
138 rewriter
.replaceOp(op
, result
);
142 static FunctionType
getElementalFuncTypeForOp(Operation
*op
) {
143 SmallVector
<Type
, 1> resultTys(op
->getNumResults());
144 SmallVector
<Type
, 2> inputTys(op
->getNumOperands());
145 std::transform(op
->result_type_begin(), op
->result_type_end(),
147 [](Type ty
) { return getElementTypeOrSelf(ty
); });
148 std::transform(op
->operand_type_begin(), op
->operand_type_end(),
150 [](Type ty
) { return getElementTypeOrSelf(ty
); });
151 return FunctionType::get(op
->getContext(), inputTys
, resultTys
);
154 /// Create linkonce_odr function to implement the power function with
155 /// the given \p elementType type inside \p module. The \p elementType
156 /// must be IntegerType, an the created function has
157 /// 'IntegerType (*)(IntegerType, IntegerType)' function type.
159 /// template <typename T>
160 /// T __mlir_math_ipowi_*(T b, T p) {
165 /// return T(1) / T(0); // trigger div-by-zero
168 /// if (b == T(-1)) {
185 static func::FuncOp
createElementIPowIFunc(ModuleOp
*module
, Type elementType
) {
186 assert(isa
<IntegerType
>(elementType
) &&
187 "non-integer element type for IPowIOp");
189 ImplicitLocOpBuilder builder
=
190 ImplicitLocOpBuilder::atBlockEnd(module
->getLoc(), module
->getBody());
192 std::string
funcName("__mlir_math_ipowi");
193 llvm::raw_string_ostream
nameOS(funcName
);
194 nameOS
<< '_' << elementType
;
196 FunctionType funcType
= FunctionType::get(
197 builder
.getContext(), {elementType
, elementType
}, elementType
);
198 auto funcOp
= builder
.create
<func::FuncOp
>(funcName
, funcType
);
199 LLVM::linkage::Linkage inlineLinkage
= LLVM::linkage::Linkage::LinkonceODR
;
201 LLVM::LinkageAttr::get(builder
.getContext(), inlineLinkage
);
202 funcOp
->setAttr("llvm.linkage", linkage
);
205 Block
*entryBlock
= funcOp
.addEntryBlock();
206 Region
*funcBody
= entryBlock
->getParent();
208 Value bArg
= funcOp
.getArgument(0);
209 Value pArg
= funcOp
.getArgument(1);
210 builder
.setInsertionPointToEnd(entryBlock
);
211 Value zeroValue
= builder
.create
<arith::ConstantOp
>(
212 elementType
, builder
.getIntegerAttr(elementType
, 0));
213 Value oneValue
= builder
.create
<arith::ConstantOp
>(
214 elementType
, builder
.getIntegerAttr(elementType
, 1));
215 Value minusOneValue
= builder
.create
<arith::ConstantOp
>(
217 builder
.getIntegerAttr(elementType
,
218 APInt(elementType
.getIntOrFloatBitWidth(), -1ULL,
219 /*isSigned=*/true)));
224 builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::eq
, pArg
, zeroValue
);
225 Block
*thenBlock
= builder
.createBlock(funcBody
);
226 builder
.create
<func::ReturnOp
>(oneValue
);
227 Block
*fallthroughBlock
= builder
.createBlock(funcBody
);
228 // Set up conditional branch for (p == T(0)).
229 builder
.setInsertionPointToEnd(pIsZero
->getBlock());
230 builder
.create
<cf::CondBranchOp
>(pIsZero
, thenBlock
, fallthroughBlock
);
233 builder
.setInsertionPointToEnd(fallthroughBlock
);
235 builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::sle
, pArg
, zeroValue
);
237 builder
.createBlock(funcBody
);
239 builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::eq
, bArg
, zeroValue
);
240 // return T(1) / T(0);
241 thenBlock
= builder
.createBlock(funcBody
);
242 builder
.create
<func::ReturnOp
>(
243 builder
.create
<arith::DivSIOp
>(oneValue
, zeroValue
).getResult());
244 fallthroughBlock
= builder
.createBlock(funcBody
);
245 // Set up conditional branch for (b == T(0)).
246 builder
.setInsertionPointToEnd(bIsZero
->getBlock());
247 builder
.create
<cf::CondBranchOp
>(bIsZero
, thenBlock
, fallthroughBlock
);
250 builder
.setInsertionPointToEnd(fallthroughBlock
);
252 builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::eq
, bArg
, oneValue
);
254 thenBlock
= builder
.createBlock(funcBody
);
255 builder
.create
<func::ReturnOp
>(oneValue
);
256 fallthroughBlock
= builder
.createBlock(funcBody
);
257 // Set up conditional branch for (b == T(1)).
258 builder
.setInsertionPointToEnd(bIsOne
->getBlock());
259 builder
.create
<cf::CondBranchOp
>(bIsOne
, thenBlock
, fallthroughBlock
);
262 builder
.setInsertionPointToEnd(fallthroughBlock
);
263 auto bIsMinusOne
= builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::eq
,
264 bArg
, minusOneValue
);
266 builder
.createBlock(funcBody
);
267 auto pIsOdd
= builder
.create
<arith::CmpIOp
>(
268 arith::CmpIPredicate::ne
, builder
.create
<arith::AndIOp
>(pArg
, oneValue
),
271 thenBlock
= builder
.createBlock(funcBody
);
272 builder
.create
<func::ReturnOp
>(minusOneValue
);
273 fallthroughBlock
= builder
.createBlock(funcBody
);
274 // Set up conditional branch for (p & T(1)).
275 builder
.setInsertionPointToEnd(pIsOdd
->getBlock());
276 builder
.create
<cf::CondBranchOp
>(pIsOdd
, thenBlock
, fallthroughBlock
);
280 builder
.setInsertionPointToEnd(fallthroughBlock
);
281 builder
.create
<func::ReturnOp
>(oneValue
);
282 fallthroughBlock
= builder
.createBlock(funcBody
);
283 // Set up conditional branch for (b == T(-1)).
284 builder
.setInsertionPointToEnd(bIsMinusOne
->getBlock());
285 builder
.create
<cf::CondBranchOp
>(bIsMinusOne
, pIsOdd
->getBlock(),
290 builder
.setInsertionPointToEnd(fallthroughBlock
);
291 builder
.create
<func::ReturnOp
>(zeroValue
);
292 Block
*loopHeader
= builder
.createBlock(
293 funcBody
, funcBody
->end(), {elementType
, elementType
, elementType
},
294 {builder
.getLoc(), builder
.getLoc(), builder
.getLoc()});
295 // Set up conditional branch for (p < T(0)).
296 builder
.setInsertionPointToEnd(pIsNeg
->getBlock());
297 // Set initial values of 'result', 'b' and 'p' for the loop.
298 builder
.create
<cf::CondBranchOp
>(pIsNeg
, bIsZero
->getBlock(), loopHeader
,
299 ValueRange
{oneValue
, bArg
, pArg
});
310 Value resultTmp
= loopHeader
->getArgument(0);
311 Value baseTmp
= loopHeader
->getArgument(1);
312 Value powerTmp
= loopHeader
->getArgument(2);
313 builder
.setInsertionPointToEnd(loopHeader
);
316 auto powerTmpIsOdd
= builder
.create
<arith::CmpIOp
>(
317 arith::CmpIPredicate::ne
,
318 builder
.create
<arith::AndIOp
>(powerTmp
, oneValue
), zeroValue
);
319 thenBlock
= builder
.createBlock(funcBody
);
321 Value newResultTmp
= builder
.create
<arith::MulIOp
>(resultTmp
, baseTmp
);
322 fallthroughBlock
= builder
.createBlock(funcBody
, funcBody
->end(), elementType
,
324 builder
.setInsertionPointToEnd(thenBlock
);
325 builder
.create
<cf::BranchOp
>(newResultTmp
, fallthroughBlock
);
326 // Set up conditional branch for (p & T(1)).
327 builder
.setInsertionPointToEnd(powerTmpIsOdd
->getBlock());
328 builder
.create
<cf::CondBranchOp
>(powerTmpIsOdd
, thenBlock
, fallthroughBlock
,
331 newResultTmp
= fallthroughBlock
->getArgument(0);
334 builder
.setInsertionPointToEnd(fallthroughBlock
);
335 Value newPowerTmp
= builder
.create
<arith::ShRUIOp
>(powerTmp
, oneValue
);
338 auto newPowerIsZero
= builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::eq
,
339 newPowerTmp
, zeroValue
);
341 thenBlock
= builder
.createBlock(funcBody
);
342 builder
.create
<func::ReturnOp
>(newResultTmp
);
343 fallthroughBlock
= builder
.createBlock(funcBody
);
344 // Set up conditional branch for (p == T(0)).
345 builder
.setInsertionPointToEnd(newPowerIsZero
->getBlock());
346 builder
.create
<cf::CondBranchOp
>(newPowerIsZero
, thenBlock
, fallthroughBlock
);
350 builder
.setInsertionPointToEnd(fallthroughBlock
);
351 Value newBaseTmp
= builder
.create
<arith::MulIOp
>(baseTmp
, baseTmp
);
352 // Pass new values for 'result', 'b' and 'p' to the loop header.
353 builder
.create
<cf::BranchOp
>(
354 ValueRange
{newResultTmp
, newBaseTmp
, newPowerTmp
}, loopHeader
);
358 /// Convert IPowI into a call to a local function implementing
359 /// the power operation. The local function computes a scalar result,
360 /// so vector forms of IPowI are linearized.
362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op
,
363 PatternRewriter
&rewriter
) const {
364 auto baseType
= dyn_cast
<IntegerType
>(op
.getOperands()[0].getType());
367 return rewriter
.notifyMatchFailure(op
, "non-integer base operand");
369 // The outlined software implementation must have been already
371 func::FuncOp elementFunc
= getFuncOpCallback(op
, baseType
);
373 return rewriter
.notifyMatchFailure(op
, "missing software implementation");
375 rewriter
.replaceOpWithNewOp
<func::CallOp
>(op
, elementFunc
, op
.getOperands());
379 /// Create linkonce_odr function to implement the power function with
380 /// the given \p funcType type inside \p module. The \p funcType must be
381 /// 'FloatType (*)(FloatType, IntegerType)' function type.
383 /// template <typename T>
384 /// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
387 /// bool isNegativePower{p < Tp{0}}
388 /// bool isMin{p == std::numeric_limits<Tp>::min()};
390 /// p = std::numeric_limits<Tp>::max();
391 /// } else if (isNegativePower) {
394 /// Tb result = Tb{1};
395 /// Tb origBase = Tb{b};
405 /// result *= origBase;
407 /// if (isNegativePower) {
408 /// result = Tb{1} / result;
412 static func::FuncOp
createElementFPowIFunc(ModuleOp
*module
,
413 FunctionType funcType
) {
414 auto baseType
= cast
<FloatType
>(funcType
.getInput(0));
415 auto powType
= cast
<IntegerType
>(funcType
.getInput(1));
416 ImplicitLocOpBuilder builder
=
417 ImplicitLocOpBuilder::atBlockEnd(module
->getLoc(), module
->getBody());
419 std::string
funcName("__mlir_math_fpowi");
420 llvm::raw_string_ostream
nameOS(funcName
);
421 nameOS
<< '_' << baseType
;
422 nameOS
<< '_' << powType
;
423 auto funcOp
= builder
.create
<func::FuncOp
>(funcName
, funcType
);
424 LLVM::linkage::Linkage inlineLinkage
= LLVM::linkage::Linkage::LinkonceODR
;
426 LLVM::LinkageAttr::get(builder
.getContext(), inlineLinkage
);
427 funcOp
->setAttr("llvm.linkage", linkage
);
430 Block
*entryBlock
= funcOp
.addEntryBlock();
431 Region
*funcBody
= entryBlock
->getParent();
433 Value bArg
= funcOp
.getArgument(0);
434 Value pArg
= funcOp
.getArgument(1);
435 builder
.setInsertionPointToEnd(entryBlock
);
436 Value oneBValue
= builder
.create
<arith::ConstantOp
>(
437 baseType
, builder
.getFloatAttr(baseType
, 1.0));
438 Value zeroPValue
= builder
.create
<arith::ConstantOp
>(
439 powType
, builder
.getIntegerAttr(powType
, 0));
440 Value onePValue
= builder
.create
<arith::ConstantOp
>(
441 powType
, builder
.getIntegerAttr(powType
, 1));
442 Value minPValue
= builder
.create
<arith::ConstantOp
>(
443 powType
, builder
.getIntegerAttr(powType
, llvm::APInt::getSignedMinValue(
444 powType
.getWidth())));
445 Value maxPValue
= builder
.create
<arith::ConstantOp
>(
446 powType
, builder
.getIntegerAttr(powType
, llvm::APInt::getSignedMaxValue(
447 powType
.getWidth())));
452 builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::eq
, pArg
, zeroPValue
);
453 Block
*thenBlock
= builder
.createBlock(funcBody
);
454 builder
.create
<func::ReturnOp
>(oneBValue
);
455 Block
*fallthroughBlock
= builder
.createBlock(funcBody
);
456 // Set up conditional branch for (p == Tp{0}).
457 builder
.setInsertionPointToEnd(pIsZero
->getBlock());
458 builder
.create
<cf::CondBranchOp
>(pIsZero
, thenBlock
, fallthroughBlock
);
460 builder
.setInsertionPointToEnd(fallthroughBlock
);
461 // bool isNegativePower{p < Tp{0}}
462 auto pIsNeg
= builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::sle
, pArg
,
464 // bool isMin{p == std::numeric_limits<Tp>::min()};
466 builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::eq
, pArg
, minPValue
);
469 // p = std::numeric_limits<Tp>::max();
470 // } else if (isNegativePower) {
473 Value negP
= builder
.create
<arith::SubIOp
>(zeroPValue
, pArg
);
474 auto pInit
= builder
.create
<arith::SelectOp
>(pIsNeg
, negP
, pArg
);
475 pInit
= builder
.create
<arith::SelectOp
>(pIsMin
, maxPValue
, pInit
);
477 // Tb result = Tb{1};
478 // Tb origBase = Tb{b};
487 Block
*loopHeader
= builder
.createBlock(
488 funcBody
, funcBody
->end(), {baseType
, baseType
, powType
},
489 {builder
.getLoc(), builder
.getLoc(), builder
.getLoc()});
490 // Set initial values of 'result', 'b' and 'p' for the loop.
491 builder
.setInsertionPointToEnd(pInit
->getBlock());
492 builder
.create
<cf::BranchOp
>(loopHeader
, ValueRange
{oneBValue
, bArg
, pInit
});
495 Value resultTmp
= loopHeader
->getArgument(0);
496 Value baseTmp
= loopHeader
->getArgument(1);
497 Value powerTmp
= loopHeader
->getArgument(2);
498 builder
.setInsertionPointToEnd(loopHeader
);
501 auto powerTmpIsOdd
= builder
.create
<arith::CmpIOp
>(
502 arith::CmpIPredicate::ne
,
503 builder
.create
<arith::AndIOp
>(powerTmp
, onePValue
), zeroPValue
);
504 thenBlock
= builder
.createBlock(funcBody
);
506 Value newResultTmp
= builder
.create
<arith::MulFOp
>(resultTmp
, baseTmp
);
507 fallthroughBlock
= builder
.createBlock(funcBody
, funcBody
->end(), baseType
,
509 builder
.setInsertionPointToEnd(thenBlock
);
510 builder
.create
<cf::BranchOp
>(newResultTmp
, fallthroughBlock
);
511 // Set up conditional branch for (p & Tp{1}).
512 builder
.setInsertionPointToEnd(powerTmpIsOdd
->getBlock());
513 builder
.create
<cf::CondBranchOp
>(powerTmpIsOdd
, thenBlock
, fallthroughBlock
,
516 newResultTmp
= fallthroughBlock
->getArgument(0);
519 builder
.setInsertionPointToEnd(fallthroughBlock
);
520 Value newPowerTmp
= builder
.create
<arith::ShRUIOp
>(powerTmp
, onePValue
);
523 auto newPowerIsZero
= builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::eq
,
524 newPowerTmp
, zeroPValue
);
527 // The conditional branch is finalized below with a jump to
528 // the loop exit block.
529 fallthroughBlock
= builder
.createBlock(funcBody
);
533 builder
.setInsertionPointToEnd(fallthroughBlock
);
534 Value newBaseTmp
= builder
.create
<arith::MulFOp
>(baseTmp
, baseTmp
);
535 // Pass new values for 'result', 'b' and 'p' to the loop header.
536 builder
.create
<cf::BranchOp
>(
537 ValueRange
{newResultTmp
, newBaseTmp
, newPowerTmp
}, loopHeader
);
539 // Set up conditional branch for early loop exit:
542 Block
*loopExit
= builder
.createBlock(funcBody
, funcBody
->end(), baseType
,
544 builder
.setInsertionPointToEnd(newPowerIsZero
->getBlock());
545 builder
.create
<cf::CondBranchOp
>(newPowerIsZero
, loopExit
, newResultTmp
,
546 fallthroughBlock
, ValueRange
{});
549 // result *= origBase;
551 newResultTmp
= loopExit
->getArgument(0);
552 thenBlock
= builder
.createBlock(funcBody
);
553 fallthroughBlock
= builder
.createBlock(funcBody
, funcBody
->end(), baseType
,
555 builder
.setInsertionPointToEnd(loopExit
);
556 builder
.create
<cf::CondBranchOp
>(pIsMin
, thenBlock
, fallthroughBlock
,
558 builder
.setInsertionPointToEnd(thenBlock
);
559 newResultTmp
= builder
.create
<arith::MulFOp
>(newResultTmp
, bArg
);
560 builder
.create
<cf::BranchOp
>(newResultTmp
, fallthroughBlock
);
562 /// if (isNegativePower) {
563 /// result = Tb{1} / result;
565 newResultTmp
= fallthroughBlock
->getArgument(0);
566 thenBlock
= builder
.createBlock(funcBody
);
567 Block
*returnBlock
= builder
.createBlock(funcBody
, funcBody
->end(), baseType
,
569 builder
.setInsertionPointToEnd(fallthroughBlock
);
570 builder
.create
<cf::CondBranchOp
>(pIsNeg
, thenBlock
, returnBlock
,
572 builder
.setInsertionPointToEnd(thenBlock
);
573 newResultTmp
= builder
.create
<arith::DivFOp
>(oneBValue
, newResultTmp
);
574 builder
.create
<cf::BranchOp
>(newResultTmp
, returnBlock
);
577 builder
.setInsertionPointToEnd(returnBlock
);
578 builder
.create
<func::ReturnOp
>(returnBlock
->getArgument(0));
583 /// Convert FPowI into a call to a local function implementing
584 /// the power operation. The local function computes a scalar result,
585 /// so vector forms of FPowI are linearized.
587 FPowIOpLowering::matchAndRewrite(math::FPowIOp op
,
588 PatternRewriter
&rewriter
) const {
589 if (dyn_cast
<VectorType
>(op
.getType()))
590 return rewriter
.notifyMatchFailure(op
, "non-scalar operation");
592 FunctionType funcType
= getElementalFuncTypeForOp(op
);
594 // The outlined software implementation must have been already
596 func::FuncOp elementFunc
= getFuncOpCallback(op
, funcType
);
598 return rewriter
.notifyMatchFailure(op
, "missing software implementation");
600 rewriter
.replaceOpWithNewOp
<func::CallOp
>(op
, elementFunc
, op
.getOperands());
604 /// Create function to implement the ctlz function the given \p elementType type
605 /// inside \p module. The \p elementType must be IntegerType, an the created
606 /// function has 'IntegerType (*)(IntegerType)' function type.
608 /// template <typename T>
609 /// T __mlir_math_ctlz_*(T x) {
610 /// bits = sizeof(x) * 8;
615 /// for (int i = 1; i < bits; ++i) {
616 /// if (x < 0) continue;
623 /// Converts to (for i32):
625 /// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
626 /// %c_32 = arith.constant 32 : index
627 /// %c_0 = arith.constant 0 : i32
628 /// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
629 /// %out = scf.if %arg_eq_zero {
630 /// scf.yield %c_32 : i32
632 /// %c_1index = arith.constant 1 : index
633 /// %c_1i32 = arith.constant 1 : i32
634 /// %n = arith.constant 0 : i32
635 /// %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
636 /// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
637 /// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
638 /// %yield_val = scf.if %cond {
639 /// scf.yield %arg_iter, %n_iter : i32, i32
641 /// %arg_next = arith.shli %arg_iter, %c_1i32 : i32
642 /// %n_next = arith.addi %n_iter, %c_1i32 : i32
643 /// scf.yield %arg_next, %n_next : i32, i32
645 /// scf.yield %yield_val: i32, i32
647 /// scf.yield %n_out : i32
651 static func::FuncOp
createCtlzFunc(ModuleOp
*module
, Type elementType
) {
652 if (!isa
<IntegerType
>(elementType
)) {
654 DBGS() << "non-integer element type for CtlzFunc; type was: ";
655 elementType
.print(llvm::dbgs());
657 llvm_unreachable("non-integer element type");
659 int64_t bitWidth
= elementType
.getIntOrFloatBitWidth();
661 Location loc
= module
->getLoc();
662 ImplicitLocOpBuilder builder
=
663 ImplicitLocOpBuilder::atBlockEnd(loc
, module
->getBody());
665 std::string
funcName("__mlir_math_ctlz");
666 llvm::raw_string_ostream
nameOS(funcName
);
667 nameOS
<< '_' << elementType
;
668 FunctionType funcType
=
669 FunctionType::get(builder
.getContext(), {elementType
}, elementType
);
670 auto funcOp
= builder
.create
<func::FuncOp
>(funcName
, funcType
);
672 // LinkonceODR ensures that there is only one implementation of this function
673 // across all math.ctlz functions that are lowered in this way.
674 LLVM::linkage::Linkage inlineLinkage
= LLVM::linkage::Linkage::LinkonceODR
;
676 LLVM::LinkageAttr::get(builder
.getContext(), inlineLinkage
);
677 funcOp
->setAttr("llvm.linkage", linkage
);
680 // set the insertion point to the start of the function
681 Block
*funcBody
= funcOp
.addEntryBlock();
682 builder
.setInsertionPointToStart(funcBody
);
684 Value arg
= funcOp
.getArgument(0);
685 Type indexType
= builder
.getIndexType();
686 Value bitWidthValue
= builder
.create
<arith::ConstantOp
>(
687 elementType
, builder
.getIntegerAttr(elementType
, bitWidth
));
688 Value zeroValue
= builder
.create
<arith::ConstantOp
>(
689 elementType
, builder
.getIntegerAttr(elementType
, 0));
692 builder
.create
<arith::CmpIOp
>(arith::CmpIPredicate::eq
, arg
, zeroValue
);
694 // if input == 0, return bit width, else enter loop.
695 scf::IfOp ifOp
= builder
.create
<scf::IfOp
>(
696 elementType
, inputEqZero
, /*addThenBlock=*/true, /*addElseBlock=*/true);
697 ifOp
.getThenBodyBuilder().create
<scf::YieldOp
>(loc
, bitWidthValue
);
700 ImplicitLocOpBuilder::atBlockEnd(loc
, &ifOp
.getElseRegion().front());
702 Value oneIndex
= elseBuilder
.create
<arith::ConstantOp
>(
703 indexType
, elseBuilder
.getIndexAttr(1));
704 Value oneValue
= elseBuilder
.create
<arith::ConstantOp
>(
705 elementType
, elseBuilder
.getIntegerAttr(elementType
, 1));
706 Value bitWidthIndex
= elseBuilder
.create
<arith::ConstantOp
>(
707 indexType
, elseBuilder
.getIndexAttr(bitWidth
));
708 Value nValue
= elseBuilder
.create
<arith::ConstantOp
>(
709 elementType
, elseBuilder
.getIntegerAttr(elementType
, 0));
711 auto loop
= elseBuilder
.create
<scf::ForOp
>(
712 oneIndex
, bitWidthIndex
, oneIndex
,
713 // Initial values for two loop induction variables, the arg which is being
714 // shifted left in each iteration, and the n value which tracks the count
716 ValueRange
{arg
, nValue
},
717 // Callback to build the body of the for loop
724 [&](OpBuilder
&b
, Location loc
, Value iv
, ValueRange args
) {
725 Value argIter
= args
[0];
726 Value nIter
= args
[1];
728 Value argIsNonNegative
= b
.create
<arith::CmpIOp
>(
729 loc
, arith::CmpIPredicate::slt
, argIter
, zeroValue
);
730 scf::IfOp ifOp
= b
.create
<scf::IfOp
>(
731 loc
, argIsNonNegative
,
732 [&](OpBuilder
&b
, Location loc
) {
733 // If arg is negative, continue (effectively, break)
734 b
.create
<scf::YieldOp
>(loc
, ValueRange
{argIter
, nIter
});
736 [&](OpBuilder
&b
, Location loc
) {
737 // Otherwise, increment n and shift arg left.
738 Value nNext
= b
.create
<arith::AddIOp
>(loc
, nIter
, oneValue
);
739 Value argNext
= b
.create
<arith::ShLIOp
>(loc
, argIter
, oneValue
);
740 b
.create
<scf::YieldOp
>(loc
, ValueRange
{argNext
, nNext
});
742 b
.create
<scf::YieldOp
>(loc
, ifOp
.getResults());
744 elseBuilder
.create
<scf::YieldOp
>(loop
.getResult(1));
746 builder
.create
<func::ReturnOp
>(ifOp
.getResult(0));
750 /// Convert ctlz into a call to a local function implementing the ctlz
752 LogicalResult
CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op
,
753 PatternRewriter
&rewriter
) const {
754 if (dyn_cast
<VectorType
>(op
.getType()))
755 return rewriter
.notifyMatchFailure(op
, "non-scalar operation");
757 Type type
= getElementTypeOrSelf(op
.getResult().getType());
758 func::FuncOp elementFunc
= getFuncOpCallback(op
, type
);
760 return rewriter
.notifyMatchFailure(op
, [&](::mlir::Diagnostic
&diag
) {
761 diag
<< "Missing software implementation for op " << op
->getName()
762 << " and type " << type
;
765 rewriter
.replaceOpWithNewOp
<func::CallOp
>(op
, elementFunc
, op
.getOperand());
770 struct ConvertMathToFuncsPass
771 : public impl::ConvertMathToFuncsBase
<ConvertMathToFuncsPass
> {
772 ConvertMathToFuncsPass() = default;
773 ConvertMathToFuncsPass(const ConvertMathToFuncsOptions
&options
)
774 : impl::ConvertMathToFuncsBase
<ConvertMathToFuncsPass
>(options
) {}
776 void runOnOperation() override
;
779 // Return true, if this FPowI operation must be converted
780 // because the width of its exponent's type is greater than
781 // or equal to minWidthOfFPowIExponent option value.
782 bool isFPowIConvertible(math::FPowIOp op
);
784 // Reture true, if operation is integer type.
785 bool isConvertible(Operation
*op
);
787 // Generate outlined implementations for power operations
788 // and store them in funcImpls map.
789 void generateOpImplementations();
791 // A map between pairs of (operation, type) deduced from operations that this
792 // pass will convert, and the corresponding outlined software implementations
793 // of these operations for the given type.
794 DenseMap
<std::pair
<OperationName
, Type
>, func::FuncOp
> funcImpls
;
798 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op
) {
800 dyn_cast
<IntegerType
>(getElementTypeOrSelf(op
.getRhs().getType()));
801 return (expTy
&& expTy
.getWidth() >= minWidthOfFPowIExponent
);
804 bool ConvertMathToFuncsPass::isConvertible(Operation
*op
) {
805 return isa
<IntegerType
>(getElementTypeOrSelf(op
->getResult(0).getType()));
808 void ConvertMathToFuncsPass::generateOpImplementations() {
809 ModuleOp module
= getOperation();
811 module
.walk([&](Operation
*op
) {
812 TypeSwitch
<Operation
*>(op
)
813 .Case
<math::CountLeadingZerosOp
>([&](math::CountLeadingZerosOp op
) {
814 if (!convertCtlz
|| !isConvertible(op
))
816 Type resultType
= getElementTypeOrSelf(op
.getResult().getType());
818 // Generate the software implementation of this operation,
819 // if it has not been generated yet.
820 auto key
= std::pair(op
->getName(), resultType
);
821 auto entry
= funcImpls
.try_emplace(key
, func::FuncOp
{});
823 entry
.first
->second
= createCtlzFunc(&module
, resultType
);
825 .Case
<math::IPowIOp
>([&](math::IPowIOp op
) {
826 if (!isConvertible(op
))
829 Type resultType
= getElementTypeOrSelf(op
.getResult().getType());
831 // Generate the software implementation of this operation,
832 // if it has not been generated yet.
833 auto key
= std::pair(op
->getName(), resultType
);
834 auto entry
= funcImpls
.try_emplace(key
, func::FuncOp
{});
836 entry
.first
->second
= createElementIPowIFunc(&module
, resultType
);
838 .Case
<math::FPowIOp
>([&](math::FPowIOp op
) {
839 if (!isFPowIConvertible(op
))
842 FunctionType funcType
= getElementalFuncTypeForOp(op
);
844 // Generate the software implementation of this operation,
845 // if it has not been generated yet.
846 // FPowI implementations are mapped via the FunctionType
847 // created from the operation's result and operands.
848 auto key
= std::pair(op
->getName(), funcType
);
849 auto entry
= funcImpls
.try_emplace(key
, func::FuncOp
{});
851 entry
.first
->second
= createElementFPowIFunc(&module
, funcType
);
856 void ConvertMathToFuncsPass::runOnOperation() {
857 ModuleOp module
= getOperation();
859 // Create outlined implementations for power operations.
860 generateOpImplementations();
862 RewritePatternSet
patterns(&getContext());
863 patterns
.add
<VecOpToScalarOp
<math::IPowIOp
>, VecOpToScalarOp
<math::FPowIOp
>,
864 VecOpToScalarOp
<math::CountLeadingZerosOp
>>(
865 patterns
.getContext());
867 // For the given Type Returns FuncOp stored in funcImpls map.
868 auto getFuncOpByType
= [&](Operation
*op
, Type type
) -> func::FuncOp
{
869 auto it
= funcImpls
.find(std::pair(op
->getName(), type
));
870 if (it
== funcImpls
.end())
875 patterns
.add
<IPowIOpLowering
, FPowIOpLowering
>(patterns
.getContext(),
879 patterns
.add
<CtlzOpLowering
>(patterns
.getContext(), getFuncOpByType
);
881 ConversionTarget
target(getContext());
882 target
.addLegalDialect
<arith::ArithDialect
, cf::ControlFlowDialect
,
883 func::FuncDialect
, scf::SCFDialect
,
884 vector::VectorDialect
>();
886 target
.addDynamicallyLegalOp
<math::IPowIOp
>(
887 [this](math::IPowIOp op
) { return !isConvertible(op
); });
889 target
.addDynamicallyLegalOp
<math::CountLeadingZerosOp
>(
890 [this](math::CountLeadingZerosOp op
) { return !isConvertible(op
); });
892 target
.addDynamicallyLegalOp
<math::FPowIOp
>(
893 [this](math::FPowIOp op
) { return !isFPowIConvertible(op
); });
894 if (failed(applyPartialConversion(module
, target
, std::move(patterns
))))