1 //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This pass custom lowers llvm.gather and llvm.scatter instructions to RISC-V intrinsics.
12 //===----------------------------------------------------------------------===//
15 #include "RISCVTargetMachine.h"
16 #include "llvm/Analysis/InstSimplifyFolder.h"
17 #include "llvm/Analysis/LoopInfo.h"
18 #include "llvm/Analysis/ValueTracking.h"
19 #include "llvm/Analysis/VectorUtils.h"
20 #include "llvm/CodeGen/TargetPassConfig.h"
21 #include "llvm/IR/GetElementPtrTypeIterator.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/IntrinsicInst.h"
24 #include "llvm/IR/IntrinsicsRISCV.h"
25 #include "llvm/IR/PatternMatch.h"
26 #include "llvm/Transforms/Utils/Local.h"
30 using namespace PatternMatch
;
32 #define DEBUG_TYPE "riscv-gather-scatter-lowering"
// FunctionPass that looks for masked gather/scatter intrinsic calls and, when
// tryCreateStridedLoadStore succeeds, replaces them with RISC-V masked strided
// load/store intrinsic calls (see runOnFunction).
36 class RISCVGatherScatterLowering
: public FunctionPass
{
// Subtarget, target lowering, loop info and data layout of the function being
// processed. All four are assigned in runOnFunction.
37 const RISCVSubtarget
*ST
= nullptr;
38 const RISCVTargetLowering
*TLI
= nullptr;
39 LoopInfo
*LI
= nullptr;
40 const DataLayout
*DL
= nullptr;
// PHIs recorded by matchStridedRecurrence once it has built a scalar
// recurrence for them. runOnFunction pops these and passes every entry that
// still casts to a PHINode to RecursivelyDeleteDeadPHINode.
42 SmallVector
<WeakTrackingVH
> MaybeDeadPHIs
;
44 // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
45 // used by multiple gathers/scatters, this allows us to reuse the scalar
46 // instructions we created for the first gather/scatter for the others.
47 DenseMap
<GetElementPtrInst
*, std::pair
<Value
*, Value
*>> StridedAddrs
;
50 static char ID
; // Pass identification, replacement for typeid
52 RISCVGatherScatterLowering() : FunctionPass(ID
) {}
// Pass entry point; defined out of line.
54 bool runOnFunction(Function
&F
) override
;
// Declares the analyses this pass requires: TargetPassConfig and
// LoopInfoWrapperPass (both are fetched with getAnalysis in runOnFunction).
56 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
58 AU
.addRequired
<TargetPassConfig
>();
59 AU
.addRequired
<LoopInfoWrapperPass
>();
// Human-readable name of the pass.
62 StringRef
getPassName() const override
{
63 return "RISC-V gather/scatter lowering";
// Helper declarations; the definitions are out of line.
// NOTE(review): the tail of the tryCreateStridedLoadStore parameter list is
// not visible in this copy of the file - confirm against the definition.
67 bool tryCreateStridedLoadStore(IntrinsicInst
*II
, Type
*DataType
, Value
*Ptr
,
// Returns a {base pointer, stride} pair for the address computation Ptr.
70 std::pair
<Value
*, Value
*> determineBaseAndStride(Instruction
*Ptr
,
71 IRBuilderBase
&Builder
);
// Stride, BasePtr and Inc are reference parameters that the function writes
// its results to.
73 bool matchStridedRecurrence(Value
*Index
, Loop
*L
, Value
*&Stride
,
74 PHINode
*&BasePtr
, BinaryOperator
*&Inc
,
75 IRBuilderBase
&Builder
);
78 } // end anonymous namespace
// Definition of the static pass ID declared in the class; the constructor
// passes it to FunctionPass.
80 char RISCVGatherScatterLowering::ID
= 0;
// Registers the pass, using DEBUG_TYPE ("riscv-gather-scatter-lowering") as
// its argument string and the literal below as its description.
82 INITIALIZE_PASS(RISCVGatherScatterLowering
, DEBUG_TYPE
,
83 "RISC-V gather/scatter lowering pass", false, false)
// Factory function: returns a newly allocated RISCVGatherScatterLowering.
// NOTE(review): ownership of the returned pointer is presumably taken by the
// pass manager that receives it - confirm at the call site.
85 FunctionPass
*llvm::createRISCVGatherScatterLoweringPass() {
86 return new RISCVGatherScatterLowering();
// Tries to describe the constant vector StartC as a start value plus a
// constant per-element stride. On success returns {StartVal, Stride}, where
// Stride is a ConstantInt of StartVal's type; every failure path returns
// {nullptr, nullptr}.
// NOTE(review): some statements of this function (for example the declaration
// of StartVal and the conditions guarding several early returns) are not
// visible in this copy of the file. The comments describe only what is shown.
89 // TODO: Should we consider the mask when looking for a stride?
90 static std::pair
<Value
*, Value
*> matchStridedConstant(Constant
*StartC
) {
// Only constants of fixed-length vector type are handled.
91 if (!isa
<FixedVectorType
>(StartC
->getType()))
92 return std::make_pair(nullptr, nullptr);
94 unsigned NumElts
= cast
<FixedVectorType
>(StartC
->getType())->getNumElements();
96 // Check that the start value is a strided constant.
// Element 0 must be a ConstantInt (dyn_cast_or_null also tolerates a null
// result from getAggregateElement).
98 dyn_cast_or_null
<ConstantInt
>(StartC
->getAggregateElement((unsigned)0));
100 return std::make_pair(nullptr, nullptr);
// Stride accumulator with the same bit width as the elements, initially 0.
101 APInt
StrideVal(StartVal
->getValue().getBitWidth(), 0);
102 ConstantInt
*Prev
= StartVal
;
// Visit the remaining elements, starting at index 1.
103 for (unsigned i
= 1; i
!= NumElts
; ++i
) {
104 auto *C
= dyn_cast_or_null
<ConstantInt
>(StartC
->getAggregateElement(i
));
106 return std::make_pair(nullptr, nullptr);
// Difference between this element and Prev.
108 APInt LocalStride
= C
->getValue() - Prev
->getValue();
// Record the stride, or reject the constant when this difference disagrees
// with the stride already recorded.
110 StrideVal
= LocalStride
;
111 else if (StrideVal
!= LocalStride
)
112 return std::make_pair(nullptr, nullptr);
// Materialize the stride as a constant of the element type.
117 Value
*Stride
= ConstantInt::get(StartVal
->getType(), StrideVal
);
119 return std::make_pair(StartVal
, Stride
);
// Tries to express the vector value Start as a scalar start value plus a
// scalar stride. Handles strided constants, the stepvector intrinsic, and
// add / disjoint-or / shl / mul of such a value with a splat, recursing
// through the non-splat operand. Returns {Start, Stride} on success and
// {nullptr, nullptr} otherwise. Scalar instructions needed for the result are
// created through Builder.
// NOTE(review): several statements (guards for some early returns, the
// declaration of Stride, `break`s in the switch) are not visible in this copy
// of the file. The comments describe only what is shown.
122 static std::pair
<Value
*, Value
*> matchStridedStart(Value
*Start
,
123 IRBuilderBase
&Builder
) {
124 // Base case, start is a strided constant.
125 auto *StartC
= dyn_cast
<Constant
>(Start
);
127 return matchStridedConstant(StartC
);
129 // Base case, start is a stepvector
// A stepvector is reported as start 0 and stride 1 in the scalar element type.
130 if (match(Start
, m_Intrinsic
<Intrinsic::experimental_stepvector
>())) {
131 auto *Ty
= Start
->getType()->getScalarType();
132 return std::make_pair(ConstantInt::get(Ty
, 0), ConstantInt::get(Ty
, 1));
135 // Not a constant, maybe it's a strided constant with a splat added or
// multiplied (only add, or, shl and mul are accepted below).
137 auto *BO
= dyn_cast
<BinaryOperator
>(Start
);
138 if (!BO
|| (BO
->getOpcode() != Instruction::Add
&&
139 BO
->getOpcode() != Instruction::Or
&&
140 BO
->getOpcode() != Instruction::Shl
&&
141 BO
->getOpcode() != Instruction::Mul
))
142 return std::make_pair(nullptr, nullptr);
// An `or` is only accepted when it is marked disjoint; the switch below emits
// an add for it.
144 if (BO
->getOpcode() == Instruction::Or
&&
145 !cast
<PossiblyDisjointInst
>(BO
)->isDisjoint())
146 return std::make_pair(nullptr, nullptr);
148 // Look for an operand that is splatted.
// Operand 1 is tried first; operand 0 is only tried for commutative opcodes.
149 unsigned OtherIndex
= 0;
150 Value
*Splat
= getSplatValue(BO
->getOperand(1));
151 if (!Splat
&& Instruction::isCommutative(BO
->getOpcode())) {
152 Splat
= getSplatValue(BO
->getOperand(0));
156 return std::make_pair(nullptr, nullptr);
// Recurse into the non-splat operand to obtain its start and stride.
159 std::tie(Start
, Stride
) = matchStridedStart(BO
->getOperand(OtherIndex
),
162 return std::make_pair(nullptr, nullptr);
// New scalar instructions are inserted at BO, with an empty debug location.
164 Builder
.SetInsertPoint(BO
);
165 Builder
.SetCurrentDebugLocation(DebugLoc());
166 // Add the splat value to the start or multiply the start and stride by the
// splat (a shift is applied to both start and stride as well).
168 switch (BO
->getOpcode()) {
170 llvm_unreachable("Unexpected opcode");
171 case Instruction::Or
:
172 // TODO: We'd be better off creating disjoint or here, but we don't yet
173 // have an IRBuilder API for that.
175 case Instruction::Add
:
176 Start
= Builder
.CreateAdd(Start
, Splat
);
178 case Instruction::Mul
:
179 Start
= Builder
.CreateMul(Start
, Splat
);
180 Stride
= Builder
.CreateMul(Stride
, Splat
);
182 case Instruction::Shl
:
183 Start
= Builder
.CreateShl(Start
, Splat
);
184 Stride
= Builder
.CreateShl(Stride
, Splat
);
188 return std::make_pair(Start
, Stride
);
191 // Recursively, walk up the use-def chain until we find a Phi with a strided
192 // start value. Build and update a scalar recurrence as we unwind the recursion.
193 // We also update the Stride as we unwind. Our goal is to move all of the
194 // arithmetic out of the loop.
// Results are written to the reference parameters Stride, BasePtr (the scalar
// phi) and Inc (the scalar increment).
// NOTE(review): part of the parameter list, the bodies of several guards, some
// local declarations and the return statements are not visible in this copy
// of the file. The comments describe only what is shown.
195 bool RISCVGatherScatterLowering::matchStridedRecurrence(Value
*Index
, Loop
*L
,
198 BinaryOperator
*&Inc
,
199 IRBuilderBase
&Builder
) {
200 // Our base case is a Phi.
201 if (auto *Phi
= dyn_cast
<PHINode
>(Index
)) {
202 // A phi node we want to perform this function on should be from the
// loop header.
204 if (Phi
->getParent() != L
->getHeader())
// The phi must form a simple recurrence whose increment is an add.
208 if (!matchSimpleRecurrence(Phi
, Inc
, Start
, Step
) ||
209 Inc
->getOpcode() != Instruction::Add
)
211 assert(Phi
->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
// Index (0 or 1) of the incoming edge that carries the increment.
212 unsigned IncrementingBlock
= Phi
->getIncomingValue(0) == Inc
? 0 : 1;
213 assert(Phi
->getIncomingValue(IncrementingBlock
) == Inc
&&
214 "Expected one operand of phi to be Inc");
216 // Only proceed if the step is loop invariant.
217 if (!L
->isLoopInvariant(Step
))
220 // Step should be a splat.
221 Step
= getSplatValue(Step
);
// Reduce the vector start value to a scalar start and stride.
225 std::tie(Start
, Stride
) = matchStridedStart(Start
, Builder
);
228 assert(Stride
!= nullptr);
230 // Build scalar phi and increment.
// The scalar phi is created at the vector phi's position and named
// "<phi name>.scalar"; the scalar add is named "<inc name>.scalar".
232 PHINode::Create(Start
->getType(), 2, Phi
->getName() + ".scalar", Phi
->getIterator());
233 Inc
= BinaryOperator::CreateAdd(BasePtr
, Step
, Inc
->getName() + ".scalar",
// Start flows in from the non-incrementing predecessor and Inc from the
// incrementing one.
235 BasePtr
->addIncoming(Start
, Phi
->getIncomingBlock(1 - IncrementingBlock
));
236 BasePtr
->addIncoming(Inc
, Phi
->getIncomingBlock(IncrementingBlock
));
238 // Note that this Phi might be eligible for removal.
239 MaybeDeadPHIs
.push_back(Phi
);
243 // Otherwise look for binary operator.
244 auto *BO
= dyn_cast
<BinaryOperator
>(Index
);
// Only or (when disjoint), add, shl and mul are listed as cases here.
248 switch (BO
->getOpcode()) {
251 case Instruction::Or
:
252 // We need to be able to treat Or as Add.
253 if (!cast
<PossiblyDisjointInst
>(BO
)->isDisjoint())
256 case Instruction::Add
:
258 case Instruction::Shl
:
260 case Instruction::Mul
:
264 // We should have one operand in the loop and one splat.
// The in-loop operand becomes the new Index to recurse on. Operand 1 is only
// accepted in that role when the opcode is commutative.
266 if (isa
<Instruction
>(BO
->getOperand(0)) &&
267 L
->contains(cast
<Instruction
>(BO
->getOperand(0)))) {
268 Index
= cast
<Instruction
>(BO
->getOperand(0));
269 OtherOp
= BO
->getOperand(1);
270 } else if (isa
<Instruction
>(BO
->getOperand(1)) &&
271 L
->contains(cast
<Instruction
>(BO
->getOperand(1))) &&
272 Instruction::isCommutative(BO
->getOpcode())) {
273 Index
= cast
<Instruction
>(BO
->getOperand(1));
274 OtherOp
= BO
->getOperand(0);
279 // Make sure other op is loop invariant.
280 if (!L
->isLoopInvariant(OtherOp
))
283 // Make sure we have a splat.
284 Value
*SplatOp
= getSplatValue(OtherOp
);
288 // Recurse up the use-def chain.
289 if (!matchStridedRecurrence(Index
, L
, Stride
, BasePtr
, Inc
, Builder
))
292 // Locate the Step and Start values from the recurrence.
// StepIndex is the operand of Inc that is not BasePtr. StartBlock is the
// incoming slot of BasePtr that is not Inc.
293 unsigned StepIndex
= Inc
->getOperand(0) == BasePtr
? 1 : 0;
294 unsigned StartBlock
= BasePtr
->getOperand(0) == Inc
? 1 : 0;
295 Value
*Step
= Inc
->getOperand(StepIndex
);
296 Value
*Start
= BasePtr
->getOperand(StartBlock
);
298 // We need to adjust the start value in the preheader.
// New instructions go before the terminator of the block that feeds the start
// value, with an empty debug location.
299 Builder
.SetInsertPoint(
300 BasePtr
->getIncomingBlock(StartBlock
)->getTerminator());
301 Builder
.SetCurrentDebugLocation(DebugLoc());
303 switch (BO
->getOpcode()) {
305 llvm_unreachable("Unexpected opcode!");
306 case Instruction::Add
:
307 case Instruction::Or
: {
308 // An add only affects the start value. It's ok to do this for Or because
309 // we already checked that there are no common set bits.
310 Start
= Builder
.CreateAdd(Start
, SplatOp
, "start");
// A multiply scales the start, the step and the stride.
313 case Instruction::Mul
: {
314 Start
= Builder
.CreateMul(Start
, SplatOp
, "start");
315 Step
= Builder
.CreateMul(Step
, SplatOp
, "step");
316 Stride
= Builder
.CreateMul(Stride
, SplatOp
, "stride");
// A shift is likewise applied to the start, the step and the stride.
319 case Instruction::Shl
: {
320 Start
= Builder
.CreateShl(Start
, SplatOp
, "start");
321 Step
= Builder
.CreateShl(Step
, SplatOp
, "step");
322 Stride
= Builder
.CreateShl(Stride
, SplatOp
, "stride");
// Write the adjusted step and start back into the scalar recurrence.
327 Inc
->setOperand(StepIndex
, Step
);
328 BasePtr
->setIncomingValue(StartBlock
, Start
);
// Computes a scalar base pointer and a stride for the vector of pointers Ptr.
// A splat pointer yields {splat value, 0}. A GEP with a vector index is
// handled either directly through matchStridedStart or, inside a loop,
// through matchStridedRecurrence. Results for a GEP are stored in
// StridedAddrs. Failure paths return {nullptr, nullptr}.
// NOTE(review): several statements (guards for some early returns, the
// declarations of BasePtr, BasePhi and Inc, and the final returns) are not
// visible in this copy of the file. The comments describe only what is shown.
332 std::pair
<Value
*, Value
*>
333 RISCVGatherScatterLowering::determineBaseAndStride(Instruction
*Ptr
,
334 IRBuilderBase
&Builder
) {
336 // A gather/scatter of a splat is a zero strided load/store.
337 if (auto *BasePtr
= getSplatValue(Ptr
)) {
338 Type
*IntPtrTy
= DL
->getIntPtrType(BasePtr
->getType());
339 return std::make_pair(BasePtr
, ConstantInt::get(IntPtrTy
, 0));
// Everything else must be a getelementptr instruction.
342 auto *GEP
= dyn_cast
<GetElementPtrInst
>(Ptr
);
344 return std::make_pair(nullptr, nullptr);
// Look the GEP up in the cache.
// NOTE(review): the statement executed on a cache hit is not visible here;
// presumably it returns the cached pair - confirm.
346 auto I
= StridedAddrs
.find(GEP
);
347 if (I
!= StridedAddrs
.end())
// Local copy of the GEP operands; the vector index is overwritten with a
// scalar further down.
350 SmallVector
<Value
*, 2> Ops(GEP
->operands());
352 // Base pointer needs to be a scalar.
// A vector base pointer is replaced by its splat value.
353 Value
*ScalarBase
= Ops
[0];
354 if (ScalarBase
->getType()->isVectorTy()) {
355 ScalarBase
= getSplatValue(ScalarBase
);
357 return std::make_pair(nullptr, nullptr);
// VecOperand holds the operand number of the vector index once found.
// TypeScale holds the element stride of the type that index steps over.
360 std::optional
<unsigned> VecOperand
;
361 unsigned TypeScale
= 0;
363 // Look for a vector operand and scale.
// Walks the GEP indices (operands 1..e-1) in step with the GEP type iterator.
364 gep_type_iterator GTI
= gep_type_begin(GEP
);
365 for (unsigned i
= 1, e
= GEP
->getNumOperands(); i
!= e
; ++i
, ++GTI
) {
366 if (!Ops
[i
]->getType()->isVectorTy())
370 return std::make_pair(nullptr, nullptr);
// Element stride for this index according to the data layout. Its fixed value
// is what gets stored in TypeScale.
374 TypeSize TS
= GTI
.getSequentialElementStride(*DL
);
376 return std::make_pair(nullptr, nullptr);
378 TypeScale
= TS
.getFixedValue();
381 // We need to find a vector index to simplify.
383 return std::make_pair(nullptr, nullptr);
385 // We can't extract the stride if the arithmetic is done at a different size
386 // than the pointer type. Adding the stride later may not wrap correctly.
387 // Technically we could handle wider indices, but I don't expect that in
388 // practice. Handle one special case here - constants. This simplifies
389 // writing test cases.
390 Value
*VecIndex
= Ops
[*VecOperand
];
391 Type
*VecIntPtrTy
= DL
->getIntPtrType(GEP
->getType());
392 if (VecIndex
->getType() != VecIntPtrTy
) {
393 auto *VecIndexC
= dyn_cast
<Constant
>(VecIndex
);
395 return std::make_pair(nullptr, nullptr);
// Constant index of another width: fold a trunc when it is wider than
// VecIntPtrTy; a sext to VecIntPtrTy is the other fold shown.
396 if (VecIndex
->getType()->getScalarSizeInBits() > VecIntPtrTy
->getScalarSizeInBits())
397 VecIndex
= ConstantFoldCastInstruction(Instruction::Trunc
, VecIndexC
, VecIntPtrTy
);
399 VecIndex
= ConstantFoldCastInstruction(Instruction::SExt
, VecIndexC
, VecIntPtrTy
);
402 // Handle the non-recursive case. This is what we see if the vectorizer
403 // decides to use a scalar IV + vid on demand instead of a vector IV.
404 auto [Start
, Stride
] = matchStridedStart(VecIndex
, Builder
);
407 Builder
.SetInsertPoint(GEP
);
409 // Replace the vector index with the scalar start and build a scalar GEP.
410 Ops
[*VecOperand
] = Start
;
411 Type
*SourceTy
= GEP
->getSourceElementType();
// drop_front() removes operand 0 (the original base pointer), leaving only
// the indices.
413 Builder
.CreateGEP(SourceTy
, ScalarBase
, ArrayRef(Ops
).drop_front());
415 // Convert stride to pointer size if needed.
// In the code shown, the stride type is only asserted to already match.
416 Type
*IntPtrTy
= DL
->getIntPtrType(BasePtr
->getType());
417 assert(Stride
->getType() == IntPtrTy
&& "Unexpected type");
419 // Scale the stride by the size of the indexed type.
421 Stride
= Builder
.CreateMul(Stride
, ConstantInt::get(IntPtrTy
, TypeScale
));
// Cache the result for this GEP.
423 auto P
= std::make_pair(BasePtr
, Stride
);
424 StridedAddrs
[GEP
] = P
;
428 // Make sure we're in a loop and that it has a pre-header and a single latch.
429 Loop
*L
= LI
->getLoopFor(GEP
->getParent());
430 if (!L
|| !L
->getLoopPreheader() || !L
->getLoopLatch())
431 return std::make_pair(nullptr, nullptr);
// Try to turn the vector index into a scalar recurrence (BasePhi / Inc).
435 if (!matchStridedRecurrence(VecIndex
, L
, Stride
, BasePhi
, Inc
, Builder
))
436 return std::make_pair(nullptr, nullptr);
438 assert(BasePhi
->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
// Index (0 or 1) of the incoming edge of BasePhi that carries Inc.
439 unsigned IncrementingBlock
= BasePhi
->getOperand(0) == Inc
? 0 : 1;
440 assert(BasePhi
->getIncomingValue(IncrementingBlock
) == Inc
&&
441 "Expected one operand of phi to be Inc");
443 Builder
.SetInsertPoint(GEP
);
445 // Replace the vector index with the scalar phi and build a scalar GEP.
446 Ops
[*VecOperand
] = BasePhi
;
447 Type
*SourceTy
= GEP
->getSourceElementType();
449 Builder
.CreateGEP(SourceTy
, ScalarBase
, ArrayRef(Ops
).drop_front());
451 // Final adjustments to stride should go in the start block.
// That is, before the terminator of the non-incrementing predecessor.
452 Builder
.SetInsertPoint(
453 BasePhi
->getIncomingBlock(1 - IncrementingBlock
)->getTerminator());
455 // Convert stride to pointer size if needed.
// In the code shown, the stride type is only asserted to already match.
456 Type
*IntPtrTy
= DL
->getIntPtrType(BasePtr
->getType());
457 assert(Stride
->getType() == IntPtrTy
&& "Unexpected type");
459 // Scale the stride by the size of the indexed type.
461 Stride
= Builder
.CreateMul(Stride
, ConstantInt::get(IntPtrTy
, TypeScale
));
// Cache the result for this GEP.
463 auto P
= std::make_pair(BasePtr
, Stride
);
464 StridedAddrs
[GEP
] = P
;
// Attempts to replace the masked gather/scatter call II with a call to
// riscv_masked_strided_load / riscv_masked_strided_store. runOnFunction
// supplies the data type, the pointer operand and the alignment operand, and
// ORs the bool result of the gather calls into its Changed flag.
// NOTE(review): most of the parameter list, the bodies of the guard `if`s, the
// declaration of Call and the return statements are not visible in this copy
// of the file. The comments describe only what is shown.
468 bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst
*II
,
472 // Make sure the operation will be supported by the backend.
// The alignment is read from the constant AlignOp and checked together with
// the data type's EVT.
473 MaybeAlign MA
= cast
<ConstantInt
>(AlignOp
)->getMaybeAlignValue();
474 EVT DataTypeVT
= TLI
->getValueType(*DL
, DataType
);
475 if (!MA
|| !TLI
->isLegalStridedLoadStore(DataTypeVT
, *MA
))
478 // FIXME: Let the backend type legalize by splitting/widening?
479 if (!TLI
->isTypeLegal(DataTypeVT
))
482 // Pointer should be an instruction.
483 auto *PtrI
= dyn_cast
<Instruction
>(Ptr
);
// Builder with InstSimplifyFolder as its folder, initially positioned at the
// pointer computation.
487 LLVMContext
&Ctx
= PtrI
->getContext();
488 IRBuilder
<InstSimplifyFolder
> Builder(Ctx
, *DL
);
489 Builder
.SetInsertPoint(PtrI
);
// Derive the scalar base pointer and the stride from the vector of pointers.
491 Value
*BasePtr
, *Stride
;
492 std::tie(BasePtr
, Stride
) = determineBaseAndStride(PtrI
, Builder
);
495 assert(Stride
!= nullptr);
// The replacement call is emitted at the original intrinsic.
497 Builder
.SetInsertPoint(II
);
// Gather: operands passed are {arg 3, BasePtr, Stride, arg 2}.
// NOTE(review): presumably arg 3 is the pass-through value and arg 2 the mask
// of llvm.masked.gather - confirm against the LangRef.
500 if (II
->getIntrinsicID() == Intrinsic::masked_gather
)
501 Call
= Builder
.CreateIntrinsic(
502 Intrinsic::riscv_masked_strided_load
,
503 {DataType
, BasePtr
->getType(), Stride
->getType()},
504 {II
->getArgOperand(3), BasePtr
, Stride
, II
->getArgOperand(2)});
// Scatter: operands passed are {arg 0, BasePtr, Stride, arg 3}.
// NOTE(review): presumably arg 0 is the stored value and arg 3 the mask of
// llvm.masked.scatter - confirm against the LangRef.
506 Call
= Builder
.CreateIntrinsic(
507 Intrinsic::riscv_masked_strided_store
,
508 {DataType
, BasePtr
->getType(), Stride
->getType()},
509 {II
->getArgOperand(0), BasePtr
, Stride
, II
->getArgOperand(3)});
// Redirect all users to the new call and delete the original intrinsic.
512 II
->replaceAllUsesWith(Call
);
513 II
->eraseFromParent();
// If the old pointer computation has no users left, delete it along with
// anything that becomes trivially dead as a result.
515 if (PtrI
->use_empty())
516 RecursivelyDeleteTriviallyDeadInstructions(PtrI
);
// Pass entry point. Sets up the per-function state, collects every masked
// gather and masked scatter call in F, tries to rewrite each of them with
// tryCreateStridedLoadStore, and finally deletes PHIs that the rewrites may
// have made dead.
// NOTE(review): some statements (the statement taken when the subtarget check
// fails, and the end of the function) are not visible in this copy of the
// file. The comments describe only what is shown.
521 bool RISCVGatherScatterLowering::runOnFunction(Function
&F
) {
// Obtain the RISC-V subtarget for F through the target pass config.
525 auto &TPC
= getAnalysis
<TargetPassConfig
>();
526 auto &TM
= TPC
.getTM
<RISCVTargetMachine
>();
527 ST
= &TM
.getSubtarget
<RISCVSubtarget
>(F
);
// Checks that the subtarget has vector instructions and uses RVV for
// fixed-length vectors.
528 if (!ST
->hasVInstructions() || !ST
->useRVVForFixedLengthVectors())
531 TLI
= ST
->getTargetLowering();
532 DL
= &F
.getParent()->getDataLayout();
533 LI
= &getAnalysis
<LoopInfoWrapperPass
>().getLoopInfo();
// Empty the GEP cache before processing this function.
535 StridedAddrs
.clear();
537 SmallVector
<IntrinsicInst
*, 4> Gathers
;
538 SmallVector
<IntrinsicInst
*, 4> Scatters
;
540 bool Changed
= false;
// Collect the candidates first. The rewrite erases the intrinsic call, so it
// is presumably kept out of this instruction walk for that reason.
542 for (BasicBlock
&BB
: F
) {
543 for (Instruction
&I
: BB
) {
544 IntrinsicInst
*II
= dyn_cast
<IntrinsicInst
>(&I
);
545 if (II
&& II
->getIntrinsicID() == Intrinsic::masked_gather
) {
546 Gathers
.push_back(II
);
547 } else if (II
&& II
->getIntrinsicID() == Intrinsic::masked_scatter
) {
548 Scatters
.push_back(II
);
553 // Rewrite gather/scatter to form strided load/store if possible.
// Gather: data type is the call's result type, pointer is arg 0 and alignment
// is arg 1.
554 for (auto *II
: Gathers
)
555 Changed
|= tryCreateStridedLoadStore(
556 II
, II
->getType(), II
->getArgOperand(0), II
->getArgOperand(1));
// Scatter: data type is the type of arg 0, pointer is arg 1 and alignment is
// arg 2.
557 for (auto *II
: Scatters
)
559 tryCreateStridedLoadStore(II
, II
->getArgOperand(0)->getType(),
560 II
->getArgOperand(1), II
->getArgOperand(2));
562 // Remove any dead phis.
// Entries that no longer cast to a PHINode (dyn_cast_or_null yields null) are
// skipped.
563 while (!MaybeDeadPHIs
.empty()) {
564 if (auto *Phi
= dyn_cast_or_null
<PHINode
>(MaybeDeadPHIs
.pop_back_val()))
565 RecursivelyDeleteDeadPHINode(Phi
);