//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass interleaves around sext/zext/trunc instructions. MVE does not have
// a single sext/zext or trunc instruction that takes the bottom half of a
// vector and extends to a full width, like NEON has with MOVL. Instead it is
// expected that this happens through top/bottom instructions. So the MVE
// equivalent VMOVLT/B instructions take either the even or odd elements of the
// input and extend them to the larger type, producing a vector with half the
// number of elements each of double the bitwidth. As there is no simple
// instruction, we often have to turn sext/zext/trunc into a series of lane
// moves (or stack loads/stores, which we do not do yet).
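//
// As an illustrative sketch of that lane behavior (register names are
// arbitrary, not something this pass emits): given a v8i16 value in q0,
//   VMOVLB.s16 q1, q0   ; sign extends the even lanes 0,2,4,6 to 32 bits
//   VMOVLT.s16 q2, q0   ; sign extends the odd lanes 1,3,5,7 to 32 bits
// each producing a v4i32 that holds half of the original lanes.
//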
// This pass takes vector code that starts at truncs, looks for interconnected
// blobs of operations that end with sext/zext (or constants/splats) of the
// form:
//   %sa = sext v8i16 %a to v8i32
//   %sb = sext v8i16 %b to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
//   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sa = sext v8i16 %sha to v8i32
//   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sb = sext v8i16 %shb to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
//   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
// Which can then be split and lowered to MVE instructions efficiently:
//   %sa_b = VMOVLB.s16 %a
//   %sa_t = VMOVLT.s16 %a
//   %sb_b = VMOVLB.s16 %b
//   %sb_t = VMOVLT.s16 %b
//   %add_b = VADD.i32 %sa_b, %sb_b
//   %add_t = VADD.i32 %sa_t, %sb_t
//   %r = VMOVNT.i16 %add_b, %add_t
//
//===----------------------------------------------------------------------===//
46 #include "ARMBaseInstrInfo.h"
47 #include "ARMSubtarget.h"
48 #include "llvm/ADT/SetVector.h"
49 #include "llvm/Analysis/TargetTransformInfo.h"
50 #include "llvm/CodeGen/TargetLowering.h"
51 #include "llvm/CodeGen/TargetPassConfig.h"
52 #include "llvm/IR/BasicBlock.h"
53 #include "llvm/IR/DerivedTypes.h"
54 #include "llvm/IR/Function.h"
55 #include "llvm/IR/IRBuilder.h"
56 #include "llvm/IR/InstIterator.h"
57 #include "llvm/IR/InstrTypes.h"
58 #include "llvm/IR/Instruction.h"
59 #include "llvm/IR/Instructions.h"
60 #include "llvm/IR/IntrinsicInst.h"
61 #include "llvm/IR/Intrinsics.h"
62 #include "llvm/IR/Type.h"
63 #include "llvm/IR/Value.h"
64 #include "llvm/InitializePasses.h"
65 #include "llvm/Pass.h"
66 #include "llvm/Support/Casting.h"

#define DEBUG_TYPE "mve-laneinterleave"

cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));
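// Note: as a cl::Hidden flag this does not appear in standard -help output,
// but it can still be set to disable the transform for testing, e.g. by
// passing -enable-mve-interleave=false to llc or opt.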

namespace {

class MVELaneInterleaving : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVELaneInterleaving() : FunctionPass(ID) {
    initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override { return "MVE lane interleaving"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    FunctionPass::getAnalysisUsage(AU);
  }
};

} // end anonymous namespace

char MVELaneInterleaving::ID = 0;

INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
                false)

Pass *llvm::createMVELaneInterleavingPass() {
  return new MVELaneInterleaving();
}

static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
                                     SmallSetVector<Instruction *, 4> &Truncs) {
  // It is not always beneficial to transform. Exts can be incorporated into
  // loads, Truncs can be folded into stores.
  // Truncs are usually the same number of instructions,
  //  VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
  // Exts are unfortunately more instructions in the general case:
  //  A=VLDRH.32; B=VLDRH.32;
  // vs with interleaving:
  //  T=VLDRH.16; A=VMOVLB T; B=VMOVLT T
  // But those VMOVL may be folded into a VMULL.

  // But expensive extends/truncs are always good to remove. FPExts always
  // involve extra VCVTs so are always considered to be beneficial to convert.
  for (auto *E : Exts) {
    if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
      return true;
    }
  }
  for (auto *T : Truncs) {
    if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
      return true;
    }
  }

  // Otherwise, we know we have a load(ext), see if any of the Extends are a
  // vmull. This is a simple heuristic and certainly not perfect.
  for (auto *E : Exts) {
    if (!E->hasOneUse() ||
        cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
      LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
      return false;
    }
  }
  return true;
}

static bool tryInterleave(Instruction *Start,
                          SmallPtrSetImpl<Instruction *> &Visited) {
  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");

  if (!isa<Instruction>(Start->getOperand(0)))
    return false;

  // Look for connected operations starting from Ext's, terminating at Truncs.
  std::vector<Instruction *> Worklist;
  Worklist.push_back(Start);
  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));

  SmallSetVector<Instruction *, 4> Truncs;
  SmallSetVector<Instruction *, 4> Reducts;
  SmallSetVector<Instruction *, 4> Exts;
  SmallSetVector<Use *, 4> OtherLeafs;
  SmallSetVector<Instruction *, 4> Ops;

  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    switch (I->getOpcode()) {
    // Truncs
    case Instruction::Trunc:
    case Instruction::FPTrunc:
      if (!Truncs.insert(I))
        continue;
      Visited.insert(I);
      break;

    // Extend leafs
    case Instruction::SExt:
    case Instruction::ZExt:
    case Instruction::FPExt:
      if (Exts.count(I))
        continue;
      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      Exts.insert(I);
      break;

    case Instruction::Call: {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
      if (!II)
        return false;

      if (II->getIntrinsicID() == Intrinsic::vector_reduce_add) {
        if (!Reducts.insert(I))
          continue;
        Visited.insert(I);
        break;
      }

      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
      case Intrinsic::smin:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::sadd_sat:
      case Intrinsic::ssub_sat:
      case Intrinsic::uadd_sat:
      case Intrinsic::usub_sat:
      case Intrinsic::minnum:
      case Intrinsic::maxnum:
      case Intrinsic::fabs:
      case Intrinsic::fma:
      case Intrinsic::ceil:
      case Intrinsic::floor:
      case Intrinsic::rint:
      case Intrinsic::round:
      case Intrinsic::trunc:
        break;
      default:
        return false;
      }
      [[fallthrough]]; // Fall through to treating these like an operator below.
    }

    // Binary/tertiary ops
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::AShr:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::FAdd:
    case Instruction::FMul:
    case Instruction::Select:
      if (!Ops.insert(I))
        continue;

      for (Use &Op : I->operands()) {
        if (!isa<FixedVectorType>(Op->getType()))
          continue;
        if (isa<Instruction>(Op))
          Worklist.push_back(cast<Instruction>(&Op));
        else
          OtherLeafs.insert(&Op);
      }

      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      break;

    case Instruction::ShuffleVector:
      // A shuffle of a splat is a splat.
      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
        continue;
      [[fallthrough]];

    default:
      LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
      return false;
    }
  }

  if (Exts.empty() && OtherLeafs.empty())
    return false;

  LLVM_DEBUG({
    dbgs() << "Found group:\n  Exts:\n";
    for (auto *I : Exts)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Ops:\n";
    for (auto *I : Ops)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  OtherLeafs:\n";
    for (auto *I : OtherLeafs)
      dbgs() << "  " << *I->get() << " of " << *I->getUser() << "\n";
    dbgs() << "  Truncs:\n";
    for (auto *I : Truncs)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Reducts:\n";
    for (auto *I : Reducts)
      dbgs() << "  " << *I << "\n";
  });

  assert((!Truncs.empty() || !Reducts.empty()) &&
         "Expected some truncs or reductions");
  if (Truncs.empty() && Exts.empty())
    return false;

  auto *VT = !Truncs.empty()
                 ? cast<FixedVectorType>(Truncs[0]->getType())
                 : cast<FixedVectorType>(Exts[0]->getOperand(0)->getType());
  LLVM_DEBUG(dbgs() << "Using VT:" << *VT << "\n");

  // Check types
  unsigned NumElts = VT->getNumElements();
  unsigned BaseElts = VT->getScalarSizeInBits() == 16
                          ? 8
                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
  if (BaseElts == 0 || NumElts % BaseElts != 0) {
    LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
    return false;
  }
  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
      VT->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << "  Type not double sized\n");
    return false;
  }
  for (Instruction *I : Exts)
    if (I->getOperand(0)->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }
  for (Instruction *I : Truncs)
    if (I->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }

  // Check that it looks beneficial
  if (!isProfitableToInterleave(Exts, Truncs))
    return false;
  if (!Reducts.empty() && (Ops.empty() || all_of(Ops, [](Instruction *I) {
        return I->getOpcode() == Instruction::Mul ||
               I->getOpcode() == Instruction::Select ||
               I->getOpcode() == Instruction::ICmp;
      }))) {
    LLVM_DEBUG(dbgs() << "Reduction does not look profitable\n");
    return false;
  }

  // Create new shuffles around the extends / truncs / other leaves.
  IRBuilder<> Builder(Start);

  SmallVector<int, 16> LeafMask;
  SmallVector<int, 16> TruncMask;
  // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14, 9, 11, 13, 15
  // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12, 9, 13, 10, 14, 11, 15
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2);
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2 + 1);
  }
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++) {
      TruncMask.push_back(Base + i);
      TruncMask.push_back(Base + i + BaseElts / 2);
    }
  }

  for (Instruction *I : Exts) {
    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
    Builder.SetInsertPoint(I);
    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
    bool FPext = isa<FPExtInst>(I);
    bool Sext = isa<SExtInst>(I);
    Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
                       : Sext ? Builder.CreateSExt(Shuffle, I->getType())
                              : Builder.CreateZExt(Shuffle, I->getType());
    I->replaceAllUsesWith(Ext);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Use *I : OtherLeafs) {
    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Instruction *I : Truncs) {
    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");

    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
    I->replaceAllUsesWith(Shuf);
    // replaceAllUsesWith also updated the operand of the new shuffle; set it
    // back to the trunc so the shuffle reads the interleaved result.
    cast<Instruction>(Shuf)->setOperand(0, I);

    LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
  }

  return true;
}

// Add reductions are fairly common and associative, meaning we can start the
// interleaving from them and don't need to emit a shuffle.
static bool isAddReduction(Instruction &I) {
  if (auto *II = dyn_cast<IntrinsicInst>(&I))
    return II->getIntrinsicID() == Intrinsic::vector_reduce_add;
  return false;
}
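
// isAddReduction matches, for example, a call such as (illustrative IR):
//   %s = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %v)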

bool MVELaneInterleaving::runOnFunction(Function &F) {
  if (!EnableInterleave)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;

  bool Changed = false;

  SmallPtrSet<Instruction *, 16> Visited;
  for (Instruction &I : reverse(instructions(F))) {
    if (((I.getType()->isVectorTy() &&
          (isa<TruncInst>(I) || isa<FPTruncInst>(I))) ||
         isAddReduction(I)) &&
        !Visited.count(&I))
      Changed |= tryInterleave(&I, Visited);
  }

  return Changed;
}