1 //===-- X86PartialReduction.cpp -------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This pass looks for add instructions used by a horizontal reduction to see
10 // if we might be able to use pmaddwd or psadbw. Some cases of this require
11 // cross basic block knowledge and can't be done in SelectionDAG.
13 //===----------------------------------------------------------------------===//
16 #include "llvm/Analysis/ValueTracking.h"
17 #include "llvm/CodeGen/TargetPassConfig.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/Instructions.h"
20 #include "llvm/IR/IntrinsicsX86.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/Operator.h"
23 #include "llvm/Pass.h"
24 #include "X86TargetMachine.h"
28 #define DEBUG_TYPE "x86-partial-reduction"
32 class X86PartialReduction
: public FunctionPass
{
34 const X86Subtarget
*ST
;
37 static char ID
; // Pass identification, replacement for typeid.
39 X86PartialReduction() : FunctionPass(ID
) { }
41 bool runOnFunction(Function
&Fn
) override
;
43 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
47 StringRef
getPassName() const override
{
48 return "X86 Partial Reduction";
52 bool tryMAddReplacement(Instruction
*Op
);
53 bool trySADReplacement(Instruction
*Op
);
57 FunctionPass
*llvm::createX86PartialReductionPass() {
58 return new X86PartialReduction();
61 char X86PartialReduction::ID
= 0;
63 INITIALIZE_PASS(X86PartialReduction
, DEBUG_TYPE
,
64 "X86 Partial Reduction", false, false)
66 bool X86PartialReduction::tryMAddReplacement(Instruction
*Op
) {
70 // Need at least 8 elements.
71 if (cast
<FixedVectorType
>(Op
->getType())->getNumElements() < 8)
74 // Element type should be i32.
75 if (!cast
<VectorType
>(Op
->getType())->getElementType()->isIntegerTy(32))
78 auto *Mul
= dyn_cast
<BinaryOperator
>(Op
);
79 if (!Mul
|| Mul
->getOpcode() != Instruction::Mul
)
82 Value
*LHS
= Mul
->getOperand(0);
83 Value
*RHS
= Mul
->getOperand(1);
85 // LHS and RHS should be only used once or if they are the same then only
86 // used twice. Only check this when SSE4.1 is enabled and we have zext/sext
87 // instructions, otherwise we use punpck to emulate zero extend in stages. The
88 // trunc/ we need to do likely won't introduce new instructions in that case.
91 if (!isa
<Constant
>(LHS
) && !LHS
->hasNUses(2))
94 if (!isa
<Constant
>(LHS
) && !LHS
->hasOneUse())
96 if (!isa
<Constant
>(RHS
) && !RHS
->hasOneUse())
101 auto CanShrinkOp
= [&](Value
*Op
) {
102 auto IsFreeTruncation
= [&](Value
*Op
) {
103 if (auto *Cast
= dyn_cast
<CastInst
>(Op
)) {
104 if (Cast
->getParent() == Mul
->getParent() &&
105 (Cast
->getOpcode() == Instruction::SExt
||
106 Cast
->getOpcode() == Instruction::ZExt
) &&
107 Cast
->getOperand(0)->getType()->getScalarSizeInBits() <= 16)
111 return isa
<Constant
>(Op
);
114 // If the operation can be freely truncated and has enough sign bits we
116 if (IsFreeTruncation(Op
) &&
117 ComputeNumSignBits(Op
, *DL
, 0, nullptr, Mul
) > 16)
120 // SelectionDAG has limited support for truncating through an add or sub if
121 // the inputs are freely truncatable.
122 if (auto *BO
= dyn_cast
<BinaryOperator
>(Op
)) {
123 if (BO
->getParent() == Mul
->getParent() &&
124 IsFreeTruncation(BO
->getOperand(0)) &&
125 IsFreeTruncation(BO
->getOperand(1)) &&
126 ComputeNumSignBits(Op
, *DL
, 0, nullptr, Mul
) > 16)
133 // Both Ops need to be shrinkable.
134 if (!CanShrinkOp(LHS
) && !CanShrinkOp(RHS
))
137 IRBuilder
<> Builder(Mul
);
139 auto *MulTy
= cast
<FixedVectorType
>(Op
->getType());
140 unsigned NumElts
= MulTy
->getNumElements();
142 // Extract even elements and odd elements and add them together. This will
143 // be pattern matched by SelectionDAG to pmaddwd. This instruction will be
144 // half the original width.
145 SmallVector
<int, 16> EvenMask(NumElts
/ 2);
146 SmallVector
<int, 16> OddMask(NumElts
/ 2);
147 for (int i
= 0, e
= NumElts
/ 2; i
!= e
; ++i
) {
149 OddMask
[i
] = i
* 2 + 1;
151 // Creating a new mul so the replaceAllUsesWith below doesn't replace the
152 // uses in the shuffles we're creating.
153 Value
*NewMul
= Builder
.CreateMul(Mul
->getOperand(0), Mul
->getOperand(1));
154 Value
*EvenElts
= Builder
.CreateShuffleVector(NewMul
, NewMul
, EvenMask
);
155 Value
*OddElts
= Builder
.CreateShuffleVector(NewMul
, NewMul
, OddMask
);
156 Value
*MAdd
= Builder
.CreateAdd(EvenElts
, OddElts
);
158 // Concatenate zeroes to extend back to the original type.
159 SmallVector
<int, 32> ConcatMask(NumElts
);
160 std::iota(ConcatMask
.begin(), ConcatMask
.end(), 0);
161 Value
*Zero
= Constant::getNullValue(MAdd
->getType());
162 Value
*Concat
= Builder
.CreateShuffleVector(MAdd
, Zero
, ConcatMask
);
164 Mul
->replaceAllUsesWith(Concat
);
165 Mul
->eraseFromParent();
170 bool X86PartialReduction::trySADReplacement(Instruction
*Op
) {
174 // TODO: There's nothing special about i32, any integer type above i16 should
175 // work just as well.
176 if (!cast
<VectorType
>(Op
->getType())->getElementType()->isIntegerTy(32))
179 // Operand should be a select.
180 auto *SI
= dyn_cast
<SelectInst
>(Op
);
184 // Select needs to implement absolute value.
186 auto SPR
= matchSelectPattern(SI
, LHS
, RHS
);
187 if (SPR
.Flavor
!= SPF_ABS
)
190 // Need a subtract of two values.
191 auto *Sub
= dyn_cast
<BinaryOperator
>(LHS
);
192 if (!Sub
|| Sub
->getOpcode() != Instruction::Sub
)
195 // Look for zero extend from i8.
196 auto getZeroExtendedVal
= [](Value
*Op
) -> Value
* {
197 if (auto *ZExt
= dyn_cast
<ZExtInst
>(Op
))
198 if (cast
<VectorType
>(ZExt
->getOperand(0)->getType())
201 return ZExt
->getOperand(0);
206 // Both operands of the subtract should be extends from vXi8.
207 Value
*Op0
= getZeroExtendedVal(Sub
->getOperand(0));
208 Value
*Op1
= getZeroExtendedVal(Sub
->getOperand(1));
212 IRBuilder
<> Builder(SI
);
214 auto *OpTy
= cast
<FixedVectorType
>(Op
->getType());
215 unsigned NumElts
= OpTy
->getNumElements();
217 unsigned IntrinsicNumElts
;
219 if (ST
->hasBWI() && NumElts
>= 64) {
220 IID
= Intrinsic::x86_avx512_psad_bw_512
;
221 IntrinsicNumElts
= 64;
222 } else if (ST
->hasAVX2() && NumElts
>= 32) {
223 IID
= Intrinsic::x86_avx2_psad_bw
;
224 IntrinsicNumElts
= 32;
226 IID
= Intrinsic::x86_sse2_psad_bw
;
227 IntrinsicNumElts
= 16;
230 Function
*PSADBWFn
= Intrinsic::getDeclaration(SI
->getModule(), IID
);
233 // Pad input with zeroes.
234 SmallVector
<int, 32> ConcatMask(16);
235 for (unsigned i
= 0; i
!= NumElts
; ++i
)
237 for (unsigned i
= NumElts
; i
!= 16; ++i
)
238 ConcatMask
[i
] = (i
% NumElts
) + NumElts
;
240 Value
*Zero
= Constant::getNullValue(Op0
->getType());
241 Op0
= Builder
.CreateShuffleVector(Op0
, Zero
, ConcatMask
);
242 Op1
= Builder
.CreateShuffleVector(Op1
, Zero
, ConcatMask
);
246 // Intrinsics produce vXi64 and need to be casted to vXi32.
248 FixedVectorType::get(Builder
.getInt32Ty(), IntrinsicNumElts
/ 4);
250 assert(NumElts
% IntrinsicNumElts
== 0 && "Unexpected number of elements!");
251 unsigned NumSplits
= NumElts
/ IntrinsicNumElts
;
253 // First collect the pieces we need.
254 SmallVector
<Value
*, 4> Ops(NumSplits
);
255 for (unsigned i
= 0; i
!= NumSplits
; ++i
) {
256 SmallVector
<int, 64> ExtractMask(IntrinsicNumElts
);
257 std::iota(ExtractMask
.begin(), ExtractMask
.end(), i
* IntrinsicNumElts
);
258 Value
*ExtractOp0
= Builder
.CreateShuffleVector(Op0
, Op0
, ExtractMask
);
259 Value
*ExtractOp1
= Builder
.CreateShuffleVector(Op1
, Op0
, ExtractMask
);
260 Ops
[i
] = Builder
.CreateCall(PSADBWFn
, {ExtractOp0
, ExtractOp1
});
261 Ops
[i
] = Builder
.CreateBitCast(Ops
[i
], I32Ty
);
264 assert(isPowerOf2_32(NumSplits
) && "Expected power of 2 splits");
265 unsigned Stages
= Log2_32(NumSplits
);
266 for (unsigned s
= Stages
; s
> 0; --s
) {
267 unsigned NumConcatElts
=
268 cast
<FixedVectorType
>(Ops
[0]->getType())->getNumElements() * 2;
269 for (unsigned i
= 0; i
!= 1U << (s
- 1); ++i
) {
270 SmallVector
<int, 64> ConcatMask(NumConcatElts
);
271 std::iota(ConcatMask
.begin(), ConcatMask
.end(), 0);
272 Ops
[i
] = Builder
.CreateShuffleVector(Ops
[i
*2], Ops
[i
*2+1], ConcatMask
);
276 // At this point the final value should be in Ops[0]. Now we need to adjust
277 // it to the final original type.
278 NumElts
= cast
<FixedVectorType
>(OpTy
)->getNumElements();
280 // Extract down to 2 elements.
281 Ops
[0] = Builder
.CreateShuffleVector(Ops
[0], Ops
[0], ArrayRef
<int>{0, 1});
282 } else if (NumElts
>= 8) {
283 SmallVector
<int, 32> ConcatMask(NumElts
);
285 cast
<FixedVectorType
>(Ops
[0]->getType())->getNumElements();
286 for (unsigned i
= 0; i
!= SubElts
; ++i
)
288 for (unsigned i
= SubElts
; i
!= NumElts
; ++i
)
289 ConcatMask
[i
] = (i
% SubElts
) + SubElts
;
291 Value
*Zero
= Constant::getNullValue(Ops
[0]->getType());
292 Ops
[0] = Builder
.CreateShuffleVector(Ops
[0], Zero
, ConcatMask
);
295 SI
->replaceAllUsesWith(Ops
[0]);
296 SI
->eraseFromParent();
301 // Walk backwards from the ExtractElementInst and determine if it is the end of
302 // a horizontal reduction. Return the input to the reduction if we find one.
303 static Value
*matchAddReduction(const ExtractElementInst
&EE
) {
304 // Make sure we're extracting index 0.
305 auto *Index
= dyn_cast
<ConstantInt
>(EE
.getIndexOperand());
306 if (!Index
|| !Index
->isNullValue())
309 const auto *BO
= dyn_cast
<BinaryOperator
>(EE
.getVectorOperand());
310 if (!BO
|| BO
->getOpcode() != Instruction::Add
|| !BO
->hasOneUse())
313 unsigned NumElems
= cast
<FixedVectorType
>(BO
->getType())->getNumElements();
314 // Ensure the reduction size is a power of 2.
315 if (!isPowerOf2_32(NumElems
))
318 const Value
*Op
= BO
;
319 unsigned Stages
= Log2_32(NumElems
);
320 for (unsigned i
= 0; i
!= Stages
; ++i
) {
321 const auto *BO
= dyn_cast
<BinaryOperator
>(Op
);
322 if (!BO
|| BO
->getOpcode() != Instruction::Add
)
325 // If this isn't the first add, then it should only have 2 users, the
326 // shuffle and another add which we checked in the previous iteration.
327 if (i
!= 0 && !BO
->hasNUses(2))
330 Value
*LHS
= BO
->getOperand(0);
331 Value
*RHS
= BO
->getOperand(1);
333 auto *Shuffle
= dyn_cast
<ShuffleVectorInst
>(LHS
);
337 Shuffle
= dyn_cast
<ShuffleVectorInst
>(RHS
);
341 // The first operand of the shuffle should be the same as the other operand
343 if (!Shuffle
|| Shuffle
->getOperand(0) != Op
)
346 // Verify the shuffle has the expected (at this stage of the pyramid) mask.
347 unsigned MaskEnd
= 1 << i
;
348 for (unsigned Index
= 0; Index
< MaskEnd
; ++Index
)
349 if (Shuffle
->getMaskValue(Index
) != (int)(MaskEnd
+ Index
))
353 return const_cast<Value
*>(Op
);
356 // See if this BO is reachable from this Phi by walking forward through single
357 // use BinaryOperators with the same opcode. If we get back then we know we've
358 // found a loop and it is safe to step through this Add to find more leaves.
359 static bool isReachableFromPHI(PHINode
*Phi
, BinaryOperator
*BO
) {
360 // The PHI itself should only have one use.
361 if (!Phi
->hasOneUse())
364 Instruction
*U
= cast
<Instruction
>(*Phi
->user_begin());
368 while (U
->hasOneUse() && U
->getOpcode() == BO
->getOpcode())
369 U
= cast
<Instruction
>(*U
->user_begin());
374 // Collect all the leaves of the tree of adds that feeds into the horizontal
375 // reduction. Root is the Value that is used by the horizontal reduction.
376 // We look through single use phis, single use adds, or adds that are used by
377 // a phi that forms a loop with the add.
378 static void collectLeaves(Value
*Root
, SmallVectorImpl
<Instruction
*> &Leaves
) {
379 SmallPtrSet
<Value
*, 8> Visited
;
380 SmallVector
<Value
*, 8> Worklist
;
381 Worklist
.push_back(Root
);
383 while (!Worklist
.empty()) {
384 Value
*V
= Worklist
.pop_back_val();
385 if (!Visited
.insert(V
).second
)
388 if (auto *PN
= dyn_cast
<PHINode
>(V
)) {
389 // PHI node should have single use unless it is the root node, then it
391 if (!PN
->hasNUses(PN
== Root
? 2 : 1))
394 // Push incoming values to the worklist.
395 append_range(Worklist
, PN
->incoming_values());
400 if (auto *BO
= dyn_cast
<BinaryOperator
>(V
)) {
401 if (BO
->getOpcode() == Instruction::Add
) {
402 // Simple case. Single use, just push its operands to the worklist.
403 if (BO
->hasNUses(BO
== Root
? 2 : 1)) {
404 append_range(Worklist
, BO
->operands());
408 // If there is additional use, make sure it is an unvisited phi that
409 // gets us back to this node.
410 if (BO
->hasNUses(BO
== Root
? 3 : 2)) {
411 PHINode
*PN
= nullptr;
412 for (auto *U
: Root
->users())
413 if (auto *P
= dyn_cast
<PHINode
>(U
))
414 if (!Visited
.count(P
))
417 // If we didn't find a 2-input PHI then this isn't a case we can
419 if (!PN
|| PN
->getNumIncomingValues() != 2)
422 // Walk forward from this phi to see if it reaches back to this add.
423 if (!isReachableFromPHI(PN
, BO
))
426 // The phi forms a loop with this Add, push its operands.
427 append_range(Worklist
, BO
->operands());
432 // Not an add or phi, make it a leaf.
433 if (auto *I
= dyn_cast
<Instruction
>(V
)) {
434 if (!V
->hasNUses(I
== Root
? 2 : 1))
437 // Add this as a leaf.
443 bool X86PartialReduction::runOnFunction(Function
&F
) {
447 auto *TPC
= getAnalysisIfAvailable
<TargetPassConfig
>();
451 auto &TM
= TPC
->getTM
<X86TargetMachine
>();
452 ST
= TM
.getSubtargetImpl(F
);
454 DL
= &F
.getParent()->getDataLayout();
456 bool MadeChange
= false;
459 auto *EE
= dyn_cast
<ExtractElementInst
>(&I
);
463 // First find a reduction tree.
464 // FIXME: Do we need to handle other opcodes than Add?
465 Value
*Root
= matchAddReduction(*EE
);
469 SmallVector
<Instruction
*, 8> Leaves
;
470 collectLeaves(Root
, Leaves
);
472 for (Instruction
*I
: Leaves
) {
473 if (tryMAddReplacement(I
)) {
478 // Don't do SAD matching on the root node. SelectionDAG already
479 // has support for that and currently generates better code.
480 if (I
!= Root
&& trySADReplacement(I
))