1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
8 // This file contains the AArch64 / Cortex-A57 specific register allocation
9 // constraints for use by the PBQP register allocator.
11 // It is essentially a transcription of what is contained in
12 // AArch64A57FPLoadBalancing, which tries to use a balanced
13 // mix of odd and even D-registers when performing a critical sequence of
14 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15 //===----------------------------------------------------------------------===//
17 #include "AArch64PBQPRegAlloc.h"
19 #include "AArch64RegisterInfo.h"
20 #include "llvm/CodeGen/LiveIntervals.h"
21 #include "llvm/CodeGen/MachineBasicBlock.h"
22 #include "llvm/CodeGen/MachineFunction.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/RegAllocPBQP.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include "llvm/Support/raw_ostream.h"
29 #define DEBUG_TYPE "aarch64-pbqp"
36 bool isFPReg(unsigned reg
) {
37 return AArch64::FPR32RegClass
.contains(reg
) ||
38 AArch64::FPR64RegClass
.contains(reg
) ||
39 AArch64::FPR128RegClass
.contains(reg
);
// isOdd: takes a register number and returns a bool.
// The only statement visible in this listing is the llvm_unreachable() call
// below, which reports a register that is not from the expected class.
// NOTE(review): the remainder of the body (presumably a switch over register
// enumerators returning the parity) and the closing brace are not present in
// this listing -- restore them from the upstream file before building.
43 bool isOdd(unsigned reg
) {
46 llvm_unreachable("Register is not from the expected class !");
149 bool haveSameParity(unsigned reg1
, unsigned reg2
) {
150 assert(isFPReg(reg1
) && "Expecting an FP register for reg1");
151 assert(isFPReg(reg2
) && "Expecting an FP register for reg2");
153 return isOdd(reg1
) == isOdd(reg2
);
// addIntraChainConstraint: biases the PBQP edge between the graph nodes of
// virtual registers Rd and Ra so that physical-register pairs of the same
// parity (see haveSameParity) are cheaper than pairs of different parity.
//  - If no edge exists between the two nodes, a new cost matrix is built and
//    added with G.addEdge().
//  - Otherwise the existing edge costs are copied, adjusted and written back
//    with G.updateEdgeCosts().
// NOTE(review): this listing is missing lines from the original file (the
// `unsigned Ra` parameter line, return statements, an `else`, and closing
// braces). The comments below describe only what is visible here.
158 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph
&G
, unsigned Rd
,
163 LiveIntervals
&LIs
= G
.getMetadata().LIS
;
// Physical registers are only reported through the debug stream here.
// NOTE(review): the early exit that presumably follows the two LLVM_DEBUG
// statements is not visible in this listing.
165 if (Register::isPhysicalRegister(Rd
) || Register::isPhysicalRegister(Ra
)) {
166 LLVM_DEBUG(dbgs() << "Rd is a physical reg:"
167 << Register::isPhysicalRegister(Rd
) << '\n');
168 LLVM_DEBUG(dbgs() << "Ra is a physical reg:"
169 << Register::isPhysicalRegister(Ra
) << '\n');
// Look up the graph nodes for both virtual registers and the physical
// registers each node is allowed to take.
173 PBQPRAGraph::NodeId node1
= G
.getMetadata().getNodeIdForVReg(Rd
);
174 PBQPRAGraph::NodeId node2
= G
.getMetadata().getNodeIdForVReg(Ra
);
176 const PBQPRAGraph::NodeMetadata::AllowedRegVector
*vRdAllowed
=
177 &G
.getNodeMetadata(node1
).getAllowedRegs();
178 const PBQPRAGraph::NodeMetadata::AllowedRegVector
*vRaAllowed
=
179 &G
.getNodeMetadata(node2
).getAllowedRegs();
181 PBQPRAGraph::EdgeId edge
= G
.findEdge(node1
, node2
);
183 // The edge does not exist. Create one with the appropriate interference
185 if (edge
== G
.invalidEdgeId()) {
186 const LiveInterval
&ld
= LIs
.getInterval(Rd
);
187 const LiveInterval
&la
= LIs
.getInterval(Ra
);
188 bool livesOverlap
= ld
.overlaps(la
);
// The matrix has one more row and column than there are allowed registers
// and is zero-initialised; allowed register i/j is stored at [i + 1][j + 1],
// so row 0 and column 0 are never written by the loops below.
190 PBQPRAGraph::RawMatrix
costs(vRdAllowed
->size() + 1,
191 vRaAllowed
->size() + 1, 0);
192 for (unsigned i
= 0, ie
= vRdAllowed
->size(); i
!= ie
; ++i
) {
193 unsigned pRd
= (*vRdAllowed
)[i
];
194 for (unsigned j
= 0, je
= vRaAllowed
->size(); j
!= je
; ++j
) {
195 unsigned pRa
= (*vRaAllowed
)[j
];
// Overlapping live ranges on overlapping physical registers get an infinite
// cost; the parity cost assigned afterwards is 0.0 for same parity and 1.0
// for different parity.
// NOTE(review): as listed, the parity assignment would always overwrite the
// infinite cost -- an `else` between the two assignments appears to have
// been dropped from this listing; confirm against the upstream file.
196 if (livesOverlap
&& TRI
->regsOverlap(pRd
, pRa
))
197 costs
[i
+ 1][j
+ 1] = std::numeric_limits
<PBQP::PBQPNum
>::infinity();
199 costs
[i
+ 1][j
+ 1] = haveSameParity(pRd
, pRa
) ? 0.0 : 1.0;
202 G
.addEdge(node1
, node2
, std::move(costs
));
// An edge already exists. If the edge stores node2 as its first node, swap
// the node ids and the allowed-register vectors so that the row index below
// iterates over the registers of the edge's first node.
206 if (G
.getEdgeNode1Id(edge
) == node2
) {
207 std::swap(node1
, node2
);
208 std::swap(vRdAllowed
, vRaAllowed
);
211 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
// Concretely: work on a copy of the edge costs; for every row, find the
// largest finite same-parity entry, then raise every different-parity entry
// that is below it to that value + 1.0.
212 PBQPRAGraph::RawMatrix
costs(G
.getEdgeCosts(edge
));
213 for (unsigned i
= 0, ie
= vRdAllowed
->size(); i
!= ie
; ++i
) {
214 unsigned pRd
= (*vRdAllowed
)[i
];
216 // Get the maximum cost (excluding unallocatable reg) for same parity
// NOTE(review): if PBQP::PBQPNum is a floating-point type,
// std::numeric_limits<>::min() is the smallest positive normalised value,
// not the most negative one -- confirm this starting value is intended.
218 PBQP::PBQPNum sameParityMax
= std::numeric_limits
<PBQP::PBQPNum
>::min();
219 for (unsigned j
= 0, je
= vRaAllowed
->size(); j
!= je
; ++j
) {
220 unsigned pRa
= (*vRaAllowed
)[j
];
221 if (haveSameParity(pRd
, pRa
))
222 if (costs
[i
+ 1][j
+ 1] !=
223 std::numeric_limits
<PBQP::PBQPNum
>::infinity() &&
224 costs
[i
+ 1][j
+ 1] > sameParityMax
)
225 sameParityMax
= costs
[i
+ 1][j
+ 1];
228 // Ensure all registers with a different parity have a higher cost
229 // than sameParityMax
230 for (unsigned j
= 0, je
= vRaAllowed
->size(); j
!= je
; ++j
) {
231 unsigned pRa
= (*vRaAllowed
)[j
];
232 if (!haveSameParity(pRd
, pRa
))
233 if (sameParityMax
> costs
[i
+ 1][j
+ 1])
234 costs
[i
+ 1][j
+ 1] = sameParityMax
+ 1.0;
// Write the adjusted copy back to the edge.
237 G
.updateEdgeCosts(edge
, std::move(costs
));
// addInterChainConstraint: for every register r in Chains whose live
// interval overlaps that of Rd, refines the existing PBQP edge between the
// two nodes so that same-parity register pairs cost more than the largest
// finite different-parity cost in the same row.
// NOTE(review): this listing is missing lines from the original file (the
// `unsigned Ra` parameter line, the statements that update Chains, an `else`,
// the end of one LLVM_DEBUG statement, and closing braces). The comments
// below describe only what is visible here.
242 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph
&G
, unsigned Rd
,
244 LiveIntervals
&LIs
= G
.getMetadata().LIS
;
246 // Do some Chain management
// When Ra is already tracked in Chains the debug output reports the chain
// moving from Ra to Rd; the other debug message reports a new chain for Rd.
247 if (Chains
.count(Ra
)) {
249 LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra
, TRI
)
250 << " to " << printReg(Rd
, TRI
) << '\n';);
255 LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd
, TRI
)
260 PBQPRAGraph::NodeId node1
= G
.getMetadata().getNodeIdForVReg(Rd
);
262 const LiveInterval
&ld
= LIs
.getInterval(Rd
);
// Visit every tracked chain register; only those whose live interval
// overlaps Rd's are constrained.
263 for (auto r
: Chains
) {
268 const LiveInterval
&lr
= LIs
.getInterval(r
);
269 if (ld
.overlaps(lr
)) {
270 const PBQPRAGraph::NodeMetadata::AllowedRegVector
*vRdAllowed
=
271 &G
.getNodeMetadata(node1
).getAllowedRegs();
273 PBQPRAGraph::NodeId node2
= G
.getMetadata().getNodeIdForVReg(r
);
274 const PBQPRAGraph::NodeMetadata::AllowedRegVector
*vRrAllowed
=
275 &G
.getNodeMetadata(node2
).getAllowedRegs();
// Unlike addIntraChainConstraint, a missing edge is treated as an error
// here: the edge is asserted to exist rather than created.
277 PBQPRAGraph::EdgeId edge
= G
.findEdge(node1
, node2
);
278 assert(edge
!= G
.invalidEdgeId() &&
279 "PBQP error ! The edge should exist !");
281 LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
// If the edge stores node2 as its first node, swap ids and allowed-register
// vectors so the row index iterates over the edge's first node.
// NOTE(review): node1 is declared outside the loop, so this swap also
// changes node1 for later iterations -- confirm that is intended.
283 if (G
.getEdgeNode1Id(edge
) == node2
) {
284 std::swap(node1
, node2
);
285 std::swap(vRdAllowed
, vRrAllowed
);
288 // Enforce that cost is higher with all other Chains of the same parity
289 PBQP::Matrix
costs(G
.getEdgeCosts(edge
));
290 for (unsigned i
= 0, ie
= vRdAllowed
->size(); i
!= ie
; ++i
) {
291 unsigned pRd
= (*vRdAllowed
)[i
];
293 // Get the maximum cost (excluding unallocatable reg) for all other
// Despite its name, sameParityMax holds the largest finite cost among the
// entries whose registers have DIFFERENT parity (note the negated test).
// NOTE(review): if PBQP::PBQPNum is a floating-point type,
// std::numeric_limits<>::min() is the smallest positive normalised value,
// not the most negative one -- confirm this starting value is intended.
295 PBQP::PBQPNum sameParityMax
= std::numeric_limits
<PBQP::PBQPNum
>::min();
296 for (unsigned j
= 0, je
= vRrAllowed
->size(); j
!= je
; ++j
) {
297 unsigned pRa
= (*vRrAllowed
)[j
];
298 if (!haveSameParity(pRd
, pRa
))
299 if (costs
[i
+ 1][j
+ 1] !=
300 std::numeric_limits
<PBQP::PBQPNum
>::infinity() &&
301 costs
[i
+ 1][j
+ 1] > sameParityMax
)
302 sameParityMax
= costs
[i
+ 1][j
+ 1];
305 // Ensure all registers with same parity have a higher cost
306 // than sameParityMax
307 for (unsigned j
= 0, je
= vRrAllowed
->size(); j
!= je
; ++j
) {
308 unsigned pRa
= (*vRrAllowed
)[j
];
309 if (haveSameParity(pRd
, pRa
))
310 if (sameParityMax
> costs
[i
+ 1][j
+ 1])
311 costs
[i
+ 1][j
+ 1] = sameParityMax
+ 1.0;
// Write the adjusted copy back to the edge.
314 G
.updateEdgeCosts(edge
, std::move(costs
));
319 static bool regJustKilledBefore(const LiveIntervals
&LIs
, unsigned reg
,
320 const MachineInstr
&MI
) {
321 const LiveInterval
&LI
= LIs
.getInterval(reg
);
322 SlotIndex SI
= LIs
.getInstructionIndex(MI
);
323 return LI
.expiredAt(SI
);
326 void A57ChainingConstraint::apply(PBQPRAGraph
&G
) {
327 const MachineFunction
&MF
= G
.getMetadata().MF
;
328 LiveIntervals
&LIs
= G
.getMetadata().LIS
;
330 TRI
= MF
.getSubtarget().getRegisterInfo();
331 LLVM_DEBUG(MF
.dump());
333 for (const auto &MBB
: MF
) {
334 Chains
.clear(); // FIXME: really needed ? Could not work at MF level ?
336 for (const auto &MI
: MBB
) {
338 // Forget Chains which have expired
339 for (auto r
: Chains
) {
340 SmallVector
<unsigned, 8> toDel
;
341 if(regJustKilledBefore(LIs
, r
, MI
)) {
342 LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r
, TRI
) << " at ";
347 while (!toDel
.empty()) {
348 Chains
.remove(toDel
.back());
353 switch (MI
.getOpcode()) {
354 case AArch64::FMSUBSrrr
:
355 case AArch64::FMADDSrrr
:
356 case AArch64::FNMSUBSrrr
:
357 case AArch64::FNMADDSrrr
:
358 case AArch64::FMSUBDrrr
:
359 case AArch64::FMADDDrrr
:
360 case AArch64::FNMSUBDrrr
:
361 case AArch64::FNMADDDrrr
: {
362 Register Rd
= MI
.getOperand(0).getReg();
363 Register Ra
= MI
.getOperand(3).getReg();
365 if (addIntraChainConstraint(G
, Rd
, Ra
))
366 addInterChainConstraint(G
, Rd
, Ra
);
370 case AArch64::FMLAv2f32
:
371 case AArch64::FMLSv2f32
: {
372 Register Rd
= MI
.getOperand(0).getReg();
373 addInterChainConstraint(G
, Rd
, Rd
);