//===- DXILOpLower.cpp - Lowering LLVM intrinsic to DXILOp function -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains passes and utilities to lower LLVM intrinsic calls
/// to DXILOp function calls.
//===----------------------------------------------------------------------===//
#include "DXILConstants.h"
#include "DXILIntrinsicExpansion.h"
#include "DXILOpBuilder.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"

#define DEBUG_TYPE "dxil-op-lower"

using namespace llvm;
using namespace llvm::dxil;
33 static bool isVectorArgExpansion(Function
&F
) {
34 switch (F
.getIntrinsicID()) {
35 case Intrinsic::dx_dot2
:
36 case Intrinsic::dx_dot3
:
37 case Intrinsic::dx_dot4
:
43 static SmallVector
<Value
*> populateOperands(Value
*Arg
, IRBuilder
<> &Builder
) {
44 SmallVector
<Value
*, 4> ExtractedElements
;
45 auto *VecArg
= dyn_cast
<FixedVectorType
>(Arg
->getType());
46 for (unsigned I
= 0; I
< VecArg
->getNumElements(); ++I
) {
47 Value
*Index
= ConstantInt::get(Type::getInt32Ty(Arg
->getContext()), I
);
48 Value
*ExtractedElement
= Builder
.CreateExtractElement(Arg
, Index
);
49 ExtractedElements
.push_back(ExtractedElement
);
51 return ExtractedElements
;
54 static SmallVector
<Value
*> argVectorFlatten(CallInst
*Orig
,
55 IRBuilder
<> &Builder
) {
56 // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
57 unsigned NumOperands
= Orig
->getNumOperands() - 1;
58 assert(NumOperands
> 0);
59 Value
*Arg0
= Orig
->getOperand(0);
60 [[maybe_unused
]] auto *VecArg0
= dyn_cast
<FixedVectorType
>(Arg0
->getType());
62 SmallVector
<Value
*> NewOperands
= populateOperands(Arg0
, Builder
);
63 for (unsigned I
= 1; I
< NumOperands
; ++I
) {
64 Value
*Arg
= Orig
->getOperand(I
);
65 [[maybe_unused
]] auto *VecArg
= dyn_cast
<FixedVectorType
>(Arg
->getType());
67 assert(VecArg0
->getElementType() == VecArg
->getElementType());
68 assert(VecArg0
->getNumElements() == VecArg
->getNumElements());
69 auto NextOperandList
= populateOperands(Arg
, Builder
);
70 NewOperands
.append(NextOperandList
.begin(), NextOperandList
.end());
75 static void lowerIntrinsic(dxil::OpCode DXILOp
, Function
&F
, Module
&M
) {
76 IRBuilder
<> B(M
.getContext());
77 DXILOpBuilder
DXILB(M
, B
);
78 Type
*OverloadTy
= DXILB
.getOverloadTy(DXILOp
, F
.getFunctionType());
79 for (User
*U
: make_early_inc_range(F
.users())) {
80 CallInst
*CI
= dyn_cast
<CallInst
>(U
);
84 SmallVector
<Value
*> Args
;
85 Value
*DXILOpArg
= B
.getInt32(static_cast<unsigned>(DXILOp
));
86 Args
.emplace_back(DXILOpArg
);
88 if (isVectorArgExpansion(F
)) {
89 SmallVector
<Value
*> NewArgs
= argVectorFlatten(CI
, B
);
90 Args
.append(NewArgs
.begin(), NewArgs
.end());
92 Args
.append(CI
->arg_begin(), CI
->arg_end());
95 DXILB
.createDXILOpCall(DXILOp
, F
.getReturnType(), OverloadTy
, Args
);
97 CI
->replaceAllUsesWith(DXILCI
);
98 CI
->eraseFromParent();
104 static bool lowerIntrinsics(Module
&M
) {
105 bool Updated
= false;
107 #define DXIL_OP_INTRINSIC_MAP
108 #include "DXILOperation.inc"
109 #undef DXIL_OP_INTRINSIC_MAP
111 for (Function
&F
: make_early_inc_range(M
.functions())) {
112 if (!F
.isDeclaration())
114 Intrinsic::ID ID
= F
.getIntrinsicID();
115 if (ID
== Intrinsic::not_intrinsic
)
117 auto LowerIt
= LowerMap
.find(ID
);
118 if (LowerIt
== LowerMap
.end())
120 lowerIntrinsic(LowerIt
->second
, F
, M
);
127 /// A pass that transforms external global definitions into declarations.
128 class DXILOpLowering
: public PassInfoMixin
<DXILOpLowering
> {
130 PreservedAnalyses
run(Module
&M
, ModuleAnalysisManager
&) {
131 if (lowerIntrinsics(M
))
132 return PreservedAnalyses::none();
133 return PreservedAnalyses::all();
139 class DXILOpLoweringLegacy
: public ModulePass
{
141 bool runOnModule(Module
&M
) override
{ return lowerIntrinsics(M
); }
142 StringRef
getPassName() const override
{ return "DXIL Op Lowering"; }
143 DXILOpLoweringLegacy() : ModulePass(ID
) {}
145 static char ID
; // Pass identification.
146 void getAnalysisUsage(llvm::AnalysisUsage
&AU
) const override
{
147 // Specify the passes that your pass depends on
148 AU
.addRequired
<DXILIntrinsicExpansionLegacy
>();
151 char DXILOpLoweringLegacy::ID
= 0;
152 } // end anonymous namespace
154 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy
, DEBUG_TYPE
, "DXIL Op Lowering",
156 INITIALIZE_PASS_END(DXILOpLoweringLegacy
, DEBUG_TYPE
, "DXIL Op Lowering", false,
159 ModulePass
*llvm::createDXILOpLoweringLegacyPass() {
160 return new DXILOpLoweringLegacy();