[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (...
[llvm-project.git] / llvm / lib / Transforms / Utils / LibCallsShrinkWrap.cpp
blob6220f850930969d4323037407ed33505f102104c
1 //===-- LibCallsShrinkWrap.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass shrink-wraps a call to function if the result is not used.
10 // The call can set errno but is otherwise side effect free. For example:
11 // sqrt(val);
12 // is transformed to
13 // if (val < 0)
14 // sqrt(val);
15 // Even if the result of library call is not being used, the compiler cannot
16 // safely delete the call because the function can set errno on error
17 // conditions.
// Note that for many functions the error condition depends solely on the
// incoming parameter. In this optimization, we generate the condition that
// can lead to errno being set and use it to shrink-wrap the call. Since the
// chance of hitting the error condition is low, the runtime call is
// effectively eliminated.
23 // These partially dead calls are usually results of C++ abstraction penalty
24 // exposed by inlining.
26 //===----------------------------------------------------------------------===//
28 #include "llvm/Transforms/Utils/LibCallsShrinkWrap.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/Statistic.h"
31 #include "llvm/Analysis/DomTreeUpdater.h"
32 #include "llvm/Analysis/GlobalsModRef.h"
33 #include "llvm/Analysis/TargetLibraryInfo.h"
34 #include "llvm/IR/Constants.h"
35 #include "llvm/IR/Dominators.h"
36 #include "llvm/IR/Function.h"
37 #include "llvm/IR/IRBuilder.h"
38 #include "llvm/IR/InstVisitor.h"
39 #include "llvm/IR/Instructions.h"
40 #include "llvm/IR/MDBuilder.h"
41 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #include <cmath>
45 using namespace llvm;
47 #define DEBUG_TYPE "libcalls-shrinkwrap"
49 STATISTIC(NumWrappedOneCond, "Number of One-Condition Wrappers Inserted");
50 STATISTIC(NumWrappedTwoCond, "Number of Two-Condition Wrappers Inserted");
52 namespace {
53 class LibCallsShrinkWrap : public InstVisitor<LibCallsShrinkWrap> {
54 public:
55 LibCallsShrinkWrap(const TargetLibraryInfo &TLI, DomTreeUpdater &DTU)
56 : TLI(TLI), DTU(DTU){};
57 void visitCallInst(CallInst &CI) { checkCandidate(CI); }
58 bool perform() {
59 bool Changed = false;
60 for (auto &CI : WorkList) {
61 LLVM_DEBUG(dbgs() << "CDCE calls: " << CI->getCalledFunction()->getName()
62 << "\n");
63 if (perform(CI)) {
64 Changed = true;
65 LLVM_DEBUG(dbgs() << "Transformed\n");
68 return Changed;
71 private:
72 bool perform(CallInst *CI);
73 void checkCandidate(CallInst &CI);
74 void shrinkWrapCI(CallInst *CI, Value *Cond);
75 bool performCallDomainErrorOnly(CallInst *CI, const LibFunc &Func);
76 bool performCallErrors(CallInst *CI, const LibFunc &Func);
77 bool performCallRangeErrorOnly(CallInst *CI, const LibFunc &Func);
78 Value *generateOneRangeCond(CallInst *CI, const LibFunc &Func);
79 Value *generateTwoRangeCond(CallInst *CI, const LibFunc &Func);
80 Value *generateCondForPow(CallInst *CI, const LibFunc &Func);
82 // Create an OR of two conditions with given Arg and Arg2.
83 Value *createOrCond(CallInst *CI, Value *Arg, CmpInst::Predicate Cmp,
84 float Val, Value *Arg2, CmpInst::Predicate Cmp2,
85 float Val2) {
86 IRBuilder<> BBBuilder(CI);
87 auto Cond2 = createCond(BBBuilder, Arg2, Cmp2, Val2);
88 auto Cond1 = createCond(BBBuilder, Arg, Cmp, Val);
89 return BBBuilder.CreateOr(Cond1, Cond2);
92 // Create an OR of two conditions.
93 Value *createOrCond(CallInst *CI, CmpInst::Predicate Cmp, float Val,
94 CmpInst::Predicate Cmp2, float Val2) {
95 Value *Arg = CI->getArgOperand(0);
96 return createOrCond(CI, Arg, Cmp, Val, Arg, Cmp2, Val2);
99 // Create a single condition using IRBuilder.
100 Value *createCond(IRBuilder<> &BBBuilder, Value *Arg, CmpInst::Predicate Cmp,
101 float Val) {
102 Constant *V = ConstantFP::get(BBBuilder.getContext(), APFloat(Val));
103 if (!Arg->getType()->isFloatTy())
104 V = ConstantFoldCastInstruction(Instruction::FPExt, V, Arg->getType());
105 if (BBBuilder.GetInsertBlock()->getParent()->hasFnAttribute(Attribute::StrictFP))
106 BBBuilder.setIsFPConstrained(true);
107 return BBBuilder.CreateFCmp(Cmp, Arg, V);
110 // Create a single condition with given Arg.
111 Value *createCond(CallInst *CI, Value *Arg, CmpInst::Predicate Cmp,
112 float Val) {
113 IRBuilder<> BBBuilder(CI);
114 return createCond(BBBuilder, Arg, Cmp, Val);
117 // Create a single condition.
118 Value *createCond(CallInst *CI, CmpInst::Predicate Cmp, float Val) {
119 Value *Arg = CI->getArgOperand(0);
120 return createCond(CI, Arg, Cmp, Val);
123 const TargetLibraryInfo &TLI;
124 DomTreeUpdater &DTU;
125 SmallVector<CallInst *, 16> WorkList;
127 } // end anonymous namespace
129 // Perform the transformation to calls with errno set by domain error.
130 bool LibCallsShrinkWrap::performCallDomainErrorOnly(CallInst *CI,
131 const LibFunc &Func) {
132 Value *Cond = nullptr;
134 switch (Func) {
135 case LibFunc_acos: // DomainError: (x < -1 || x > 1)
136 case LibFunc_acosf: // Same as acos
137 case LibFunc_acosl: // Same as acos
138 case LibFunc_asin: // DomainError: (x < -1 || x > 1)
139 case LibFunc_asinf: // Same as asin
140 case LibFunc_asinl: // Same as asin
142 ++NumWrappedTwoCond;
143 Cond = createOrCond(CI, CmpInst::FCMP_OLT, -1.0f, CmpInst::FCMP_OGT, 1.0f);
144 break;
146 case LibFunc_cos: // DomainError: (x == +inf || x == -inf)
147 case LibFunc_cosf: // Same as cos
148 case LibFunc_cosl: // Same as cos
149 case LibFunc_sin: // DomainError: (x == +inf || x == -inf)
150 case LibFunc_sinf: // Same as sin
151 case LibFunc_sinl: // Same as sin
153 ++NumWrappedTwoCond;
154 Cond = createOrCond(CI, CmpInst::FCMP_OEQ, INFINITY, CmpInst::FCMP_OEQ,
155 -INFINITY);
156 break;
158 case LibFunc_acosh: // DomainError: (x < 1)
159 case LibFunc_acoshf: // Same as acosh
160 case LibFunc_acoshl: // Same as acosh
162 ++NumWrappedOneCond;
163 Cond = createCond(CI, CmpInst::FCMP_OLT, 1.0f);
164 break;
166 case LibFunc_sqrt: // DomainError: (x < 0)
167 case LibFunc_sqrtf: // Same as sqrt
168 case LibFunc_sqrtl: // Same as sqrt
170 ++NumWrappedOneCond;
171 Cond = createCond(CI, CmpInst::FCMP_OLT, 0.0f);
172 break;
174 default:
175 return false;
177 shrinkWrapCI(CI, Cond);
178 return true;
181 // Perform the transformation to calls with errno set by range error.
182 bool LibCallsShrinkWrap::performCallRangeErrorOnly(CallInst *CI,
183 const LibFunc &Func) {
184 Value *Cond = nullptr;
186 switch (Func) {
187 case LibFunc_cosh:
188 case LibFunc_coshf:
189 case LibFunc_coshl:
190 case LibFunc_exp:
191 case LibFunc_expf:
192 case LibFunc_expl:
193 case LibFunc_exp10:
194 case LibFunc_exp10f:
195 case LibFunc_exp10l:
196 case LibFunc_exp2:
197 case LibFunc_exp2f:
198 case LibFunc_exp2l:
199 case LibFunc_sinh:
200 case LibFunc_sinhf:
201 case LibFunc_sinhl: {
202 Cond = generateTwoRangeCond(CI, Func);
203 break;
205 case LibFunc_expm1: // RangeError: (709, inf)
206 case LibFunc_expm1f: // RangeError: (88, inf)
207 case LibFunc_expm1l: // RangeError: (11356, inf)
209 Cond = generateOneRangeCond(CI, Func);
210 break;
212 default:
213 return false;
215 shrinkWrapCI(CI, Cond);
216 return true;
219 // Perform the transformation to calls with errno set by combination of errors.
220 bool LibCallsShrinkWrap::performCallErrors(CallInst *CI,
221 const LibFunc &Func) {
222 Value *Cond = nullptr;
224 switch (Func) {
225 case LibFunc_atanh: // DomainError: (x < -1 || x > 1)
226 // PoleError: (x == -1 || x == 1)
227 // Overall Cond: (x <= -1 || x >= 1)
228 case LibFunc_atanhf: // Same as atanh
229 case LibFunc_atanhl: // Same as atanh
231 ++NumWrappedTwoCond;
232 Cond = createOrCond(CI, CmpInst::FCMP_OLE, -1.0f, CmpInst::FCMP_OGE, 1.0f);
233 break;
235 case LibFunc_log: // DomainError: (x < 0)
236 // PoleError: (x == 0)
237 // Overall Cond: (x <= 0)
238 case LibFunc_logf: // Same as log
239 case LibFunc_logl: // Same as log
240 case LibFunc_log10: // Same as log
241 case LibFunc_log10f: // Same as log
242 case LibFunc_log10l: // Same as log
243 case LibFunc_log2: // Same as log
244 case LibFunc_log2f: // Same as log
245 case LibFunc_log2l: // Same as log
246 case LibFunc_logb: // Same as log
247 case LibFunc_logbf: // Same as log
248 case LibFunc_logbl: // Same as log
250 ++NumWrappedOneCond;
251 Cond = createCond(CI, CmpInst::FCMP_OLE, 0.0f);
252 break;
254 case LibFunc_log1p: // DomainError: (x < -1)
255 // PoleError: (x == -1)
256 // Overall Cond: (x <= -1)
257 case LibFunc_log1pf: // Same as log1p
258 case LibFunc_log1pl: // Same as log1p
260 ++NumWrappedOneCond;
261 Cond = createCond(CI, CmpInst::FCMP_OLE, -1.0f);
262 break;
264 case LibFunc_pow: // DomainError: x < 0 and y is noninteger
265 // PoleError: x == 0 and y < 0
266 // RangeError: overflow or underflow
267 case LibFunc_powf:
268 case LibFunc_powl: {
269 Cond = generateCondForPow(CI, Func);
270 if (Cond == nullptr)
271 return false;
272 break;
274 default:
275 return false;
277 assert(Cond && "performCallErrors should not see an empty condition");
278 shrinkWrapCI(CI, Cond);
279 return true;
282 // Checks if CI is a candidate for shrinkwrapping and put it into work list if
283 // true.
284 void LibCallsShrinkWrap::checkCandidate(CallInst &CI) {
285 if (CI.isNoBuiltin())
286 return;
287 // A possible improvement is to handle the calls with the return value being
288 // used. If there is API for fast libcall implementation without setting
289 // errno, we can use the same framework to direct/wrap the call to the fast
290 // API in the error free path, and leave the original call in the slow path.
291 if (!CI.use_empty())
292 return;
294 LibFunc Func;
295 Function *Callee = CI.getCalledFunction();
296 if (!Callee)
297 return;
298 if (!TLI.getLibFunc(*Callee, Func) || !TLI.has(Func))
299 return;
301 if (CI.arg_empty())
302 return;
303 // TODO: Handle long double in other formats.
304 Type *ArgType = CI.getArgOperand(0)->getType();
305 if (!(ArgType->isFloatTy() || ArgType->isDoubleTy() ||
306 ArgType->isX86_FP80Ty()))
307 return;
309 WorkList.push_back(&CI);
312 // Generate the upper bound condition for RangeError.
313 Value *LibCallsShrinkWrap::generateOneRangeCond(CallInst *CI,
314 const LibFunc &Func) {
315 float UpperBound;
316 switch (Func) {
317 case LibFunc_expm1: // RangeError: (709, inf)
318 UpperBound = 709.0f;
319 break;
320 case LibFunc_expm1f: // RangeError: (88, inf)
321 UpperBound = 88.0f;
322 break;
323 case LibFunc_expm1l: // RangeError: (11356, inf)
324 UpperBound = 11356.0f;
325 break;
326 default:
327 llvm_unreachable("Unhandled library call!");
330 ++NumWrappedOneCond;
331 return createCond(CI, CmpInst::FCMP_OGT, UpperBound);
334 // Generate the lower and upper bound condition for RangeError.
335 Value *LibCallsShrinkWrap::generateTwoRangeCond(CallInst *CI,
336 const LibFunc &Func) {
337 float UpperBound, LowerBound;
338 switch (Func) {
339 case LibFunc_cosh: // RangeError: (x < -710 || x > 710)
340 case LibFunc_sinh: // Same as cosh
341 LowerBound = -710.0f;
342 UpperBound = 710.0f;
343 break;
344 case LibFunc_coshf: // RangeError: (x < -89 || x > 89)
345 case LibFunc_sinhf: // Same as coshf
346 LowerBound = -89.0f;
347 UpperBound = 89.0f;
348 break;
349 case LibFunc_coshl: // RangeError: (x < -11357 || x > 11357)
350 case LibFunc_sinhl: // Same as coshl
351 LowerBound = -11357.0f;
352 UpperBound = 11357.0f;
353 break;
354 case LibFunc_exp: // RangeError: (x < -745 || x > 709)
355 LowerBound = -745.0f;
356 UpperBound = 709.0f;
357 break;
358 case LibFunc_expf: // RangeError: (x < -103 || x > 88)
359 LowerBound = -103.0f;
360 UpperBound = 88.0f;
361 break;
362 case LibFunc_expl: // RangeError: (x < -11399 || x > 11356)
363 LowerBound = -11399.0f;
364 UpperBound = 11356.0f;
365 break;
366 case LibFunc_exp10: // RangeError: (x < -323 || x > 308)
367 LowerBound = -323.0f;
368 UpperBound = 308.0f;
369 break;
370 case LibFunc_exp10f: // RangeError: (x < -45 || x > 38)
371 LowerBound = -45.0f;
372 UpperBound = 38.0f;
373 break;
374 case LibFunc_exp10l: // RangeError: (x < -4950 || x > 4932)
375 LowerBound = -4950.0f;
376 UpperBound = 4932.0f;
377 break;
378 case LibFunc_exp2: // RangeError: (x < -1074 || x > 1023)
379 LowerBound = -1074.0f;
380 UpperBound = 1023.0f;
381 break;
382 case LibFunc_exp2f: // RangeError: (x < -149 || x > 127)
383 LowerBound = -149.0f;
384 UpperBound = 127.0f;
385 break;
386 case LibFunc_exp2l: // RangeError: (x < -16445 || x > 11383)
387 LowerBound = -16445.0f;
388 UpperBound = 11383.0f;
389 break;
390 default:
391 llvm_unreachable("Unhandled library call!");
394 ++NumWrappedTwoCond;
395 return createOrCond(CI, CmpInst::FCMP_OGT, UpperBound, CmpInst::FCMP_OLT,
396 LowerBound);
399 // For pow(x,y), We only handle the following cases:
400 // (1) x is a constant && (x >= 1) && (x < MaxUInt8)
401 // Cond is: (y > 127)
402 // (2) x is a value coming from an integer type.
403 // (2.1) if x's bit_size == 8
404 // Cond: (x <= 0 || y > 128)
405 // (2.2) if x's bit_size is 16
406 // Cond: (x <= 0 || y > 64)
407 // (2.3) if x's bit_size is 32
408 // Cond: (x <= 0 || y > 32)
409 // Support for powl(x,y) and powf(x,y) are TBD.
411 // Note that condition can be more conservative than the actual condition
412 // (i.e. we might invoke the calls that will not set the errno.).
414 Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI,
415 const LibFunc &Func) {
416 // FIXME: LibFunc_powf and powl TBD.
417 if (Func != LibFunc_pow) {
418 LLVM_DEBUG(dbgs() << "Not handled powf() and powl()\n");
419 return nullptr;
422 Value *Base = CI->getArgOperand(0);
423 Value *Exp = CI->getArgOperand(1);
425 // Constant Base case.
426 if (ConstantFP *CF = dyn_cast<ConstantFP>(Base)) {
427 double D = CF->getValueAPF().convertToDouble();
428 if (D < 1.0f || D > APInt::getMaxValue(8).getZExtValue()) {
429 LLVM_DEBUG(dbgs() << "Not handled pow(): constant base out of range\n");
430 return nullptr;
433 ++NumWrappedOneCond;
434 return createCond(CI, Exp, CmpInst::FCMP_OGT, 127.0f);
437 // If the Base value coming from an integer type.
438 Instruction *I = dyn_cast<Instruction>(Base);
439 if (!I) {
440 LLVM_DEBUG(dbgs() << "Not handled pow(): FP type base\n");
441 return nullptr;
443 unsigned Opcode = I->getOpcode();
444 if (Opcode == Instruction::UIToFP || Opcode == Instruction::SIToFP) {
445 unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
446 float UpperV = 0.0f;
447 if (BW == 8)
448 UpperV = 128.0f;
449 else if (BW == 16)
450 UpperV = 64.0f;
451 else if (BW == 32)
452 UpperV = 32.0f;
453 else {
454 LLVM_DEBUG(dbgs() << "Not handled pow(): type too wide\n");
455 return nullptr;
458 ++NumWrappedTwoCond;
459 return createOrCond(CI, Base, CmpInst::FCMP_OLE, 0.0f, Exp,
460 CmpInst::FCMP_OGT, UpperV);
462 LLVM_DEBUG(dbgs() << "Not handled pow(): base not from integer convert\n");
463 return nullptr;
466 // Wrap conditions that can potentially generate errno to the library call.
467 void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) {
468 assert(Cond != nullptr && "ShrinkWrapCI is not expecting an empty call inst");
469 MDNode *BranchWeights =
470 MDBuilder(CI->getContext()).createBranchWeights(1, 2000);
472 Instruction *NewInst =
473 SplitBlockAndInsertIfThen(Cond, CI, false, BranchWeights, &DTU);
474 BasicBlock *CallBB = NewInst->getParent();
475 CallBB->setName("cdce.call");
476 BasicBlock *SuccBB = CallBB->getSingleSuccessor();
477 assert(SuccBB && "The split block should have a single successor");
478 SuccBB->setName("cdce.end");
479 CI->removeFromParent();
480 CI->insertInto(CallBB, CallBB->getFirstInsertionPt());
481 LLVM_DEBUG(dbgs() << "== Basic Block After ==");
482 LLVM_DEBUG(dbgs() << *CallBB->getSinglePredecessor() << *CallBB
483 << *CallBB->getSingleSuccessor() << "\n");
486 // Perform the transformation to a single candidate.
487 bool LibCallsShrinkWrap::perform(CallInst *CI) {
488 LibFunc Func;
489 Function *Callee = CI->getCalledFunction();
490 assert(Callee && "perform() should apply to a non-empty callee");
491 TLI.getLibFunc(*Callee, Func);
492 assert(Func && "perform() is not expecting an empty function");
494 if (performCallDomainErrorOnly(CI, Func) || performCallRangeErrorOnly(CI, Func))
495 return true;
496 return performCallErrors(CI, Func);
499 static bool runImpl(Function &F, const TargetLibraryInfo &TLI,
500 DominatorTree *DT) {
501 if (F.hasFnAttribute(Attribute::OptimizeForSize))
502 return false;
503 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
504 LibCallsShrinkWrap CCDCE(TLI, DTU);
505 CCDCE.visit(F);
506 bool Changed = CCDCE.perform();
508 // Verify the dominator after we've updated it locally.
509 assert(!DT ||
510 DTU.getDomTree().verify(DominatorTree::VerificationLevel::Fast));
511 return Changed;
514 PreservedAnalyses LibCallsShrinkWrapPass::run(Function &F,
515 FunctionAnalysisManager &FAM) {
516 auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
517 auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F);
518 if (!runImpl(F, TLI, DT))
519 return PreservedAnalyses::all();
520 auto PA = PreservedAnalyses();
521 PA.preserve<DominatorTreeAnalysis>();
522 return PA;