Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / MathToFuncs / MathToFuncs.cpp
blobdf5396ac628cf678950759db03a6e00852906a01
1 //===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Utils/IndexingUtils.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
33 using namespace mlir;
35 #define DEBUG_TYPE "math-to-funcs"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
38 namespace {
39 // Pattern to convert vector operations to scalar operations.
40 template <typename Op>
41 struct VecOpToScalarOp : public OpRewritePattern<Op> {
42 public:
43 using OpRewritePattern<Op>::OpRewritePattern;
45 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
48 // Callback type for getting pre-generated FuncOp implementing
49 // an operation of the given type.
50 using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>;
52 // Pattern to convert scalar IPowIOp into a call of outlined
53 // software implementation.
54 class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
55 public:
56 IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
57 : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
59 /// Convert IPowI into a call to a local function implementing
60 /// the power operation. The local function computes a scalar result,
61 /// so vector forms of IPowI are linearized.
62 LogicalResult matchAndRewrite(math::IPowIOp op,
63 PatternRewriter &rewriter) const final;
65 private:
66 GetFuncCallbackTy getFuncOpCallback;
69 // Pattern to convert scalar FPowIOp into a call of outlined
70 // software implementation.
71 class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
72 public:
73 FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
74 : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
76 /// Convert FPowI into a call to a local function implementing
77 /// the power operation. The local function computes a scalar result,
78 /// so vector forms of FPowI are linearized.
79 LogicalResult matchAndRewrite(math::FPowIOp op,
80 PatternRewriter &rewriter) const final;
82 private:
83 GetFuncCallbackTy getFuncOpCallback;
86 // Pattern to convert scalar ctlz into a call of outlined software
87 // implementation.
88 class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {
89 public:
90 CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
91 : OpRewritePattern<math::CountLeadingZerosOp>(context),
92 getFuncOpCallback(cb) {}
94 /// Convert ctlz into a call to a local function implementing
95 /// the count leading zeros operation.
96 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
97 PatternRewriter &rewriter) const final;
99 private:
100 GetFuncCallbackTy getFuncOpCallback;
102 } // namespace
104 template <typename Op>
105 LogicalResult
106 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
107 Type opType = op.getType();
108 Location loc = op.getLoc();
109 auto vecType = dyn_cast<VectorType>(opType);
111 if (!vecType)
112 return rewriter.notifyMatchFailure(op, "not a vector operation");
113 if (!vecType.hasRank())
114 return rewriter.notifyMatchFailure(op, "unknown vector rank");
115 ArrayRef<int64_t> shape = vecType.getShape();
116 int64_t numElements = vecType.getNumElements();
118 Type resultElementType = vecType.getElementType();
119 Attribute initValueAttr;
120 if (isa<FloatType>(resultElementType))
121 initValueAttr = FloatAttr::get(resultElementType, 0.0);
122 else
123 initValueAttr = IntegerAttr::get(resultElementType, 0);
124 Value result = rewriter.create<arith::ConstantOp>(
125 loc, DenseElementsAttr::get(vecType, initValueAttr));
126 SmallVector<int64_t> strides = computeStrides(shape);
127 for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
128 SmallVector<int64_t> positions = delinearize(linearIndex, strides);
129 SmallVector<Value> operands;
130 for (Value input : op->getOperands())
131 operands.push_back(
132 rewriter.create<vector::ExtractOp>(loc, input, positions));
133 Value scalarOp =
134 rewriter.create<Op>(loc, vecType.getElementType(), operands);
135 result =
136 rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
138 rewriter.replaceOp(op, result);
139 return success();
142 static FunctionType getElementalFuncTypeForOp(Operation *op) {
143 SmallVector<Type, 1> resultTys(op->getNumResults());
144 SmallVector<Type, 2> inputTys(op->getNumOperands());
145 std::transform(op->result_type_begin(), op->result_type_end(),
146 resultTys.begin(),
147 [](Type ty) { return getElementTypeOrSelf(ty); });
148 std::transform(op->operand_type_begin(), op->operand_type_end(),
149 inputTys.begin(),
150 [](Type ty) { return getElementTypeOrSelf(ty); });
151 return FunctionType::get(op->getContext(), inputTys, resultTys);
154 /// Create linkonce_odr function to implement the power function with
155 /// the given \p elementType type inside \p module. The \p elementType
156 /// must be IntegerType, an the created function has
157 /// 'IntegerType (*)(IntegerType, IntegerType)' function type.
159 /// template <typename T>
160 /// T __mlir_math_ipowi_*(T b, T p) {
161 /// if (p == T(0))
162 /// return T(1);
163 /// if (p < T(0)) {
164 /// if (b == T(0))
165 /// return T(1) / T(0); // trigger div-by-zero
166 /// if (b == T(1))
167 /// return T(1);
168 /// if (b == T(-1)) {
169 /// if (p & T(1))
170 /// return T(-1);
171 /// return T(1);
172 /// }
173 /// return T(0);
174 /// }
175 /// T result = T(1);
176 /// while (true) {
177 /// if (p & T(1))
178 /// result *= b;
179 /// p >>= T(1);
180 /// if (p == T(0))
181 /// return result;
182 /// b *= b;
183 /// }
184 /// }
185 static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
186 assert(isa<IntegerType>(elementType) &&
187 "non-integer element type for IPowIOp");
189 ImplicitLocOpBuilder builder =
190 ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
192 std::string funcName("__mlir_math_ipowi");
193 llvm::raw_string_ostream nameOS(funcName);
194 nameOS << '_' << elementType;
196 FunctionType funcType = FunctionType::get(
197 builder.getContext(), {elementType, elementType}, elementType);
198 auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
199 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
200 Attribute linkage =
201 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
202 funcOp->setAttr("llvm.linkage", linkage);
203 funcOp.setPrivate();
205 Block *entryBlock = funcOp.addEntryBlock();
206 Region *funcBody = entryBlock->getParent();
208 Value bArg = funcOp.getArgument(0);
209 Value pArg = funcOp.getArgument(1);
210 builder.setInsertionPointToEnd(entryBlock);
211 Value zeroValue = builder.create<arith::ConstantOp>(
212 elementType, builder.getIntegerAttr(elementType, 0));
213 Value oneValue = builder.create<arith::ConstantOp>(
214 elementType, builder.getIntegerAttr(elementType, 1));
215 Value minusOneValue = builder.create<arith::ConstantOp>(
216 elementType,
217 builder.getIntegerAttr(elementType,
218 APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
219 /*isSigned=*/true)));
221 // if (p == T(0))
222 // return T(1);
223 auto pIsZero =
224 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
225 Block *thenBlock = builder.createBlock(funcBody);
226 builder.create<func::ReturnOp>(oneValue);
227 Block *fallthroughBlock = builder.createBlock(funcBody);
228 // Set up conditional branch for (p == T(0)).
229 builder.setInsertionPointToEnd(pIsZero->getBlock());
230 builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
232 // if (p < T(0)) {
233 builder.setInsertionPointToEnd(fallthroughBlock);
234 auto pIsNeg =
235 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
236 // if (b == T(0))
237 builder.createBlock(funcBody);
238 auto bIsZero =
239 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
240 // return T(1) / T(0);
241 thenBlock = builder.createBlock(funcBody);
242 builder.create<func::ReturnOp>(
243 builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
244 fallthroughBlock = builder.createBlock(funcBody);
245 // Set up conditional branch for (b == T(0)).
246 builder.setInsertionPointToEnd(bIsZero->getBlock());
247 builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
249 // if (b == T(1))
250 builder.setInsertionPointToEnd(fallthroughBlock);
251 auto bIsOne =
252 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
253 // return T(1);
254 thenBlock = builder.createBlock(funcBody);
255 builder.create<func::ReturnOp>(oneValue);
256 fallthroughBlock = builder.createBlock(funcBody);
257 // Set up conditional branch for (b == T(1)).
258 builder.setInsertionPointToEnd(bIsOne->getBlock());
259 builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
261 // if (b == T(-1)) {
262 builder.setInsertionPointToEnd(fallthroughBlock);
263 auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
264 bArg, minusOneValue);
265 // if (p & T(1))
266 builder.createBlock(funcBody);
267 auto pIsOdd = builder.create<arith::CmpIOp>(
268 arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
269 zeroValue);
270 // return T(-1);
271 thenBlock = builder.createBlock(funcBody);
272 builder.create<func::ReturnOp>(minusOneValue);
273 fallthroughBlock = builder.createBlock(funcBody);
274 // Set up conditional branch for (p & T(1)).
275 builder.setInsertionPointToEnd(pIsOdd->getBlock());
276 builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
278 // return T(1);
279 // } // b == T(-1)
280 builder.setInsertionPointToEnd(fallthroughBlock);
281 builder.create<func::ReturnOp>(oneValue);
282 fallthroughBlock = builder.createBlock(funcBody);
283 // Set up conditional branch for (b == T(-1)).
284 builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
285 builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
286 fallthroughBlock);
288 // return T(0);
289 // } // (p < T(0))
290 builder.setInsertionPointToEnd(fallthroughBlock);
291 builder.create<func::ReturnOp>(zeroValue);
292 Block *loopHeader = builder.createBlock(
293 funcBody, funcBody->end(), {elementType, elementType, elementType},
294 {builder.getLoc(), builder.getLoc(), builder.getLoc()});
295 // Set up conditional branch for (p < T(0)).
296 builder.setInsertionPointToEnd(pIsNeg->getBlock());
297 // Set initial values of 'result', 'b' and 'p' for the loop.
298 builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
299 ValueRange{oneValue, bArg, pArg});
301 // T result = T(1);
302 // while (true) {
303 // if (p & T(1))
304 // result *= b;
305 // p >>= T(1);
306 // if (p == T(0))
307 // return result;
308 // b *= b;
309 // }
310 Value resultTmp = loopHeader->getArgument(0);
311 Value baseTmp = loopHeader->getArgument(1);
312 Value powerTmp = loopHeader->getArgument(2);
313 builder.setInsertionPointToEnd(loopHeader);
315 // if (p & T(1))
316 auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
317 arith::CmpIPredicate::ne,
318 builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
319 thenBlock = builder.createBlock(funcBody);
320 // result *= b;
321 Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
322 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
323 builder.getLoc());
324 builder.setInsertionPointToEnd(thenBlock);
325 builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
326 // Set up conditional branch for (p & T(1)).
327 builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
328 builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
329 resultTmp);
330 // Merged 'result'.
331 newResultTmp = fallthroughBlock->getArgument(0);
333 // p >>= T(1);
334 builder.setInsertionPointToEnd(fallthroughBlock);
335 Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
337 // if (p == T(0))
338 auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
339 newPowerTmp, zeroValue);
340 // return result;
341 thenBlock = builder.createBlock(funcBody);
342 builder.create<func::ReturnOp>(newResultTmp);
343 fallthroughBlock = builder.createBlock(funcBody);
344 // Set up conditional branch for (p == T(0)).
345 builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
346 builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
348 // b *= b;
349 // }
350 builder.setInsertionPointToEnd(fallthroughBlock);
351 Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
352 // Pass new values for 'result', 'b' and 'p' to the loop header.
353 builder.create<cf::BranchOp>(
354 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
355 return funcOp;
358 /// Convert IPowI into a call to a local function implementing
359 /// the power operation. The local function computes a scalar result,
360 /// so vector forms of IPowI are linearized.
361 LogicalResult
362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
363 PatternRewriter &rewriter) const {
364 auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
366 if (!baseType)
367 return rewriter.notifyMatchFailure(op, "non-integer base operand");
369 // The outlined software implementation must have been already
370 // generated.
371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
372 if (!elementFunc)
373 return rewriter.notifyMatchFailure(op, "missing software implementation");
375 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
376 return success();
379 /// Create linkonce_odr function to implement the power function with
380 /// the given \p funcType type inside \p module. The \p funcType must be
381 /// 'FloatType (*)(FloatType, IntegerType)' function type.
383 /// template <typename T>
384 /// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
385 /// if (p == Tp{0})
386 /// return Tb{1};
387 /// bool isNegativePower{p < Tp{0}}
388 /// bool isMin{p == std::numeric_limits<Tp>::min()};
389 /// if (isMin) {
390 /// p = std::numeric_limits<Tp>::max();
391 /// } else if (isNegativePower) {
392 /// p = -p;
393 /// }
394 /// Tb result = Tb{1};
395 /// Tb origBase = Tb{b};
396 /// while (true) {
397 /// if (p & Tp{1})
398 /// result *= b;
399 /// p >>= Tp{1};
400 /// if (p == Tp{0})
401 /// break;
402 /// b *= b;
403 /// }
404 /// if (isMin) {
405 /// result *= origBase;
406 /// }
407 /// if (isNegativePower) {
408 /// result = Tb{1} / result;
409 /// }
410 /// return result;
411 /// }
412 static func::FuncOp createElementFPowIFunc(ModuleOp *module,
413 FunctionType funcType) {
414 auto baseType = cast<FloatType>(funcType.getInput(0));
415 auto powType = cast<IntegerType>(funcType.getInput(1));
416 ImplicitLocOpBuilder builder =
417 ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
419 std::string funcName("__mlir_math_fpowi");
420 llvm::raw_string_ostream nameOS(funcName);
421 nameOS << '_' << baseType;
422 nameOS << '_' << powType;
423 auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
425 Attribute linkage =
426 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
427 funcOp->setAttr("llvm.linkage", linkage);
428 funcOp.setPrivate();
430 Block *entryBlock = funcOp.addEntryBlock();
431 Region *funcBody = entryBlock->getParent();
433 Value bArg = funcOp.getArgument(0);
434 Value pArg = funcOp.getArgument(1);
435 builder.setInsertionPointToEnd(entryBlock);
436 Value oneBValue = builder.create<arith::ConstantOp>(
437 baseType, builder.getFloatAttr(baseType, 1.0));
438 Value zeroPValue = builder.create<arith::ConstantOp>(
439 powType, builder.getIntegerAttr(powType, 0));
440 Value onePValue = builder.create<arith::ConstantOp>(
441 powType, builder.getIntegerAttr(powType, 1));
442 Value minPValue = builder.create<arith::ConstantOp>(
443 powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
444 powType.getWidth())));
445 Value maxPValue = builder.create<arith::ConstantOp>(
446 powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
447 powType.getWidth())));
449 // if (p == Tp{0})
450 // return Tb{1};
451 auto pIsZero =
452 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
453 Block *thenBlock = builder.createBlock(funcBody);
454 builder.create<func::ReturnOp>(oneBValue);
455 Block *fallthroughBlock = builder.createBlock(funcBody);
456 // Set up conditional branch for (p == Tp{0}).
457 builder.setInsertionPointToEnd(pIsZero->getBlock());
458 builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
460 builder.setInsertionPointToEnd(fallthroughBlock);
461 // bool isNegativePower{p < Tp{0}}
462 auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
463 zeroPValue);
464 // bool isMin{p == std::numeric_limits<Tp>::min()};
465 auto pIsMin =
466 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
468 // if (isMin) {
469 // p = std::numeric_limits<Tp>::max();
470 // } else if (isNegativePower) {
471 // p = -p;
472 // }
473 Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg);
474 auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
475 pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit);
477 // Tb result = Tb{1};
478 // Tb origBase = Tb{b};
479 // while (true) {
480 // if (p & Tp{1})
481 // result *= b;
482 // p >>= Tp{1};
483 // if (p == Tp{0})
484 // break;
485 // b *= b;
486 // }
487 Block *loopHeader = builder.createBlock(
488 funcBody, funcBody->end(), {baseType, baseType, powType},
489 {builder.getLoc(), builder.getLoc(), builder.getLoc()});
490 // Set initial values of 'result', 'b' and 'p' for the loop.
491 builder.setInsertionPointToEnd(pInit->getBlock());
492 builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});
494 // Create loop body.
495 Value resultTmp = loopHeader->getArgument(0);
496 Value baseTmp = loopHeader->getArgument(1);
497 Value powerTmp = loopHeader->getArgument(2);
498 builder.setInsertionPointToEnd(loopHeader);
500 // if (p & Tp{1})
501 auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
502 arith::CmpIPredicate::ne,
503 builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
504 thenBlock = builder.createBlock(funcBody);
505 // result *= b;
506 Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
507 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
508 builder.getLoc());
509 builder.setInsertionPointToEnd(thenBlock);
510 builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
511 // Set up conditional branch for (p & Tp{1}).
512 builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
513 builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
514 resultTmp);
515 // Merged 'result'.
516 newResultTmp = fallthroughBlock->getArgument(0);
518 // p >>= Tp{1};
519 builder.setInsertionPointToEnd(fallthroughBlock);
520 Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);
522 // if (p == Tp{0})
523 auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
524 newPowerTmp, zeroPValue);
525 // break;
527 // The conditional branch is finalized below with a jump to
528 // the loop exit block.
529 fallthroughBlock = builder.createBlock(funcBody);
531 // b *= b;
532 // }
533 builder.setInsertionPointToEnd(fallthroughBlock);
534 Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
535 // Pass new values for 'result', 'b' and 'p' to the loop header.
536 builder.create<cf::BranchOp>(
537 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
539 // Set up conditional branch for early loop exit:
540 // if (p == Tp{0})
541 // break;
542 Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
543 builder.getLoc());
544 builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
545 builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
546 fallthroughBlock, ValueRange{});
548 // if (isMin) {
549 // result *= origBase;
550 // }
551 newResultTmp = loopExit->getArgument(0);
552 thenBlock = builder.createBlock(funcBody);
553 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
554 builder.getLoc());
555 builder.setInsertionPointToEnd(loopExit);
556 builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
557 newResultTmp);
558 builder.setInsertionPointToEnd(thenBlock);
559 newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
560 builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
562 /// if (isNegativePower) {
563 /// result = Tb{1} / result;
564 /// }
565 newResultTmp = fallthroughBlock->getArgument(0);
566 thenBlock = builder.createBlock(funcBody);
567 Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
568 builder.getLoc());
569 builder.setInsertionPointToEnd(fallthroughBlock);
570 builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
571 newResultTmp);
572 builder.setInsertionPointToEnd(thenBlock);
573 newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
574 builder.create<cf::BranchOp>(newResultTmp, returnBlock);
576 // return result;
577 builder.setInsertionPointToEnd(returnBlock);
578 builder.create<func::ReturnOp>(returnBlock->getArgument(0));
580 return funcOp;
583 /// Convert FPowI into a call to a local function implementing
584 /// the power operation. The local function computes a scalar result,
585 /// so vector forms of FPowI are linearized.
586 LogicalResult
587 FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
588 PatternRewriter &rewriter) const {
589 if (dyn_cast<VectorType>(op.getType()))
590 return rewriter.notifyMatchFailure(op, "non-scalar operation");
592 FunctionType funcType = getElementalFuncTypeForOp(op);
594 // The outlined software implementation must have been already
595 // generated.
596 func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
597 if (!elementFunc)
598 return rewriter.notifyMatchFailure(op, "missing software implementation");
600 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
601 return success();
604 /// Create function to implement the ctlz function the given \p elementType type
605 /// inside \p module. The \p elementType must be IntegerType, an the created
606 /// function has 'IntegerType (*)(IntegerType)' function type.
608 /// template <typename T>
609 /// T __mlir_math_ctlz_*(T x) {
610 /// bits = sizeof(x) * 8;
611 /// if (x == 0)
612 /// return bits;
614 /// uint32_t n = 0;
615 /// for (int i = 1; i < bits; ++i) {
616 /// if (x < 0) continue;
617 /// n++;
618 /// x <<= 1;
619 /// }
620 /// return n;
621 /// }
623 /// Converts to (for i32):
625 /// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
626 /// %c_32 = arith.constant 32 : index
627 /// %c_0 = arith.constant 0 : i32
628 /// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
629 /// %out = scf.if %arg_eq_zero {
630 /// scf.yield %c_32 : i32
631 /// } else {
632 /// %c_1index = arith.constant 1 : index
633 /// %c_1i32 = arith.constant 1 : i32
634 /// %n = arith.constant 0 : i32
635 /// %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
636 /// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
637 /// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
638 /// %yield_val = scf.if %cond {
639 /// scf.yield %arg_iter, %n_iter : i32, i32
640 /// } else {
641 /// %arg_next = arith.shli %arg_iter, %c_1i32 : i32
642 /// %n_next = arith.addi %n_iter, %c_1i32 : i32
643 /// scf.yield %arg_next, %n_next : i32, i32
644 /// }
645 /// scf.yield %yield_val: i32, i32
646 /// }
647 /// scf.yield %n_out : i32
648 /// }
649 /// return %out: i32
650 /// }
651 static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
652 if (!isa<IntegerType>(elementType)) {
653 LLVM_DEBUG({
654 DBGS() << "non-integer element type for CtlzFunc; type was: ";
655 elementType.print(llvm::dbgs());
657 llvm_unreachable("non-integer element type");
659 int64_t bitWidth = elementType.getIntOrFloatBitWidth();
661 Location loc = module->getLoc();
662 ImplicitLocOpBuilder builder =
663 ImplicitLocOpBuilder::atBlockEnd(loc, module->getBody());
665 std::string funcName("__mlir_math_ctlz");
666 llvm::raw_string_ostream nameOS(funcName);
667 nameOS << '_' << elementType;
668 FunctionType funcType =
669 FunctionType::get(builder.getContext(), {elementType}, elementType);
670 auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
672 // LinkonceODR ensures that there is only one implementation of this function
673 // across all math.ctlz functions that are lowered in this way.
674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
675 Attribute linkage =
676 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
677 funcOp->setAttr("llvm.linkage", linkage);
678 funcOp.setPrivate();
680 // set the insertion point to the start of the function
681 Block *funcBody = funcOp.addEntryBlock();
682 builder.setInsertionPointToStart(funcBody);
684 Value arg = funcOp.getArgument(0);
685 Type indexType = builder.getIndexType();
686 Value bitWidthValue = builder.create<arith::ConstantOp>(
687 elementType, builder.getIntegerAttr(elementType, bitWidth));
688 Value zeroValue = builder.create<arith::ConstantOp>(
689 elementType, builder.getIntegerAttr(elementType, 0));
691 Value inputEqZero =
692 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
694 // if input == 0, return bit width, else enter loop.
695 scf::IfOp ifOp = builder.create<scf::IfOp>(
696 elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true);
697 ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
699 auto elseBuilder =
700 ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
702 Value oneIndex = elseBuilder.create<arith::ConstantOp>(
703 indexType, elseBuilder.getIndexAttr(1));
704 Value oneValue = elseBuilder.create<arith::ConstantOp>(
705 elementType, elseBuilder.getIntegerAttr(elementType, 1));
706 Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
707 indexType, elseBuilder.getIndexAttr(bitWidth));
708 Value nValue = elseBuilder.create<arith::ConstantOp>(
709 elementType, elseBuilder.getIntegerAttr(elementType, 0));
711 auto loop = elseBuilder.create<scf::ForOp>(
712 oneIndex, bitWidthIndex, oneIndex,
713 // Initial values for two loop induction variables, the arg which is being
714 // shifted left in each iteration, and the n value which tracks the count
715 // of leading zeros.
716 ValueRange{arg, nValue},
717 // Callback to build the body of the for loop
718 // if (arg < 0) {
719 // continue;
720 // } else {
721 // n++;
722 // arg <<= 1;
723 // }
724 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
725 Value argIter = args[0];
726 Value nIter = args[1];
728 Value argIsNonNegative = b.create<arith::CmpIOp>(
729 loc, arith::CmpIPredicate::slt, argIter, zeroValue);
730 scf::IfOp ifOp = b.create<scf::IfOp>(
731 loc, argIsNonNegative,
732 [&](OpBuilder &b, Location loc) {
733 // If arg is negative, continue (effectively, break)
734 b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter});
736 [&](OpBuilder &b, Location loc) {
737 // Otherwise, increment n and shift arg left.
738 Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
739 Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
740 b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext});
742 b.create<scf::YieldOp>(loc, ifOp.getResults());
744 elseBuilder.create<scf::YieldOp>(loop.getResult(1));
746 builder.create<func::ReturnOp>(ifOp.getResult(0));
747 return funcOp;
750 /// Convert ctlz into a call to a local function implementing the ctlz
751 /// operation.
752 LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
753 PatternRewriter &rewriter) const {
754 if (dyn_cast<VectorType>(op.getType()))
755 return rewriter.notifyMatchFailure(op, "non-scalar operation");
757 Type type = getElementTypeOrSelf(op.getResult().getType());
758 func::FuncOp elementFunc = getFuncOpCallback(op, type);
759 if (!elementFunc)
760 return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
761 diag << "Missing software implementation for op " << op->getName()
762 << " and type " << type;
765 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand());
766 return success();
769 namespace {
770 struct ConvertMathToFuncsPass
771 : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
772 ConvertMathToFuncsPass() = default;
773 ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
774 : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
776 void runOnOperation() override;
778 private:
779 // Return true, if this FPowI operation must be converted
780 // because the width of its exponent's type is greater than
781 // or equal to minWidthOfFPowIExponent option value.
782 bool isFPowIConvertible(math::FPowIOp op);
784 // Reture true, if operation is integer type.
785 bool isConvertible(Operation *op);
787 // Generate outlined implementations for power operations
788 // and store them in funcImpls map.
789 void generateOpImplementations();
791 // A map between pairs of (operation, type) deduced from operations that this
792 // pass will convert, and the corresponding outlined software implementations
793 // of these operations for the given type.
794 DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls;
796 } // namespace
798 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
799 auto expTy =
800 dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType()));
801 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
804 bool ConvertMathToFuncsPass::isConvertible(Operation *op) {
805 return isa<IntegerType>(getElementTypeOrSelf(op->getResult(0).getType()));
808 void ConvertMathToFuncsPass::generateOpImplementations() {
809 ModuleOp module = getOperation();
811 module.walk([&](Operation *op) {
812 TypeSwitch<Operation *>(op)
813 .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
814 if (!convertCtlz || !isConvertible(op))
815 return;
816 Type resultType = getElementTypeOrSelf(op.getResult().getType());
818 // Generate the software implementation of this operation,
819 // if it has not been generated yet.
820 auto key = std::pair(op->getName(), resultType);
821 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
822 if (entry.second)
823 entry.first->second = createCtlzFunc(&module, resultType);
825 .Case<math::IPowIOp>([&](math::IPowIOp op) {
826 if (!isConvertible(op))
827 return;
829 Type resultType = getElementTypeOrSelf(op.getResult().getType());
831 // Generate the software implementation of this operation,
832 // if it has not been generated yet.
833 auto key = std::pair(op->getName(), resultType);
834 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
835 if (entry.second)
836 entry.first->second = createElementIPowIFunc(&module, resultType);
838 .Case<math::FPowIOp>([&](math::FPowIOp op) {
839 if (!isFPowIConvertible(op))
840 return;
842 FunctionType funcType = getElementalFuncTypeForOp(op);
844 // Generate the software implementation of this operation,
845 // if it has not been generated yet.
846 // FPowI implementations are mapped via the FunctionType
847 // created from the operation's result and operands.
848 auto key = std::pair(op->getName(), funcType);
849 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
850 if (entry.second)
851 entry.first->second = createElementFPowIFunc(&module, funcType);
856 void ConvertMathToFuncsPass::runOnOperation() {
857 ModuleOp module = getOperation();
859 // Create outlined implementations for power operations.
860 generateOpImplementations();
862 RewritePatternSet patterns(&getContext());
863 patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
864 VecOpToScalarOp<math::CountLeadingZerosOp>>(
865 patterns.getContext());
867 // For the given Type Returns FuncOp stored in funcImpls map.
868 auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
869 auto it = funcImpls.find(std::pair(op->getName(), type));
870 if (it == funcImpls.end())
871 return {};
873 return it->second;
875 patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
876 getFuncOpByType);
878 if (convertCtlz)
879 patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
881 ConversionTarget target(getContext());
882 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
883 func::FuncDialect, scf::SCFDialect,
884 vector::VectorDialect>();
886 target.addDynamicallyLegalOp<math::IPowIOp>(
887 [this](math::IPowIOp op) { return !isConvertible(op); });
888 if (convertCtlz) {
889 target.addDynamicallyLegalOp<math::CountLeadingZerosOp>(
890 [this](math::CountLeadingZerosOp op) { return !isConvertible(op); });
892 target.addDynamicallyLegalOp<math::FPowIOp>(
893 [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
894 if (failed(applyPartialConversion(module, target, std::move(patterns))))
895 signalPassFailure();