//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.gather and llvm.scatter instructions to
/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
/// produce a better final result as we go.
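///
/// As a rough illustrative sketch (not lifted from the in-tree tests), a
/// masked gather such as
///   %r = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> %ptrs,
///                               i32 4, <4 x i1> %mask, <4 x i32> undef)
/// becomes an MVE gather intrinsic along the lines of
///   %r = call <4 x i32>
///        @llvm.arm.mve.vldr.gather.base.predicated.v4i32.v4p0i32.v4i1(
///                               <4 x i32*> %ptrs, i32 0, <4 x i1> %mask)
/// with the base+offset and incrementing forms used when the pointers come
/// from a getelementptr.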
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));
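// Note: the option defaults to true, so this lowering can be disabled on the
// command line (for example when running llc or opt on an MVE target) with
// -enable-arm-maskedgatscat=false.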

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
  // scalar base and vector offsets, or else fallback to using a base of 0 and
  // offset of Ptr where possible.
  Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
                      FixedVectorType *Ty, Type *MemoryTy,
                      IRBuilder<> &Builder);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
                      GetElementPtrInst *GEP, IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  Optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Instruction *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                           Instruction *&Root,
                                           IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                         IRBuilder<> &Builder,
                                         int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                           IRBuilder<> &Builder,
                                           int64_t Increment = 0);

  Instruction *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                            IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                          IRBuilder<> &Builder,
                                          int64_t Increment = 0);
  // Create an incrementing scatter from a vector of pointers
  Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder,
                                            int64_t Increment = 0);

  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value (digit)
  Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value (digit) - this creates a writeback QI
  // gather/scatter
  Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                              Value *Ptr, unsigned TypeScale,
                                              IRBuilder<> &Builder);

  // Optimise the base and offsets of the given address
  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  // Try to fold consecutive geps together into one
  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder);
  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scattering lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}

static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  // Offsets that are not of type <N x i32> are sign extended by the
  // getelementptr instruction, and MVE gathers/scatters treat the offset as
  // unsigned. Thus, if the element size is smaller than 32, we can only allow
  // positive offsets - i.e., the offsets are not allowed to be variables we
  // can't look into.
  // Additionally, <N x i32> offsets have to either originate from a zext of a
  // vector with element types smaller or equal the type of the gather we're
  // looking at, or consist of constants that we can check are small enough
  // to fit into the gather type.
  // Thus we check that 0 < value < 2^TargetElemSize.
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      int SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}

Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
                                              int &Scale, FixedVectorType *Ty,
                                              Type *MemoryTy,
                                              IRBuilder<> &Builder) {
  if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
    if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
      Scale =
          computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
                       MemoryTy->getScalarSizeInBits());
      return Scale == -1 ? nullptr : V;
    }
  }

  // If we couldn't use the GEP (or it doesn't exist), attempt to use a
  // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
  // elements.
  FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
  if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
    return nullptr;
  Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
  Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getInt8PtrTy());
  Offsets = Builder.CreatePtrToInt(
      Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
  Scale = 0;
  return BasePtr;
}

Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
                                              FixedVectorType *Ty,
                                              GetElementPtrInst *GEP,
                                              IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
                      << "found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  // Paranoid check whether the number of parallel lanes is the same
  assert(Ty->getNumElements() == OffsetsElemCount);

  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  // having to make sure that they won't overflow.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  // The offset sizes have been checked; if any truncating or zext-ing is
  // required to fix them, do that now
  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, Ty);
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
                        << "bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by
  // 2, or an 8bit, 16bit or 32bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C != nullptr)
    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return Optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add ||
      I->getOpcode() == Instruction::Mul) {
    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return Optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
    if (I->getOpcode() == Instruction::Mul)
      return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
  }
  return Optional<int64_t>{};
}

std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or we
  // bail out
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
    return ReturnFalse;

  Value *Summand;
  Optional<int64_t> Const;
  // Find out which operand the value that is increased is
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Check that the constant is small enough for an incrementing gather
  int64_t Immediate = Const.getValue() << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}

Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;

  Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = SelectInst::Create(Mask, Load, PassThru);
    Builder.Insert(Load);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
                    << *Load << "\n");
  return Load;
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
                    << "writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *MemoryTy = I->getType();
  Type *ResultTy = MemoryTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  bool TruncResult = false;
  if (MemoryTy->getPrimitiveSizeInBits() < 128) {
    if (I->hasOneUse()) {
      // If the gather has a single extend of the correct type, use an
      // extending gather and replace the ext. In which case the correct root
      // to replace is not the CallInst itself, but the instruction which
      // extends it.
      Instruction *User = cast<Instruction>(*I->users().begin());
      if (isa<SExtInst>(User) &&
          User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *User << "\n");
        Extend = User;
        ResultTy = User->getType();
        Unsigned = 0;
      } else if (isa<ZExtInst>(User) &&
                 User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *ResultTy << "\n");
        Extend = User;
        ResultTy = User->getType();
      }
    }

    // If an extend hasn't been found and the type is an integer, create an
    // extending gather and truncate back to the original type.
    if (ResultTy->getPrimitiveSizeInBits() < 128 &&
        ResultTy->isIntOrIntVectorTy()) {
      ResultTy = ResultTy->getWithNewBitWidth(
          128 / cast<FixedVectorType>(ResultTy)->getNumElements());
      TruncResult = true;
      LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
                        << *ResultTy << "\n");
    }

    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
                           "from the correct type. Expanding\n");
      return nullptr;
    }
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  Root = Extend;
  Value *Mask = I->getArgOperand(2);
  Instruction *Load = nullptr;
  if (!match(Mask, m_One()))
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});

  if (TruncResult) {
    Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
    Builder.Insert(Load);
  }
  return Load;
}

Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
                    << *Store << "\n");
  I->eraseFromParent();
  return Store;
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(3);
  //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
                    << "with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;

  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  bool ExtendInput = false;
  if (InputTy->getPrimitiveSizeInBits() < 128 &&
      InputTy->isIntOrIntVectorTy()) {
    // If we can't find a trunc to incorporate into the instruction, create an
    // implicit one with a zext, so that we can still create a scatter. We know
    // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
    // smaller than 128 bits will divide evenly into a 128bit vector.
    InputTy = InputTy->getWithNewBitWidth(
        128 / cast<FixedVectorType>(InputTy)->getNumElements());
    ExtendInput = true;
    LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
                      << *Input << "\n");
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
                         "non-standard input types. Expanding.\n");
    return nullptr;
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  if (ExtendInput)
    Input = Builder.CreateZExt(Input, InputTy);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());

  // Incrementing gathers only exist for v4i32
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  // Incrementing gathers are not beneficial outside of a loop
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    return nullptr;

  // Decompose the GEP into Base and Offsets
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
  if (!BasePtr)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  // The gep was in charge of making sure the offsets are scaled correctly
  // - calculate that factor so it can be applied by hand
  DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
  int TypeScale =
      computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
                   DT.getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    // Only in this case do we want to build a wb gather, because the wb will
    // change the phi which does affect other users of the gep (which will
    // still be using the phi in the old way)
    if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
                                                    TypeScale, Builder))
      return Load;
  }

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(),
                                Builder.getInt32(TypeScale)),
      "ScaledIndex", I);
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I);

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
  else
    return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
}

Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Instruction *EndResult;
  Instruction *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = ExtractValueInst::Create(Load, 0, "Gather");
    NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
    Builder.Insert(EndResult);
    Builder.Insert(NewInduction);
  } else {
    // Build the incrementing scatter
    EndResult = NewInduction =
        tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces mov's)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which loop variable can be incremented
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetic only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"
                    << *Offsets << "\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed)
      return false;
    if (isa<PHINode>(Offs->getOperand(0))) {
      Phi = cast<PHINode>(Offs->getOperand(0));
      OffsSecondOp = 1;
    } else if (isa<PHINode>(Offs->getOperand(1))) {
      Phi = cast<PHINode>(Offs->getOperand(1));
      OffsSecondOp = 0;
    } else {
      return false;
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header.
  if (Phi->getParent() != L->getHeader())
    return false;

  // We're looking for a simple add recurrence.
  BinaryOperator *IncInstruction;
  Value *Start, *IncrementPerRound;
  if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
      IncInstruction->getOpcode() != Instruction::Add)
    return false;

  int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
      !L->isLoopInvariant(OffsSecondOperand))
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi);
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
                    << "add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
                                      IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  Value *Add = Builder.CreateAdd(X, Y);

  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  if (checkOffsetSize(Add, GEPType->getNumElements()))
    return Add;
  else
    return nullptr;
}

Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets,
                                         IRBuilder<> &Builder) {
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  // We only merge geps with constant offsets, because only for those
  // we can make sure that we do not cause an overflow
  if (!isa<Constant>(Offsets))
    return nullptr;
  GetElementPtrInst *BaseGEP;
  if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
    // Merge the two geps into one
    Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
    if (!BaseBasePtr)
      return nullptr;
    Offsets =
        CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
    if (Offsets == nullptr)
      return nullptr;
    return BaseBasePtr;
  }
  return GEPPtr;
}

bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  if (GEP->hasOneUse() &&
      dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    Value *Base = foldGEP(GEP, Offsets, Builder);
    // We only want to merge the geps if there is a real chance that they can
    // be used by an MVE gather; thus the offset has to have the correct size
    // (always i32 if it is not of vector type) and the base has to be a
    // pointer.
    if (Offsets && Base && Base != GEP) {
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          GEP->getSourceElementType(), Base, Offsets, "gep.merged", GEP);
      GEP->replaceAllUsesWith(NewAddress);
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    Changed |= SimplifyInstructionsInBlock(&BB);

    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    Instruction *L = lowerGather(I);
    if (L == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(L->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    Instruction *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(S->getParent());
    Changed = true;
  }
  return Changed;
}