1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation errors should halt traversal and prevent further graph building.
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas, fixed-width vectors are recognized for both
25 // shufflevector instruction and intrinsics.
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
48 // Each node maintains a reference to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
52 // Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
60 //===----------------------------------------------------------------------===//
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/Statistic.h"
64 #include "llvm/Analysis/TargetLibraryInfo.h"
65 #include "llvm/Analysis/TargetTransformInfo.h"
66 #include "llvm/CodeGen/TargetLowering.h"
67 #include "llvm/CodeGen/TargetPassConfig.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
77 using namespace PatternMatch
;
79 #define DEBUG_TYPE "complex-deinterleaving"
81 STATISTIC(NumComplexTransformations
, "Amount of complex patterns transformed");
83 static cl::opt
<bool> ComplexDeinterleavingEnabled(
84 "enable-complex-deinterleaving",
85 cl::desc("Enable generation of complex instructions"), cl::init(true),
88 /// Checks the given mask, and determines whether said mask is interleaving.
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef
<int> Mask
);
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101 static bool isDeinterleavingMask(ArrayRef
<int> Mask
);
105 class ComplexDeinterleavingLegacyPass
: public FunctionPass
{
109 ComplexDeinterleavingLegacyPass(const TargetMachine
*TM
= nullptr)
110 : FunctionPass(ID
), TM(TM
) {
111 initializeComplexDeinterleavingLegacyPassPass(
112 *PassRegistry::getPassRegistry());
115 StringRef
getPassName() const override
{
116 return "Complex Deinterleaving Pass";
119 bool runOnFunction(Function
&F
) override
;
120 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
121 AU
.addRequired
<TargetLibraryInfoWrapperPass
>();
122 AU
.setPreservesCFG();
126 const TargetMachine
*TM
;
129 class ComplexDeinterleavingGraph
;
130 struct ComplexDeinterleavingCompositeNode
{
132 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op
,
134 : Operation(Op
), Real(R
), Imag(I
) {}
137 friend class ComplexDeinterleavingGraph
;
138 using NodePtr
= std::shared_ptr
<ComplexDeinterleavingCompositeNode
>;
139 using RawNodePtr
= ComplexDeinterleavingCompositeNode
*;
142 ComplexDeinterleavingOperation Operation
;
146 // This two members are required exclusively for generating
147 // ComplexDeinterleavingOperation::Symmetric operations.
151 ComplexDeinterleavingRotation Rotation
=
152 ComplexDeinterleavingRotation::Rotation_0
;
153 SmallVector
<RawNodePtr
> Operands
;
154 Value
*ReplacementNode
= nullptr;
156 void addOperand(NodePtr Node
) { Operands
.push_back(Node
.get()); }
158 void dump() { dump(dbgs()); }
159 void dump(raw_ostream
&OS
) {
160 auto PrintValue
= [&](Value
*V
) {
168 auto PrintNodeRef
= [&](RawNodePtr Ptr
) {
175 OS
<< "- CompositeNode: " << this << "\n";
180 OS
<< " ReplacementNode: ";
181 PrintValue(ReplacementNode
);
182 OS
<< " Operation: " << (int)Operation
<< "\n";
183 OS
<< " Rotation: " << ((int)Rotation
* 90) << "\n";
184 OS
<< " Operands: \n";
185 for (const auto &Op
: Operands
) {
192 class ComplexDeinterleavingGraph
{
200 using Addend
= std::pair
<Value
*, bool>;
201 using NodePtr
= ComplexDeinterleavingCompositeNode::NodePtr
;
202 using RawNodePtr
= ComplexDeinterleavingCompositeNode::RawNodePtr
;
204 // Helper struct for holding info about potential partial multiplication
206 struct PartialMulCandidate
{
214 explicit ComplexDeinterleavingGraph(const TargetLowering
*TL
,
215 const TargetLibraryInfo
*TLI
)
216 : TL(TL
), TLI(TLI
) {}
219 const TargetLowering
*TL
= nullptr;
220 const TargetLibraryInfo
*TLI
= nullptr;
221 SmallVector
<NodePtr
> CompositeNodes
;
223 SmallPtrSet
<Instruction
*, 16> FinalInstructions
;
225 /// Root instructions are instructions from which complex computation starts
226 std::map
<Instruction
*, NodePtr
> RootToNode
;
228 /// Topologically sorted root instructions
229 SmallVector
<Instruction
*, 1> OrderedRoots
;
231 /// When examining a basic block for complex deinterleaving, if it is a simple
232 /// one-block loop, then the only incoming block is 'Incoming' and the
233 /// 'BackEdge' block is the block itself."
234 BasicBlock
*BackEdge
= nullptr;
235 BasicBlock
*Incoming
= nullptr;
237 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
238 /// %OutsideUser as it is shown in the IR:
241 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
242 /// [ %ReductionOp, %vector.body ]
244 /// %ReductionOp = fadd i64 ...
246 /// br i1 %condition, label %vector.body, %middle.block
249 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
251 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
252 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
253 std::map
<Instruction
*, std::pair
<PHINode
*, Instruction
*>> ReductionInfo
;
255 /// In the process of detecting a reduction, we consider a pair of
256 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
257 /// traverse the use-tree to detect complex operations. As this is a reduction
258 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
259 /// to the %ReductionOPs that we suspect to be complex.
260 /// RealPHI and ImagPHI are used by the identifyPHINode method.
261 PHINode
*RealPHI
= nullptr;
262 PHINode
*ImagPHI
= nullptr;
264 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
266 bool PHIsFound
= false;
268 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
269 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
270 /// This mapping is populated during
271 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
272 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
273 /// replacement process.
274 std::map
<PHINode
*, PHINode
*> OldToNewPHI
;
276 NodePtr
prepareCompositeNode(ComplexDeinterleavingOperation Operation
,
277 Value
*R
, Value
*I
) {
278 assert(((Operation
!= ComplexDeinterleavingOperation::ReductionPHI
&&
279 Operation
!= ComplexDeinterleavingOperation::ReductionOperation
) ||
281 "Reduction related nodes must have Real and Imaginary parts");
282 return std::make_shared
<ComplexDeinterleavingCompositeNode
>(Operation
, R
,
286 NodePtr
submitCompositeNode(NodePtr Node
) {
287 CompositeNodes
.push_back(Node
);
291 NodePtr
getContainingComposite(Value
*R
, Value
*I
) {
292 for (const auto &CN
: CompositeNodes
) {
293 if (CN
->Real
== R
&& CN
->Imag
== I
)
299 /// Identifies a complex partial multiply pattern and its rotation, based on
300 /// the following patterns
302 /// 0: r: cr + ar * br
304 /// 90: r: cr - ai * bi
306 /// 180: r: cr - ar * br
308 /// 270: r: cr + ai * bi
310 NodePtr
identifyPartialMul(Instruction
*Real
, Instruction
*Imag
);
312 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
313 /// is partially known from identifyPartialMul, filling in the other half of
314 /// the complex pair.
316 identifyNodeWithImplicitAdd(Instruction
*I
, Instruction
*J
,
317 std::pair
<Value
*, Value
*> &CommonOperandI
);
319 /// Identifies a complex add pattern and its rotation, based on the following
326 NodePtr
identifyAdd(Instruction
*Real
, Instruction
*Imag
);
327 NodePtr
identifySymmetricOperation(Instruction
*Real
, Instruction
*Imag
);
329 NodePtr
identifyNode(Value
*R
, Value
*I
);
331 /// Determine if a sum of complex numbers can be formed from \p RealAddends
332 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
333 /// Return nullptr if it is not possible to construct a complex number.
334 /// \p Flags are needed to generate symmetric Add and Sub operations.
335 NodePtr
identifyAdditions(std::list
<Addend
> &RealAddends
,
336 std::list
<Addend
> &ImagAddends
, FastMathFlags Flags
,
337 NodePtr Accumulator
);
339 /// Extract one addend that have both real and imaginary parts positive.
340 NodePtr
extractPositiveAddend(std::list
<Addend
> &RealAddends
,
341 std::list
<Addend
> &ImagAddends
);
343 /// Determine if sum of multiplications of complex numbers can be formed from
344 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
345 /// to it. Return nullptr if it is not possible to construct a complex number.
346 NodePtr
identifyMultiplications(std::vector
<Product
> &RealMuls
,
347 std::vector
<Product
> &ImagMuls
,
348 NodePtr Accumulator
);
350 /// Go through pairs of multiplication (one Real and one Imag) and find all
351 /// possible candidates for partial multiplication and put them into \p
352 /// Candidates. Returns true if all Product has pair with common operand
353 bool collectPartialMuls(const std::vector
<Product
> &RealMuls
,
354 const std::vector
<Product
> &ImagMuls
,
355 std::vector
<PartialMulCandidate
> &Candidates
);
357 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
358 /// the order of complex computation operations may be significantly altered,
359 /// and the real and imaginary parts may not be executed in parallel. This
360 /// function takes this into consideration and employs a more general approach
361 /// to identify complex computations. Initially, it gathers all the addends
362 /// and multiplicands and then constructs a complex expression from them.
363 NodePtr
identifyReassocNodes(Instruction
*I
, Instruction
*J
);
365 NodePtr
identifyRoot(Instruction
*I
);
367 /// Identifies the Deinterleave operation applied to a vector containing
368 /// complex numbers. There are two ways to represent the Deinterleave
370 /// * Using two shufflevectors with even indices for /pReal instruction and
371 /// odd indices for /pImag instructions (only for fixed-width vectors)
372 /// * Using two extractvalue instructions applied to `vector.deinterleave2`
373 /// intrinsic (for both fixed and scalable vectors)
374 NodePtr
identifyDeinterleave(Instruction
*Real
, Instruction
*Imag
);
376 /// identifying the operation that represents a complex number repeated in a
377 /// Splat vector. There are two possible types of splats: ConstantExpr with
378 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
379 /// initialization mask with all values set to zero.
380 NodePtr
identifySplat(Value
*Real
, Value
*Imag
);
382 NodePtr
identifyPHINode(Instruction
*Real
, Instruction
*Imag
);
384 /// Identifies SelectInsts in a loop that has reduction with predication masks
385 /// and/or predicated tail folding
386 NodePtr
identifySelectNode(Instruction
*Real
, Instruction
*Imag
);
388 Value
*replaceNode(IRBuilderBase
&Builder
, RawNodePtr Node
);
390 /// Complete IR modifications after producing new reduction operation:
391 /// * Populate the PHINode generated for
392 /// ComplexDeinterleavingOperation::ReductionPHI
393 /// * Deinterleave the final value outside of the loop and repurpose original
395 void processReductionOperation(Value
*OperationReplacement
, RawNodePtr Node
);
398 void dump() { dump(dbgs()); }
399 void dump(raw_ostream
&OS
) {
400 for (const auto &Node
: CompositeNodes
)
404 /// Returns false if the deinterleaving operation should be cancelled for the
406 bool identifyNodes(Instruction
*RootI
);
408 /// In case \pB is one-block loop, this function seeks potential reductions
409 /// and populates ReductionInfo. Returns true if any reductions were
411 bool collectPotentialReductions(BasicBlock
*B
);
413 void identifyReductionNodes();
415 /// Check that every instruction, from the roots to the leaves, has internal
419 /// Perform the actual replacement of the underlying instruction graph.
423 class ComplexDeinterleaving
{
425 ComplexDeinterleaving(const TargetLowering
*tl
, const TargetLibraryInfo
*tli
)
426 : TL(tl
), TLI(tli
) {}
427 bool runOnFunction(Function
&F
);
430 bool evaluateBasicBlock(BasicBlock
*B
);
432 const TargetLowering
*TL
= nullptr;
433 const TargetLibraryInfo
*TLI
= nullptr;
438 char ComplexDeinterleavingLegacyPass::ID
= 0;
440 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass
, DEBUG_TYPE
,
441 "Complex Deinterleaving", false, false)
442 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass
, DEBUG_TYPE
,
443 "Complex Deinterleaving", false, false)
445 PreservedAnalyses
ComplexDeinterleavingPass::run(Function
&F
,
446 FunctionAnalysisManager
&AM
) {
447 const TargetLowering
*TL
= TM
->getSubtargetImpl(F
)->getTargetLowering();
448 auto &TLI
= AM
.getResult
<llvm::TargetLibraryAnalysis
>(F
);
449 if (!ComplexDeinterleaving(TL
, &TLI
).runOnFunction(F
))
450 return PreservedAnalyses::all();
452 PreservedAnalyses PA
;
453 PA
.preserve
<FunctionAnalysisManagerModuleProxy
>();
457 FunctionPass
*llvm::createComplexDeinterleavingPass(const TargetMachine
*TM
) {
458 return new ComplexDeinterleavingLegacyPass(TM
);
461 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function
&F
) {
462 const auto *TL
= TM
->getSubtargetImpl(F
)->getTargetLowering();
463 auto TLI
= getAnalysis
<TargetLibraryInfoWrapperPass
>().getTLI(F
);
464 return ComplexDeinterleaving(TL
, &TLI
).runOnFunction(F
);
467 bool ComplexDeinterleaving::runOnFunction(Function
&F
) {
468 if (!ComplexDeinterleavingEnabled
) {
470 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
474 if (!TL
->isComplexDeinterleavingSupported()) {
476 dbgs() << "Complex deinterleaving has been disabled, target does "
477 "not support lowering of complex number operations.\n");
481 bool Changed
= false;
483 Changed
|= evaluateBasicBlock(&B
);
488 static bool isInterleavingMask(ArrayRef
<int> Mask
) {
489 // If the size is not even, it's not an interleaving mask
490 if ((Mask
.size() & 1))
493 int HalfNumElements
= Mask
.size() / 2;
494 for (int Idx
= 0; Idx
< HalfNumElements
; ++Idx
) {
495 int MaskIdx
= Idx
* 2;
496 if (Mask
[MaskIdx
] != Idx
|| Mask
[MaskIdx
+ 1] != (Idx
+ HalfNumElements
))
503 static bool isDeinterleavingMask(ArrayRef
<int> Mask
) {
504 int Offset
= Mask
[0];
505 int HalfNumElements
= Mask
.size() / 2;
507 for (int Idx
= 1; Idx
< HalfNumElements
; ++Idx
) {
508 if (Mask
[Idx
] != (Idx
* 2) + Offset
)
515 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock
*B
) {
516 ComplexDeinterleavingGraph
Graph(TL
, TLI
);
517 if (Graph
.collectPotentialReductions(B
))
518 Graph
.identifyReductionNodes();
521 Graph
.identifyNodes(&I
);
523 if (Graph
.checkNodes()) {
524 Graph
.replaceNodes();
531 ComplexDeinterleavingGraph::NodePtr
532 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
533 Instruction
*Real
, Instruction
*Imag
,
534 std::pair
<Value
*, Value
*> &PartialMatch
) {
535 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real
<< " / " << *Imag
538 if (!Real
->hasOneUse() || !Imag
->hasOneUse()) {
539 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
543 if (Real
->getOpcode() != Instruction::FMul
||
544 Imag
->getOpcode() != Instruction::FMul
) {
545 LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n");
549 Value
*R0
= Real
->getOperand(0);
550 Value
*R1
= Real
->getOperand(1);
551 Value
*I0
= Imag
->getOperand(0);
552 Value
*I1
= Imag
->getOperand(1);
554 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
555 // rotations and use the operand.
558 if (match(R0
, m_Neg(m_Value(Op
)))) {
561 } else if (match(R1
, m_Neg(m_Value(Op
)))) {
566 if (match(I0
, m_Neg(m_Value(Op
)))) {
570 } else if (match(I1
, m_Neg(m_Value(Op
)))) {
576 ComplexDeinterleavingRotation Rotation
= (ComplexDeinterleavingRotation
)Negs
;
578 Value
*CommonOperand
;
579 Value
*UncommonRealOp
;
580 Value
*UncommonImagOp
;
582 if (R0
== I0
|| R0
== I1
) {
585 } else if (R1
== I0
|| R1
== I1
) {
589 LLVM_DEBUG(dbgs() << " - No equal operand\n");
593 UncommonImagOp
= (CommonOperand
== I0
) ? I1
: I0
;
594 if (Rotation
== ComplexDeinterleavingRotation::Rotation_90
||
595 Rotation
== ComplexDeinterleavingRotation::Rotation_270
)
596 std::swap(UncommonRealOp
, UncommonImagOp
);
598 // Between identifyPartialMul and here we need to have found a complete valid
599 // pair from the CommonOperand of each part.
600 if (Rotation
== ComplexDeinterleavingRotation::Rotation_0
||
601 Rotation
== ComplexDeinterleavingRotation::Rotation_180
)
602 PartialMatch
.first
= CommonOperand
;
604 PartialMatch
.second
= CommonOperand
;
606 if (!PartialMatch
.first
|| !PartialMatch
.second
) {
607 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
611 NodePtr CommonNode
= identifyNode(PartialMatch
.first
, PartialMatch
.second
);
613 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
617 NodePtr UncommonNode
= identifyNode(UncommonRealOp
, UncommonImagOp
);
619 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
623 NodePtr Node
= prepareCompositeNode(
624 ComplexDeinterleavingOperation::CMulPartial
, Real
, Imag
);
625 Node
->Rotation
= Rotation
;
626 Node
->addOperand(CommonNode
);
627 Node
->addOperand(UncommonNode
);
628 return submitCompositeNode(Node
);
631 ComplexDeinterleavingGraph::NodePtr
632 ComplexDeinterleavingGraph::identifyPartialMul(Instruction
*Real
,
634 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real
<< " / " << *Imag
636 // Determine rotation
637 ComplexDeinterleavingRotation Rotation
;
638 if (Real
->getOpcode() == Instruction::FAdd
&&
639 Imag
->getOpcode() == Instruction::FAdd
)
640 Rotation
= ComplexDeinterleavingRotation::Rotation_0
;
641 else if (Real
->getOpcode() == Instruction::FSub
&&
642 Imag
->getOpcode() == Instruction::FAdd
)
643 Rotation
= ComplexDeinterleavingRotation::Rotation_90
;
644 else if (Real
->getOpcode() == Instruction::FSub
&&
645 Imag
->getOpcode() == Instruction::FSub
)
646 Rotation
= ComplexDeinterleavingRotation::Rotation_180
;
647 else if (Real
->getOpcode() == Instruction::FAdd
&&
648 Imag
->getOpcode() == Instruction::FSub
)
649 Rotation
= ComplexDeinterleavingRotation::Rotation_270
;
651 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
655 if (!Real
->getFastMathFlags().allowContract() ||
656 !Imag
->getFastMathFlags().allowContract()) {
657 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
661 Value
*CR
= Real
->getOperand(0);
662 Instruction
*RealMulI
= dyn_cast
<Instruction
>(Real
->getOperand(1));
665 Value
*CI
= Imag
->getOperand(0);
666 Instruction
*ImagMulI
= dyn_cast
<Instruction
>(Imag
->getOperand(1));
670 if (!RealMulI
->hasOneUse() || !ImagMulI
->hasOneUse()) {
671 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
675 Value
*R0
= RealMulI
->getOperand(0);
676 Value
*R1
= RealMulI
->getOperand(1);
677 Value
*I0
= ImagMulI
->getOperand(0);
678 Value
*I1
= ImagMulI
->getOperand(1);
680 Value
*CommonOperand
;
681 Value
*UncommonRealOp
;
682 Value
*UncommonImagOp
;
684 if (R0
== I0
|| R0
== I1
) {
687 } else if (R1
== I0
|| R1
== I1
) {
691 LLVM_DEBUG(dbgs() << " - No equal operand\n");
695 UncommonImagOp
= (CommonOperand
== I0
) ? I1
: I0
;
696 if (Rotation
== ComplexDeinterleavingRotation::Rotation_90
||
697 Rotation
== ComplexDeinterleavingRotation::Rotation_270
)
698 std::swap(UncommonRealOp
, UncommonImagOp
);
700 std::pair
<Value
*, Value
*> PartialMatch(
701 (Rotation
== ComplexDeinterleavingRotation::Rotation_0
||
702 Rotation
== ComplexDeinterleavingRotation::Rotation_180
)
705 (Rotation
== ComplexDeinterleavingRotation::Rotation_90
||
706 Rotation
== ComplexDeinterleavingRotation::Rotation_270
)
710 auto *CRInst
= dyn_cast
<Instruction
>(CR
);
711 auto *CIInst
= dyn_cast
<Instruction
>(CI
);
713 if (!CRInst
|| !CIInst
) {
714 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
718 NodePtr CNode
= identifyNodeWithImplicitAdd(CRInst
, CIInst
, PartialMatch
);
720 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
724 NodePtr UncommonRes
= identifyNode(UncommonRealOp
, UncommonImagOp
);
726 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
730 assert(PartialMatch
.first
&& PartialMatch
.second
);
731 NodePtr CommonRes
= identifyNode(PartialMatch
.first
, PartialMatch
.second
);
733 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
737 NodePtr Node
= prepareCompositeNode(
738 ComplexDeinterleavingOperation::CMulPartial
, Real
, Imag
);
739 Node
->Rotation
= Rotation
;
740 Node
->addOperand(CommonRes
);
741 Node
->addOperand(UncommonRes
);
742 Node
->addOperand(CNode
);
743 return submitCompositeNode(Node
);
746 ComplexDeinterleavingGraph::NodePtr
747 ComplexDeinterleavingGraph::identifyAdd(Instruction
*Real
, Instruction
*Imag
) {
748 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real
<< " / " << *Imag
<< "\n");
750 // Determine rotation
751 ComplexDeinterleavingRotation Rotation
;
752 if ((Real
->getOpcode() == Instruction::FSub
&&
753 Imag
->getOpcode() == Instruction::FAdd
) ||
754 (Real
->getOpcode() == Instruction::Sub
&&
755 Imag
->getOpcode() == Instruction::Add
))
756 Rotation
= ComplexDeinterleavingRotation::Rotation_90
;
757 else if ((Real
->getOpcode() == Instruction::FAdd
&&
758 Imag
->getOpcode() == Instruction::FSub
) ||
759 (Real
->getOpcode() == Instruction::Add
&&
760 Imag
->getOpcode() == Instruction::Sub
))
761 Rotation
= ComplexDeinterleavingRotation::Rotation_270
;
763 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
767 auto *AR
= dyn_cast
<Instruction
>(Real
->getOperand(0));
768 auto *BI
= dyn_cast
<Instruction
>(Real
->getOperand(1));
769 auto *AI
= dyn_cast
<Instruction
>(Imag
->getOperand(0));
770 auto *BR
= dyn_cast
<Instruction
>(Imag
->getOperand(1));
772 if (!AR
|| !AI
|| !BR
|| !BI
) {
773 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
777 NodePtr ResA
= identifyNode(AR
, AI
);
779 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
782 NodePtr ResB
= identifyNode(BR
, BI
);
784 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
789 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd
, Real
, Imag
);
790 Node
->Rotation
= Rotation
;
791 Node
->addOperand(ResA
);
792 Node
->addOperand(ResB
);
793 return submitCompositeNode(Node
);
796 static bool isInstructionPairAdd(Instruction
*A
, Instruction
*B
) {
797 unsigned OpcA
= A
->getOpcode();
798 unsigned OpcB
= B
->getOpcode();
800 return (OpcA
== Instruction::FSub
&& OpcB
== Instruction::FAdd
) ||
801 (OpcA
== Instruction::FAdd
&& OpcB
== Instruction::FSub
) ||
802 (OpcA
== Instruction::Sub
&& OpcB
== Instruction::Add
) ||
803 (OpcA
== Instruction::Add
&& OpcB
== Instruction::Sub
);
806 static bool isInstructionPairMul(Instruction
*A
, Instruction
*B
) {
808 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
810 return match(A
, Pattern
) && match(B
, Pattern
);
813 static bool isInstructionPotentiallySymmetric(Instruction
*I
) {
814 switch (I
->getOpcode()) {
815 case Instruction::FAdd
:
816 case Instruction::FSub
:
817 case Instruction::FMul
:
818 case Instruction::FNeg
:
825 ComplexDeinterleavingGraph::NodePtr
826 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction
*Real
,
828 if (Real
->getOpcode() != Imag
->getOpcode())
831 if (!isInstructionPotentiallySymmetric(Real
) ||
832 !isInstructionPotentiallySymmetric(Imag
))
835 auto *R0
= Real
->getOperand(0);
836 auto *I0
= Imag
->getOperand(0);
838 NodePtr Op0
= identifyNode(R0
, I0
);
839 NodePtr Op1
= nullptr;
843 if (Real
->isBinaryOp()) {
844 auto *R1
= Real
->getOperand(1);
845 auto *I1
= Imag
->getOperand(1);
846 Op1
= identifyNode(R1
, I1
);
851 if (isa
<FPMathOperator
>(Real
) &&
852 Real
->getFastMathFlags() != Imag
->getFastMathFlags())
855 auto Node
= prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric
,
857 Node
->Opcode
= Real
->getOpcode();
858 if (isa
<FPMathOperator
>(Real
))
859 Node
->Flags
= Real
->getFastMathFlags();
861 Node
->addOperand(Op0
);
862 if (Real
->isBinaryOp())
863 Node
->addOperand(Op1
);
865 return submitCompositeNode(Node
);
868 ComplexDeinterleavingGraph::NodePtr
869 ComplexDeinterleavingGraph::identifyNode(Value
*R
, Value
*I
) {
870 LLVM_DEBUG(dbgs() << "identifyNode on " << *R
<< " / " << *I
<< "\n");
871 assert(R
->getType() == I
->getType() &&
872 "Real and imaginary parts should not have different types");
873 if (NodePtr CN
= getContainingComposite(R
, I
)) {
874 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
878 if (NodePtr CN
= identifySplat(R
, I
))
881 auto *Real
= dyn_cast
<Instruction
>(R
);
882 auto *Imag
= dyn_cast
<Instruction
>(I
);
886 if (NodePtr CN
= identifyDeinterleave(Real
, Imag
))
889 if (NodePtr CN
= identifyPHINode(Real
, Imag
))
892 if (NodePtr CN
= identifySelectNode(Real
, Imag
))
895 auto *VTy
= cast
<VectorType
>(Real
->getType());
896 auto *NewVTy
= VectorType::getDoubleElementsVectorType(VTy
);
898 bool HasCMulSupport
= TL
->isComplexDeinterleavingOperationSupported(
899 ComplexDeinterleavingOperation::CMulPartial
, NewVTy
);
900 bool HasCAddSupport
= TL
->isComplexDeinterleavingOperationSupported(
901 ComplexDeinterleavingOperation::CAdd
, NewVTy
);
903 if (HasCMulSupport
&& isInstructionPairMul(Real
, Imag
)) {
904 if (NodePtr CN
= identifyPartialMul(Real
, Imag
))
908 if (HasCAddSupport
&& isInstructionPairAdd(Real
, Imag
)) {
909 if (NodePtr CN
= identifyAdd(Real
, Imag
))
913 if (HasCMulSupport
&& HasCAddSupport
) {
914 if (NodePtr CN
= identifyReassocNodes(Real
, Imag
))
918 if (NodePtr CN
= identifySymmetricOperation(Real
, Imag
))
921 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
925 ComplexDeinterleavingGraph::NodePtr
926 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction
*Real
,
929 if ((Real
->getOpcode() != Instruction::FAdd
&&
930 Real
->getOpcode() != Instruction::FSub
&&
931 Real
->getOpcode() != Instruction::FNeg
) ||
932 (Imag
->getOpcode() != Instruction::FAdd
&&
933 Imag
->getOpcode() != Instruction::FSub
&&
934 Imag
->getOpcode() != Instruction::FNeg
))
937 if (Real
->getFastMathFlags() != Imag
->getFastMathFlags()) {
940 << "The flags in Real and Imaginary instructions are not identical\n");
944 FastMathFlags Flags
= Real
->getFastMathFlags();
945 if (!Flags
.allowReassoc()) {
947 dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n");
951 // Collect multiplications and addend instructions from the given instruction
952 // while traversing it operands. Additionally, verify that all instructions
953 // have the same fast math flags.
954 auto Collect
= [&Flags
](Instruction
*Insn
, std::vector
<Product
> &Muls
,
955 std::list
<Addend
> &Addends
) -> bool {
956 SmallVector
<PointerIntPair
<Value
*, 1, bool>> Worklist
= {{Insn
, true}};
957 SmallPtrSet
<Value
*, 8> Visited
;
958 while (!Worklist
.empty()) {
959 auto [V
, IsPositive
] = Worklist
.back();
961 if (!Visited
.insert(V
).second
)
964 Instruction
*I
= dyn_cast
<Instruction
>(V
);
966 Addends
.emplace_back(V
, IsPositive
);
970 // If an instruction has more than one user, it indicates that it either
971 // has an external user, which will be later checked by the checkNodes
972 // function, or it is a subexpression utilized by multiple expressions. In
973 // the latter case, we will attempt to separately identify the complex
974 // operation from here in order to create a shared
975 // ComplexDeinterleavingCompositeNode.
976 if (I
!= Insn
&& I
->getNumUses() > 1) {
977 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I
<< "\n");
978 Addends
.emplace_back(I
, IsPositive
);
982 if (I
->getOpcode() == Instruction::FAdd
) {
983 Worklist
.emplace_back(I
->getOperand(1), IsPositive
);
984 Worklist
.emplace_back(I
->getOperand(0), IsPositive
);
985 } else if (I
->getOpcode() == Instruction::FSub
) {
986 Worklist
.emplace_back(I
->getOperand(1), !IsPositive
);
987 Worklist
.emplace_back(I
->getOperand(0), IsPositive
);
988 } else if (I
->getOpcode() == Instruction::FMul
) {
990 if (match(I
->getOperand(0), m_FNeg(m_Value(A
)))) {
991 IsPositive
= !IsPositive
;
993 A
= I
->getOperand(0);
996 if (match(I
->getOperand(1), m_FNeg(m_Value(B
)))) {
997 IsPositive
= !IsPositive
;
999 B
= I
->getOperand(1);
1001 Muls
.push_back(Product
{A
, B
, IsPositive
});
1002 } else if (I
->getOpcode() == Instruction::FNeg
) {
1003 Worklist
.emplace_back(I
->getOperand(0), !IsPositive
);
1005 Addends
.emplace_back(I
, IsPositive
);
1009 if (I
->getFastMathFlags() != Flags
) {
1010 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1011 "inconsistent with the root instructions' flags: "
1019 std::vector
<Product
> RealMuls
, ImagMuls
;
1020 std::list
<Addend
> RealAddends
, ImagAddends
;
1021 if (!Collect(Real
, RealMuls
, RealAddends
) ||
1022 !Collect(Imag
, ImagMuls
, ImagAddends
))
1025 if (RealAddends
.size() != ImagAddends
.size())
1029 if (!RealMuls
.empty() || !ImagMuls
.empty()) {
1030 // If there are multiplicands, extract positive addend and use it as an
1032 FinalNode
= extractPositiveAddend(RealAddends
, ImagAddends
);
1033 FinalNode
= identifyMultiplications(RealMuls
, ImagMuls
, FinalNode
);
1038 // Identify and process remaining additions
1039 if (!RealAddends
.empty() || !ImagAddends
.empty()) {
1040 FinalNode
= identifyAdditions(RealAddends
, ImagAddends
, Flags
, FinalNode
);
1044 assert(FinalNode
&& "FinalNode can not be nullptr here");
1045 // Set the Real and Imag fields of the final node and submit it
1046 FinalNode
->Real
= Real
;
1047 FinalNode
->Imag
= Imag
;
1048 submitCompositeNode(FinalNode
);
1052 bool ComplexDeinterleavingGraph::collectPartialMuls(
1053 const std::vector
<Product
> &RealMuls
, const std::vector
<Product
> &ImagMuls
,
1054 std::vector
<PartialMulCandidate
> &PartialMulCandidates
) {
1055 // Helper function to extract a common operand from two products
1056 auto FindCommonInstruction
= [](const Product
&Real
,
1057 const Product
&Imag
) -> Value
* {
1058 if (Real
.Multiplicand
== Imag
.Multiplicand
||
1059 Real
.Multiplicand
== Imag
.Multiplier
)
1060 return Real
.Multiplicand
;
1062 if (Real
.Multiplier
== Imag
.Multiplicand
||
1063 Real
.Multiplier
== Imag
.Multiplier
)
1064 return Real
.Multiplier
;
1069 // Iterating over real and imaginary multiplications to find common operands
1070 // If a common operand is found, a partial multiplication candidate is created
1071 // and added to the candidates vector The function returns false if no common
1072 // operands are found for any product
1073 for (unsigned i
= 0; i
< RealMuls
.size(); ++i
) {
1074 bool FoundCommon
= false;
1075 for (unsigned j
= 0; j
< ImagMuls
.size(); ++j
) {
1076 auto *Common
= FindCommonInstruction(RealMuls
[i
], ImagMuls
[j
]);
1080 auto *A
= RealMuls
[i
].Multiplicand
== Common
? RealMuls
[i
].Multiplier
1081 : RealMuls
[i
].Multiplicand
;
1082 auto *B
= ImagMuls
[j
].Multiplicand
== Common
? ImagMuls
[j
].Multiplier
1083 : ImagMuls
[j
].Multiplicand
;
1085 auto Node
= identifyNode(A
, B
);
1088 PartialMulCandidates
.push_back({Common
, Node
, i
, j
, false});
1091 Node
= identifyNode(B
, A
);
1094 PartialMulCandidates
.push_back({Common
, Node
, i
, j
, true});
1103 ComplexDeinterleavingGraph::NodePtr
1104 ComplexDeinterleavingGraph::identifyMultiplications(
1105 std::vector
<Product
> &RealMuls
, std::vector
<Product
> &ImagMuls
,
1106 NodePtr Accumulator
= nullptr) {
1107 if (RealMuls
.size() != ImagMuls
.size())
1110 std::vector
<PartialMulCandidate
> Info
;
1111 if (!collectPartialMuls(RealMuls
, ImagMuls
, Info
))
1114 // Map to store common instruction to node pointers
1115 std::map
<Value
*, NodePtr
> CommonToNode
;
1116 std::vector
<bool> Processed(Info
.size(), false);
1117 for (unsigned I
= 0; I
< Info
.size(); ++I
) {
1121 PartialMulCandidate
&InfoA
= Info
[I
];
1122 for (unsigned J
= I
+ 1; J
< Info
.size(); ++J
) {
1126 PartialMulCandidate
&InfoB
= Info
[J
];
1127 auto *InfoReal
= &InfoA
;
1128 auto *InfoImag
= &InfoB
;
1130 auto NodeFromCommon
= identifyNode(InfoReal
->Common
, InfoImag
->Common
);
1131 if (!NodeFromCommon
) {
1132 std::swap(InfoReal
, InfoImag
);
1133 NodeFromCommon
= identifyNode(InfoReal
->Common
, InfoImag
->Common
);
1135 if (!NodeFromCommon
)
1138 CommonToNode
[InfoReal
->Common
] = NodeFromCommon
;
1139 CommonToNode
[InfoImag
->Common
] = NodeFromCommon
;
1140 Processed
[I
] = true;
1141 Processed
[J
] = true;
1145 std::vector
<bool> ProcessedReal(RealMuls
.size(), false);
1146 std::vector
<bool> ProcessedImag(ImagMuls
.size(), false);
1147 NodePtr Result
= Accumulator
;
1148 for (auto &PMI
: Info
) {
1149 if (ProcessedReal
[PMI
.RealIdx
] || ProcessedImag
[PMI
.ImagIdx
])
1152 auto It
= CommonToNode
.find(PMI
.Common
);
1153 // TODO: Process independent complex multiplications. Cases like this:
1154 // A.real() * B where both A and B are complex numbers.
1155 if (It
== CommonToNode
.end()) {
1157 dbgs() << "Unprocessed independent partial multiplication:\n";
1158 for (auto *Mul
: {&RealMuls
[PMI
.RealIdx
], &RealMuls
[PMI
.RealIdx
]})
1159 dbgs().indent(4) << (Mul
->IsPositive
? "+" : "-") << *Mul
->Multiplier
1160 << " multiplied by " << *Mul
->Multiplicand
<< "\n";
1165 auto &RealMul
= RealMuls
[PMI
.RealIdx
];
1166 auto &ImagMul
= ImagMuls
[PMI
.ImagIdx
];
1168 auto NodeA
= It
->second
;
1169 auto NodeB
= PMI
.Node
;
1170 auto IsMultiplicandReal
= PMI
.Common
== NodeA
->Real
;
1171 // The following table illustrates the relationship between multiplications
1172 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1175 // Rotation | Real | Imag |
1176 // ---------+--------+--------+
1177 // 0 | x * u | x * v |
1178 // 90 | -y * v | y * u |
1179 // 180 | -x * u | -x * v |
1180 // 270 | y * v | -y * u |
1182 // Check if the candidate can indeed be represented by partial
1184 // TODO: Add support for multiplication by complex one
1185 if ((IsMultiplicandReal
&& PMI
.IsNodeInverted
) ||
1186 (!IsMultiplicandReal
&& !PMI
.IsNodeInverted
))
1189 // Determine the rotation based on the multiplications
1190 ComplexDeinterleavingRotation Rotation
;
1191 if (IsMultiplicandReal
) {
1192 // Detect 0 and 180 degrees rotation
1193 if (RealMul
.IsPositive
&& ImagMul
.IsPositive
)
1194 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_0
;
1195 else if (!RealMul
.IsPositive
&& !ImagMul
.IsPositive
)
1196 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_180
;
1201 // Detect 90 and 270 degrees rotation
1202 if (!RealMul
.IsPositive
&& ImagMul
.IsPositive
)
1203 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_90
;
1204 else if (RealMul
.IsPositive
&& !ImagMul
.IsPositive
)
1205 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_270
;
1211 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1212 dbgs().indent(4) << "X: " << *NodeA
->Real
<< "\n";
1213 dbgs().indent(4) << "Y: " << *NodeA
->Imag
<< "\n";
1214 dbgs().indent(4) << "U: " << *NodeB
->Real
<< "\n";
1215 dbgs().indent(4) << "V: " << *NodeB
->Imag
<< "\n";
1216 dbgs().indent(4) << "Rotation - " << (int)Rotation
* 90 << "\n";
1219 NodePtr NodeMul
= prepareCompositeNode(
1220 ComplexDeinterleavingOperation::CMulPartial
, nullptr, nullptr);
1221 NodeMul
->Rotation
= Rotation
;
1222 NodeMul
->addOperand(NodeA
);
1223 NodeMul
->addOperand(NodeB
);
1225 NodeMul
->addOperand(Result
);
1226 submitCompositeNode(NodeMul
);
1228 ProcessedReal
[PMI
.RealIdx
] = true;
1229 ProcessedImag
[PMI
.ImagIdx
] = true;
1232 // Ensure all products have been processed, if not return nullptr.
1233 if (!all_of(ProcessedReal
, [](bool V
) { return V
; }) ||
1234 !all_of(ProcessedImag
, [](bool V
) { return V
; })) {
1236 // Dump debug information about which partial multiplications are not
1239 dbgs() << "Unprocessed products (Real):\n";
1240 for (size_t i
= 0; i
< ProcessedReal
.size(); ++i
) {
1241 if (!ProcessedReal
[i
])
1242 dbgs().indent(4) << (RealMuls
[i
].IsPositive
? "+" : "-")
1243 << *RealMuls
[i
].Multiplier
<< " multiplied by "
1244 << *RealMuls
[i
].Multiplicand
<< "\n";
1246 dbgs() << "Unprocessed products (Imag):\n";
1247 for (size_t i
= 0; i
< ProcessedImag
.size(); ++i
) {
1248 if (!ProcessedImag
[i
])
1249 dbgs().indent(4) << (ImagMuls
[i
].IsPositive
? "+" : "-")
1250 << *ImagMuls
[i
].Multiplier
<< " multiplied by "
1251 << *ImagMuls
[i
].Multiplicand
<< "\n";
1260 ComplexDeinterleavingGraph::NodePtr
1261 ComplexDeinterleavingGraph::identifyAdditions(std::list
<Addend
> &RealAddends
,
1262 std::list
<Addend
> &ImagAddends
,
1263 FastMathFlags Flags
,
1264 NodePtr Accumulator
= nullptr) {
1265 if (RealAddends
.size() != ImagAddends
.size())
1269 // If we have accumulator use it as first addend
1271 Result
= Accumulator
;
1272 // Otherwise find an element with both positive real and imaginary parts.
1274 Result
= extractPositiveAddend(RealAddends
, ImagAddends
);
1279 while (!RealAddends
.empty()) {
1280 auto ItR
= RealAddends
.begin();
1281 auto [R
, IsPositiveR
] = *ItR
;
1283 bool FoundImag
= false;
1284 for (auto ItI
= ImagAddends
.begin(); ItI
!= ImagAddends
.end(); ++ItI
) {
1285 auto [I
, IsPositiveI
] = *ItI
;
1286 ComplexDeinterleavingRotation Rotation
;
1287 if (IsPositiveR
&& IsPositiveI
)
1288 Rotation
= ComplexDeinterleavingRotation::Rotation_0
;
1289 else if (!IsPositiveR
&& IsPositiveI
)
1290 Rotation
= ComplexDeinterleavingRotation::Rotation_90
;
1291 else if (!IsPositiveR
&& !IsPositiveI
)
1292 Rotation
= ComplexDeinterleavingRotation::Rotation_180
;
1294 Rotation
= ComplexDeinterleavingRotation::Rotation_270
;
1297 if (Rotation
== ComplexDeinterleavingRotation::Rotation_0
||
1298 Rotation
== ComplexDeinterleavingRotation::Rotation_180
) {
1299 AddNode
= identifyNode(R
, I
);
1301 AddNode
= identifyNode(I
, R
);
1305 dbgs() << "Identified addition:\n";
1306 dbgs().indent(4) << "X: " << *R
<< "\n";
1307 dbgs().indent(4) << "Y: " << *I
<< "\n";
1308 dbgs().indent(4) << "Rotation - " << (int)Rotation
* 90 << "\n";
1312 if (Rotation
== llvm::ComplexDeinterleavingRotation::Rotation_0
) {
1313 TmpNode
= prepareCompositeNode(
1314 ComplexDeinterleavingOperation::Symmetric
, nullptr, nullptr);
1315 TmpNode
->Opcode
= Instruction::FAdd
;
1316 TmpNode
->Flags
= Flags
;
1317 } else if (Rotation
==
1318 llvm::ComplexDeinterleavingRotation::Rotation_180
) {
1319 TmpNode
= prepareCompositeNode(
1320 ComplexDeinterleavingOperation::Symmetric
, nullptr, nullptr);
1321 TmpNode
->Opcode
= Instruction::FSub
;
1322 TmpNode
->Flags
= Flags
;
1324 TmpNode
= prepareCompositeNode(ComplexDeinterleavingOperation::CAdd
,
1326 TmpNode
->Rotation
= Rotation
;
1329 TmpNode
->addOperand(Result
);
1330 TmpNode
->addOperand(AddNode
);
1331 submitCompositeNode(TmpNode
);
1333 RealAddends
.erase(ItR
);
1334 ImagAddends
.erase(ItI
);
1345 ComplexDeinterleavingGraph::NodePtr
1346 ComplexDeinterleavingGraph::extractPositiveAddend(
1347 std::list
<Addend
> &RealAddends
, std::list
<Addend
> &ImagAddends
) {
1348 for (auto ItR
= RealAddends
.begin(); ItR
!= RealAddends
.end(); ++ItR
) {
1349 for (auto ItI
= ImagAddends
.begin(); ItI
!= ImagAddends
.end(); ++ItI
) {
1350 auto [R
, IsPositiveR
] = *ItR
;
1351 auto [I
, IsPositiveI
] = *ItI
;
1352 if (IsPositiveR
&& IsPositiveI
) {
1353 auto Result
= identifyNode(R
, I
);
1355 RealAddends
.erase(ItR
);
1356 ImagAddends
.erase(ItI
);
1365 bool ComplexDeinterleavingGraph::identifyNodes(Instruction
*RootI
) {
1366 // This potential root instruction might already have been recognized as
1367 // reduction. Because RootToNode maps both Real and Imaginary parts to
1368 // CompositeNode we should choose only one either Real or Imag instruction to
1369 // use as an anchor for generating complex instruction.
1370 auto It
= RootToNode
.find(RootI
);
1371 if (It
!= RootToNode
.end() && It
->second
->Real
== RootI
) {
1372 OrderedRoots
.push_back(RootI
);
1376 auto RootNode
= identifyRoot(RootI
);
1381 Function
*F
= RootI
->getFunction();
1382 BasicBlock
*B
= RootI
->getParent();
1383 dbgs() << "Complex deinterleaving graph for " << F
->getName()
1384 << "::" << B
->getName() << ".\n";
1388 RootToNode
[RootI
] = RootNode
;
1389 OrderedRoots
.push_back(RootI
);
1393 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock
*B
) {
1394 bool FoundPotentialReduction
= false;
1396 auto *Br
= dyn_cast
<BranchInst
>(B
->getTerminator());
1397 if (!Br
|| Br
->getNumSuccessors() != 2)
1400 // Identify simple one-block loop
1401 if (Br
->getSuccessor(0) != B
&& Br
->getSuccessor(1) != B
)
1404 SmallVector
<PHINode
*> PHIs
;
1405 for (auto &PHI
: B
->phis()) {
1406 if (PHI
.getNumIncomingValues() != 2)
1409 if (!PHI
.getType()->isVectorTy())
1412 auto *ReductionOp
= dyn_cast
<Instruction
>(PHI
.getIncomingValueForBlock(B
));
1416 // Check if final instruction is reduced outside of current block
1417 Instruction
*FinalReduction
= nullptr;
1419 for (auto *U
: ReductionOp
->users()) {
1423 FinalReduction
= dyn_cast
<Instruction
>(U
);
1426 if (NumUsers
!= 2 || !FinalReduction
|| FinalReduction
->getParent() == B
||
1427 isa
<PHINode
>(FinalReduction
))
1430 ReductionInfo
[ReductionOp
] = {&PHI
, FinalReduction
};
1432 auto BackEdgeIdx
= PHI
.getBasicBlockIndex(B
);
1433 auto IncomingIdx
= BackEdgeIdx
== 0 ? 1 : 0;
1434 Incoming
= PHI
.getIncomingBlock(IncomingIdx
);
1435 FoundPotentialReduction
= true;
1437 // If the initial value of PHINode is an Instruction, consider it a leaf
1438 // value of a complex deinterleaving graph.
1440 dyn_cast
<Instruction
>(PHI
.getIncomingValueForBlock(Incoming
)))
1441 FinalInstructions
.insert(InitPHI
);
1443 return FoundPotentialReduction
;
1446 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1447 SmallVector
<bool> Processed(ReductionInfo
.size(), false);
1448 SmallVector
<Instruction
*> OperationInstruction
;
1449 for (auto &P
: ReductionInfo
)
1450 OperationInstruction
.push_back(P
.first
);
1452 // Identify a complex computation by evaluating two reduction operations that
1453 // potentially could be involved
1454 for (size_t i
= 0; i
< OperationInstruction
.size(); ++i
) {
1457 for (size_t j
= i
+ 1; j
< OperationInstruction
.size(); ++j
) {
1461 auto *Real
= OperationInstruction
[i
];
1462 auto *Imag
= OperationInstruction
[j
];
1463 if (Real
->getType() != Imag
->getType())
1466 RealPHI
= ReductionInfo
[Real
].first
;
1467 ImagPHI
= ReductionInfo
[Imag
].first
;
1469 auto Node
= identifyNode(Real
, Imag
);
1471 std::swap(Real
, Imag
);
1472 std::swap(RealPHI
, ImagPHI
);
1473 Node
= identifyNode(Real
, Imag
);
1476 // If a node is identified and reduction PHINode is used in the chain of
1477 // operations, mark its operation instructions as used to prevent
1478 // re-identification and attach the node to the real part
1479 if (Node
&& PHIsFound
) {
1480 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1481 << *Real
<< " / " << *Imag
<< "\n");
1482 Processed
[i
] = true;
1483 Processed
[j
] = true;
1484 auto RootNode
= prepareCompositeNode(
1485 ComplexDeinterleavingOperation::ReductionOperation
, Real
, Imag
);
1486 RootNode
->addOperand(Node
);
1487 RootToNode
[Real
] = RootNode
;
1488 RootToNode
[Imag
] = RootNode
;
1489 submitCompositeNode(RootNode
);
1499 bool ComplexDeinterleavingGraph::checkNodes() {
1500 // Collect all instructions from roots to leaves
1501 SmallPtrSet
<Instruction
*, 16> AllInstructions
;
1502 SmallVector
<Instruction
*, 8> Worklist
;
1503 for (auto &Pair
: RootToNode
)
1504 Worklist
.push_back(Pair
.first
);
1506 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1508 while (!Worklist
.empty()) {
1509 auto *I
= Worklist
.back();
1510 Worklist
.pop_back();
1512 if (!AllInstructions
.insert(I
).second
)
1515 for (Value
*Op
: I
->operands()) {
1516 if (auto *OpI
= dyn_cast
<Instruction
>(Op
)) {
1517 if (!FinalInstructions
.count(I
))
1518 Worklist
.emplace_back(OpI
);
1523 // Find instructions that have users outside of chain
1524 SmallVector
<Instruction
*, 2> OuterInstructions
;
1525 for (auto *I
: AllInstructions
) {
1527 if (RootToNode
.count(I
))
1530 for (User
*U
: I
->users()) {
1531 if (AllInstructions
.count(cast
<Instruction
>(U
)))
1534 // Found an instruction that is not used by XCMLA/XCADD chain
1535 Worklist
.emplace_back(I
);
1540 // If any instructions are found to be used outside, find and remove roots
1541 // that somehow connect to those instructions.
1542 SmallPtrSet
<Instruction
*, 16> Visited
;
1543 while (!Worklist
.empty()) {
1544 auto *I
= Worklist
.back();
1545 Worklist
.pop_back();
1546 if (!Visited
.insert(I
).second
)
1549 // Found an impacted root node. Removing it from the nodes to be
1551 if (RootToNode
.count(I
)) {
1552 LLVM_DEBUG(dbgs() << "Instruction " << *I
1553 << " could be deinterleaved but its chain of complex "
1554 "operations have an outside user\n");
1555 RootToNode
.erase(I
);
1558 if (!AllInstructions
.count(I
) || FinalInstructions
.count(I
))
1561 for (User
*U
: I
->users())
1562 Worklist
.emplace_back(cast
<Instruction
>(U
));
1564 for (Value
*Op
: I
->operands()) {
1565 if (auto *OpI
= dyn_cast
<Instruction
>(Op
))
1566 Worklist
.emplace_back(OpI
);
1569 return !RootToNode
.empty();
1572 ComplexDeinterleavingGraph::NodePtr
1573 ComplexDeinterleavingGraph::identifyRoot(Instruction
*RootI
) {
1574 if (auto *Intrinsic
= dyn_cast
<IntrinsicInst
>(RootI
)) {
1575 if (Intrinsic
->getIntrinsicID() !=
1576 Intrinsic::experimental_vector_interleave2
)
1579 auto *Real
= dyn_cast
<Instruction
>(Intrinsic
->getOperand(0));
1580 auto *Imag
= dyn_cast
<Instruction
>(Intrinsic
->getOperand(1));
1584 return identifyNode(Real
, Imag
);
1587 auto *SVI
= dyn_cast
<ShuffleVectorInst
>(RootI
);
1591 // Look for a shufflevector that takes separate vectors of the real and
1592 // imaginary components and recombines them into a single vector.
1593 if (!isInterleavingMask(SVI
->getShuffleMask()))
1598 if (!match(RootI
, m_Shuffle(m_Instruction(Real
), m_Instruction(Imag
))))
1601 return identifyNode(Real
, Imag
);
1604 ComplexDeinterleavingGraph::NodePtr
1605 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction
*Real
,
1606 Instruction
*Imag
) {
1607 Instruction
*I
= nullptr;
1608 Value
*FinalValue
= nullptr;
1609 if (match(Real
, m_ExtractValue
<0>(m_Instruction(I
))) &&
1610 match(Imag
, m_ExtractValue
<1>(m_Specific(I
))) &&
1611 match(I
, m_Intrinsic
<Intrinsic::experimental_vector_deinterleave2
>(
1612 m_Value(FinalValue
)))) {
1613 NodePtr PlaceholderNode
= prepareCompositeNode(
1614 llvm::ComplexDeinterleavingOperation::Deinterleave
, Real
, Imag
);
1615 PlaceholderNode
->ReplacementNode
= FinalValue
;
1616 FinalInstructions
.insert(Real
);
1617 FinalInstructions
.insert(Imag
);
1618 return submitCompositeNode(PlaceholderNode
);
1621 auto *RealShuffle
= dyn_cast
<ShuffleVectorInst
>(Real
);
1622 auto *ImagShuffle
= dyn_cast
<ShuffleVectorInst
>(Imag
);
1623 if (!RealShuffle
|| !ImagShuffle
) {
1624 if (RealShuffle
|| ImagShuffle
)
1625 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1629 Value
*RealOp1
= RealShuffle
->getOperand(1);
1630 if (!isa
<UndefValue
>(RealOp1
) && !isa
<ConstantAggregateZero
>(RealOp1
)) {
1631 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1634 Value
*ImagOp1
= ImagShuffle
->getOperand(1);
1635 if (!isa
<UndefValue
>(ImagOp1
) && !isa
<ConstantAggregateZero
>(ImagOp1
)) {
1636 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1640 Value
*RealOp0
= RealShuffle
->getOperand(0);
1641 Value
*ImagOp0
= ImagShuffle
->getOperand(0);
1643 if (RealOp0
!= ImagOp0
) {
1644 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1648 ArrayRef
<int> RealMask
= RealShuffle
->getShuffleMask();
1649 ArrayRef
<int> ImagMask
= ImagShuffle
->getShuffleMask();
1650 if (!isDeinterleavingMask(RealMask
) || !isDeinterleavingMask(ImagMask
)) {
1651 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1655 if (RealMask
[0] != 0 || ImagMask
[0] != 1) {
1656 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1660 // Type checking, the shuffle type should be a vector type of the same
1661 // scalar type, but half the size
1662 auto CheckType
= [&](ShuffleVectorInst
*Shuffle
) {
1663 Value
*Op
= Shuffle
->getOperand(0);
1664 auto *ShuffleTy
= cast
<FixedVectorType
>(Shuffle
->getType());
1665 auto *OpTy
= cast
<FixedVectorType
>(Op
->getType());
1667 if (OpTy
->getScalarType() != ShuffleTy
->getScalarType())
1669 if ((ShuffleTy
->getNumElements() * 2) != OpTy
->getNumElements())
1675 auto CheckDeinterleavingShuffle
= [&](ShuffleVectorInst
*Shuffle
) -> bool {
1676 if (!CheckType(Shuffle
))
1679 ArrayRef
<int> Mask
= Shuffle
->getShuffleMask();
1680 int Last
= *Mask
.rbegin();
1682 Value
*Op
= Shuffle
->getOperand(0);
1683 auto *OpTy
= cast
<FixedVectorType
>(Op
->getType());
1684 int NumElements
= OpTy
->getNumElements();
1686 // Ensure that the deinterleaving shuffle only pulls from the first
1688 return Last
< NumElements
;
1691 if (RealShuffle
->getType() != ImagShuffle
->getType()) {
1692 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1695 if (!CheckDeinterleavingShuffle(RealShuffle
)) {
1696 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1699 if (!CheckDeinterleavingShuffle(ImagShuffle
)) {
1700 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1704 NodePtr PlaceholderNode
=
1705 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave
,
1706 RealShuffle
, ImagShuffle
);
1707 PlaceholderNode
->ReplacementNode
= RealShuffle
->getOperand(0);
1708 FinalInstructions
.insert(RealShuffle
);
1709 FinalInstructions
.insert(ImagShuffle
);
1710 return submitCompositeNode(PlaceholderNode
);
1713 ComplexDeinterleavingGraph::NodePtr
1714 ComplexDeinterleavingGraph::identifySplat(Value
*R
, Value
*I
) {
1715 auto IsSplat
= [](Value
*V
) -> bool {
1716 // Fixed-width vector with constants
1717 if (isa
<ConstantDataVector
>(V
))
1722 // Splats are represented differently depending on whether the repeated
1723 // value is a constant or an Instruction
1724 if (auto *Const
= dyn_cast
<ConstantExpr
>(V
)) {
1725 if (Const
->getOpcode() != Instruction::ShuffleVector
)
1727 VTy
= cast
<VectorType
>(Const
->getType());
1728 Mask
= Const
->getShuffleMask();
1729 } else if (auto *Shuf
= dyn_cast
<ShuffleVectorInst
>(V
)) {
1730 VTy
= Shuf
->getType();
1731 Mask
= Shuf
->getShuffleMask();
1736 // When the data type is <1 x Type>, it's not possible to differentiate
1737 // between the ComplexDeinterleaving::Deinterleave and
1738 // ComplexDeinterleaving::Splat operations.
1739 if (!VTy
->isScalableTy() && VTy
->getElementCount().getKnownMinValue() == 1)
1742 return all_equal(Mask
) && Mask
[0] == 0;
1745 if (!IsSplat(R
) || !IsSplat(I
))
1748 auto *Real
= dyn_cast
<Instruction
>(R
);
1749 auto *Imag
= dyn_cast
<Instruction
>(I
);
1750 if ((!Real
&& Imag
) || (Real
&& !Imag
))
1754 // Non-constant splats should be in the same basic block
1755 if (Real
->getParent() != Imag
->getParent())
1758 FinalInstructions
.insert(Real
);
1759 FinalInstructions
.insert(Imag
);
1761 NodePtr PlaceholderNode
=
1762 prepareCompositeNode(ComplexDeinterleavingOperation::Splat
, R
, I
);
1763 return submitCompositeNode(PlaceholderNode
);
1766 ComplexDeinterleavingGraph::NodePtr
1767 ComplexDeinterleavingGraph::identifyPHINode(Instruction
*Real
,
1768 Instruction
*Imag
) {
1769 if (Real
!= RealPHI
|| Imag
!= ImagPHI
)
1773 NodePtr PlaceholderNode
= prepareCompositeNode(
1774 ComplexDeinterleavingOperation::ReductionPHI
, Real
, Imag
);
1775 return submitCompositeNode(PlaceholderNode
);
1778 ComplexDeinterleavingGraph::NodePtr
1779 ComplexDeinterleavingGraph::identifySelectNode(Instruction
*Real
,
1780 Instruction
*Imag
) {
1781 auto *SelectReal
= dyn_cast
<SelectInst
>(Real
);
1782 auto *SelectImag
= dyn_cast
<SelectInst
>(Imag
);
1783 if (!SelectReal
|| !SelectImag
)
1786 Instruction
*MaskA
, *MaskB
;
1787 Instruction
*AR
, *AI
, *RA
, *BI
;
1788 if (!match(Real
, m_Select(m_Instruction(MaskA
), m_Instruction(AR
),
1789 m_Instruction(RA
))) ||
1790 !match(Imag
, m_Select(m_Instruction(MaskB
), m_Instruction(AI
),
1791 m_Instruction(BI
))))
1794 if (MaskA
!= MaskB
&& !MaskA
->isIdenticalTo(MaskB
))
1797 if (!MaskA
->getType()->isVectorTy())
1800 auto NodeA
= identifyNode(AR
, AI
);
1804 auto NodeB
= identifyNode(RA
, BI
);
1808 NodePtr PlaceholderNode
= prepareCompositeNode(
1809 ComplexDeinterleavingOperation::ReductionSelect
, Real
, Imag
);
1810 PlaceholderNode
->addOperand(NodeA
);
1811 PlaceholderNode
->addOperand(NodeB
);
1812 FinalInstructions
.insert(MaskA
);
1813 FinalInstructions
.insert(MaskB
);
1814 return submitCompositeNode(PlaceholderNode
);
1817 static Value
*replaceSymmetricNode(IRBuilderBase
&B
, unsigned Opcode
,
1818 FastMathFlags Flags
, Value
*InputA
,
1822 case Instruction::FNeg
:
1823 I
= B
.CreateFNeg(InputA
);
1825 case Instruction::FAdd
:
1826 I
= B
.CreateFAdd(InputA
, InputB
);
1828 case Instruction::FSub
:
1829 I
= B
.CreateFSub(InputA
, InputB
);
1831 case Instruction::FMul
:
1832 I
= B
.CreateFMul(InputA
, InputB
);
1835 llvm_unreachable("Incorrect symmetric opcode");
1837 cast
<Instruction
>(I
)->setFastMathFlags(Flags
);
1841 Value
*ComplexDeinterleavingGraph::replaceNode(IRBuilderBase
&Builder
,
1843 if (Node
->ReplacementNode
)
1844 return Node
->ReplacementNode
;
1846 auto ReplaceOperandIfExist
= [&](RawNodePtr
&Node
, unsigned Idx
) -> Value
* {
1847 return Node
->Operands
.size() > Idx
1848 ? replaceNode(Builder
, Node
->Operands
[Idx
])
1852 Value
*ReplacementNode
;
1853 switch (Node
->Operation
) {
1854 case ComplexDeinterleavingOperation::CAdd
:
1855 case ComplexDeinterleavingOperation::CMulPartial
:
1856 case ComplexDeinterleavingOperation::Symmetric
: {
1857 Value
*Input0
= ReplaceOperandIfExist(Node
, 0);
1858 Value
*Input1
= ReplaceOperandIfExist(Node
, 1);
1859 Value
*Accumulator
= ReplaceOperandIfExist(Node
, 2);
1860 assert(!Input1
|| (Input0
->getType() == Input1
->getType() &&
1861 "Node inputs need to be of the same type"));
1862 assert(!Accumulator
||
1863 (Input0
->getType() == Accumulator
->getType() &&
1864 "Accumulator and input need to be of the same type"));
1865 if (Node
->Operation
== ComplexDeinterleavingOperation::Symmetric
)
1866 ReplacementNode
= replaceSymmetricNode(Builder
, Node
->Opcode
, Node
->Flags
,
1869 ReplacementNode
= TL
->createComplexDeinterleavingIR(
1870 Builder
, Node
->Operation
, Node
->Rotation
, Input0
, Input1
,
1874 case ComplexDeinterleavingOperation::Deinterleave
:
1875 llvm_unreachable("Deinterleave node should already have ReplacementNode");
1877 case ComplexDeinterleavingOperation::Splat
: {
1878 auto *NewTy
= VectorType::getDoubleElementsVectorType(
1879 cast
<VectorType
>(Node
->Real
->getType()));
1880 auto *R
= dyn_cast
<Instruction
>(Node
->Real
);
1881 auto *I
= dyn_cast
<Instruction
>(Node
->Imag
);
1883 // Splats that are not constant are interleaved where they are located
1884 Instruction
*InsertPoint
= (I
->comesBefore(R
) ? R
: I
)->getNextNode();
1885 IRBuilder
<> IRB(InsertPoint
);
1887 IRB
.CreateIntrinsic(Intrinsic::experimental_vector_interleave2
, NewTy
,
1888 {Node
->Real
, Node
->Imag
});
1891 Builder
.CreateIntrinsic(Intrinsic::experimental_vector_interleave2
,
1892 NewTy
, {Node
->Real
, Node
->Imag
});
1896 case ComplexDeinterleavingOperation::ReductionPHI
: {
1897 // If Operation is ReductionPHI, a new empty PHINode is created.
1898 // It is filled later when the ReductionOperation is processed.
1899 auto *VTy
= cast
<VectorType
>(Node
->Real
->getType());
1900 auto *NewVTy
= VectorType::getDoubleElementsVectorType(VTy
);
1901 auto *NewPHI
= PHINode::Create(NewVTy
, 0, "", BackEdge
->getFirstNonPHI());
1902 OldToNewPHI
[dyn_cast
<PHINode
>(Node
->Real
)] = NewPHI
;
1903 ReplacementNode
= NewPHI
;
1906 case ComplexDeinterleavingOperation::ReductionOperation
:
1907 ReplacementNode
= replaceNode(Builder
, Node
->Operands
[0]);
1908 processReductionOperation(ReplacementNode
, Node
);
1910 case ComplexDeinterleavingOperation::ReductionSelect
: {
1911 auto *MaskReal
= cast
<Instruction
>(Node
->Real
)->getOperand(0);
1912 auto *MaskImag
= cast
<Instruction
>(Node
->Imag
)->getOperand(0);
1913 auto *A
= replaceNode(Builder
, Node
->Operands
[0]);
1914 auto *B
= replaceNode(Builder
, Node
->Operands
[1]);
1915 auto *NewMaskTy
= VectorType::getDoubleElementsVectorType(
1916 cast
<VectorType
>(MaskReal
->getType()));
1918 Builder
.CreateIntrinsic(Intrinsic::experimental_vector_interleave2
,
1919 NewMaskTy
, {MaskReal
, MaskImag
});
1920 ReplacementNode
= Builder
.CreateSelect(NewMask
, A
, B
);
1925 assert(ReplacementNode
&& "Target failed to create Intrinsic call.");
1926 NumComplexTransformations
+= 1;
1927 Node
->ReplacementNode
= ReplacementNode
;
1928 return ReplacementNode
;
1931 void ComplexDeinterleavingGraph::processReductionOperation(
1932 Value
*OperationReplacement
, RawNodePtr Node
) {
1933 auto *Real
= cast
<Instruction
>(Node
->Real
);
1934 auto *Imag
= cast
<Instruction
>(Node
->Imag
);
1935 auto *OldPHIReal
= ReductionInfo
[Real
].first
;
1936 auto *OldPHIImag
= ReductionInfo
[Imag
].first
;
1937 auto *NewPHI
= OldToNewPHI
[OldPHIReal
];
1939 auto *VTy
= cast
<VectorType
>(Real
->getType());
1940 auto *NewVTy
= VectorType::getDoubleElementsVectorType(VTy
);
1942 // We have to interleave initial origin values coming from IncomingBlock
1943 Value
*InitReal
= OldPHIReal
->getIncomingValueForBlock(Incoming
);
1944 Value
*InitImag
= OldPHIImag
->getIncomingValueForBlock(Incoming
);
1946 IRBuilder
<> Builder(Incoming
->getTerminator());
1947 auto *NewInit
= Builder
.CreateIntrinsic(
1948 Intrinsic::experimental_vector_interleave2
, NewVTy
, {InitReal
, InitImag
});
1950 NewPHI
->addIncoming(NewInit
, Incoming
);
1951 NewPHI
->addIncoming(OperationReplacement
, BackEdge
);
1953 // Deinterleave complex vector outside of loop so that it can be finally
1955 auto *FinalReductionReal
= ReductionInfo
[Real
].second
;
1956 auto *FinalReductionImag
= ReductionInfo
[Imag
].second
;
1958 Builder
.SetInsertPoint(
1959 &*FinalReductionReal
->getParent()->getFirstInsertionPt());
1960 auto *Deinterleave
= Builder
.CreateIntrinsic(
1961 Intrinsic::experimental_vector_deinterleave2
,
1962 OperationReplacement
->getType(), OperationReplacement
);
1964 auto *NewReal
= Builder
.CreateExtractValue(Deinterleave
, (uint64_t)0);
1965 FinalReductionReal
->replaceUsesOfWith(Real
, NewReal
);
1967 Builder
.SetInsertPoint(FinalReductionImag
);
1968 auto *NewImag
= Builder
.CreateExtractValue(Deinterleave
, 1);
1969 FinalReductionImag
->replaceUsesOfWith(Imag
, NewImag
);
1972 void ComplexDeinterleavingGraph::replaceNodes() {
1973 SmallVector
<Instruction
*, 16> DeadInstrRoots
;
1974 for (auto *RootInstruction
: OrderedRoots
) {
1975 // Check if this potential root went through check process and we can
1977 if (!RootToNode
.count(RootInstruction
))
1980 IRBuilder
<> Builder(RootInstruction
);
1981 auto RootNode
= RootToNode
[RootInstruction
];
1982 Value
*R
= replaceNode(Builder
, RootNode
.get());
1984 if (RootNode
->Operation
==
1985 ComplexDeinterleavingOperation::ReductionOperation
) {
1986 auto *RootReal
= cast
<Instruction
>(RootNode
->Real
);
1987 auto *RootImag
= cast
<Instruction
>(RootNode
->Imag
);
1988 ReductionInfo
[RootReal
].first
->removeIncomingValue(BackEdge
);
1989 ReductionInfo
[RootImag
].first
->removeIncomingValue(BackEdge
);
1990 DeadInstrRoots
.push_back(cast
<Instruction
>(RootReal
));
1991 DeadInstrRoots
.push_back(cast
<Instruction
>(RootImag
));
1993 assert(R
&& "Unable to find replacement for RootInstruction");
1994 DeadInstrRoots
.push_back(RootInstruction
);
1995 RootInstruction
->replaceAllUsesWith(R
);
1999 for (auto *I
: DeadInstrRoots
)
2000 RecursivelyDeleteTriviallyDeadInstructions(I
, TLI
);