//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass interleaves around sext/zext/trunc instructions. MVE does not have
// a single sext/zext or trunc instruction that takes the bottom half of a
// vector and extends it to full width, like NEON has with MOVL. Instead it is
// expected that this happens through top/bottom instructions. So the MVE
// equivalent VMOVLT/B instructions take either the even or odd elements of the
// input and extend them to the larger type, producing a vector with half the
// number of elements, each of double the bitwidth. As there is no simple
// instruction, we often have to turn sext/zext/trunc into a series of lane
// moves (or stack loads/stores, which we do not do yet).
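//
// For illustration (an example added here, not quoted from the MVE reference):
// given a v8i16 input %a = <a0, a1, a2, a3, a4, a5, a6, a7>,
//   VMOVLB.s16 %a extends the even (bottom) lanes a0, a2, a4, a6
//   VMOVLT.s16 %a extends the odd (top) lanes     a1, a3, a5, a7
// each producing a v4i32 result.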
//
// This pass takes vector code that starts at truncs, looks for interconnected
// blobs of operations that end with sext/zext (or constants/splats), of the
// form:
//   %sa = sext v8i16 %a to v8i32
//   %sb = sext v8i16 %b to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
//   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sa = sext v8i16 %sha to v8i32
//   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sb = sext v8i16 %shb to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
//   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
// Which can then be split and lowered to MVE instructions efficiently:
//   %sa_b = VMOVLB.s16 %a
//   %sa_t = VMOVLT.s16 %a
//   %sb_b = VMOVLB.s16 %b
//   %sb_t = VMOVLT.s16 %b
//   %add_b = VADD.i32 %sa_b, %sb_b
//   %add_t = VADD.i32 %sa_t, %sb_t
//   %r = VMOVNT.i16 %add_b, %add_t
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-laneinterleave"

cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));
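
// Note: the option defaults to on and is cl::Hidden, so it is only listed
// under -help-hidden. A hypothetical invocation to disable the transform when
// debugging would be:
//   llc -mtriple=thumbv8.1m.main -mattr=+mve -enable-mve-interleave=false ...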

namespace {

class MVELaneInterleaving : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVELaneInterleaving() : FunctionPass(ID) {
    initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override { return "MVE lane interleaving"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    FunctionPass::getAnalysisUsage(AU);
  }
};

} // end anonymous namespace

char MVELaneInterleaving::ID = 0;

INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
                false)

Pass *llvm::createMVELaneInterleavingPass() {
  return new MVELaneInterleaving();
}

static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
                                     SmallSetVector<Instruction *, 4> &Truncs) {
  // This is not always beneficial to transform. Exts can be incorporated into
  // loads, Truncs can be folded into stores.
  // Truncs are usually the same number of instructions,
  //  VSTRH.32(A); VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
  // Exts are unfortunately more instructions in the general case:
  //  A=VLDRH.32; B=VLDRH.32
  // vs with interleaving:
  //  T=VLDRH.16; A=VMOVLB T; B=VMOVLT T
  // But those VMOVLs may be folded into a VMULL.

  // Expensive extends/truncs are always good to remove. FPExts always involve
  // extra VCVTs, so they are always considered beneficial to convert.
  for (auto *E : Exts) {
    if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
      return true;
    }
  }
  for (auto *T : Truncs) {
    if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
      return true;
    }
  }

  // Otherwise, we know we have a load(ext), see if any of the Extends are a
  // vmull. This is a simple heuristic and certainly not perfect.
  for (auto *E : Exts) {
    if (!E->hasOneUse() ||
        cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
      LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
      return false;
    }
  }
  return true;
}

static bool tryInterleave(Instruction *Start,
                          SmallPtrSetImpl<Instruction *> &Visited) {
  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");
  auto *VT = cast<FixedVectorType>(Start->getType());

  if (!isa<Instruction>(Start->getOperand(0)))
    return false;

  // Look for connected operations starting from Ext's, terminating at Truncs.
  std::vector<Instruction *> Worklist;
  Worklist.push_back(Start);
  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));

  SmallSetVector<Instruction *, 4> Truncs;
  SmallSetVector<Instruction *, 4> Exts;
  SmallSetVector<Use *, 4> OtherLeafs;
  SmallSetVector<Instruction *, 4> Ops;

  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    switch (I->getOpcode()) {
    // Truncs
    case Instruction::Trunc:
    case Instruction::FPTrunc:
      if (Truncs.count(I))
        continue;
      Truncs.insert(I);
      Visited.insert(I);
      break;

    // Extend leafs
    case Instruction::SExt:
    case Instruction::ZExt:
    case Instruction::FPExt:
      if (Exts.count(I))
        continue;
      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      Exts.insert(I);
      break;

    case Instruction::Call: {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
      if (!II)
        return false;

      switch (II->getIntrinsicID()) {
      // Only these lane-wise intrinsics are treated like regular operators.
      case Intrinsic::smin:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::sadd_sat:
      case Intrinsic::ssub_sat:
      case Intrinsic::uadd_sat:
      case Intrinsic::usub_sat:
      case Intrinsic::minnum:
      case Intrinsic::maxnum:
      case Intrinsic::fabs:
      case Intrinsic::ceil:
      case Intrinsic::floor:
      case Intrinsic::rint:
      case Intrinsic::round:
      case Intrinsic::trunc:
        break;
      default:
        return false;
      }
      LLVM_FALLTHROUGH; // Fall through to treating these like an operator below.
    }
    // Binary/tertiary ops
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::AShr:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::FAdd:
    case Instruction::FMul:
    case Instruction::Select:
      if (Ops.count(I))
        continue;
      Ops.insert(I);

      for (Use &Op : I->operands()) {
        if (!isa<FixedVectorType>(Op->getType()))
          continue;
        if (isa<Instruction>(Op))
          Worklist.push_back(cast<Instruction>(&Op));
        else
          OtherLeafs.insert(&Op);
      }

      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      break;

    case Instruction::ShuffleVector:
      // A shuffle of a splat is a splat.
      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
        break;
      LLVM_FALLTHROUGH;

    default:
      LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
      return false;
    }
  }

  if (Exts.empty() && OtherLeafs.empty())
    return false;

  LLVM_DEBUG({
    dbgs() << "Found group:\n  Exts:";
    for (auto *I : Exts)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Ops:";
    for (auto *I : Ops)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  OtherLeafs:";
    for (auto *I : OtherLeafs)
      dbgs() << "  " << *I->get() << " of " << *I->getUser() << "\n";
    dbgs() << "  Truncs:";
    for (auto *I : Truncs)
      dbgs() << "  " << *I << "\n";
  });

  assert(!Truncs.empty() && "Expected some truncs");

  // Check types: only groups whose narrow type fills a multiple of one
  // 128-bit MVE register (8 x i16 or 16 x i8), extended to exactly double the
  // bitwidth, are handled.
  unsigned NumElts = VT->getNumElements();
  unsigned BaseElts = VT->getScalarSizeInBits() == 16
                          ? 8
                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
  if (BaseElts == 0 || NumElts % BaseElts != 0) {
    LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
    return false;
  }
  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
      VT->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << "  Type not double sized\n");
    return false;
  }
  for (Instruction *I : Exts)
    if (I->getOperand(0)->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }
  for (Instruction *I : Truncs)
    if (I->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }

  // Check that it looks beneficial
  if (!isProfitableToInterleave(Exts, Truncs))
    return false;

  // Create new shuffles around the extends / truncs / other leaves.
  IRBuilder<> Builder(Start);

  SmallVector<int, 16> LeafMask;
  SmallVector<int, 16> TruncMask;
  // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14, 9, 11, 13, 15
  // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12, 9, 13, 10, 14, 11, 15
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2);
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2 + 1);
  }
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++) {
      TruncMask.push_back(Base + i);
      TruncMask.push_back(Base + i + BaseElts / 2);
    }
  }
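
  // As a concrete illustration (derived from the loops above, not part of the
  // original comments): for an <8 x i16> group, NumElts == BaseElts == 8 and
  //   LeafMask  = <0, 2, 4, 6, 1, 3, 5, 7>
  //   TruncMask = <0, 4, 1, 5, 2, 6, 3, 7>
  // exactly the shuffles shown in the file header; TruncMask is the inverse
  // permutation of LeafMask, so the two cancel out around the group.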

  for (Instruction *I : Exts) {
    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
    Builder.SetInsertPoint(I);
    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
    bool FPext = isa<FPExtInst>(I);
    bool Sext = isa<SExtInst>(I);
    Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
                       : Sext ? Builder.CreateSExt(Shuffle, I->getType())
                              : Builder.CreateZExt(Shuffle, I->getType());
    I->replaceAllUsesWith(Ext);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Use *I : OtherLeafs) {
    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I->get() << "\n");
    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Instruction *I : Truncs) {
    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");

    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
    I->replaceAllUsesWith(Shuf);
    // The RAUW above also rewrote the new shuffle's own operand; point it back
    // at the trunc.
    cast<Instruction>(Shuf)->setOperand(0, I);

    LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
  }

  return true;
}

bool MVELaneInterleaving::runOnFunction(Function &F) {
  if (!EnableInterleave)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;

  bool Changed = false;

  SmallPtrSet<Instruction *, 16> Visited;
  // Walk the function backwards, looking for vector truncs to start a group
  // from; tryInterleave marks the truncs it absorbs in Visited.
  for (Instruction &I : reverse(instructions(F))) {
    if (I.getType()->isVectorTy() &&
        (isa<TruncInst>(I) || isa<FPTruncInst>(I)) && !Visited.count(&I))
      Changed |= tryInterleave(&I, Visited);
  }

  return Changed;
}
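
// Note (an assumption, not stated in this file): the pass is constructed via
// createMVELaneInterleavingPass() from the ARM target's codegen pipeline, and
// under the legacy pass manager it can hypothetically be run standalone with
// "opt -mve-laneinterleave".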