//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.gather and llvm.scatter instructions to
/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
/// produce a better final result as we go.
//
//===----------------------------------------------------------------------===//

16 #include "ARMBaseInstrInfo.h"
17 #include "ARMSubtarget.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/Analysis/ValueTracking.h"
21 #include "llvm/CodeGen/TargetLowering.h"
22 #include "llvm/CodeGen/TargetPassConfig.h"
23 #include "llvm/CodeGen/TargetSubtargetInfo.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/IR/BasicBlock.h"
26 #include "llvm/IR/Constant.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/InstrTypes.h"
31 #include "llvm/IR/Instruction.h"
32 #include "llvm/IR/Instructions.h"
33 #include "llvm/IR/IntrinsicInst.h"
34 #include "llvm/IR/Intrinsics.h"
35 #include "llvm/IR/IntrinsicsARM.h"
36 #include "llvm/IR/IRBuilder.h"
37 #include "llvm/IR/PatternMatch.h"
38 #include "llvm/IR/Type.h"
39 #include "llvm/IR/Value.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
  // scalar base and vector offsets, or else fallback to using a base of 0 and
  // offset of Ptr where possible.
  Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
                      FixedVectorType *Ty, Type *MemoryTy,
                      IRBuilder<> &Builder);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
                      GetElementPtrInst *GEP, IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  Optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Instruction *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                           Instruction *&Root,
                                           IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                         IRBuilder<> &Builder,
                                         int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                           IRBuilder<> &Builder,
                                           int64_t Increment = 0);

  Instruction *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                            IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                          IRBuilder<> &Builder,
                                          int64_t Increment = 0);
  // Create an incrementing scatter from a vector of pointers
  Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder,
                                            int64_t Increment = 0);

  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value (digit)
  Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value (digit) - this creates a writeback QI
  // gather/scatter
  Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                              Value *Ptr, unsigned TypeScale,
                                              IRBuilder<> &Builder);

  // Optimise the base and offsets of the given address
  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  // Try to fold consecutive geps together into one
  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets,
                 IRBuilder<> &Builder);
  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scattering lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

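// MVE gathers and scatters operate on full 128-bit Q registers, so the only
// vector shapes handled here are 4 lanes of i8/i16/i32, 8 lanes of i8/i16 and
// 16 lanes of i8, each with at least element-size (in bytes) alignment;
// anything else is rejected and the masked intrinsic is left untouched.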
bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}

static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  // Offsets that are not of type <N x i32> are sign extended by the
  // getelementptr instruction, and MVE gathers/scatters treat the offset as
  // unsigned. Thus, if the element size is smaller than 32, we can only allow
  // positive offsets - i.e., the offsets are not allowed to be variables we
  // can't look into.
  // Additionally, <N x i32> offsets have to either originate from a zext of a
  // vector with element types smaller or equal the type of the gather we're
  // looking at, or consist of constants that we can check are small enough
  // to fit into the gather type.
  // Thus we check that 0 < value < 2^TargetElemSize.
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      int SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}

Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
                                              int &Scale, FixedVectorType *Ty,
                                              Type *MemoryTy,
                                              IRBuilder<> &Builder) {
  if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
    if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
      Scale =
          computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
                       MemoryTy->getScalarSizeInBits());
      return Scale == -1 ? nullptr : V;
    }
    return nullptr;
  }

  // If we couldn't use the GEP (or it doesn't exist), attempt to use a
  // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
  // elements.
  FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
  if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
    return nullptr;
  Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
  Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getInt8PtrTy());
  Offsets = Builder.CreatePtrToInt(
      Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
  Scale = 0;
  return BasePtr;
}

Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
                                              FixedVectorType *Ty,
                                              GetElementPtrInst *GEP,
                                              IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
                      << "found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  // Paranoid check whether the number of parallel lanes is the same
  assert(Ty->getNumElements() == OffsetsElemCount);

  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  // having to make sure that they won't overflow.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  // The offset sizes have been checked; if any truncating or zext-ing is
  // required to fix them, do that now
  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, Ty);
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
                        << "bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

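// For example, a 32-bit gather indexed through an i32 GEP uses scale 2 (each
// offset is multiplied by 4 bytes), a 16-bit access through an i16 GEP uses
// scale 1, and any i8-based GEP uses scale 0 (byte offsets); -1 marks a
// combination the hardware cannot encode.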
int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by
  // 2, or an 8bit, 16bit or 32bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C)
    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return Optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add ||
      I->getOpcode() == Instruction::Mul) {
    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return Optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
    if (I->getOpcode() == Instruction::Mul)
      return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
  }
  return Optional<int64_t>{};
}

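// For example, offsets of the form '%iv + 8' with a TypeScale of 2 yield the
// pair {%iv, 32}; the scaled constant must be a multiple of 4 in [-512, 512]
// to be usable as the immediate of an incrementing gather/scatter.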
std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or we
  // bail out
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
    return ReturnFalse;

  Value *Summand;
  Optional<int64_t> Const;
  // Find out which operand is the value that is increased
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Check that the constant is small enough for an incrementing gather
  int64_t Immediate = Const.getValue() << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}

Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;

  Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = SelectInst::Create(Mask, Load, PassThru);
    Builder.Insert(Load);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
                    << *Load << "\n");
  return Load;
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  return Builder.CreateIntrinsic(
      Intrinsic::arm_mve_vldr_gather_base_predicated,
      {Ty, Ptr->getType(), Mask->getType()},
      {Ptr, Builder.getInt32(Increment), Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
                    << "writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  return Builder.CreateIntrinsic(
      Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
      {Ty, Ptr->getType(), Mask->getType()},
      {Ptr, Builder.getInt32(Increment), Mask});
}

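// A gather that loads less than a full 128 bits is built as an extending
// gather: if its only user is a matching sext/zext, that extend is absorbed
// and becomes the Root that gets replaced; otherwise a zero-extending gather
// is emitted and the result is truncated back to the original type.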
Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *MemoryTy = I->getType();
  Type *ResultTy = MemoryTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  bool TruncResult = false;
  if (MemoryTy->getPrimitiveSizeInBits() < 128) {
    if (I->hasOneUse()) {
      // If the gather has a single extend of the correct type, use an extending
      // gather and replace the ext. In which case the correct root to replace
      // is not the CallInst itself, but the instruction which extends it.
      Instruction *User = cast<Instruction>(*I->users().begin());
      if (isa<SExtInst>(User) &&
          User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *ResultTy << "\n");
        Unsigned = 0;
        ResultTy = User->getType();
        Root = User;
      } else if (isa<ZExtInst>(User) &&
                 User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *ResultTy << "\n");
        ResultTy = User->getType();
        Root = User;
      }
    }

    // If an extend hasn't been found and the type is an integer, create an
    // extending gather and truncate back to the original type.
    if (ResultTy->getPrimitiveSizeInBits() < 128 &&
        ResultTy->isIntOrIntVectorTy()) {
      ResultTy = ResultTy->getWithNewBitWidth(
          128 / cast<FixedVectorType>(ResultTy)->getNumElements());
      TruncResult = true;
      LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncating to: "
                        << *ResultTy << "\n");
    }

    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
                           "from the correct type. Expanding\n");
      return nullptr;
    }
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  Value *Mask = I->getArgOperand(2);
  Instruction *Load = nullptr;
  if (!match(Mask, m_One()))
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});

  if (TruncResult) {
    Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
    Builder.Insert(Load);
  }
  return Load;
}

Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
                    << *Store << "\n");
  I->eraseFromParent();
  return Store;
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(3);
  // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  return Builder.CreateIntrinsic(
      Intrinsic::arm_mve_vstr_scatter_base_predicated,
      {Ptr->getType(), Input->getType(), Mask->getType()},
      {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
                    << "with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  return Builder.CreateIntrinsic(
      Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
      {Ptr->getType(), Input->getType(), Mask->getType()},
      {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;

  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  bool ExtendInput = false;
  if (InputTy->getPrimitiveSizeInBits() < 128 &&
      InputTy->isIntOrIntVectorTy()) {
    // If we can't find a trunc to incorporate into the instruction, create an
    // implicit one with a zext, so that we can still create a scatter. We know
    // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
    // smaller than 128 bits will divide evenly into a 128bit vector.
    InputTy = InputTy->getWithNewBitWidth(
        128 / cast<FixedVectorType>(InputTy)->getNumElements());
    ExtendInput = true;
    LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
                      << *Input << "\n");
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
                         "non-standard input types. Expanding.\n");
    return nullptr;
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  if (ExtendInput)
    Input = Builder.CreateZExt(Input, InputTy);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  return Builder.CreateIntrinsic(
      Intrinsic::arm_mve_vstr_scatter_offset,
      {BasePtr->getType(), Offsets->getType(), Input->getType()},
      {BasePtr, Offsets, Input,
       Builder.getInt32(MemoryTy->getScalarSizeInBits()),
       Builder.getInt32(Scale)});
}

Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());

  // Incrementing gathers only exist for v4i32
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  // Incrementing gathers are not beneficial outside of a loop
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    return nullptr;

  // Decompose the GEP into Base and Offsets
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
  if (!BasePtr)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  // The gep was in charge of making sure the offsets are scaled correctly
  // - calculate that factor so it can be applied by hand
  DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
  int TypeScale =
      computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
                   DT.getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    // Only in this case do we want to build a wb gather, because the wb will
    // change the phi which does affect other users of the gep (which will still
    // be using the phi in the old way)
    if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
                                                    TypeScale, Builder))
      return Load;
  }

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
      "ScaledIndex", I);
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I);

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
  else
    return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
}

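// The writeback form folds the loop's offset increment into the gather or
// scatter itself: the phi's start value is rewritten to the scaled,
// pre-decremented base-plus-offsets vector, and the instruction's second
// result replaces the separate offset add as the next value of the induction
// variable.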
Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Instruction *EndResult;
  Instruction *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = ExtractValueInst::Create(Load, 0, "Gather");
    NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
    Builder.Insert(EndResult);
    Builder.Insert(NewInduction);
  } else {
    // Build the incrementing scatter
    EndResult = NewInduction =
        tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

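// The two helpers below hoist a loop-invariant operand out of an offset
// recurrence: for 'offs = phi + invariant' the invariant is folded into the
// phi's start value, and for 'offs = phi * invariant' both the start value and
// the per-round increment are pre-multiplied, leaving a plain add of the phi
// inside the loop.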
void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces mov's)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which loop variable can be incremented
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetic only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"
                    << *Offsets << "\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed)
      return false;
    if (isa<PHINode>(Offs->getOperand(0))) {
      Phi = cast<PHINode>(Offs->getOperand(0));
      OffsSecondOp = 1;
    } else if (isa<PHINode>(Offs->getOperand(1))) {
      Phi = cast<PHINode>(Offs->getOperand(1));
      OffsSecondOp = 0;
    } else {
      return false;
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header.
  if (Phi->getParent() != L->getHeader())
    return false;

  // We're looking for a simple add recurrence.
  BinaryOperator *IncInstruction;
  Value *Start, *IncrementPerRound;
  if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
      IncInstruction->getOpcode() != Instruction::Add)
    return false;

  int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
      !L->isLoopInvariant(OffsSecondOperand))
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi);
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
                       "add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

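// foldGEP below merges chains such as gep(gep(Base, C1), C2) into a single
// gep(Base, C1 + C2); CheckAndCreateOffsetAdd splats any scalar summand to a
// vector and gives up if the summed constants would no longer fit the narrow
// offset range verified by checkOffsetSize.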
static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
                                      IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  Value *Add = Builder.CreateAdd(X, Y);

  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  if (checkOffsetSize(Add, GEPType->getNumElements()))
    return Add;
  else
    return nullptr;
}

*MVEGatherScatterLowering::foldGEP(GetElementPtrInst
*GEP
,
1157 IRBuilder
<> &Builder
) {
1158 Value
*GEPPtr
= GEP
->getPointerOperand();
1159 Offsets
= GEP
->getOperand(1);
1160 // We only merge geps with constant offsets, because only for those
1161 // we can make sure that we do not cause an overflow
1162 if (!isa
<Constant
>(Offsets
))
1164 GetElementPtrInst
*BaseGEP
;
1165 if ((BaseGEP
= dyn_cast
<GetElementPtrInst
>(GEPPtr
))) {
1166 // Merge the two geps into one
1167 Value
*BaseBasePtr
= foldGEP(BaseGEP
, Offsets
, Builder
);
1171 CheckAndCreateOffsetAdd(Offsets
, GEP
->getOperand(1), GEP
, Builder
);
1172 if (Offsets
== nullptr)
bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  if (GEP->hasOneUse() &&
      dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    Value *Base = foldGEP(GEP, Offsets, Builder);
    // We only want to merge the geps if there is a real chance that they can be
    // used by an MVE gather; thus the offset has to have the correct size
    // (always i32 if it is not of vector type) and the base has to be a
    // pointer.
    if (Offsets && Base && Base != GEP) {
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          GEP->getSourceElementType(), Base, Offsets, "gep.merged", GEP);
      GEP->replaceAllUsesWith(NewAddress);
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    Changed |= SimplifyInstructionsInBlock(&BB);

    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    Instruction *L = lowerGather(I);
    if (L == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(L->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    Instruction *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(S->getParent());
    Changed = true;
  }
  return Changed;
}