1 //===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "mlir/Transforms/SROA.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
#define GEN_PASS_DEF_SROA
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "sroa"

using namespace mlir;
27 /// Information computed by destructurable memory slot analysis used to perform
28 /// actual destructuring of the slot. This struct is only constructed if
29 /// destructuring is possible, and contains the necessary data to perform it.
30 struct MemorySlotDestructuringInfo
{
31 /// Set of the indices that are actually used when accessing the subelements.
32 SmallPtrSet
<Attribute
, 8> usedIndices
;
33 /// Blocking uses of a given user of the memory slot that must be eliminated.
34 DenseMap
<Operation
*, SmallPtrSet
<OpOperand
*, 4>> userToBlockingUses
;
35 /// List of potentially indirect accessors of the memory slot that need
37 SmallVector
<DestructurableAccessorOpInterface
> accessors
;
42 /// Computes information for slot destructuring. This will compute whether this
43 /// slot can be destructured and data to perform the destructuring. Returns
44 /// nothing if the slot cannot be destructured or if there is no useful work to
46 static std::optional
<MemorySlotDestructuringInfo
>
47 computeDestructuringInfo(DestructurableMemorySlot
&slot
,
48 const DataLayout
&dataLayout
) {
49 assert(isa
<DestructurableTypeInterface
>(slot
.elemType
));
51 if (slot
.ptr
.use_empty())
54 MemorySlotDestructuringInfo info
;
56 SmallVector
<MemorySlot
> usedSafelyWorklist
;
58 auto scheduleAsBlockingUse
= [&](OpOperand
&use
) {
59 SmallPtrSetImpl
<OpOperand
*> &blockingUses
=
60 info
.userToBlockingUses
[use
.getOwner()];
61 blockingUses
.insert(&use
);
64 // Initialize the analysis with the immediate users of the slot.
65 for (OpOperand
&use
: slot
.ptr
.getUses()) {
67 dyn_cast
<DestructurableAccessorOpInterface
>(use
.getOwner())) {
68 if (accessor
.canRewire(slot
, info
.usedIndices
, usedSafelyWorklist
,
70 info
.accessors
.push_back(accessor
);
75 // If it cannot be shown that the operation uses the slot safely, maybe it
76 // can be promoted out of using the slot?
77 scheduleAsBlockingUse(use
);
80 SmallPtrSet
<OpOperand
*, 16> visited
;
81 while (!usedSafelyWorklist
.empty()) {
82 MemorySlot mustBeUsedSafely
= usedSafelyWorklist
.pop_back_val();
83 for (OpOperand
&subslotUse
: mustBeUsedSafely
.ptr
.getUses()) {
84 if (!visited
.insert(&subslotUse
).second
)
86 Operation
*subslotUser
= subslotUse
.getOwner();
88 if (auto memOp
= dyn_cast
<SafeMemorySlotAccessOpInterface
>(subslotUser
))
89 if (succeeded(memOp
.ensureOnlySafeAccesses(
90 mustBeUsedSafely
, usedSafelyWorklist
, dataLayout
)))
93 // If it cannot be shown that the operation uses the slot safely, maybe it
94 // can be promoted out of using the slot?
95 scheduleAsBlockingUse(subslotUse
);
99 SetVector
<Operation
*> forwardSlice
;
100 mlir::getForwardSlice(slot
.ptr
, &forwardSlice
);
101 for (Operation
*user
: forwardSlice
) {
102 // If the next operation has no blocking uses, everything is fine.
103 auto it
= info
.userToBlockingUses
.find(user
);
104 if (it
== info
.userToBlockingUses
.end())
107 SmallPtrSet
<OpOperand
*, 4> &blockingUses
= it
->second
;
108 auto promotable
= dyn_cast
<PromotableOpInterface
>(user
);
110 // An operation that has blocking uses must be promoted. If it is not
111 // promotable, destructuring must fail.
115 SmallVector
<OpOperand
*> newBlockingUses
;
116 // If the operation decides it cannot deal with removing the blocking uses,
117 // destructuring must fail.
118 if (!promotable
.canUsesBeRemoved(blockingUses
, newBlockingUses
, dataLayout
))
121 // Then, register any new blocking uses for coming operations.
122 for (OpOperand
*blockingUse
: newBlockingUses
) {
123 assert(llvm::is_contained(user
->getResults(), blockingUse
->get()));
125 SmallPtrSetImpl
<OpOperand
*> &newUserBlockingUseSet
=
126 info
.userToBlockingUses
[blockingUse
->getOwner()];
127 newUserBlockingUseSet
.insert(blockingUse
);
134 /// Performs the destructuring of a destructible slot given associated
135 /// destructuring information. The provided slot will be destructured in
136 /// subslots as specified by its allocator.
137 static void destructureSlot(
138 DestructurableMemorySlot
&slot
,
139 DestructurableAllocationOpInterface allocator
, OpBuilder
&builder
,
140 const DataLayout
&dataLayout
, MemorySlotDestructuringInfo
&info
,
141 SmallVectorImpl
<DestructurableAllocationOpInterface
> &newAllocators
,
142 const SROAStatistics
&statistics
) {
143 OpBuilder::InsertionGuard
guard(builder
);
145 builder
.setInsertionPointToStart(slot
.ptr
.getParentBlock());
146 DenseMap
<Attribute
, MemorySlot
> subslots
=
147 allocator
.destructure(slot
, info
.usedIndices
, builder
, newAllocators
);
149 if (statistics
.slotsWithMemoryBenefit
&&
150 slot
.subelementTypes
.size() != info
.usedIndices
.size())
151 (*statistics
.slotsWithMemoryBenefit
)++;
153 if (statistics
.maxSubelementAmount
)
154 statistics
.maxSubelementAmount
->updateMax(slot
.subelementTypes
.size());
156 SetVector
<Operation
*> usersToRewire
;
157 for (Operation
*user
: llvm::make_first_range(info
.userToBlockingUses
))
158 usersToRewire
.insert(user
);
159 for (DestructurableAccessorOpInterface accessor
: info
.accessors
)
160 usersToRewire
.insert(accessor
);
161 usersToRewire
= mlir::topologicalSort(usersToRewire
);
163 llvm::SmallVector
<Operation
*> toErase
;
164 for (Operation
*toRewire
: llvm::reverse(usersToRewire
)) {
165 builder
.setInsertionPointAfter(toRewire
);
166 if (auto accessor
= dyn_cast
<DestructurableAccessorOpInterface
>(toRewire
)) {
167 if (accessor
.rewire(slot
, subslots
, builder
, dataLayout
) ==
168 DeletionKind::Delete
)
169 toErase
.push_back(accessor
);
173 auto promotable
= cast
<PromotableOpInterface
>(toRewire
);
174 if (promotable
.removeBlockingUses(info
.userToBlockingUses
[promotable
],
175 builder
) == DeletionKind::Delete
)
176 toErase
.push_back(promotable
);
179 for (Operation
*toEraseOp
: toErase
)
182 assert(slot
.ptr
.use_empty() && "after destructuring, the original slot "
183 "pointer should no longer be used");
185 LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot
.ptr
188 if (statistics
.destructuredAmount
)
189 (*statistics
.destructuredAmount
)++;
191 std::optional
<DestructurableAllocationOpInterface
> newAllocator
=
192 allocator
.handleDestructuringComplete(slot
, builder
);
193 // Add newly created allocators to the worklist for further processing.
195 newAllocators
.push_back(*newAllocator
);
198 LogicalResult
mlir::tryToDestructureMemorySlots(
199 ArrayRef
<DestructurableAllocationOpInterface
> allocators
,
200 OpBuilder
&builder
, const DataLayout
&dataLayout
,
201 SROAStatistics statistics
) {
202 bool destructuredAny
= false;
204 SmallVector
<DestructurableAllocationOpInterface
> workList(allocators
);
205 SmallVector
<DestructurableAllocationOpInterface
> newWorkList
;
206 newWorkList
.reserve(allocators
.size());
207 // Destructuring a slot can allow for further destructuring of other
208 // slots, destructuring is tried until no destructuring succeeds.
210 bool changesInThisRound
= false;
212 for (DestructurableAllocationOpInterface allocator
: workList
) {
213 bool destructuredAnySlot
= false;
214 for (DestructurableMemorySlot slot
: allocator
.getDestructurableSlots()) {
215 std::optional
<MemorySlotDestructuringInfo
> info
=
216 computeDestructuringInfo(slot
, dataLayout
);
220 destructureSlot(slot
, allocator
, builder
, dataLayout
, *info
,
221 newWorkList
, statistics
);
222 destructuredAnySlot
= true;
224 // A break is required, since destructuring a slot may invalidate the
225 // remaning slots of an allocator.
228 if (!destructuredAnySlot
)
229 newWorkList
.push_back(allocator
);
230 changesInThisRound
|= destructuredAnySlot
;
233 if (!changesInThisRound
)
235 destructuredAny
|= changesInThisRound
;
237 // Swap the vector's backing memory and clear the entries in newWorkList
238 // afterwards. This ensures that additional heap allocations can be avoided.
239 workList
.swap(newWorkList
);
243 return success(destructuredAny
);
248 struct SROA
: public impl::SROABase
<SROA
> {
249 using impl::SROABase
<SROA
>::SROABase
;
251 void runOnOperation() override
{
252 Operation
*scopeOp
= getOperation();
254 SROAStatistics statistics
{&destructuredAmount
, &slotsWithMemoryBenefit
,
255 &maxSubelementAmount
};
257 auto &dataLayoutAnalysis
= getAnalysis
<DataLayoutAnalysis
>();
258 const DataLayout
&dataLayout
= dataLayoutAnalysis
.getAtOrAbove(scopeOp
);
259 bool changed
= false;
261 for (Region
®ion
: scopeOp
->getRegions()) {
262 if (region
.getBlocks().empty())
265 OpBuilder
builder(®ion
.front(), region
.front().begin());
267 SmallVector
<DestructurableAllocationOpInterface
> allocators
;
268 // Build a list of allocators to attempt to destructure the slots of.
269 region
.walk([&](DestructurableAllocationOpInterface allocator
) {
270 allocators
.emplace_back(allocator
);
273 // Attempt to destructure as many slots as possible.
274 if (succeeded(tryToDestructureMemorySlots(allocators
, builder
, dataLayout
,
279 markAllAnalysesPreserved();