1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
8 // This file contains the AArch64 / Cortex-A57 specific register allocation
9 // constraints for use by the PBQP register allocator.
11 // It is essentially a transcription of what is contained in
12 // AArch64A57FPLoadBalancing, which tries to use a balanced
13 // mix of odd and even D-registers when performing a critical sequence of
14 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15 //===----------------------------------------------------------------------===//
17 #include "AArch64PBQPRegAlloc.h"
18 #include "AArch64InstrInfo.h"
19 #include "AArch64RegisterInfo.h"
20 #include "llvm/CodeGen/LiveIntervals.h"
21 #include "llvm/CodeGen/MachineBasicBlock.h"
22 #include "llvm/CodeGen/MachineFunction.h"
23 #include "llvm/CodeGen/RegAllocPBQP.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/raw_ostream.h"
28 #define DEBUG_TYPE "aarch64-pbqp"
// Presumably returns true when `reg` is an odd-numbered FP/SIMD register:
// haveSameParity compares two registers by comparing their isOdd() results,
// and the file header describes balancing odd and even D-registers.
// NOTE(review): the register-to-parity mapping (original lines 35-36 and
// 38-139) is not visible in this excerpt. The only visible body line reports
// an unexpected register through llvm_unreachable, presumably as the
// fallback for registers outside the handled class; confirm.
34 bool isOdd(unsigned reg
) {
37 llvm_unreachable("Register is not from the expected class !");
// Returns true when reg1 and reg2 have the same parity, i.e. isOdd() gives
// the same answer for both registers. Both arguments are expected to satisfy
// AArch64InstrInfo::isFpOrNEON. This is checked with assert, so the check
// only fires in builds where assertions are enabled.
// NOTE(review): the closing brace (original line 147) is not visible in
// this excerpt.
140 bool haveSameParity(unsigned reg1
, unsigned reg2
) {
141 assert(AArch64InstrInfo::isFpOrNEON(reg1
) &&
142 "Expecting an FP register for reg1");
143 assert(AArch64InstrInfo::isFpOrNEON(reg2
) &&
144 "Expecting an FP register for reg2");
// Same parity: both odd or both even.
146 return isOdd(reg1
) == isOdd(reg2
);
// Adds or refines the PBQP edge between the graph nodes of Rd and Ra so that
// giving both the same register parity (see haveSameParity) is the cheaper
// choice. apply() calls this with operand 0 (Rd) and operand 3 (Ra) of the
// scalar FMADD/FMSUB/FNMADD/FNMSUB instructions. It calls
// addInterChainConstraint only when this function returns true.
//
// - If Rd or Ra is a physical register, LLVM_DEBUG prints whether each one
//   is physical. The rest of that branch (original lines 163-165) is not
//   visible in this excerpt.
// - If the two nodes have no edge yet, a fresh cost matrix is built and
//   attached with G.addEdge.
// - Otherwise the existing costs are adjusted row by row (one row per
//   candidate register of Rd). Afterwards, different-parity entries are no
//   cheaper than the largest finite same-parity entry in that row.
//
// NOTE(review): not visible in this excerpt: the second signature line
// (original line 152, presumably declaring `unsigned Ra`), original lines
// 153-155, and the function's return statements.
151 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph
&G
, unsigned Rd
,
156 LiveIntervals
&LIs
= G
.getMetadata().LIS
;
// Physical registers are not constrained here; only debug output is visible.
158 if (Register::isPhysicalRegister(Rd
) || Register::isPhysicalRegister(Ra
)) {
159 LLVM_DEBUG(dbgs() << "Rd is a physical reg:"
160 << Register::isPhysicalRegister(Rd
) << '\n');
161 LLVM_DEBUG(dbgs() << "Ra is a physical reg:"
162 << Register::isPhysicalRegister(Ra
) << '\n');
// Graph nodes and candidate physical-register lists for Rd and Ra.
166 PBQPRAGraph::NodeId node1
= G
.getMetadata().getNodeIdForVReg(Rd
);
167 PBQPRAGraph::NodeId node2
= G
.getMetadata().getNodeIdForVReg(Ra
);
169 const PBQPRAGraph::NodeMetadata::AllowedRegVector
*vRdAllowed
=
170 &G
.getNodeMetadata(node1
).getAllowedRegs();
171 const PBQPRAGraph::NodeMetadata::AllowedRegVector
*vRaAllowed
=
172 &G
.getNodeMetadata(node2
).getAllowedRegs();
174 PBQPRAGraph::EdgeId edge
= G
.findEdge(node1
, node2
);
176 // The edge does not exist. Create one with the appropriate interference
178 if (edge
== G
.invalidEdgeId()) {
179 const LiveInterval
&ld
= LIs
.getInterval(Rd
);
180 const LiveInterval
&la
= LIs
.getInterval(Ra
);
181 bool livesOverlap
= ld
.overlaps(la
);
// The matrix has one extra row and one extra column. Index 0 is not a
// candidate register (presumably the spill option in LLVM's PBQP encoding;
// confirm), so candidate i of vRdAllowed and candidate j of vRaAllowed map
// to entry [i + 1][j + 1]. The trailing 0 is presumably the initial value
// of every entry.
183 PBQPRAGraph::RawMatrix
costs(vRdAllowed
->size() + 1,
184 vRaAllowed
->size() + 1, 0);
185 for (unsigned i
= 0, ie
= vRdAllowed
->size(); i
!= ie
; ++i
) {
186 unsigned pRd
= (*vRdAllowed
)[i
];
187 for (unsigned j
= 0, je
= vRaAllowed
->size(); j
!= je
; ++j
) {
188 unsigned pRa
= (*vRaAllowed
)[j
];
// Cost is infinite when the live intervals overlap and the two candidate
// physical registers overlap (TRI->regsOverlap). Otherwise it is 0.0 for
// same parity and 1.0 for different parity.
// NOTE(review): in this excerpt the infinity assignment (original line 190)
// is directly followed by the parity assignment (original line 192), which
// would overwrite it. Original line 191 (presumably `else`) is not visible.
// Confirm the infinite cost is not being clobbered.
189 if (livesOverlap
&& TRI
->regsOverlap(pRd
, pRa
))
190 costs
[i
+ 1][j
+ 1] = std::numeric_limits
<PBQP::PBQPNum
>::infinity();
192 costs
[i
+ 1][j
+ 1] = haveSameParity(pRd
, pRa
) ? 0.0 : 1.0;
195 G
.addEdge(node1
, node2
, std::move(costs
));
// Presumably reached only when the edge already existed. Original lines
// 196-198 (the end of the branch above) are not visible.
// If the edge's first node is Ra's node, swap so that node1 and vRdAllowed
// describe the edge's first node. This is presumably because the rows of
// the edge cost matrix follow the first node; confirm.
199 if (G
.getEdgeNode1Id(edge
) == node2
) {
200 std::swap(node1
, node2
);
201 std::swap(vRdAllowed
, vRaAllowed
);
204 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
// NOTE(review): the loops below raise *different*-parity costs above the
// largest finite *same*-parity cost, which makes same parity the cheaper
// choice. The inequality on the previous line is stated the other way round.
205 PBQPRAGraph::RawMatrix
costs(G
.getEdgeCosts(edge
));
206 for (unsigned i
= 0, ie
= vRdAllowed
->size(); i
!= ie
; ++i
) {
207 unsigned pRd
= (*vRdAllowed
)[i
];
209 // Get the maximum cost (excluding unallocatable reg) for same parity
// NOTE(review): if PBQPNum is a floating-point type, numeric_limits::min()
// is the smallest *positive* value, not the most negative one. A row whose
// finite same-parity costs are all 0.0 therefore leaves sameParityMax at
// min(), and the second loop then lifts different-parity 0.0 entries to
// about 1.0. std::numeric_limits<>::lowest() is the usual seed for a running
// maximum; confirm which behavior is intended before changing it.
211 PBQP::PBQPNum sameParityMax
= std::numeric_limits
<PBQP::PBQPNum
>::min();
212 for (unsigned j
= 0, je
= vRaAllowed
->size(); j
!= je
; ++j
) {
213 unsigned pRa
= (*vRaAllowed
)[j
];
214 if (haveSameParity(pRd
, pRa
))
215 if (costs
[i
+ 1][j
+ 1] !=
216 std::numeric_limits
<PBQP::PBQPNum
>::infinity() &&
217 costs
[i
+ 1][j
+ 1] > sameParityMax
)
218 sameParityMax
= costs
[i
+ 1][j
+ 1];
221 // Ensure all registers with a different parity have a higher cost
222 // than sameParityMax
// Only entries strictly below sameParityMax are raised, to
// sameParityMax + 1.0. Entries equal to or above it are left unchanged.
223 for (unsigned j
= 0, je
= vRaAllowed
->size(); j
!= je
; ++j
) {
224 unsigned pRa
= (*vRaAllowed
)[j
];
225 if (!haveSameParity(pRd
, pRa
))
226 if (sameParityMax
> costs
[i
+ 1][j
+ 1])
227 costs
[i
+ 1][j
+ 1] = sameParityMax
+ 1.0;
230 G
.updateEdgeCosts(edge
, std::move(costs
));
// apply() calls this in two situations:
// - with (Rd, Ra) from a scalar FMADD/FMSUB-family instruction, after
//   addIntraChainConstraint returns true;
// - with Rd == Ra for FMLAv2f32 / FMLSv2f32.
// It first updates the accumulator-chain bookkeeping. The debug messages
// show that an existing chain on Ra is moved to Rd; otherwise a new chain
// is created for Rd.
// Then, for every register r in Chains whose live interval overlaps Rd's,
// it adjusts the existing PBQP edge. Within each Rd row, same-parity
// entries end up no cheaper than the largest finite different-parity entry.
// This pushes overlapping chains toward registers of different parity.
//
// NOTE(review): not visible in this excerpt:
// - the second signature line (original line 236, presumably declaring
//   `unsigned Ra`);
// - the chain bookkeeping itself (original lines 244-247 and 249-252);
// - original lines 257-260 at the top of the loop.
// If those lines do not skip r == Rd, findEdge would be asked for an edge
// from Rd's node to itself; confirm.
235 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph
&G
, unsigned Rd
,
237 LiveIntervals
&LIs
= G
.getMetadata().LIS
;
239 // Do some Chain management
240 if (Chains
.count(Ra
)) {
242 LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra
, TRI
)
243 << " to " << printReg(Rd
, TRI
) << '\n';);
248 LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd
, TRI
)
253 PBQPRAGraph::NodeId node1
= G
.getMetadata().getNodeIdForVReg(Rd
);
255 const LiveInterval
&ld
= LIs
.getInterval(Rd
);
// Visit every tracked chain register and refine only those whose live
// interval overlaps Rd's.
256 for (auto r
: Chains
) {
261 const LiveInterval
&lr
= LIs
.getInterval(r
);
262 if (ld
.overlaps(lr
)) {
263 const PBQPRAGraph::NodeMetadata::AllowedRegVector
*vRdAllowed
=
264 &G
.getNodeMetadata(node1
).getAllowedRegs();
266 PBQPRAGraph::NodeId node2
= G
.getMetadata().getNodeIdForVReg(r
);
267 const PBQPRAGraph::NodeMetadata::AllowedRegVector
*vRrAllowed
=
268 &G
.getNodeMetadata(node2
).getAllowedRegs();
// This function never creates edges. It asserts that an edge between the
// two overlapping registers already exists.
270 PBQPRAGraph::EdgeId edge
= G
.findEdge(node1
, node2
);
271 assert(edge
!= G
.invalidEdgeId() &&
272 "PBQP error ! The edge should exist !");
274 LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
// If the edge's first node is r's node, swap so that node1 and vRdAllowed
// describe the edge's first node. This is presumably because the rows of
// the edge cost matrix follow the first node; confirm.
276 if (G
.getEdgeNode1Id(edge
) == node2
) {
277 std::swap(node1
, node2
);
278 std::swap(vRdAllowed
, vRrAllowed
);
281 // Enforce that cost is higher with all other Chains of the same parity
// Candidate i of vRdAllowed and candidate j of vRrAllowed map to entry
// [i + 1][j + 1]. Index 0 is not a candidate register (presumably the
// spill option; confirm).
282 PBQP::Matrix
costs(G
.getEdgeCosts(edge
));
283 for (unsigned i
= 0, ie
= vRdAllowed
->size(); i
!= ie
; ++i
) {
284 unsigned pRd
= (*vRdAllowed
)[i
];
286 // Get the maximum cost (excluding unallocatable reg) for all other
// parity registers. Despite its name, sameParityMax here collects the
// largest finite cost among *different*-parity pairs (note the negated
// haveSameParity test below).
// NOTE(review): if PBQPNum is a floating-point type, numeric_limits::min()
// is the smallest *positive* value, not the most negative one. A row whose
// finite different-parity costs are all 0.0 therefore leaves sameParityMax
// at min(), and the second loop then lifts same-parity 0.0 entries to about
// 1.0. std::numeric_limits<>::lowest() is the usual seed for a running
// maximum; confirm which behavior is intended before changing it.
288 PBQP::PBQPNum sameParityMax
= std::numeric_limits
<PBQP::PBQPNum
>::min();
289 for (unsigned j
= 0, je
= vRrAllowed
->size(); j
!= je
; ++j
) {
// pRa is a candidate physical register for chain register r.
290 unsigned pRa
= (*vRrAllowed
)[j
];
291 if (!haveSameParity(pRd
, pRa
))
292 if (costs
[i
+ 1][j
+ 1] !=
293 std::numeric_limits
<PBQP::PBQPNum
>::infinity() &&
294 costs
[i
+ 1][j
+ 1] > sameParityMax
)
295 sameParityMax
= costs
[i
+ 1][j
+ 1];
298 // Ensure all registers with same parity have a higher cost
299 // than sameParityMax
// Only entries strictly below sameParityMax are raised, to
// sameParityMax + 1.0. Entries equal to or above it are left unchanged.
300 for (unsigned j
= 0, je
= vRrAllowed
->size(); j
!= je
; ++j
) {
301 unsigned pRa
= (*vRrAllowed
)[j
];
302 if (haveSameParity(pRd
, pRa
))
303 if (sameParityMax
> costs
[i
+ 1][j
+ 1])
304 costs
[i
+ 1][j
+ 1] = sameParityMax
+ 1.0;
307 G
.updateEdgeCosts(edge
, std::move(costs
));
// Returns the result of LiveInterval::expiredAt for the live interval of
// `reg`, queried at the slot index of `MI`. LiveRange::expiredAt presumably
// reports true when the interval ends at or before that index; confirm
// against the LLVM API. apply() calls this for each tracked chain register
// to decide which chains to forget ("Killing chain").
// NOTE(review): the closing brace (original line 317) is not visible in
// this excerpt.
312 static bool regJustKilledBefore(const LiveIntervals
&LIs
, unsigned reg
,
313 const MachineInstr
&MI
) {
314 const LiveInterval
&LI
= LIs
.getInterval(reg
);
315 SlotIndex SI
= LIs
.getInstructionIndex(MI
);
316 return LI
.expiredAt(SI
);
319 void A57ChainingConstraint::apply(PBQPRAGraph
&G
) {
320 const MachineFunction
&MF
= G
.getMetadata().MF
;
321 LiveIntervals
&LIs
= G
.getMetadata().LIS
;
323 TRI
= MF
.getSubtarget().getRegisterInfo();
324 LLVM_DEBUG(MF
.dump());
326 for (const auto &MBB
: MF
) {
327 Chains
.clear(); // FIXME: really needed ? Could not work at MF level ?
329 for (const auto &MI
: MBB
) {
331 // Forget Chains which have expired
332 for (auto r
: Chains
) {
333 SmallVector
<unsigned, 8> toDel
;
334 if(regJustKilledBefore(LIs
, r
, MI
)) {
335 LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r
, TRI
) << " at ";
340 while (!toDel
.empty()) {
341 Chains
.remove(toDel
.back());
346 switch (MI
.getOpcode()) {
347 case AArch64::FMSUBSrrr
:
348 case AArch64::FMADDSrrr
:
349 case AArch64::FNMSUBSrrr
:
350 case AArch64::FNMADDSrrr
:
351 case AArch64::FMSUBDrrr
:
352 case AArch64::FMADDDrrr
:
353 case AArch64::FNMSUBDrrr
:
354 case AArch64::FNMADDDrrr
: {
355 Register Rd
= MI
.getOperand(0).getReg();
356 Register Ra
= MI
.getOperand(3).getReg();
358 if (addIntraChainConstraint(G
, Rd
, Ra
))
359 addInterChainConstraint(G
, Rd
, Ra
);
363 case AArch64::FMLAv2f32
:
364 case AArch64::FMLSv2f32
: {
365 Register Rd
= MI
.getOperand(0).getReg();
366 addInterChainConstraint(G
, Rd
, Rd
);