1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
// pairs. Validity of each node is expected to be done upon creation, and any
// validation errors should halt traversal and prevent further graph
// construction.
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas, fixed-width vectors are recognized for both
25 // shufflevector instruction and intrinsics.
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
48 // Each node maintains a reference to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// into the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imag, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
60 //===----------------------------------------------------------------------===//
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/MapVector.h"
64 #include "llvm/ADT/Statistic.h"
65 #include "llvm/Analysis/TargetLibraryInfo.h"
66 #include "llvm/Analysis/TargetTransformInfo.h"
67 #include "llvm/CodeGen/TargetLowering.h"
68 #include "llvm/CodeGen/TargetPassConfig.h"
69 #include "llvm/CodeGen/TargetSubtargetInfo.h"
70 #include "llvm/IR/IRBuilder.h"
71 #include "llvm/IR/PatternMatch.h"
72 #include "llvm/InitializePasses.h"
73 #include "llvm/Target/TargetMachine.h"
74 #include "llvm/Transforms/Utils/Local.h"
78 using namespace PatternMatch
;
80 #define DEBUG_TYPE "complex-deinterleaving"
82 STATISTIC(NumComplexTransformations
, "Amount of complex patterns transformed");
84 static cl::opt
<bool> ComplexDeinterleavingEnabled(
85 "enable-complex-deinterleaving",
86 cl::desc("Enable generation of complex instructions"), cl::init(true),
89 /// Checks the given mask, and determines whether said mask is interleaving.
91 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
92 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
93 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
94 static bool isInterleavingMask(ArrayRef
<int> Mask
);
96 /// Checks the given mask, and determines whether said mask is deinterleaving.
98 /// To be deinterleaving, a mask must increment in steps of 2, and either start
100 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
102 static bool isDeinterleavingMask(ArrayRef
<int> Mask
);
104 /// Returns true if the operation is a negation of V, and it works for both
105 /// integers and floats.
106 static bool isNeg(Value
*V
);
108 /// Returns the operand for negation operation.
109 static Value
*getNegOperand(Value
*V
);
113 class ComplexDeinterleavingLegacyPass
: public FunctionPass
{
117 ComplexDeinterleavingLegacyPass(const TargetMachine
*TM
= nullptr)
118 : FunctionPass(ID
), TM(TM
) {
119 initializeComplexDeinterleavingLegacyPassPass(
120 *PassRegistry::getPassRegistry());
123 StringRef
getPassName() const override
{
124 return "Complex Deinterleaving Pass";
127 bool runOnFunction(Function
&F
) override
;
128 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
129 AU
.addRequired
<TargetLibraryInfoWrapperPass
>();
130 AU
.setPreservesCFG();
134 const TargetMachine
*TM
;
137 class ComplexDeinterleavingGraph
;
138 struct ComplexDeinterleavingCompositeNode
{
140 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op
,
142 : Operation(Op
), Real(R
), Imag(I
) {}
145 friend class ComplexDeinterleavingGraph
;
146 using NodePtr
= std::shared_ptr
<ComplexDeinterleavingCompositeNode
>;
147 using RawNodePtr
= ComplexDeinterleavingCompositeNode
*;
150 ComplexDeinterleavingOperation Operation
;
// These two members are required exclusively for generating
155 // ComplexDeinterleavingOperation::Symmetric operations.
157 std::optional
<FastMathFlags
> Flags
;
159 ComplexDeinterleavingRotation Rotation
=
160 ComplexDeinterleavingRotation::Rotation_0
;
161 SmallVector
<RawNodePtr
> Operands
;
162 Value
*ReplacementNode
= nullptr;
164 void addOperand(NodePtr Node
) { Operands
.push_back(Node
.get()); }
166 void dump() { dump(dbgs()); }
167 void dump(raw_ostream
&OS
) {
168 auto PrintValue
= [&](Value
*V
) {
176 auto PrintNodeRef
= [&](RawNodePtr Ptr
) {
183 OS
<< "- CompositeNode: " << this << "\n";
188 OS
<< " ReplacementNode: ";
189 PrintValue(ReplacementNode
);
190 OS
<< " Operation: " << (int)Operation
<< "\n";
191 OS
<< " Rotation: " << ((int)Rotation
* 90) << "\n";
192 OS
<< " Operands: \n";
193 for (const auto &Op
: Operands
) {
200 class ComplexDeinterleavingGraph
{
208 using Addend
= std::pair
<Value
*, bool>;
209 using NodePtr
= ComplexDeinterleavingCompositeNode::NodePtr
;
210 using RawNodePtr
= ComplexDeinterleavingCompositeNode::RawNodePtr
;
212 // Helper struct for holding info about potential partial multiplication
214 struct PartialMulCandidate
{
222 explicit ComplexDeinterleavingGraph(const TargetLowering
*TL
,
223 const TargetLibraryInfo
*TLI
)
224 : TL(TL
), TLI(TLI
) {}
227 const TargetLowering
*TL
= nullptr;
228 const TargetLibraryInfo
*TLI
= nullptr;
229 SmallVector
<NodePtr
> CompositeNodes
;
230 DenseMap
<std::pair
<Value
*, Value
*>, NodePtr
> CachedResult
;
232 SmallPtrSet
<Instruction
*, 16> FinalInstructions
;
234 /// Root instructions are instructions from which complex computation starts
235 std::map
<Instruction
*, NodePtr
> RootToNode
;
237 /// Topologically sorted root instructions
238 SmallVector
<Instruction
*, 1> OrderedRoots
;
240 /// When examining a basic block for complex deinterleaving, if it is a simple
241 /// one-block loop, then the only incoming block is 'Incoming' and the
/// 'BackEdge' block is the block itself.
243 BasicBlock
*BackEdge
= nullptr;
244 BasicBlock
*Incoming
= nullptr;
246 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
247 /// %OutsideUser as it is shown in the IR:
250 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
251 /// [ %ReductionOp, %vector.body ]
253 /// %ReductionOp = fadd i64 ...
255 /// br i1 %condition, label %vector.body, %middle.block
258 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
260 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
261 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
262 MapVector
<Instruction
*, std::pair
<PHINode
*, Instruction
*>> ReductionInfo
;
264 /// In the process of detecting a reduction, we consider a pair of
265 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
266 /// traverse the use-tree to detect complex operations. As this is a reduction
267 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
268 /// to the %ReductionOPs that we suspect to be complex.
269 /// RealPHI and ImagPHI are used by the identifyPHINode method.
270 PHINode
*RealPHI
= nullptr;
271 PHINode
*ImagPHI
= nullptr;
273 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
275 bool PHIsFound
= false;
277 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
278 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
279 /// This mapping is populated during
280 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
281 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
282 /// replacement process.
283 std::map
<PHINode
*, PHINode
*> OldToNewPHI
;
285 NodePtr
prepareCompositeNode(ComplexDeinterleavingOperation Operation
,
286 Value
*R
, Value
*I
) {
287 assert(((Operation
!= ComplexDeinterleavingOperation::ReductionPHI
&&
288 Operation
!= ComplexDeinterleavingOperation::ReductionOperation
) ||
290 "Reduction related nodes must have Real and Imaginary parts");
291 return std::make_shared
<ComplexDeinterleavingCompositeNode
>(Operation
, R
,
295 NodePtr
submitCompositeNode(NodePtr Node
) {
296 CompositeNodes
.push_back(Node
);
297 if (Node
->Real
&& Node
->Imag
)
298 CachedResult
[{Node
->Real
, Node
->Imag
}] = Node
;
302 /// Identifies a complex partial multiply pattern and its rotation, based on
303 /// the following patterns
305 /// 0: r: cr + ar * br
307 /// 90: r: cr - ai * bi
309 /// 180: r: cr - ar * br
311 /// 270: r: cr + ai * bi
313 NodePtr
identifyPartialMul(Instruction
*Real
, Instruction
*Imag
);
315 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
316 /// is partially known from identifyPartialMul, filling in the other half of
317 /// the complex pair.
319 identifyNodeWithImplicitAdd(Instruction
*I
, Instruction
*J
,
320 std::pair
<Value
*, Value
*> &CommonOperandI
);
322 /// Identifies a complex add pattern and its rotation, based on the following
329 NodePtr
identifyAdd(Instruction
*Real
, Instruction
*Imag
);
330 NodePtr
identifySymmetricOperation(Instruction
*Real
, Instruction
*Imag
);
332 NodePtr
identifyNode(Value
*R
, Value
*I
);
334 /// Determine if a sum of complex numbers can be formed from \p RealAddends
/// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
336 /// Return nullptr if it is not possible to construct a complex number.
337 /// \p Flags are needed to generate symmetric Add and Sub operations.
338 NodePtr
identifyAdditions(std::list
<Addend
> &RealAddends
,
339 std::list
<Addend
> &ImagAddends
,
340 std::optional
<FastMathFlags
> Flags
,
341 NodePtr Accumulator
);
/// Extract one addend that has both real and imaginary parts positive.
344 NodePtr
extractPositiveAddend(std::list
<Addend
> &RealAddends
,
345 std::list
<Addend
> &ImagAddends
);
347 /// Determine if sum of multiplications of complex numbers can be formed from
348 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
349 /// to it. Return nullptr if it is not possible to construct a complex number.
350 NodePtr
identifyMultiplications(std::vector
<Product
> &RealMuls
,
351 std::vector
<Product
> &ImagMuls
,
352 NodePtr Accumulator
);
354 /// Go through pairs of multiplication (one Real and one Imag) and find all
355 /// possible candidates for partial multiplication and put them into \p
/// Candidates. Returns true if every Product has a pair with a common operand.
357 bool collectPartialMuls(const std::vector
<Product
> &RealMuls
,
358 const std::vector
<Product
> &ImagMuls
,
359 std::vector
<PartialMulCandidate
> &Candidates
);
361 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
362 /// the order of complex computation operations may be significantly altered,
363 /// and the real and imaginary parts may not be executed in parallel. This
364 /// function takes this into consideration and employs a more general approach
365 /// to identify complex computations. Initially, it gathers all the addends
366 /// and multiplicands and then constructs a complex expression from them.
367 NodePtr
identifyReassocNodes(Instruction
*I
, Instruction
*J
);
369 NodePtr
identifyRoot(Instruction
*I
);
371 /// Identifies the Deinterleave operation applied to a vector containing
372 /// complex numbers. There are two ways to represent the Deinterleave
/// * Using two shufflevectors with even indices for the \p Real instruction
/// and odd indices for the \p Imag instruction (only for fixed-width vectors)
376 /// * Using two extractvalue instructions applied to `vector.deinterleave2`
377 /// intrinsic (for both fixed and scalable vectors)
378 NodePtr
identifyDeinterleave(Instruction
*Real
, Instruction
*Imag
);
/// Identifies the operation that represents a complex number repeated in a
381 /// Splat vector. There are two possible types of splats: ConstantExpr with
382 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
383 /// initialization mask with all values set to zero.
384 NodePtr
identifySplat(Value
*Real
, Value
*Imag
);
386 NodePtr
identifyPHINode(Instruction
*Real
, Instruction
*Imag
);
388 /// Identifies SelectInsts in a loop that has reduction with predication masks
389 /// and/or predicated tail folding
390 NodePtr
identifySelectNode(Instruction
*Real
, Instruction
*Imag
);
392 Value
*replaceNode(IRBuilderBase
&Builder
, RawNodePtr Node
);
394 /// Complete IR modifications after producing new reduction operation:
395 /// * Populate the PHINode generated for
396 /// ComplexDeinterleavingOperation::ReductionPHI
397 /// * Deinterleave the final value outside of the loop and repurpose original
399 void processReductionOperation(Value
*OperationReplacement
, RawNodePtr Node
);
402 void dump() { dump(dbgs()); }
403 void dump(raw_ostream
&OS
) {
404 for (const auto &Node
: CompositeNodes
)
408 /// Returns false if the deinterleaving operation should be cancelled for the
410 bool identifyNodes(Instruction
*RootI
);
/// In case \p B is a one-block loop, this function seeks potential reductions
413 /// and populates ReductionInfo. Returns true if any reductions were
415 bool collectPotentialReductions(BasicBlock
*B
);
417 void identifyReductionNodes();
419 /// Check that every instruction, from the roots to the leaves, has internal
423 /// Perform the actual replacement of the underlying instruction graph.
427 class ComplexDeinterleaving
{
429 ComplexDeinterleaving(const TargetLowering
*tl
, const TargetLibraryInfo
*tli
)
430 : TL(tl
), TLI(tli
) {}
431 bool runOnFunction(Function
&F
);
434 bool evaluateBasicBlock(BasicBlock
*B
);
436 const TargetLowering
*TL
= nullptr;
437 const TargetLibraryInfo
*TLI
= nullptr;
442 char ComplexDeinterleavingLegacyPass::ID
= 0;
444 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass
, DEBUG_TYPE
,
445 "Complex Deinterleaving", false, false)
446 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass
, DEBUG_TYPE
,
447 "Complex Deinterleaving", false, false)
449 PreservedAnalyses
ComplexDeinterleavingPass::run(Function
&F
,
450 FunctionAnalysisManager
&AM
) {
451 const TargetLowering
*TL
= TM
->getSubtargetImpl(F
)->getTargetLowering();
452 auto &TLI
= AM
.getResult
<llvm::TargetLibraryAnalysis
>(F
);
453 if (!ComplexDeinterleaving(TL
, &TLI
).runOnFunction(F
))
454 return PreservedAnalyses::all();
456 PreservedAnalyses PA
;
457 PA
.preserve
<FunctionAnalysisManagerModuleProxy
>();
461 FunctionPass
*llvm::createComplexDeinterleavingPass(const TargetMachine
*TM
) {
462 return new ComplexDeinterleavingLegacyPass(TM
);
465 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function
&F
) {
466 const auto *TL
= TM
->getSubtargetImpl(F
)->getTargetLowering();
467 auto TLI
= getAnalysis
<TargetLibraryInfoWrapperPass
>().getTLI(F
);
468 return ComplexDeinterleaving(TL
, &TLI
).runOnFunction(F
);
471 bool ComplexDeinterleaving::runOnFunction(Function
&F
) {
472 if (!ComplexDeinterleavingEnabled
) {
474 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
478 if (!TL
->isComplexDeinterleavingSupported()) {
480 dbgs() << "Complex deinterleaving has been disabled, target does "
481 "not support lowering of complex number operations.\n");
485 bool Changed
= false;
487 Changed
|= evaluateBasicBlock(&B
);
492 static bool isInterleavingMask(ArrayRef
<int> Mask
) {
493 // If the size is not even, it's not an interleaving mask
494 if ((Mask
.size() & 1))
497 int HalfNumElements
= Mask
.size() / 2;
498 for (int Idx
= 0; Idx
< HalfNumElements
; ++Idx
) {
499 int MaskIdx
= Idx
* 2;
500 if (Mask
[MaskIdx
] != Idx
|| Mask
[MaskIdx
+ 1] != (Idx
+ HalfNumElements
))
507 static bool isDeinterleavingMask(ArrayRef
<int> Mask
) {
508 int Offset
= Mask
[0];
509 int HalfNumElements
= Mask
.size() / 2;
511 for (int Idx
= 1; Idx
< HalfNumElements
; ++Idx
) {
512 if (Mask
[Idx
] != (Idx
* 2) + Offset
)
519 bool isNeg(Value
*V
) {
520 return match(V
, m_FNeg(m_Value())) || match(V
, m_Neg(m_Value()));
523 Value
*getNegOperand(Value
*V
) {
525 auto *I
= cast
<Instruction
>(V
);
526 if (I
->getOpcode() == Instruction::FNeg
)
527 return I
->getOperand(0);
529 return I
->getOperand(1);
532 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock
*B
) {
533 ComplexDeinterleavingGraph
Graph(TL
, TLI
);
534 if (Graph
.collectPotentialReductions(B
))
535 Graph
.identifyReductionNodes();
538 Graph
.identifyNodes(&I
);
540 if (Graph
.checkNodes()) {
541 Graph
.replaceNodes();
548 ComplexDeinterleavingGraph::NodePtr
549 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550 Instruction
*Real
, Instruction
*Imag
,
551 std::pair
<Value
*, Value
*> &PartialMatch
) {
552 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real
<< " / " << *Imag
555 if (!Real
->hasOneUse() || !Imag
->hasOneUse()) {
556 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
560 if ((Real
->getOpcode() != Instruction::FMul
&&
561 Real
->getOpcode() != Instruction::Mul
) ||
562 (Imag
->getOpcode() != Instruction::FMul
&&
563 Imag
->getOpcode() != Instruction::Mul
)) {
565 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
569 Value
*R0
= Real
->getOperand(0);
570 Value
*R1
= Real
->getOperand(1);
571 Value
*I0
= Imag
->getOperand(0);
572 Value
*I1
= Imag
->getOperand(1);
574 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575 // rotations and use the operand.
578 if (match(R0
, m_Neg(m_Value(Op
)))) {
581 } else if (match(R1
, m_Neg(m_Value(Op
)))) {
590 } else if (match(I1
, m_Neg(m_Value(Op
)))) {
596 ComplexDeinterleavingRotation Rotation
= (ComplexDeinterleavingRotation
)Negs
;
598 Value
*CommonOperand
;
599 Value
*UncommonRealOp
;
600 Value
*UncommonImagOp
;
602 if (R0
== I0
|| R0
== I1
) {
605 } else if (R1
== I0
|| R1
== I1
) {
609 LLVM_DEBUG(dbgs() << " - No equal operand\n");
613 UncommonImagOp
= (CommonOperand
== I0
) ? I1
: I0
;
614 if (Rotation
== ComplexDeinterleavingRotation::Rotation_90
||
615 Rotation
== ComplexDeinterleavingRotation::Rotation_270
)
616 std::swap(UncommonRealOp
, UncommonImagOp
);
618 // Between identifyPartialMul and here we need to have found a complete valid
619 // pair from the CommonOperand of each part.
620 if (Rotation
== ComplexDeinterleavingRotation::Rotation_0
||
621 Rotation
== ComplexDeinterleavingRotation::Rotation_180
)
622 PartialMatch
.first
= CommonOperand
;
624 PartialMatch
.second
= CommonOperand
;
626 if (!PartialMatch
.first
|| !PartialMatch
.second
) {
627 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
631 NodePtr CommonNode
= identifyNode(PartialMatch
.first
, PartialMatch
.second
);
633 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
637 NodePtr UncommonNode
= identifyNode(UncommonRealOp
, UncommonImagOp
);
639 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
643 NodePtr Node
= prepareCompositeNode(
644 ComplexDeinterleavingOperation::CMulPartial
, Real
, Imag
);
645 Node
->Rotation
= Rotation
;
646 Node
->addOperand(CommonNode
);
647 Node
->addOperand(UncommonNode
);
648 return submitCompositeNode(Node
);
651 ComplexDeinterleavingGraph::NodePtr
652 ComplexDeinterleavingGraph::identifyPartialMul(Instruction
*Real
,
654 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real
<< " / " << *Imag
656 // Determine rotation
657 auto IsAdd
= [](unsigned Op
) {
658 return Op
== Instruction::FAdd
|| Op
== Instruction::Add
;
660 auto IsSub
= [](unsigned Op
) {
661 return Op
== Instruction::FSub
|| Op
== Instruction::Sub
;
663 ComplexDeinterleavingRotation Rotation
;
664 if (IsAdd(Real
->getOpcode()) && IsAdd(Imag
->getOpcode()))
665 Rotation
= ComplexDeinterleavingRotation::Rotation_0
;
666 else if (IsSub(Real
->getOpcode()) && IsAdd(Imag
->getOpcode()))
667 Rotation
= ComplexDeinterleavingRotation::Rotation_90
;
668 else if (IsSub(Real
->getOpcode()) && IsSub(Imag
->getOpcode()))
669 Rotation
= ComplexDeinterleavingRotation::Rotation_180
;
670 else if (IsAdd(Real
->getOpcode()) && IsSub(Imag
->getOpcode()))
671 Rotation
= ComplexDeinterleavingRotation::Rotation_270
;
673 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
677 if (isa
<FPMathOperator
>(Real
) &&
678 (!Real
->getFastMathFlags().allowContract() ||
679 !Imag
->getFastMathFlags().allowContract())) {
680 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
684 Value
*CR
= Real
->getOperand(0);
685 Instruction
*RealMulI
= dyn_cast
<Instruction
>(Real
->getOperand(1));
688 Value
*CI
= Imag
->getOperand(0);
689 Instruction
*ImagMulI
= dyn_cast
<Instruction
>(Imag
->getOperand(1));
693 if (!RealMulI
->hasOneUse() || !ImagMulI
->hasOneUse()) {
694 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
698 Value
*R0
= RealMulI
->getOperand(0);
699 Value
*R1
= RealMulI
->getOperand(1);
700 Value
*I0
= ImagMulI
->getOperand(0);
701 Value
*I1
= ImagMulI
->getOperand(1);
703 Value
*CommonOperand
;
704 Value
*UncommonRealOp
;
705 Value
*UncommonImagOp
;
707 if (R0
== I0
|| R0
== I1
) {
710 } else if (R1
== I0
|| R1
== I1
) {
714 LLVM_DEBUG(dbgs() << " - No equal operand\n");
718 UncommonImagOp
= (CommonOperand
== I0
) ? I1
: I0
;
719 if (Rotation
== ComplexDeinterleavingRotation::Rotation_90
||
720 Rotation
== ComplexDeinterleavingRotation::Rotation_270
)
721 std::swap(UncommonRealOp
, UncommonImagOp
);
723 std::pair
<Value
*, Value
*> PartialMatch(
724 (Rotation
== ComplexDeinterleavingRotation::Rotation_0
||
725 Rotation
== ComplexDeinterleavingRotation::Rotation_180
)
728 (Rotation
== ComplexDeinterleavingRotation::Rotation_90
||
729 Rotation
== ComplexDeinterleavingRotation::Rotation_270
)
733 auto *CRInst
= dyn_cast
<Instruction
>(CR
);
734 auto *CIInst
= dyn_cast
<Instruction
>(CI
);
736 if (!CRInst
|| !CIInst
) {
737 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
741 NodePtr CNode
= identifyNodeWithImplicitAdd(CRInst
, CIInst
, PartialMatch
);
743 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
747 NodePtr UncommonRes
= identifyNode(UncommonRealOp
, UncommonImagOp
);
749 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
753 assert(PartialMatch
.first
&& PartialMatch
.second
);
754 NodePtr CommonRes
= identifyNode(PartialMatch
.first
, PartialMatch
.second
);
756 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
760 NodePtr Node
= prepareCompositeNode(
761 ComplexDeinterleavingOperation::CMulPartial
, Real
, Imag
);
762 Node
->Rotation
= Rotation
;
763 Node
->addOperand(CommonRes
);
764 Node
->addOperand(UncommonRes
);
765 Node
->addOperand(CNode
);
766 return submitCompositeNode(Node
);
769 ComplexDeinterleavingGraph::NodePtr
770 ComplexDeinterleavingGraph::identifyAdd(Instruction
*Real
, Instruction
*Imag
) {
771 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real
<< " / " << *Imag
<< "\n");
773 // Determine rotation
774 ComplexDeinterleavingRotation Rotation
;
775 if ((Real
->getOpcode() == Instruction::FSub
&&
776 Imag
->getOpcode() == Instruction::FAdd
) ||
777 (Real
->getOpcode() == Instruction::Sub
&&
778 Imag
->getOpcode() == Instruction::Add
))
779 Rotation
= ComplexDeinterleavingRotation::Rotation_90
;
780 else if ((Real
->getOpcode() == Instruction::FAdd
&&
781 Imag
->getOpcode() == Instruction::FSub
) ||
782 (Real
->getOpcode() == Instruction::Add
&&
783 Imag
->getOpcode() == Instruction::Sub
))
784 Rotation
= ComplexDeinterleavingRotation::Rotation_270
;
786 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
790 auto *AR
= dyn_cast
<Instruction
>(Real
->getOperand(0));
791 auto *BI
= dyn_cast
<Instruction
>(Real
->getOperand(1));
792 auto *AI
= dyn_cast
<Instruction
>(Imag
->getOperand(0));
793 auto *BR
= dyn_cast
<Instruction
>(Imag
->getOperand(1));
795 if (!AR
|| !AI
|| !BR
|| !BI
) {
796 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
800 NodePtr ResA
= identifyNode(AR
, AI
);
802 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
805 NodePtr ResB
= identifyNode(BR
, BI
);
807 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
812 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd
, Real
, Imag
);
813 Node
->Rotation
= Rotation
;
814 Node
->addOperand(ResA
);
815 Node
->addOperand(ResB
);
816 return submitCompositeNode(Node
);
819 static bool isInstructionPairAdd(Instruction
*A
, Instruction
*B
) {
820 unsigned OpcA
= A
->getOpcode();
821 unsigned OpcB
= B
->getOpcode();
823 return (OpcA
== Instruction::FSub
&& OpcB
== Instruction::FAdd
) ||
824 (OpcA
== Instruction::FAdd
&& OpcB
== Instruction::FSub
) ||
825 (OpcA
== Instruction::Sub
&& OpcB
== Instruction::Add
) ||
826 (OpcA
== Instruction::Add
&& OpcB
== Instruction::Sub
);
829 static bool isInstructionPairMul(Instruction
*A
, Instruction
*B
) {
831 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
833 return match(A
, Pattern
) && match(B
, Pattern
);
836 static bool isInstructionPotentiallySymmetric(Instruction
*I
) {
837 switch (I
->getOpcode()) {
838 case Instruction::FAdd
:
839 case Instruction::FSub
:
840 case Instruction::FMul
:
841 case Instruction::FNeg
:
842 case Instruction::Add
:
843 case Instruction::Sub
:
844 case Instruction::Mul
:
851 ComplexDeinterleavingGraph::NodePtr
852 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction
*Real
,
854 if (Real
->getOpcode() != Imag
->getOpcode())
857 if (!isInstructionPotentiallySymmetric(Real
) ||
858 !isInstructionPotentiallySymmetric(Imag
))
861 auto *R0
= Real
->getOperand(0);
862 auto *I0
= Imag
->getOperand(0);
864 NodePtr Op0
= identifyNode(R0
, I0
);
865 NodePtr Op1
= nullptr;
869 if (Real
->isBinaryOp()) {
870 auto *R1
= Real
->getOperand(1);
871 auto *I1
= Imag
->getOperand(1);
872 Op1
= identifyNode(R1
, I1
);
877 if (isa
<FPMathOperator
>(Real
) &&
878 Real
->getFastMathFlags() != Imag
->getFastMathFlags())
881 auto Node
= prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric
,
883 Node
->Opcode
= Real
->getOpcode();
884 if (isa
<FPMathOperator
>(Real
))
885 Node
->Flags
= Real
->getFastMathFlags();
887 Node
->addOperand(Op0
);
888 if (Real
->isBinaryOp())
889 Node
->addOperand(Op1
);
891 return submitCompositeNode(Node
);
894 ComplexDeinterleavingGraph::NodePtr
895 ComplexDeinterleavingGraph::identifyNode(Value
*R
, Value
*I
) {
896 LLVM_DEBUG(dbgs() << "identifyNode on " << *R
<< " / " << *I
<< "\n");
897 assert(R
->getType() == I
->getType() &&
898 "Real and imaginary parts should not have different types");
900 auto It
= CachedResult
.find({R
, I
});
901 if (It
!= CachedResult
.end()) {
902 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
906 if (NodePtr CN
= identifySplat(R
, I
))
909 auto *Real
= dyn_cast
<Instruction
>(R
);
910 auto *Imag
= dyn_cast
<Instruction
>(I
);
914 if (NodePtr CN
= identifyDeinterleave(Real
, Imag
))
917 if (NodePtr CN
= identifyPHINode(Real
, Imag
))
920 if (NodePtr CN
= identifySelectNode(Real
, Imag
))
923 auto *VTy
= cast
<VectorType
>(Real
->getType());
924 auto *NewVTy
= VectorType::getDoubleElementsVectorType(VTy
);
926 bool HasCMulSupport
= TL
->isComplexDeinterleavingOperationSupported(
927 ComplexDeinterleavingOperation::CMulPartial
, NewVTy
);
928 bool HasCAddSupport
= TL
->isComplexDeinterleavingOperationSupported(
929 ComplexDeinterleavingOperation::CAdd
, NewVTy
);
931 if (HasCMulSupport
&& isInstructionPairMul(Real
, Imag
)) {
932 if (NodePtr CN
= identifyPartialMul(Real
, Imag
))
936 if (HasCAddSupport
&& isInstructionPairAdd(Real
, Imag
)) {
937 if (NodePtr CN
= identifyAdd(Real
, Imag
))
941 if (HasCMulSupport
&& HasCAddSupport
) {
942 if (NodePtr CN
= identifyReassocNodes(Real
, Imag
))
946 if (NodePtr CN
= identifySymmetricOperation(Real
, Imag
))
949 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
950 CachedResult
[{R
, I
}] = nullptr;
954 ComplexDeinterleavingGraph::NodePtr
955 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction
*Real
,
957 auto IsOperationSupported
= [](unsigned Opcode
) -> bool {
958 return Opcode
== Instruction::FAdd
|| Opcode
== Instruction::FSub
||
959 Opcode
== Instruction::FNeg
|| Opcode
== Instruction::Add
||
960 Opcode
== Instruction::Sub
;
963 if (!IsOperationSupported(Real
->getOpcode()) ||
964 !IsOperationSupported(Imag
->getOpcode()))
967 std::optional
<FastMathFlags
> Flags
;
968 if (isa
<FPMathOperator
>(Real
)) {
969 if (Real
->getFastMathFlags() != Imag
->getFastMathFlags()) {
970 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
975 Flags
= Real
->getFastMathFlags();
976 if (!Flags
->allowReassoc()) {
979 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
984 // Collect multiplications and addend instructions from the given instruction
985 // while traversing it operands. Additionally, verify that all instructions
986 // have the same fast math flags.
987 auto Collect
= [&Flags
](Instruction
*Insn
, std::vector
<Product
> &Muls
,
988 std::list
<Addend
> &Addends
) -> bool {
989 SmallVector
<PointerIntPair
<Value
*, 1, bool>> Worklist
= {{Insn
, true}};
990 SmallPtrSet
<Value
*, 8> Visited
;
991 while (!Worklist
.empty()) {
992 auto [V
, IsPositive
] = Worklist
.back();
994 if (!Visited
.insert(V
).second
)
997 Instruction
*I
= dyn_cast
<Instruction
>(V
);
999 Addends
.emplace_back(V
, IsPositive
);
1003 // If an instruction has more than one user, it indicates that it either
1004 // has an external user, which will be later checked by the checkNodes
1005 // function, or it is a subexpression utilized by multiple expressions. In
1006 // the latter case, we will attempt to separately identify the complex
1007 // operation from here in order to create a shared
1008 // ComplexDeinterleavingCompositeNode.
1009 if (I
!= Insn
&& I
->getNumUses() > 1) {
1010 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I
<< "\n");
1011 Addends
.emplace_back(I
, IsPositive
);
1014 switch (I
->getOpcode()) {
1015 case Instruction::FAdd
:
1016 case Instruction::Add
:
1017 Worklist
.emplace_back(I
->getOperand(1), IsPositive
);
1018 Worklist
.emplace_back(I
->getOperand(0), IsPositive
);
1020 case Instruction::FSub
:
1021 Worklist
.emplace_back(I
->getOperand(1), !IsPositive
);
1022 Worklist
.emplace_back(I
->getOperand(0), IsPositive
);
1024 case Instruction::Sub
:
1026 Worklist
.emplace_back(getNegOperand(I
), !IsPositive
);
1028 Worklist
.emplace_back(I
->getOperand(1), !IsPositive
);
1029 Worklist
.emplace_back(I
->getOperand(0), IsPositive
);
1032 case Instruction::FMul
:
1033 case Instruction::Mul
: {
1035 if (isNeg(I
->getOperand(0))) {
1036 A
= getNegOperand(I
->getOperand(0));
1037 IsPositive
= !IsPositive
;
1039 A
= I
->getOperand(0);
1042 if (isNeg(I
->getOperand(1))) {
1043 B
= getNegOperand(I
->getOperand(1));
1044 IsPositive
= !IsPositive
;
1046 B
= I
->getOperand(1);
1048 Muls
.push_back(Product
{A
, B
, IsPositive
});
1051 case Instruction::FNeg
:
1052 Worklist
.emplace_back(I
->getOperand(0), !IsPositive
);
1055 Addends
.emplace_back(I
, IsPositive
);
1059 if (Flags
&& I
->getFastMathFlags() != *Flags
) {
1060 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1061 "inconsistent with the root instructions' flags: "
1069 std::vector
<Product
> RealMuls
, ImagMuls
;
1070 std::list
<Addend
> RealAddends
, ImagAddends
;
1071 if (!Collect(Real
, RealMuls
, RealAddends
) ||
1072 !Collect(Imag
, ImagMuls
, ImagAddends
))
1075 if (RealAddends
.size() != ImagAddends
.size())
1079 if (!RealMuls
.empty() || !ImagMuls
.empty()) {
1080 // If there are multiplicands, extract positive addend and use it as an
1082 FinalNode
= extractPositiveAddend(RealAddends
, ImagAddends
);
1083 FinalNode
= identifyMultiplications(RealMuls
, ImagMuls
, FinalNode
);
1088 // Identify and process remaining additions
1089 if (!RealAddends
.empty() || !ImagAddends
.empty()) {
1090 FinalNode
= identifyAdditions(RealAddends
, ImagAddends
, Flags
, FinalNode
);
1094 assert(FinalNode
&& "FinalNode can not be nullptr here");
1095 // Set the Real and Imag fields of the final node and submit it
1096 FinalNode
->Real
= Real
;
1097 FinalNode
->Imag
= Imag
;
1098 submitCompositeNode(FinalNode
);
1102 bool ComplexDeinterleavingGraph::collectPartialMuls(
1103 const std::vector
<Product
> &RealMuls
, const std::vector
<Product
> &ImagMuls
,
1104 std::vector
<PartialMulCandidate
> &PartialMulCandidates
) {
1105 // Helper function to extract a common operand from two products
1106 auto FindCommonInstruction
= [](const Product
&Real
,
1107 const Product
&Imag
) -> Value
* {
1108 if (Real
.Multiplicand
== Imag
.Multiplicand
||
1109 Real
.Multiplicand
== Imag
.Multiplier
)
1110 return Real
.Multiplicand
;
1112 if (Real
.Multiplier
== Imag
.Multiplicand
||
1113 Real
.Multiplier
== Imag
.Multiplier
)
1114 return Real
.Multiplier
;
1119 // Iterating over real and imaginary multiplications to find common operands
1120 // If a common operand is found, a partial multiplication candidate is created
1121 // and added to the candidates vector The function returns false if no common
1122 // operands are found for any product
1123 for (unsigned i
= 0; i
< RealMuls
.size(); ++i
) {
1124 bool FoundCommon
= false;
1125 for (unsigned j
= 0; j
< ImagMuls
.size(); ++j
) {
1126 auto *Common
= FindCommonInstruction(RealMuls
[i
], ImagMuls
[j
]);
1130 auto *A
= RealMuls
[i
].Multiplicand
== Common
? RealMuls
[i
].Multiplier
1131 : RealMuls
[i
].Multiplicand
;
1132 auto *B
= ImagMuls
[j
].Multiplicand
== Common
? ImagMuls
[j
].Multiplier
1133 : ImagMuls
[j
].Multiplicand
;
1135 auto Node
= identifyNode(A
, B
);
1138 PartialMulCandidates
.push_back({Common
, Node
, i
, j
, false});
1141 Node
= identifyNode(B
, A
);
1144 PartialMulCandidates
.push_back({Common
, Node
, i
, j
, true});
1153 ComplexDeinterleavingGraph::NodePtr
1154 ComplexDeinterleavingGraph::identifyMultiplications(
1155 std::vector
<Product
> &RealMuls
, std::vector
<Product
> &ImagMuls
,
1156 NodePtr Accumulator
= nullptr) {
1157 if (RealMuls
.size() != ImagMuls
.size())
1160 std::vector
<PartialMulCandidate
> Info
;
1161 if (!collectPartialMuls(RealMuls
, ImagMuls
, Info
))
1164 // Map to store common instruction to node pointers
1165 std::map
<Value
*, NodePtr
> CommonToNode
;
1166 std::vector
<bool> Processed(Info
.size(), false);
1167 for (unsigned I
= 0; I
< Info
.size(); ++I
) {
1171 PartialMulCandidate
&InfoA
= Info
[I
];
1172 for (unsigned J
= I
+ 1; J
< Info
.size(); ++J
) {
1176 PartialMulCandidate
&InfoB
= Info
[J
];
1177 auto *InfoReal
= &InfoA
;
1178 auto *InfoImag
= &InfoB
;
1180 auto NodeFromCommon
= identifyNode(InfoReal
->Common
, InfoImag
->Common
);
1181 if (!NodeFromCommon
) {
1182 std::swap(InfoReal
, InfoImag
);
1183 NodeFromCommon
= identifyNode(InfoReal
->Common
, InfoImag
->Common
);
1185 if (!NodeFromCommon
)
1188 CommonToNode
[InfoReal
->Common
] = NodeFromCommon
;
1189 CommonToNode
[InfoImag
->Common
] = NodeFromCommon
;
1190 Processed
[I
] = true;
1191 Processed
[J
] = true;
1195 std::vector
<bool> ProcessedReal(RealMuls
.size(), false);
1196 std::vector
<bool> ProcessedImag(ImagMuls
.size(), false);
1197 NodePtr Result
= Accumulator
;
1198 for (auto &PMI
: Info
) {
1199 if (ProcessedReal
[PMI
.RealIdx
] || ProcessedImag
[PMI
.ImagIdx
])
1202 auto It
= CommonToNode
.find(PMI
.Common
);
1203 // TODO: Process independent complex multiplications. Cases like this:
1204 // A.real() * B where both A and B are complex numbers.
1205 if (It
== CommonToNode
.end()) {
1207 dbgs() << "Unprocessed independent partial multiplication:\n";
1208 for (auto *Mul
: {&RealMuls
[PMI
.RealIdx
], &RealMuls
[PMI
.RealIdx
]})
1209 dbgs().indent(4) << (Mul
->IsPositive
? "+" : "-") << *Mul
->Multiplier
1210 << " multiplied by " << *Mul
->Multiplicand
<< "\n";
1215 auto &RealMul
= RealMuls
[PMI
.RealIdx
];
1216 auto &ImagMul
= ImagMuls
[PMI
.ImagIdx
];
1218 auto NodeA
= It
->second
;
1219 auto NodeB
= PMI
.Node
;
1220 auto IsMultiplicandReal
= PMI
.Common
== NodeA
->Real
;
1221 // The following table illustrates the relationship between multiplications
1222 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1225 // Rotation | Real | Imag |
1226 // ---------+--------+--------+
1227 // 0 | x * u | x * v |
1228 // 90 | -y * v | y * u |
1229 // 180 | -x * u | -x * v |
1230 // 270 | y * v | -y * u |
1232 // Check if the candidate can indeed be represented by partial
1234 // TODO: Add support for multiplication by complex one
1235 if ((IsMultiplicandReal
&& PMI
.IsNodeInverted
) ||
1236 (!IsMultiplicandReal
&& !PMI
.IsNodeInverted
))
1239 // Determine the rotation based on the multiplications
1240 ComplexDeinterleavingRotation Rotation
;
1241 if (IsMultiplicandReal
) {
1242 // Detect 0 and 180 degrees rotation
1243 if (RealMul
.IsPositive
&& ImagMul
.IsPositive
)
1244 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_0
;
1245 else if (!RealMul
.IsPositive
&& !ImagMul
.IsPositive
)
1246 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_180
;
1251 // Detect 90 and 270 degrees rotation
1252 if (!RealMul
.IsPositive
&& ImagMul
.IsPositive
)
1253 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_90
;
1254 else if (RealMul
.IsPositive
&& !ImagMul
.IsPositive
)
1255 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_270
;
1261 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1262 dbgs().indent(4) << "X: " << *NodeA
->Real
<< "\n";
1263 dbgs().indent(4) << "Y: " << *NodeA
->Imag
<< "\n";
1264 dbgs().indent(4) << "U: " << *NodeB
->Real
<< "\n";
1265 dbgs().indent(4) << "V: " << *NodeB
->Imag
<< "\n";
1266 dbgs().indent(4) << "Rotation - " << (int)Rotation
* 90 << "\n";
1269 NodePtr NodeMul
= prepareCompositeNode(
1270 ComplexDeinterleavingOperation::CMulPartial
, nullptr, nullptr);
1271 NodeMul
->Rotation
= Rotation
;
1272 NodeMul
->addOperand(NodeA
);
1273 NodeMul
->addOperand(NodeB
);
1275 NodeMul
->addOperand(Result
);
1276 submitCompositeNode(NodeMul
);
1278 ProcessedReal
[PMI
.RealIdx
] = true;
1279 ProcessedImag
[PMI
.ImagIdx
] = true;
1282 // Ensure all products have been processed, if not return nullptr.
1283 if (!all_of(ProcessedReal
, [](bool V
) { return V
; }) ||
1284 !all_of(ProcessedImag
, [](bool V
) { return V
; })) {
1286 // Dump debug information about which partial multiplications are not
1289 dbgs() << "Unprocessed products (Real):\n";
1290 for (size_t i
= 0; i
< ProcessedReal
.size(); ++i
) {
1291 if (!ProcessedReal
[i
])
1292 dbgs().indent(4) << (RealMuls
[i
].IsPositive
? "+" : "-")
1293 << *RealMuls
[i
].Multiplier
<< " multiplied by "
1294 << *RealMuls
[i
].Multiplicand
<< "\n";
1296 dbgs() << "Unprocessed products (Imag):\n";
1297 for (size_t i
= 0; i
< ProcessedImag
.size(); ++i
) {
1298 if (!ProcessedImag
[i
])
1299 dbgs().indent(4) << (ImagMuls
[i
].IsPositive
? "+" : "-")
1300 << *ImagMuls
[i
].Multiplier
<< " multiplied by "
1301 << *ImagMuls
[i
].Multiplicand
<< "\n";
1310 ComplexDeinterleavingGraph::NodePtr
1311 ComplexDeinterleavingGraph::identifyAdditions(
1312 std::list
<Addend
> &RealAddends
, std::list
<Addend
> &ImagAddends
,
1313 std::optional
<FastMathFlags
> Flags
, NodePtr Accumulator
= nullptr) {
1314 if (RealAddends
.size() != ImagAddends
.size())
1318 // If we have accumulator use it as first addend
1320 Result
= Accumulator
;
1321 // Otherwise find an element with both positive real and imaginary parts.
1323 Result
= extractPositiveAddend(RealAddends
, ImagAddends
);
1328 while (!RealAddends
.empty()) {
1329 auto ItR
= RealAddends
.begin();
1330 auto [R
, IsPositiveR
] = *ItR
;
1332 bool FoundImag
= false;
1333 for (auto ItI
= ImagAddends
.begin(); ItI
!= ImagAddends
.end(); ++ItI
) {
1334 auto [I
, IsPositiveI
] = *ItI
;
1335 ComplexDeinterleavingRotation Rotation
;
1336 if (IsPositiveR
&& IsPositiveI
)
1337 Rotation
= ComplexDeinterleavingRotation::Rotation_0
;
1338 else if (!IsPositiveR
&& IsPositiveI
)
1339 Rotation
= ComplexDeinterleavingRotation::Rotation_90
;
1340 else if (!IsPositiveR
&& !IsPositiveI
)
1341 Rotation
= ComplexDeinterleavingRotation::Rotation_180
;
1343 Rotation
= ComplexDeinterleavingRotation::Rotation_270
;
1346 if (Rotation
== ComplexDeinterleavingRotation::Rotation_0
||
1347 Rotation
== ComplexDeinterleavingRotation::Rotation_180
) {
1348 AddNode
= identifyNode(R
, I
);
1350 AddNode
= identifyNode(I
, R
);
1354 dbgs() << "Identified addition:\n";
1355 dbgs().indent(4) << "X: " << *R
<< "\n";
1356 dbgs().indent(4) << "Y: " << *I
<< "\n";
1357 dbgs().indent(4) << "Rotation - " << (int)Rotation
* 90 << "\n";
1361 if (Rotation
== llvm::ComplexDeinterleavingRotation::Rotation_0
) {
1362 TmpNode
= prepareCompositeNode(
1363 ComplexDeinterleavingOperation::Symmetric
, nullptr, nullptr);
1365 TmpNode
->Opcode
= Instruction::FAdd
;
1366 TmpNode
->Flags
= *Flags
;
1368 TmpNode
->Opcode
= Instruction::Add
;
1370 } else if (Rotation
==
1371 llvm::ComplexDeinterleavingRotation::Rotation_180
) {
1372 TmpNode
= prepareCompositeNode(
1373 ComplexDeinterleavingOperation::Symmetric
, nullptr, nullptr);
1375 TmpNode
->Opcode
= Instruction::FSub
;
1376 TmpNode
->Flags
= *Flags
;
1378 TmpNode
->Opcode
= Instruction::Sub
;
1381 TmpNode
= prepareCompositeNode(ComplexDeinterleavingOperation::CAdd
,
1383 TmpNode
->Rotation
= Rotation
;
1386 TmpNode
->addOperand(Result
);
1387 TmpNode
->addOperand(AddNode
);
1388 submitCompositeNode(TmpNode
);
1390 RealAddends
.erase(ItR
);
1391 ImagAddends
.erase(ItI
);
1402 ComplexDeinterleavingGraph::NodePtr
1403 ComplexDeinterleavingGraph::extractPositiveAddend(
1404 std::list
<Addend
> &RealAddends
, std::list
<Addend
> &ImagAddends
) {
1405 for (auto ItR
= RealAddends
.begin(); ItR
!= RealAddends
.end(); ++ItR
) {
1406 for (auto ItI
= ImagAddends
.begin(); ItI
!= ImagAddends
.end(); ++ItI
) {
1407 auto [R
, IsPositiveR
] = *ItR
;
1408 auto [I
, IsPositiveI
] = *ItI
;
1409 if (IsPositiveR
&& IsPositiveI
) {
1410 auto Result
= identifyNode(R
, I
);
1412 RealAddends
.erase(ItR
);
1413 ImagAddends
.erase(ItI
);
1422 bool ComplexDeinterleavingGraph::identifyNodes(Instruction
*RootI
) {
1423 // This potential root instruction might already have been recognized as
1424 // reduction. Because RootToNode maps both Real and Imaginary parts to
1425 // CompositeNode we should choose only one either Real or Imag instruction to
1426 // use as an anchor for generating complex instruction.
1427 auto It
= RootToNode
.find(RootI
);
1428 if (It
!= RootToNode
.end()) {
1429 auto RootNode
= It
->second
;
1430 assert(RootNode
->Operation
==
1431 ComplexDeinterleavingOperation::ReductionOperation
);
1432 // Find out which part, Real or Imag, comes later, and only if we come to
1433 // the latest part, add it to OrderedRoots.
1434 auto *R
= cast
<Instruction
>(RootNode
->Real
);
1435 auto *I
= cast
<Instruction
>(RootNode
->Imag
);
1436 auto *ReplacementAnchor
= R
->comesBefore(I
) ? I
: R
;
1437 if (ReplacementAnchor
!= RootI
)
1439 OrderedRoots
.push_back(RootI
);
1443 auto RootNode
= identifyRoot(RootI
);
1448 Function
*F
= RootI
->getFunction();
1449 BasicBlock
*B
= RootI
->getParent();
1450 dbgs() << "Complex deinterleaving graph for " << F
->getName()
1451 << "::" << B
->getName() << ".\n";
1455 RootToNode
[RootI
] = RootNode
;
1456 OrderedRoots
.push_back(RootI
);
1460 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock
*B
) {
1461 bool FoundPotentialReduction
= false;
1463 auto *Br
= dyn_cast
<BranchInst
>(B
->getTerminator());
1464 if (!Br
|| Br
->getNumSuccessors() != 2)
1467 // Identify simple one-block loop
1468 if (Br
->getSuccessor(0) != B
&& Br
->getSuccessor(1) != B
)
1471 SmallVector
<PHINode
*> PHIs
;
1472 for (auto &PHI
: B
->phis()) {
1473 if (PHI
.getNumIncomingValues() != 2)
1476 if (!PHI
.getType()->isVectorTy())
1479 auto *ReductionOp
= dyn_cast
<Instruction
>(PHI
.getIncomingValueForBlock(B
));
1483 // Check if final instruction is reduced outside of current block
1484 Instruction
*FinalReduction
= nullptr;
1486 for (auto *U
: ReductionOp
->users()) {
1490 FinalReduction
= dyn_cast
<Instruction
>(U
);
1493 if (NumUsers
!= 2 || !FinalReduction
|| FinalReduction
->getParent() == B
||
1494 isa
<PHINode
>(FinalReduction
))
1497 ReductionInfo
[ReductionOp
] = {&PHI
, FinalReduction
};
1499 auto BackEdgeIdx
= PHI
.getBasicBlockIndex(B
);
1500 auto IncomingIdx
= BackEdgeIdx
== 0 ? 1 : 0;
1501 Incoming
= PHI
.getIncomingBlock(IncomingIdx
);
1502 FoundPotentialReduction
= true;
1504 // If the initial value of PHINode is an Instruction, consider it a leaf
1505 // value of a complex deinterleaving graph.
1507 dyn_cast
<Instruction
>(PHI
.getIncomingValueForBlock(Incoming
)))
1508 FinalInstructions
.insert(InitPHI
);
1510 return FoundPotentialReduction
;
1513 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1514 SmallVector
<bool> Processed(ReductionInfo
.size(), false);
1515 SmallVector
<Instruction
*> OperationInstruction
;
1516 for (auto &P
: ReductionInfo
)
1517 OperationInstruction
.push_back(P
.first
);
1519 // Identify a complex computation by evaluating two reduction operations that
1520 // potentially could be involved
1521 for (size_t i
= 0; i
< OperationInstruction
.size(); ++i
) {
1524 for (size_t j
= i
+ 1; j
< OperationInstruction
.size(); ++j
) {
1528 auto *Real
= OperationInstruction
[i
];
1529 auto *Imag
= OperationInstruction
[j
];
1530 if (Real
->getType() != Imag
->getType())
1533 RealPHI
= ReductionInfo
[Real
].first
;
1534 ImagPHI
= ReductionInfo
[Imag
].first
;
1536 auto Node
= identifyNode(Real
, Imag
);
1538 std::swap(Real
, Imag
);
1539 std::swap(RealPHI
, ImagPHI
);
1540 Node
= identifyNode(Real
, Imag
);
1543 // If a node is identified and reduction PHINode is used in the chain of
1544 // operations, mark its operation instructions as used to prevent
1545 // re-identification and attach the node to the real part
1546 if (Node
&& PHIsFound
) {
1547 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1548 << *Real
<< " / " << *Imag
<< "\n");
1549 Processed
[i
] = true;
1550 Processed
[j
] = true;
1551 auto RootNode
= prepareCompositeNode(
1552 ComplexDeinterleavingOperation::ReductionOperation
, Real
, Imag
);
1553 RootNode
->addOperand(Node
);
1554 RootToNode
[Real
] = RootNode
;
1555 RootToNode
[Imag
] = RootNode
;
1556 submitCompositeNode(RootNode
);
1566 bool ComplexDeinterleavingGraph::checkNodes() {
1567 // Collect all instructions from roots to leaves
1568 SmallPtrSet
<Instruction
*, 16> AllInstructions
;
1569 SmallVector
<Instruction
*, 8> Worklist
;
1570 for (auto &Pair
: RootToNode
)
1571 Worklist
.push_back(Pair
.first
);
1573 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1575 while (!Worklist
.empty()) {
1576 auto *I
= Worklist
.back();
1577 Worklist
.pop_back();
1579 if (!AllInstructions
.insert(I
).second
)
1582 for (Value
*Op
: I
->operands()) {
1583 if (auto *OpI
= dyn_cast
<Instruction
>(Op
)) {
1584 if (!FinalInstructions
.count(I
))
1585 Worklist
.emplace_back(OpI
);
1590 // Find instructions that have users outside of chain
1591 SmallVector
<Instruction
*, 2> OuterInstructions
;
1592 for (auto *I
: AllInstructions
) {
1594 if (RootToNode
.count(I
))
1597 for (User
*U
: I
->users()) {
1598 if (AllInstructions
.count(cast
<Instruction
>(U
)))
1601 // Found an instruction that is not used by XCMLA/XCADD chain
1602 Worklist
.emplace_back(I
);
1607 // If any instructions are found to be used outside, find and remove roots
1608 // that somehow connect to those instructions.
1609 SmallPtrSet
<Instruction
*, 16> Visited
;
1610 while (!Worklist
.empty()) {
1611 auto *I
= Worklist
.back();
1612 Worklist
.pop_back();
1613 if (!Visited
.insert(I
).second
)
1616 // Found an impacted root node. Removing it from the nodes to be
1618 if (RootToNode
.count(I
)) {
1619 LLVM_DEBUG(dbgs() << "Instruction " << *I
1620 << " could be deinterleaved but its chain of complex "
1621 "operations have an outside user\n");
1622 RootToNode
.erase(I
);
1625 if (!AllInstructions
.count(I
) || FinalInstructions
.count(I
))
1628 for (User
*U
: I
->users())
1629 Worklist
.emplace_back(cast
<Instruction
>(U
));
1631 for (Value
*Op
: I
->operands()) {
1632 if (auto *OpI
= dyn_cast
<Instruction
>(Op
))
1633 Worklist
.emplace_back(OpI
);
1636 return !RootToNode
.empty();
1639 ComplexDeinterleavingGraph::NodePtr
1640 ComplexDeinterleavingGraph::identifyRoot(Instruction
*RootI
) {
1641 if (auto *Intrinsic
= dyn_cast
<IntrinsicInst
>(RootI
)) {
1642 if (Intrinsic
->getIntrinsicID() !=
1643 Intrinsic::experimental_vector_interleave2
)
1646 auto *Real
= dyn_cast
<Instruction
>(Intrinsic
->getOperand(0));
1647 auto *Imag
= dyn_cast
<Instruction
>(Intrinsic
->getOperand(1));
1651 return identifyNode(Real
, Imag
);
1654 auto *SVI
= dyn_cast
<ShuffleVectorInst
>(RootI
);
1658 // Look for a shufflevector that takes separate vectors of the real and
1659 // imaginary components and recombines them into a single vector.
1660 if (!isInterleavingMask(SVI
->getShuffleMask()))
1665 if (!match(RootI
, m_Shuffle(m_Instruction(Real
), m_Instruction(Imag
))))
1668 return identifyNode(Real
, Imag
);
1671 ComplexDeinterleavingGraph::NodePtr
1672 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction
*Real
,
1673 Instruction
*Imag
) {
1674 Instruction
*I
= nullptr;
1675 Value
*FinalValue
= nullptr;
1676 if (match(Real
, m_ExtractValue
<0>(m_Instruction(I
))) &&
1677 match(Imag
, m_ExtractValue
<1>(m_Specific(I
))) &&
1678 match(I
, m_Intrinsic
<Intrinsic::experimental_vector_deinterleave2
>(
1679 m_Value(FinalValue
)))) {
1680 NodePtr PlaceholderNode
= prepareCompositeNode(
1681 llvm::ComplexDeinterleavingOperation::Deinterleave
, Real
, Imag
);
1682 PlaceholderNode
->ReplacementNode
= FinalValue
;
1683 FinalInstructions
.insert(Real
);
1684 FinalInstructions
.insert(Imag
);
1685 return submitCompositeNode(PlaceholderNode
);
1688 auto *RealShuffle
= dyn_cast
<ShuffleVectorInst
>(Real
);
1689 auto *ImagShuffle
= dyn_cast
<ShuffleVectorInst
>(Imag
);
1690 if (!RealShuffle
|| !ImagShuffle
) {
1691 if (RealShuffle
|| ImagShuffle
)
1692 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1696 Value
*RealOp1
= RealShuffle
->getOperand(1);
1697 if (!isa
<UndefValue
>(RealOp1
) && !isa
<ConstantAggregateZero
>(RealOp1
)) {
1698 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1701 Value
*ImagOp1
= ImagShuffle
->getOperand(1);
1702 if (!isa
<UndefValue
>(ImagOp1
) && !isa
<ConstantAggregateZero
>(ImagOp1
)) {
1703 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1707 Value
*RealOp0
= RealShuffle
->getOperand(0);
1708 Value
*ImagOp0
= ImagShuffle
->getOperand(0);
1710 if (RealOp0
!= ImagOp0
) {
1711 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1715 ArrayRef
<int> RealMask
= RealShuffle
->getShuffleMask();
1716 ArrayRef
<int> ImagMask
= ImagShuffle
->getShuffleMask();
1717 if (!isDeinterleavingMask(RealMask
) || !isDeinterleavingMask(ImagMask
)) {
1718 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1722 if (RealMask
[0] != 0 || ImagMask
[0] != 1) {
1723 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1727 // Type checking, the shuffle type should be a vector type of the same
1728 // scalar type, but half the size
1729 auto CheckType
= [&](ShuffleVectorInst
*Shuffle
) {
1730 Value
*Op
= Shuffle
->getOperand(0);
1731 auto *ShuffleTy
= cast
<FixedVectorType
>(Shuffle
->getType());
1732 auto *OpTy
= cast
<FixedVectorType
>(Op
->getType());
1734 if (OpTy
->getScalarType() != ShuffleTy
->getScalarType())
1736 if ((ShuffleTy
->getNumElements() * 2) != OpTy
->getNumElements())
1742 auto CheckDeinterleavingShuffle
= [&](ShuffleVectorInst
*Shuffle
) -> bool {
1743 if (!CheckType(Shuffle
))
1746 ArrayRef
<int> Mask
= Shuffle
->getShuffleMask();
1747 int Last
= *Mask
.rbegin();
1749 Value
*Op
= Shuffle
->getOperand(0);
1750 auto *OpTy
= cast
<FixedVectorType
>(Op
->getType());
1751 int NumElements
= OpTy
->getNumElements();
1753 // Ensure that the deinterleaving shuffle only pulls from the first
1755 return Last
< NumElements
;
1758 if (RealShuffle
->getType() != ImagShuffle
->getType()) {
1759 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1762 if (!CheckDeinterleavingShuffle(RealShuffle
)) {
1763 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1766 if (!CheckDeinterleavingShuffle(ImagShuffle
)) {
1767 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1771 NodePtr PlaceholderNode
=
1772 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave
,
1773 RealShuffle
, ImagShuffle
);
1774 PlaceholderNode
->ReplacementNode
= RealShuffle
->getOperand(0);
1775 FinalInstructions
.insert(RealShuffle
);
1776 FinalInstructions
.insert(ImagShuffle
);
1777 return submitCompositeNode(PlaceholderNode
);
1780 ComplexDeinterleavingGraph::NodePtr
1781 ComplexDeinterleavingGraph::identifySplat(Value
*R
, Value
*I
) {
1782 auto IsSplat
= [](Value
*V
) -> bool {
1783 // Fixed-width vector with constants
1784 if (isa
<ConstantDataVector
>(V
))
1789 // Splats are represented differently depending on whether the repeated
1790 // value is a constant or an Instruction
1791 if (auto *Const
= dyn_cast
<ConstantExpr
>(V
)) {
1792 if (Const
->getOpcode() != Instruction::ShuffleVector
)
1794 VTy
= cast
<VectorType
>(Const
->getType());
1795 Mask
= Const
->getShuffleMask();
1796 } else if (auto *Shuf
= dyn_cast
<ShuffleVectorInst
>(V
)) {
1797 VTy
= Shuf
->getType();
1798 Mask
= Shuf
->getShuffleMask();
1803 // When the data type is <1 x Type>, it's not possible to differentiate
1804 // between the ComplexDeinterleaving::Deinterleave and
1805 // ComplexDeinterleaving::Splat operations.
1806 if (!VTy
->isScalableTy() && VTy
->getElementCount().getKnownMinValue() == 1)
1809 return all_equal(Mask
) && Mask
[0] == 0;
1812 if (!IsSplat(R
) || !IsSplat(I
))
1815 auto *Real
= dyn_cast
<Instruction
>(R
);
1816 auto *Imag
= dyn_cast
<Instruction
>(I
);
1817 if ((!Real
&& Imag
) || (Real
&& !Imag
))
1821 // Non-constant splats should be in the same basic block
1822 if (Real
->getParent() != Imag
->getParent())
1825 FinalInstructions
.insert(Real
);
1826 FinalInstructions
.insert(Imag
);
1828 NodePtr PlaceholderNode
=
1829 prepareCompositeNode(ComplexDeinterleavingOperation::Splat
, R
, I
);
1830 return submitCompositeNode(PlaceholderNode
);
1833 ComplexDeinterleavingGraph::NodePtr
1834 ComplexDeinterleavingGraph::identifyPHINode(Instruction
*Real
,
1835 Instruction
*Imag
) {
1836 if (Real
!= RealPHI
|| Imag
!= ImagPHI
)
1840 NodePtr PlaceholderNode
= prepareCompositeNode(
1841 ComplexDeinterleavingOperation::ReductionPHI
, Real
, Imag
);
1842 return submitCompositeNode(PlaceholderNode
);
1845 ComplexDeinterleavingGraph::NodePtr
1846 ComplexDeinterleavingGraph::identifySelectNode(Instruction
*Real
,
1847 Instruction
*Imag
) {
1848 auto *SelectReal
= dyn_cast
<SelectInst
>(Real
);
1849 auto *SelectImag
= dyn_cast
<SelectInst
>(Imag
);
1850 if (!SelectReal
|| !SelectImag
)
1853 Instruction
*MaskA
, *MaskB
;
1854 Instruction
*AR
, *AI
, *RA
, *BI
;
1855 if (!match(Real
, m_Select(m_Instruction(MaskA
), m_Instruction(AR
),
1856 m_Instruction(RA
))) ||
1857 !match(Imag
, m_Select(m_Instruction(MaskB
), m_Instruction(AI
),
1858 m_Instruction(BI
))))
1861 if (MaskA
!= MaskB
&& !MaskA
->isIdenticalTo(MaskB
))
1864 if (!MaskA
->getType()->isVectorTy())
1867 auto NodeA
= identifyNode(AR
, AI
);
1871 auto NodeB
= identifyNode(RA
, BI
);
1875 NodePtr PlaceholderNode
= prepareCompositeNode(
1876 ComplexDeinterleavingOperation::ReductionSelect
, Real
, Imag
);
1877 PlaceholderNode
->addOperand(NodeA
);
1878 PlaceholderNode
->addOperand(NodeB
);
1879 FinalInstructions
.insert(MaskA
);
1880 FinalInstructions
.insert(MaskB
);
1881 return submitCompositeNode(PlaceholderNode
);
1884 static Value
*replaceSymmetricNode(IRBuilderBase
&B
, unsigned Opcode
,
1885 std::optional
<FastMathFlags
> Flags
,
1886 Value
*InputA
, Value
*InputB
) {
1889 case Instruction::FNeg
:
1890 I
= B
.CreateFNeg(InputA
);
1892 case Instruction::FAdd
:
1893 I
= B
.CreateFAdd(InputA
, InputB
);
1895 case Instruction::Add
:
1896 I
= B
.CreateAdd(InputA
, InputB
);
1898 case Instruction::FSub
:
1899 I
= B
.CreateFSub(InputA
, InputB
);
1901 case Instruction::Sub
:
1902 I
= B
.CreateSub(InputA
, InputB
);
1904 case Instruction::FMul
:
1905 I
= B
.CreateFMul(InputA
, InputB
);
1907 case Instruction::Mul
:
1908 I
= B
.CreateMul(InputA
, InputB
);
1911 llvm_unreachable("Incorrect symmetric opcode");
1914 cast
<Instruction
>(I
)->setFastMathFlags(*Flags
);
1918 Value
*ComplexDeinterleavingGraph::replaceNode(IRBuilderBase
&Builder
,
1920 if (Node
->ReplacementNode
)
1921 return Node
->ReplacementNode
;
1923 auto ReplaceOperandIfExist
= [&](RawNodePtr
&Node
, unsigned Idx
) -> Value
* {
1924 return Node
->Operands
.size() > Idx
1925 ? replaceNode(Builder
, Node
->Operands
[Idx
])
1929 Value
*ReplacementNode
;
1930 switch (Node
->Operation
) {
1931 case ComplexDeinterleavingOperation::CAdd
:
1932 case ComplexDeinterleavingOperation::CMulPartial
:
1933 case ComplexDeinterleavingOperation::Symmetric
: {
1934 Value
*Input0
= ReplaceOperandIfExist(Node
, 0);
1935 Value
*Input1
= ReplaceOperandIfExist(Node
, 1);
1936 Value
*Accumulator
= ReplaceOperandIfExist(Node
, 2);
1937 assert(!Input1
|| (Input0
->getType() == Input1
->getType() &&
1938 "Node inputs need to be of the same type"));
1939 assert(!Accumulator
||
1940 (Input0
->getType() == Accumulator
->getType() &&
1941 "Accumulator and input need to be of the same type"));
1942 if (Node
->Operation
== ComplexDeinterleavingOperation::Symmetric
)
1943 ReplacementNode
= replaceSymmetricNode(Builder
, Node
->Opcode
, Node
->Flags
,
1946 ReplacementNode
= TL
->createComplexDeinterleavingIR(
1947 Builder
, Node
->Operation
, Node
->Rotation
, Input0
, Input1
,
1951 case ComplexDeinterleavingOperation::Deinterleave
:
1952 llvm_unreachable("Deinterleave node should already have ReplacementNode");
1954 case ComplexDeinterleavingOperation::Splat
: {
1955 auto *NewTy
= VectorType::getDoubleElementsVectorType(
1956 cast
<VectorType
>(Node
->Real
->getType()));
1957 auto *R
= dyn_cast
<Instruction
>(Node
->Real
);
1958 auto *I
= dyn_cast
<Instruction
>(Node
->Imag
);
1960 // Splats that are not constant are interleaved where they are located
1961 Instruction
*InsertPoint
= (I
->comesBefore(R
) ? R
: I
)->getNextNode();
1962 IRBuilder
<> IRB(InsertPoint
);
1964 IRB
.CreateIntrinsic(Intrinsic::experimental_vector_interleave2
, NewTy
,
1965 {Node
->Real
, Node
->Imag
});
1968 Builder
.CreateIntrinsic(Intrinsic::experimental_vector_interleave2
,
1969 NewTy
, {Node
->Real
, Node
->Imag
});
1973 case ComplexDeinterleavingOperation::ReductionPHI
: {
1974 // If Operation is ReductionPHI, a new empty PHINode is created.
1975 // It is filled later when the ReductionOperation is processed.
1976 auto *VTy
= cast
<VectorType
>(Node
->Real
->getType());
1977 auto *NewVTy
= VectorType::getDoubleElementsVectorType(VTy
);
1978 auto *NewPHI
= PHINode::Create(NewVTy
, 0, "", BackEdge
->getFirstNonPHI());
1979 OldToNewPHI
[dyn_cast
<PHINode
>(Node
->Real
)] = NewPHI
;
1980 ReplacementNode
= NewPHI
;
1983 case ComplexDeinterleavingOperation::ReductionOperation
:
1984 ReplacementNode
= replaceNode(Builder
, Node
->Operands
[0]);
1985 processReductionOperation(ReplacementNode
, Node
);
1987 case ComplexDeinterleavingOperation::ReductionSelect
: {
1988 auto *MaskReal
= cast
<Instruction
>(Node
->Real
)->getOperand(0);
1989 auto *MaskImag
= cast
<Instruction
>(Node
->Imag
)->getOperand(0);
1990 auto *A
= replaceNode(Builder
, Node
->Operands
[0]);
1991 auto *B
= replaceNode(Builder
, Node
->Operands
[1]);
1992 auto *NewMaskTy
= VectorType::getDoubleElementsVectorType(
1993 cast
<VectorType
>(MaskReal
->getType()));
1995 Builder
.CreateIntrinsic(Intrinsic::experimental_vector_interleave2
,
1996 NewMaskTy
, {MaskReal
, MaskImag
});
1997 ReplacementNode
= Builder
.CreateSelect(NewMask
, A
, B
);
2002 assert(ReplacementNode
&& "Target failed to create Intrinsic call.");
2003 NumComplexTransformations
+= 1;
2004 Node
->ReplacementNode
= ReplacementNode
;
2005 return ReplacementNode
;
2008 void ComplexDeinterleavingGraph::processReductionOperation(
2009 Value
*OperationReplacement
, RawNodePtr Node
) {
2010 auto *Real
= cast
<Instruction
>(Node
->Real
);
2011 auto *Imag
= cast
<Instruction
>(Node
->Imag
);
2012 auto *OldPHIReal
= ReductionInfo
[Real
].first
;
2013 auto *OldPHIImag
= ReductionInfo
[Imag
].first
;
2014 auto *NewPHI
= OldToNewPHI
[OldPHIReal
];
2016 auto *VTy
= cast
<VectorType
>(Real
->getType());
2017 auto *NewVTy
= VectorType::getDoubleElementsVectorType(VTy
);
2019 // We have to interleave initial origin values coming from IncomingBlock
2020 Value
*InitReal
= OldPHIReal
->getIncomingValueForBlock(Incoming
);
2021 Value
*InitImag
= OldPHIImag
->getIncomingValueForBlock(Incoming
);
2023 IRBuilder
<> Builder(Incoming
->getTerminator());
2024 auto *NewInit
= Builder
.CreateIntrinsic(
2025 Intrinsic::experimental_vector_interleave2
, NewVTy
, {InitReal
, InitImag
});
2027 NewPHI
->addIncoming(NewInit
, Incoming
);
2028 NewPHI
->addIncoming(OperationReplacement
, BackEdge
);
2030 // Deinterleave complex vector outside of loop so that it can be finally
2032 auto *FinalReductionReal
= ReductionInfo
[Real
].second
;
2033 auto *FinalReductionImag
= ReductionInfo
[Imag
].second
;
2035 Builder
.SetInsertPoint(
2036 &*FinalReductionReal
->getParent()->getFirstInsertionPt());
2037 auto *Deinterleave
= Builder
.CreateIntrinsic(
2038 Intrinsic::experimental_vector_deinterleave2
,
2039 OperationReplacement
->getType(), OperationReplacement
);
2041 auto *NewReal
= Builder
.CreateExtractValue(Deinterleave
, (uint64_t)0);
2042 FinalReductionReal
->replaceUsesOfWith(Real
, NewReal
);
2044 Builder
.SetInsertPoint(FinalReductionImag
);
2045 auto *NewImag
= Builder
.CreateExtractValue(Deinterleave
, 1);
2046 FinalReductionImag
->replaceUsesOfWith(Imag
, NewImag
);
2049 void ComplexDeinterleavingGraph::replaceNodes() {
2050 SmallVector
<Instruction
*, 16> DeadInstrRoots
;
2051 for (auto *RootInstruction
: OrderedRoots
) {
2052 // Check if this potential root went through check process and we can
2054 if (!RootToNode
.count(RootInstruction
))
2057 IRBuilder
<> Builder(RootInstruction
);
2058 auto RootNode
= RootToNode
[RootInstruction
];
2059 Value
*R
= replaceNode(Builder
, RootNode
.get());
2061 if (RootNode
->Operation
==
2062 ComplexDeinterleavingOperation::ReductionOperation
) {
2063 auto *RootReal
= cast
<Instruction
>(RootNode
->Real
);
2064 auto *RootImag
= cast
<Instruction
>(RootNode
->Imag
);
2065 ReductionInfo
[RootReal
].first
->removeIncomingValue(BackEdge
);
2066 ReductionInfo
[RootImag
].first
->removeIncomingValue(BackEdge
);
2067 DeadInstrRoots
.push_back(cast
<Instruction
>(RootReal
));
2068 DeadInstrRoots
.push_back(cast
<Instruction
>(RootImag
));
2070 assert(R
&& "Unable to find replacement for RootInstruction");
2071 DeadInstrRoots
.push_back(RootInstruction
);
2072 RootInstruction
->replaceAllUsesWith(R
);
2076 for (auto *I
: DeadInstrRoots
)
2077 RecursivelyDeleteTriviallyDeadInstructions(I
, TLI
);