1 //===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // General infrastructure for keeping track of the values that according to
10 // the SPIR-V binary layout should be global to the whole module.
12 //===----------------------------------------------------------------------===//
14 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
15 #define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
17 #include "MCTargetDesc/SPIRVBaseInfo.h"
18 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
22 #include "llvm/CodeGen/MachineModuleInfo.h"
24 #include <type_traits>
// NOTE: using MapVector instead of DenseMap because it keeps everything
// ordered in a stable manner, at the price of an extra (NumKeys)*PtrSize of
// memory and expensive removals, which do not happen anyway.
31 class DTSortableEntry
: public MapVector
<const MachineFunction
*, Register
> {
32 SmallVector
<DTSortableEntry
*, 2> Deps
;
37 // NOTE: bit-field default init is a C++20 feature.
38 FlagsTy() : IsFunc(0), IsGV(0) {}
43 // Common hoisting utility doesn't support function, because their hoisting
44 // require hoisting of params as well.
45 bool getIsFunc() const { return Flags
.IsFunc
; }
46 bool getIsGV() const { return Flags
.IsGV
; }
47 void setIsFunc(bool V
) { Flags
.IsFunc
= V
; }
48 void setIsGV(bool V
) { Flags
.IsGV
= V
; }
50 const SmallVector
<DTSortableEntry
*, 2> &getDeps() const { return Deps
; }
51 void addDep(DTSortableEntry
*E
) { Deps
.push_back(E
); }
54 struct SpecialTypeDescriptor
{
55 enum SpecialTypeKind
{
69 SpecialTypeDescriptor() = delete;
70 SpecialTypeDescriptor(SpecialTypeKind K
) : Kind(K
) { Hash
= Kind
; }
72 unsigned getHash() const { return Hash
; }
74 virtual ~SpecialTypeDescriptor() {}
77 struct ImageTypeDescriptor
: public SpecialTypeDescriptor
{
85 unsigned ImageFormat
: 6;
91 ImageTypeDescriptor(const Type
*SampledTy
, unsigned Dim
, unsigned Depth
,
92 unsigned Arrayed
, unsigned MS
, unsigned Sampled
,
93 unsigned ImageFormat
, unsigned AQ
= 0)
94 : SpecialTypeDescriptor(SpecialTypeKind::STK_Image
) {
97 Attrs
.Flags
.Dim
= Dim
;
98 Attrs
.Flags
.Depth
= Depth
;
99 Attrs
.Flags
.Arrayed
= Arrayed
;
101 Attrs
.Flags
.Sampled
= Sampled
;
102 Attrs
.Flags
.ImageFormat
= ImageFormat
;
104 Hash
= (DenseMapInfo
<Type
*>().getHashValue(SampledTy
) & 0xffff) ^
105 ((Attrs
.Val
<< 8) | Kind
);
108 static bool classof(const SpecialTypeDescriptor
*TD
) {
109 return TD
->Kind
== SpecialTypeKind::STK_Image
;
113 struct SampledImageTypeDescriptor
: public SpecialTypeDescriptor
{
114 SampledImageTypeDescriptor(const Type
*SampledTy
, const MachineInstr
*ImageTy
)
115 : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage
) {
116 assert(ImageTy
->getOpcode() == SPIRV::OpTypeImage
);
117 ImageTypeDescriptor
TD(
118 SampledTy
, ImageTy
->getOperand(2).getImm(),
119 ImageTy
->getOperand(3).getImm(), ImageTy
->getOperand(4).getImm(),
120 ImageTy
->getOperand(5).getImm(), ImageTy
->getOperand(6).getImm(),
121 ImageTy
->getOperand(7).getImm(), ImageTy
->getOperand(8).getImm());
122 Hash
= TD
.getHash() ^ Kind
;
125 static bool classof(const SpecialTypeDescriptor
*TD
) {
126 return TD
->Kind
== SpecialTypeKind::STK_SampledImage
;
130 struct SamplerTypeDescriptor
: public SpecialTypeDescriptor
{
131 SamplerTypeDescriptor()
132 : SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler
) {
136 static bool classof(const SpecialTypeDescriptor
*TD
) {
137 return TD
->Kind
== SpecialTypeKind::STK_Sampler
;
141 struct PipeTypeDescriptor
: public SpecialTypeDescriptor
{
143 PipeTypeDescriptor(uint8_t AQ
)
144 : SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe
) {
145 Hash
= (AQ
<< 8) | Kind
;
148 static bool classof(const SpecialTypeDescriptor
*TD
) {
149 return TD
->Kind
== SpecialTypeKind::STK_Pipe
;
153 struct DeviceEventTypeDescriptor
: public SpecialTypeDescriptor
{
155 DeviceEventTypeDescriptor()
156 : SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent
) {
160 static bool classof(const SpecialTypeDescriptor
*TD
) {
161 return TD
->Kind
== SpecialTypeKind::STK_DeviceEvent
;
165 struct PointerTypeDescriptor
: public SpecialTypeDescriptor
{
166 const Type
*ElementType
;
167 unsigned AddressSpace
;
169 PointerTypeDescriptor() = delete;
170 PointerTypeDescriptor(const Type
*ElementType
, unsigned AddressSpace
)
171 : SpecialTypeDescriptor(SpecialTypeKind::STK_Pointer
),
172 ElementType(ElementType
), AddressSpace(AddressSpace
) {
173 Hash
= (DenseMapInfo
<Type
*>().getHashValue(ElementType
) & 0xffff) ^
174 ((AddressSpace
<< 8) | Kind
);
177 static bool classof(const SpecialTypeDescriptor
*TD
) {
178 return TD
->Kind
== SpecialTypeKind::STK_Pointer
;
183 template <> struct DenseMapInfo
<SPIRV::SpecialTypeDescriptor
> {
184 static inline SPIRV::SpecialTypeDescriptor
getEmptyKey() {
185 return SPIRV::SpecialTypeDescriptor(
186 SPIRV::SpecialTypeDescriptor::STK_Empty
);
188 static inline SPIRV::SpecialTypeDescriptor
getTombstoneKey() {
189 return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last
);
191 static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val
) {
192 return Val
.getHash();
194 static bool isEqual(SPIRV::SpecialTypeDescriptor LHS
,
195 SPIRV::SpecialTypeDescriptor RHS
) {
196 return getHashValue(LHS
) == getHashValue(RHS
);
200 template <typename KeyTy
> class SPIRVDuplicatesTrackerBase
{
// NOTE: using MapVector instead of DenseMap keeps everything ordered in a
// stable manner, at the price of an extra (NumKeys)*PtrSize of memory and
// expensive removals, which don't happen anyway.
205 using StorageTy
= MapVector
<KeyTy
, SPIRV::DTSortableEntry
>;
211 void add(KeyTy V
, const MachineFunction
*MF
, Register R
) {
212 if (find(V
, MF
).isValid())
216 if (std::is_same
<Function
,
217 typename
std::remove_const
<
218 typename
std::remove_pointer
<KeyTy
>::type
>::type
>() ||
219 std::is_same
<Argument
,
220 typename
std::remove_const
<
221 typename
std::remove_pointer
<KeyTy
>::type
>::type
>())
222 Storage
[V
].setIsFunc(true);
223 if (std::is_same
<GlobalVariable
,
224 typename
std::remove_const
<
225 typename
std::remove_pointer
<KeyTy
>::type
>::type
>())
226 Storage
[V
].setIsGV(true);
229 Register
find(KeyTy V
, const MachineFunction
*MF
) const {
230 auto iter
= Storage
.find(V
);
231 if (iter
!= Storage
.end()) {
232 auto Map
= iter
->second
;
233 auto iter2
= Map
.find(MF
);
234 if (iter2
!= Map
.end())
235 return iter2
->second
;
240 const StorageTy
&getAllUses() const { return Storage
; }
243 StorageTy
&getAllUses() { return Storage
; }
245 // The friend class needs to have access to the internal storage
246 // to be able to build dependency graph, can't declare only one
247 // function a 'friend' due to the incomplete declaration at this point
248 // and mutual dependency problems.
249 friend class SPIRVGeneralDuplicatesTracker
;
252 template <typename T
>
253 class SPIRVDuplicatesTracker
: public SPIRVDuplicatesTrackerBase
<const T
*> {};
256 class SPIRVDuplicatesTracker
<SPIRV::SpecialTypeDescriptor
>
257 : public SPIRVDuplicatesTrackerBase
<SPIRV::SpecialTypeDescriptor
> {};
259 class SPIRVGeneralDuplicatesTracker
{
260 SPIRVDuplicatesTracker
<Type
> TT
;
261 SPIRVDuplicatesTracker
<Constant
> CT
;
262 SPIRVDuplicatesTracker
<GlobalVariable
> GT
;
263 SPIRVDuplicatesTracker
<Function
> FT
;
264 SPIRVDuplicatesTracker
<Argument
> AT
;
265 SPIRVDuplicatesTracker
<SPIRV::SpecialTypeDescriptor
> ST
;
267 // NOTE: using MOs instead of regs to get rid of MF dependency to be able
268 // to use flat data structure.
269 // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness
270 // but makes LITs more stable, should prefer DenseMap still due to
271 // significant perf difference.
272 using SPIRVReg2EntryTy
=
273 MapVector
<MachineOperand
*, SPIRV::DTSortableEntry
*>;
275 template <typename T
>
276 void prebuildReg2Entry(SPIRVDuplicatesTracker
<T
> &DT
,
277 SPIRVReg2EntryTy
&Reg2Entry
);
280 void buildDepsGraph(std::vector
<SPIRV::DTSortableEntry
*> &Graph
,
281 MachineModuleInfo
*MMI
);
283 void add(const Type
*Ty
, const MachineFunction
*MF
, Register R
) {
287 void add(const Type
*PointerElementType
, unsigned AddressSpace
,
288 const MachineFunction
*MF
, Register R
) {
289 ST
.add(SPIRV::PointerTypeDescriptor(PointerElementType
, AddressSpace
), MF
,
293 void add(const Constant
*C
, const MachineFunction
*MF
, Register R
) {
297 void add(const GlobalVariable
*GV
, const MachineFunction
*MF
, Register R
) {
301 void add(const Function
*F
, const MachineFunction
*MF
, Register R
) {
305 void add(const Argument
*Arg
, const MachineFunction
*MF
, Register R
) {
309 void add(const SPIRV::SpecialTypeDescriptor
&TD
, const MachineFunction
*MF
,
314 Register
find(const Type
*Ty
, const MachineFunction
*MF
) {
315 return TT
.find(const_cast<Type
*>(Ty
), MF
);
318 Register
find(const Type
*PointerElementType
, unsigned AddressSpace
,
319 const MachineFunction
*MF
) {
321 SPIRV::PointerTypeDescriptor(PointerElementType
, AddressSpace
), MF
);
324 Register
find(const Constant
*C
, const MachineFunction
*MF
) {
325 return CT
.find(const_cast<Constant
*>(C
), MF
);
328 Register
find(const GlobalVariable
*GV
, const MachineFunction
*MF
) {
329 return GT
.find(const_cast<GlobalVariable
*>(GV
), MF
);
332 Register
find(const Function
*F
, const MachineFunction
*MF
) {
333 return FT
.find(const_cast<Function
*>(F
), MF
);
336 Register
find(const Argument
*Arg
, const MachineFunction
*MF
) {
337 return AT
.find(const_cast<Argument
*>(Arg
), MF
);
340 Register
find(const SPIRV::SpecialTypeDescriptor
&TD
,
341 const MachineFunction
*MF
) {
342 return ST
.find(TD
, MF
);
345 const SPIRVDuplicatesTracker
<Type
> *getTypes() { return &TT
; }
348 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H