[IR] Replace of PointerType::get(Type) with opaque version (NFC) (#123617)
[llvm-project.git] / llvm / lib / Target / DirectX / DXILIntrinsicExpansion.cpp
blobcf142806bb1df6bc16ff229e3d9027f0ce6e934e
//===- DXILIntrinsicExpansion.cpp - Expand intrinsics with no DXIL opcode--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains DXIL intrinsic expansions for intrinsics that
/// don't have opcodes in DirectX Intermediate Language (DXIL).
11 //===----------------------------------------------------------------------===//
13 #include "DXILIntrinsicExpansion.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/CodeGen/Passes.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/InstrTypes.h"
20 #include "llvm/IR/Instruction.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Intrinsics.h"
23 #include "llvm/IR/IntrinsicsDirectX.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/PassManager.h"
26 #include "llvm/IR/Type.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/MathExtras.h"
31 #define DEBUG_TYPE "dxil-intrinsic-expansion"
33 using namespace llvm;
35 class DXILIntrinsicExpansionLegacy : public ModulePass {
37 public:
38 bool runOnModule(Module &M) override;
39 DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}
41 static char ID; // Pass identification.
44 static bool isIntrinsicExpansion(Function &F) {
45 switch (F.getIntrinsicID()) {
46 case Intrinsic::abs:
47 case Intrinsic::atan2:
48 case Intrinsic::exp:
49 case Intrinsic::log:
50 case Intrinsic::log10:
51 case Intrinsic::pow:
52 case Intrinsic::dx_all:
53 case Intrinsic::dx_any:
54 case Intrinsic::dx_cross:
55 case Intrinsic::dx_uclamp:
56 case Intrinsic::dx_sclamp:
57 case Intrinsic::dx_nclamp:
58 case Intrinsic::dx_degrees:
59 case Intrinsic::dx_lerp:
60 case Intrinsic::dx_normalize:
61 case Intrinsic::dx_fdot:
62 case Intrinsic::dx_sdot:
63 case Intrinsic::dx_udot:
64 case Intrinsic::dx_sign:
65 case Intrinsic::dx_step:
66 case Intrinsic::dx_radians:
67 case Intrinsic::vector_reduce_add:
68 case Intrinsic::vector_reduce_fadd:
69 return true;
71 return false;
73 static Value *expandVecReduceAdd(CallInst *Orig, Intrinsic::ID IntrinsicId) {
74 assert(IntrinsicId == Intrinsic::vector_reduce_add ||
75 IntrinsicId == Intrinsic::vector_reduce_fadd);
77 IRBuilder<> Builder(Orig);
78 bool IsFAdd = (IntrinsicId == Intrinsic::vector_reduce_fadd);
80 Value *X = Orig->getOperand(IsFAdd ? 1 : 0);
81 Type *Ty = X->getType();
82 auto *XVec = dyn_cast<FixedVectorType>(Ty);
83 unsigned XVecSize = XVec->getNumElements();
84 Value *Sum = Builder.CreateExtractElement(X, static_cast<uint64_t>(0));
86 // Handle the initial start value for floating-point addition.
87 if (IsFAdd) {
88 Constant *StartValue = dyn_cast<Constant>(Orig->getOperand(0));
89 if (StartValue && !StartValue->isZeroValue())
90 Sum = Builder.CreateFAdd(Sum, StartValue);
93 // Accumulate the remaining vector elements.
94 for (unsigned I = 1; I < XVecSize; I++) {
95 Value *Elt = Builder.CreateExtractElement(X, I);
96 if (IsFAdd)
97 Sum = Builder.CreateFAdd(Sum, Elt);
98 else
99 Sum = Builder.CreateAdd(Sum, Elt);
102 return Sum;
105 static Value *expandAbs(CallInst *Orig) {
106 Value *X = Orig->getOperand(0);
107 IRBuilder<> Builder(Orig);
108 Type *Ty = X->getType();
109 Type *EltTy = Ty->getScalarType();
110 Constant *Zero = Ty->isVectorTy()
111 ? ConstantVector::getSplat(
112 ElementCount::getFixed(
113 cast<FixedVectorType>(Ty)->getNumElements()),
114 ConstantInt::get(EltTy, 0))
115 : ConstantInt::get(EltTy, 0);
116 auto *V = Builder.CreateSub(Zero, X);
117 return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr,
118 "dx.max");
121 static Value *expandCrossIntrinsic(CallInst *Orig) {
123 VectorType *VT = cast<VectorType>(Orig->getType());
124 if (cast<FixedVectorType>(VT)->getNumElements() != 3)
125 report_fatal_error(Twine("return vector must have exactly 3 elements"),
126 /* gen_crash_diag=*/false);
128 Value *op0 = Orig->getOperand(0);
129 Value *op1 = Orig->getOperand(1);
130 IRBuilder<> Builder(Orig);
132 Value *op0_x = Builder.CreateExtractElement(op0, (uint64_t)0, "x0");
133 Value *op0_y = Builder.CreateExtractElement(op0, 1, "x1");
134 Value *op0_z = Builder.CreateExtractElement(op0, 2, "x2");
136 Value *op1_x = Builder.CreateExtractElement(op1, (uint64_t)0, "y0");
137 Value *op1_y = Builder.CreateExtractElement(op1, 1, "y1");
138 Value *op1_z = Builder.CreateExtractElement(op1, 2, "y2");
140 auto MulSub = [&](Value *x0, Value *y0, Value *x1, Value *y1) -> Value * {
141 Value *xy = Builder.CreateFMul(x0, y1);
142 Value *yx = Builder.CreateFMul(y0, x1);
143 return Builder.CreateFSub(xy, yx, Orig->getName());
146 Value *yz_zy = MulSub(op0_y, op0_z, op1_y, op1_z);
147 Value *zx_xz = MulSub(op0_z, op0_x, op1_z, op1_x);
148 Value *xy_yx = MulSub(op0_x, op0_y, op1_x, op1_y);
150 Value *cross = UndefValue::get(VT);
151 cross = Builder.CreateInsertElement(cross, yz_zy, (uint64_t)0);
152 cross = Builder.CreateInsertElement(cross, zx_xz, 1);
153 cross = Builder.CreateInsertElement(cross, xy_yx, 2);
154 return cross;
157 // Create appropriate DXIL float dot intrinsic for the given A and B operands
158 // The appropriate opcode will be determined by the size of the operands
159 // The dot product is placed in the position indicated by Orig
160 static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
161 Type *ATy = A->getType();
162 [[maybe_unused]] Type *BTy = B->getType();
163 assert(ATy->isVectorTy() && BTy->isVectorTy());
165 IRBuilder<> Builder(Orig);
167 auto *AVec = dyn_cast<FixedVectorType>(ATy);
169 assert(ATy->getScalarType()->isFloatingPointTy());
171 Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
172 switch (AVec->getNumElements()) {
173 case 2:
174 DotIntrinsic = Intrinsic::dx_dot2;
175 break;
176 case 3:
177 DotIntrinsic = Intrinsic::dx_dot3;
178 break;
179 case 4:
180 DotIntrinsic = Intrinsic::dx_dot4;
181 break;
182 default:
183 report_fatal_error(
184 Twine("Invalid dot product input vector: length is outside 2-4"),
185 /* gen_crash_diag=*/false);
186 return nullptr;
188 return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
189 ArrayRef<Value *>{A, B}, nullptr, "dot");
192 // Create the appropriate DXIL float dot intrinsic for the operands of Orig
193 // The appropriate opcode will be determined by the size of the operands
194 // The dot product is placed in the position indicated by Orig
195 static Value *expandFloatDotIntrinsic(CallInst *Orig) {
196 return expandFloatDotIntrinsic(Orig, Orig->getOperand(0),
197 Orig->getOperand(1));
200 // Expand integer dot product to multiply and add ops
201 static Value *expandIntegerDotIntrinsic(CallInst *Orig,
202 Intrinsic::ID DotIntrinsic) {
203 assert(DotIntrinsic == Intrinsic::dx_sdot ||
204 DotIntrinsic == Intrinsic::dx_udot);
205 Value *A = Orig->getOperand(0);
206 Value *B = Orig->getOperand(1);
207 Type *ATy = A->getType();
208 [[maybe_unused]] Type *BTy = B->getType();
209 assert(ATy->isVectorTy() && BTy->isVectorTy());
211 IRBuilder<> Builder(Orig);
213 auto *AVec = dyn_cast<FixedVectorType>(ATy);
215 assert(ATy->getScalarType()->isIntegerTy());
217 Value *Result;
218 Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
219 ? Intrinsic::dx_imad
220 : Intrinsic::dx_umad;
221 Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
222 Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
223 Result = Builder.CreateMul(Elt0, Elt1);
224 for (unsigned I = 1; I < AVec->getNumElements(); I++) {
225 Elt0 = Builder.CreateExtractElement(A, I);
226 Elt1 = Builder.CreateExtractElement(B, I);
227 Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
228 ArrayRef<Value *>{Elt0, Elt1, Result},
229 nullptr, "dx.mad");
231 return Result;
234 static Value *expandExpIntrinsic(CallInst *Orig) {
235 Value *X = Orig->getOperand(0);
236 IRBuilder<> Builder(Orig);
237 Type *Ty = X->getType();
238 Type *EltTy = Ty->getScalarType();
239 Constant *Log2eConst =
240 Ty->isVectorTy() ? ConstantVector::getSplat(
241 ElementCount::getFixed(
242 cast<FixedVectorType>(Ty)->getNumElements()),
243 ConstantFP::get(EltTy, numbers::log2ef))
244 : ConstantFP::get(EltTy, numbers::log2ef);
245 Value *NewX = Builder.CreateFMul(Log2eConst, X);
246 auto *Exp2Call =
247 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
248 Exp2Call->setTailCall(Orig->isTailCall());
249 Exp2Call->setAttributes(Orig->getAttributes());
250 return Exp2Call;
253 static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
254 Intrinsic::ID intrinsicId) {
255 Value *X = Orig->getOperand(0);
256 IRBuilder<> Builder(Orig);
257 Type *Ty = X->getType();
258 Type *EltTy = Ty->getScalarType();
260 auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
261 Value *Elt) {
262 if (IntrinsicId == Intrinsic::dx_any)
263 return Builder.CreateOr(Result, Elt);
264 assert(IntrinsicId == Intrinsic::dx_all);
265 return Builder.CreateAnd(Result, Elt);
268 Value *Result = nullptr;
269 if (!Ty->isVectorTy()) {
270 Result = EltTy->isFloatingPointTy()
271 ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
272 : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
273 } else {
274 auto *XVec = dyn_cast<FixedVectorType>(Ty);
275 Value *Cond =
276 EltTy->isFloatingPointTy()
277 ? Builder.CreateFCmpUNE(
278 X, ConstantVector::getSplat(
279 ElementCount::getFixed(XVec->getNumElements()),
280 ConstantFP::get(EltTy, 0)))
281 : Builder.CreateICmpNE(
282 X, ConstantVector::getSplat(
283 ElementCount::getFixed(XVec->getNumElements()),
284 ConstantInt::get(EltTy, 0)));
285 Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
286 for (unsigned I = 1; I < XVec->getNumElements(); I++) {
287 Value *Elt = Builder.CreateExtractElement(Cond, I);
288 Result = ApplyOp(intrinsicId, Result, Elt);
291 return Result;
294 static Value *expandLerpIntrinsic(CallInst *Orig) {
295 Value *X = Orig->getOperand(0);
296 Value *Y = Orig->getOperand(1);
297 Value *S = Orig->getOperand(2);
298 IRBuilder<> Builder(Orig);
299 auto *V = Builder.CreateFSub(Y, X);
300 V = Builder.CreateFMul(S, V);
301 return Builder.CreateFAdd(X, V, "dx.lerp");
304 static Value *expandLogIntrinsic(CallInst *Orig,
305 float LogConstVal = numbers::ln2f) {
306 Value *X = Orig->getOperand(0);
307 IRBuilder<> Builder(Orig);
308 Type *Ty = X->getType();
309 Type *EltTy = Ty->getScalarType();
310 Constant *Ln2Const =
311 Ty->isVectorTy() ? ConstantVector::getSplat(
312 ElementCount::getFixed(
313 cast<FixedVectorType>(Ty)->getNumElements()),
314 ConstantFP::get(EltTy, LogConstVal))
315 : ConstantFP::get(EltTy, LogConstVal);
316 auto *Log2Call =
317 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
318 Log2Call->setTailCall(Orig->isTailCall());
319 Log2Call->setAttributes(Orig->getAttributes());
320 return Builder.CreateFMul(Ln2Const, Log2Call);
322 static Value *expandLog10Intrinsic(CallInst *Orig) {
323 return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
326 // Use dot product of vector operand with itself to calculate the length.
327 // Divide the vector by that length to normalize it.
328 static Value *expandNormalizeIntrinsic(CallInst *Orig) {
329 Value *X = Orig->getOperand(0);
330 Type *Ty = Orig->getType();
331 Type *EltTy = Ty->getScalarType();
332 IRBuilder<> Builder(Orig);
334 auto *XVec = dyn_cast<FixedVectorType>(Ty);
335 if (!XVec) {
336 if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
337 const APFloat &fpVal = constantFP->getValueAPF();
338 if (fpVal.isZero())
339 report_fatal_error(Twine("Invalid input scalar: length is zero"),
340 /* gen_crash_diag=*/false);
342 return Builder.CreateFDiv(X, X);
345 Value *DotProduct = expandFloatDotIntrinsic(Orig, X, X);
347 // verify that the length is non-zero
348 // (if the dot product is non-zero, then the length is non-zero)
349 if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) {
350 const APFloat &fpVal = constantFP->getValueAPF();
351 if (fpVal.isZero())
352 report_fatal_error(Twine("Invalid input vector: length is zero"),
353 /* gen_crash_diag=*/false);
356 Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
357 ArrayRef<Value *>{DotProduct},
358 nullptr, "dx.rsqrt");
360 Value *MultiplicandVec =
361 Builder.CreateVectorSplat(XVec->getNumElements(), Multiplicand);
362 return Builder.CreateFMul(X, MultiplicandVec);
365 static Value *expandAtan2Intrinsic(CallInst *Orig) {
366 Value *Y = Orig->getOperand(0);
367 Value *X = Orig->getOperand(1);
368 Type *Ty = X->getType();
369 IRBuilder<> Builder(Orig);
370 Builder.setFastMathFlags(Orig->getFastMathFlags());
372 Value *Tan = Builder.CreateFDiv(Y, X);
374 CallInst *Atan =
375 Builder.CreateIntrinsic(Ty, Intrinsic::atan, {Tan}, nullptr, "Elt.Atan");
376 Atan->setTailCall(Orig->isTailCall());
377 Atan->setAttributes(Orig->getAttributes());
379 // Modify atan result based on https://en.wikipedia.org/wiki/Atan2.
380 Constant *Pi = ConstantFP::get(Ty, llvm::numbers::pi);
381 Constant *HalfPi = ConstantFP::get(Ty, llvm::numbers::pi / 2);
382 Constant *NegHalfPi = ConstantFP::get(Ty, -llvm::numbers::pi / 2);
383 Constant *Zero = ConstantFP::get(Ty, 0);
384 Value *AtanAddPi = Builder.CreateFAdd(Atan, Pi);
385 Value *AtanSubPi = Builder.CreateFSub(Atan, Pi);
387 // x > 0 -> atan.
388 Value *Result = Atan;
389 Value *XLt0 = Builder.CreateFCmpOLT(X, Zero);
390 Value *XEq0 = Builder.CreateFCmpOEQ(X, Zero);
391 Value *YGe0 = Builder.CreateFCmpOGE(Y, Zero);
392 Value *YLt0 = Builder.CreateFCmpOLT(Y, Zero);
394 // x < 0, y >= 0 -> atan + pi.
395 Value *XLt0AndYGe0 = Builder.CreateAnd(XLt0, YGe0);
396 Result = Builder.CreateSelect(XLt0AndYGe0, AtanAddPi, Result);
398 // x < 0, y < 0 -> atan - pi.
399 Value *XLt0AndYLt0 = Builder.CreateAnd(XLt0, YLt0);
400 Result = Builder.CreateSelect(XLt0AndYLt0, AtanSubPi, Result);
402 // x == 0, y < 0 -> -pi/2
403 Value *XEq0AndYLt0 = Builder.CreateAnd(XEq0, YLt0);
404 Result = Builder.CreateSelect(XEq0AndYLt0, NegHalfPi, Result);
406 // x == 0, y > 0 -> pi/2
407 Value *XEq0AndYGe0 = Builder.CreateAnd(XEq0, YGe0);
408 Result = Builder.CreateSelect(XEq0AndYGe0, HalfPi, Result);
410 return Result;
413 static Value *expandPowIntrinsic(CallInst *Orig) {
415 Value *X = Orig->getOperand(0);
416 Value *Y = Orig->getOperand(1);
417 Type *Ty = X->getType();
418 IRBuilder<> Builder(Orig);
420 auto *Log2Call =
421 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
422 auto *Mul = Builder.CreateFMul(Log2Call, Y);
423 auto *Exp2Call =
424 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
425 Exp2Call->setTailCall(Orig->isTailCall());
426 Exp2Call->setAttributes(Orig->getAttributes());
427 return Exp2Call;
430 static Value *expandStepIntrinsic(CallInst *Orig) {
432 Value *X = Orig->getOperand(0);
433 Value *Y = Orig->getOperand(1);
434 Type *Ty = X->getType();
435 IRBuilder<> Builder(Orig);
437 Constant *One = ConstantFP::get(Ty->getScalarType(), 1.0);
438 Constant *Zero = ConstantFP::get(Ty->getScalarType(), 0.0);
439 Value *Cond = Builder.CreateFCmpOLT(Y, X);
441 if (Ty != Ty->getScalarType()) {
442 auto *XVec = dyn_cast<FixedVectorType>(Ty);
443 One = ConstantVector::getSplat(
444 ElementCount::getFixed(XVec->getNumElements()), One);
445 Zero = ConstantVector::getSplat(
446 ElementCount::getFixed(XVec->getNumElements()), Zero);
449 return Builder.CreateSelect(Cond, Zero, One);
452 static Value *expandRadiansIntrinsic(CallInst *Orig) {
453 Value *X = Orig->getOperand(0);
454 Type *Ty = X->getType();
455 IRBuilder<> Builder(Orig);
456 Value *PiOver180 = ConstantFP::get(Ty, llvm::numbers::pi / 180.0);
457 return Builder.CreateFMul(X, PiOver180);
460 static Intrinsic::ID getMaxForClamp(Intrinsic::ID ClampIntrinsic) {
461 if (ClampIntrinsic == Intrinsic::dx_uclamp)
462 return Intrinsic::umax;
463 if (ClampIntrinsic == Intrinsic::dx_sclamp)
464 return Intrinsic::smax;
465 assert(ClampIntrinsic == Intrinsic::dx_nclamp);
466 return Intrinsic::maxnum;
469 static Intrinsic::ID getMinForClamp(Intrinsic::ID ClampIntrinsic) {
470 if (ClampIntrinsic == Intrinsic::dx_uclamp)
471 return Intrinsic::umin;
472 if (ClampIntrinsic == Intrinsic::dx_sclamp)
473 return Intrinsic::smin;
474 assert(ClampIntrinsic == Intrinsic::dx_nclamp);
475 return Intrinsic::minnum;
478 static Value *expandClampIntrinsic(CallInst *Orig,
479 Intrinsic::ID ClampIntrinsic) {
480 Value *X = Orig->getOperand(0);
481 Value *Min = Orig->getOperand(1);
482 Value *Max = Orig->getOperand(2);
483 Type *Ty = X->getType();
484 IRBuilder<> Builder(Orig);
485 auto *MaxCall = Builder.CreateIntrinsic(Ty, getMaxForClamp(ClampIntrinsic),
486 {X, Min}, nullptr, "dx.max");
487 return Builder.CreateIntrinsic(Ty, getMinForClamp(ClampIntrinsic),
488 {MaxCall, Max}, nullptr, "dx.min");
491 static Value *expandDegreesIntrinsic(CallInst *Orig) {
492 Value *X = Orig->getOperand(0);
493 Type *Ty = X->getType();
494 IRBuilder<> Builder(Orig);
495 Value *DegreesRatio = ConstantFP::get(Ty, 180.0 * llvm::numbers::inv_pi);
496 return Builder.CreateFMul(X, DegreesRatio);
499 static Value *expandSignIntrinsic(CallInst *Orig) {
500 Value *X = Orig->getOperand(0);
501 Type *Ty = X->getType();
502 Type *ScalarTy = Ty->getScalarType();
503 Type *RetTy = Orig->getType();
504 Constant *Zero = Constant::getNullValue(Ty);
506 IRBuilder<> Builder(Orig);
508 Value *GT;
509 Value *LT;
510 if (ScalarTy->isFloatingPointTy()) {
511 GT = Builder.CreateFCmpOLT(Zero, X);
512 LT = Builder.CreateFCmpOLT(X, Zero);
513 } else {
514 assert(ScalarTy->isIntegerTy());
515 GT = Builder.CreateICmpSLT(Zero, X);
516 LT = Builder.CreateICmpSLT(X, Zero);
519 Value *ZextGT = Builder.CreateZExt(GT, RetTy);
520 Value *ZextLT = Builder.CreateZExt(LT, RetTy);
522 return Builder.CreateSub(ZextGT, ZextLT);
525 static bool expandIntrinsic(Function &F, CallInst *Orig) {
526 Value *Result = nullptr;
527 Intrinsic::ID IntrinsicId = F.getIntrinsicID();
528 switch (IntrinsicId) {
529 case Intrinsic::abs:
530 Result = expandAbs(Orig);
531 break;
532 case Intrinsic::atan2:
533 Result = expandAtan2Intrinsic(Orig);
534 break;
535 case Intrinsic::exp:
536 Result = expandExpIntrinsic(Orig);
537 break;
538 case Intrinsic::log:
539 Result = expandLogIntrinsic(Orig);
540 break;
541 case Intrinsic::log10:
542 Result = expandLog10Intrinsic(Orig);
543 break;
544 case Intrinsic::pow:
545 Result = expandPowIntrinsic(Orig);
546 break;
547 case Intrinsic::dx_all:
548 case Intrinsic::dx_any:
549 Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
550 break;
551 case Intrinsic::dx_cross:
552 Result = expandCrossIntrinsic(Orig);
553 break;
554 case Intrinsic::dx_uclamp:
555 case Intrinsic::dx_sclamp:
556 case Intrinsic::dx_nclamp:
557 Result = expandClampIntrinsic(Orig, IntrinsicId);
558 break;
559 case Intrinsic::dx_degrees:
560 Result = expandDegreesIntrinsic(Orig);
561 break;
562 case Intrinsic::dx_lerp:
563 Result = expandLerpIntrinsic(Orig);
564 break;
565 case Intrinsic::dx_normalize:
566 Result = expandNormalizeIntrinsic(Orig);
567 break;
568 case Intrinsic::dx_fdot:
569 Result = expandFloatDotIntrinsic(Orig);
570 break;
571 case Intrinsic::dx_sdot:
572 case Intrinsic::dx_udot:
573 Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
574 break;
575 case Intrinsic::dx_sign:
576 Result = expandSignIntrinsic(Orig);
577 break;
578 case Intrinsic::dx_step:
579 Result = expandStepIntrinsic(Orig);
580 break;
581 case Intrinsic::dx_radians:
582 Result = expandRadiansIntrinsic(Orig);
583 break;
584 case Intrinsic::vector_reduce_add:
585 case Intrinsic::vector_reduce_fadd:
586 Result = expandVecReduceAdd(Orig, IntrinsicId);
587 break;
589 if (Result) {
590 Orig->replaceAllUsesWith(Result);
591 Orig->eraseFromParent();
592 return true;
594 return false;
597 static bool expansionIntrinsics(Module &M) {
598 for (auto &F : make_early_inc_range(M.functions())) {
599 if (!isIntrinsicExpansion(F))
600 continue;
601 bool IntrinsicExpanded = false;
602 for (User *U : make_early_inc_range(F.users())) {
603 auto *IntrinsicCall = dyn_cast<CallInst>(U);
604 if (!IntrinsicCall)
605 continue;
606 IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
608 if (F.user_empty() && IntrinsicExpanded)
609 F.eraseFromParent();
611 return true;
614 PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
615 ModuleAnalysisManager &) {
616 if (expansionIntrinsics(M))
617 return PreservedAnalyses::none();
618 return PreservedAnalyses::all();
621 bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
622 return expansionIntrinsics(M);
625 char DXILIntrinsicExpansionLegacy::ID = 0;
627 INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
628 "DXIL Intrinsic Expansion", false, false)
629 INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
630 "DXIL Intrinsic Expansion", false, false)
632 ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
633 return new DXILIntrinsicExpansionLegacy();