1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Validity of each node is expected to be done upon creation, and any
19 // validation errors should halt traversal and prevent further graph
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas, fixed-width vectors are recognized for both
25 // shufflevector instruction and intrinsics.
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
48 // Each node maintains a reference to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
52 // Operands contains pointers to other CompositeNodes, acting as the edges in
53 // the graph. ReplacementValue is the transformed Value* that has been emitted
56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue
58 // should be pre-populated.
60 //===----------------------------------------------------------------------===//
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/MapVector.h"
64 #include "llvm/ADT/Statistic.h"
65 #include "llvm/Analysis/TargetLibraryInfo.h"
66 #include "llvm/Analysis/TargetTransformInfo.h"
67 #include "llvm/CodeGen/TargetLowering.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
77 using namespace PatternMatch
;
79 #define DEBUG_TYPE "complex-deinterleaving"
81 STATISTIC(NumComplexTransformations
, "Amount of complex patterns transformed");
83 static cl::opt
<bool> ComplexDeinterleavingEnabled(
84 "enable-complex-deinterleaving",
85 cl::desc("Enable generation of complex instructions"), cl::init(true),
88 /// Checks the given mask, and determines whether said mask is interleaving.
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef
<int> Mask
);
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101 static bool isDeinterleavingMask(ArrayRef
<int> Mask
);
103 /// Returns true if the operation is a negation of V, and it works for both
104 /// integers and floats.
105 static bool isNeg(Value
*V
);
107 /// Returns the operand for negation operation.
108 static Value
*getNegOperand(Value
*V
);
/// Returns the first element of \p A (converted to \p T) that is also present
/// in \p B, or std::nullopt if the two collections have nothing in common.
///
/// Elements of \p B are compared against the converted value with `==`. Both
/// collections are scanned linearly, so the cost is O(|A| * |B|).
template <typename T, typename IterT>
std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
  for (auto &&Candidate : A) {
    const T Item = Candidate;
    for (auto &&Other : B)
      if (Other == Item)
        return Item;
  }
  return std::nullopt;
}
119 class ComplexDeinterleavingLegacyPass
: public FunctionPass
{
123 ComplexDeinterleavingLegacyPass(const TargetMachine
*TM
= nullptr)
124 : FunctionPass(ID
), TM(TM
) {
125 initializeComplexDeinterleavingLegacyPassPass(
126 *PassRegistry::getPassRegistry());
129 StringRef
getPassName() const override
{
130 return "Complex Deinterleaving Pass";
133 bool runOnFunction(Function
&F
) override
;
134 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
135 AU
.addRequired
<TargetLibraryInfoWrapperPass
>();
136 AU
.setPreservesCFG();
140 const TargetMachine
*TM
;
143 class ComplexDeinterleavingGraph
;
144 struct ComplexDeinterleavingCompositeNode
{
146 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op
,
148 : Operation(Op
), Real(R
), Imag(I
) {}
151 friend class ComplexDeinterleavingGraph
;
152 using NodePtr
= std::shared_ptr
<ComplexDeinterleavingCompositeNode
>;
153 using RawNodePtr
= ComplexDeinterleavingCompositeNode
*;
154 bool OperandsValid
= true;
157 ComplexDeinterleavingOperation Operation
;
161 // This two members are required exclusively for generating
162 // ComplexDeinterleavingOperation::Symmetric operations.
164 std::optional
<FastMathFlags
> Flags
;
166 ComplexDeinterleavingRotation Rotation
=
167 ComplexDeinterleavingRotation::Rotation_0
;
168 SmallVector
<RawNodePtr
> Operands
;
169 Value
*ReplacementNode
= nullptr;
171 void addOperand(NodePtr Node
) {
172 if (!Node
|| !Node
.get())
173 OperandsValid
= false;
174 Operands
.push_back(Node
.get());
177 void dump() { dump(dbgs()); }
178 void dump(raw_ostream
&OS
) {
179 auto PrintValue
= [&](Value
*V
) {
187 auto PrintNodeRef
= [&](RawNodePtr Ptr
) {
194 OS
<< "- CompositeNode: " << this << "\n";
199 OS
<< " ReplacementNode: ";
200 PrintValue(ReplacementNode
);
201 OS
<< " Operation: " << (int)Operation
<< "\n";
202 OS
<< " Rotation: " << ((int)Rotation
* 90) << "\n";
203 OS
<< " Operands: \n";
204 for (const auto &Op
: Operands
) {
210 bool areOperandsValid() { return OperandsValid
; }
213 class ComplexDeinterleavingGraph
{
221 using Addend
= std::pair
<Value
*, bool>;
222 using NodePtr
= ComplexDeinterleavingCompositeNode::NodePtr
;
223 using RawNodePtr
= ComplexDeinterleavingCompositeNode::RawNodePtr
;
225 // Helper struct for holding info about potential partial multiplication
227 struct PartialMulCandidate
{
235 explicit ComplexDeinterleavingGraph(const TargetLowering
*TL
,
236 const TargetLibraryInfo
*TLI
)
237 : TL(TL
), TLI(TLI
) {}
240 const TargetLowering
*TL
= nullptr;
241 const TargetLibraryInfo
*TLI
= nullptr;
242 SmallVector
<NodePtr
> CompositeNodes
;
243 DenseMap
<std::pair
<Value
*, Value
*>, NodePtr
> CachedResult
;
245 SmallPtrSet
<Instruction
*, 16> FinalInstructions
;
247 /// Root instructions are instructions from which complex computation starts
248 std::map
<Instruction
*, NodePtr
> RootToNode
;
250 /// Topologically sorted root instructions
251 SmallVector
<Instruction
*, 1> OrderedRoots
;
253 /// When examining a basic block for complex deinterleaving, if it is a simple
254 /// one-block loop, then the only incoming block is 'Incoming' and the
255 /// 'BackEdge' block is the block itself."
256 BasicBlock
*BackEdge
= nullptr;
257 BasicBlock
*Incoming
= nullptr;
259 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
260 /// %OutsideUser as it is shown in the IR:
263 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
264 /// [ %ReductionOp, %vector.body ]
266 /// %ReductionOp = fadd i64 ...
268 /// br i1 %condition, label %vector.body, %middle.block
271 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
273 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
274 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
275 MapVector
<Instruction
*, std::pair
<PHINode
*, Instruction
*>> ReductionInfo
;
277 /// In the process of detecting a reduction, we consider a pair of
278 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
279 /// traverse the use-tree to detect complex operations. As this is a reduction
280 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
281 /// to the %ReductionOPs that we suspect to be complex.
282 /// RealPHI and ImagPHI are used by the identifyPHINode method.
283 PHINode
*RealPHI
= nullptr;
284 PHINode
*ImagPHI
= nullptr;
286 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
288 bool PHIsFound
= false;
290 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
291 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
292 /// This mapping is populated during
293 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
294 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
295 /// replacement process.
296 std::map
<PHINode
*, PHINode
*> OldToNewPHI
;
298 NodePtr
prepareCompositeNode(ComplexDeinterleavingOperation Operation
,
299 Value
*R
, Value
*I
) {
300 assert(((Operation
!= ComplexDeinterleavingOperation::ReductionPHI
&&
301 Operation
!= ComplexDeinterleavingOperation::ReductionOperation
) ||
303 "Reduction related nodes must have Real and Imaginary parts");
304 return std::make_shared
<ComplexDeinterleavingCompositeNode
>(Operation
, R
,
308 NodePtr
submitCompositeNode(NodePtr Node
) {
309 CompositeNodes
.push_back(Node
);
311 CachedResult
[{Node
->Real
, Node
->Imag
}] = Node
;
315 /// Identifies a complex partial multiply pattern and its rotation, based on
316 /// the following patterns
318 /// 0: r: cr + ar * br
320 /// 90: r: cr - ai * bi
322 /// 180: r: cr - ar * br
324 /// 270: r: cr + ai * bi
326 NodePtr
identifyPartialMul(Instruction
*Real
, Instruction
*Imag
);
328 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
329 /// is partially known from identifyPartialMul, filling in the other half of
330 /// the complex pair.
332 identifyNodeWithImplicitAdd(Instruction
*I
, Instruction
*J
,
333 std::pair
<Value
*, Value
*> &CommonOperandI
);
335 /// Identifies a complex add pattern and its rotation, based on the following
342 NodePtr
identifyAdd(Instruction
*Real
, Instruction
*Imag
);
343 NodePtr
identifySymmetricOperation(Instruction
*Real
, Instruction
*Imag
);
344 NodePtr
identifyPartialReduction(Value
*R
, Value
*I
);
345 NodePtr
identifyDotProduct(Value
*Inst
);
347 NodePtr
identifyNode(Value
*R
, Value
*I
);
349 /// Determine if a sum of complex numbers can be formed from \p RealAddends
350 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
351 /// Return nullptr if it is not possible to construct a complex number.
352 /// \p Flags are needed to generate symmetric Add and Sub operations.
353 NodePtr
identifyAdditions(std::list
<Addend
> &RealAddends
,
354 std::list
<Addend
> &ImagAddends
,
355 std::optional
<FastMathFlags
> Flags
,
356 NodePtr Accumulator
);
358 /// Extract one addend that have both real and imaginary parts positive.
359 NodePtr
extractPositiveAddend(std::list
<Addend
> &RealAddends
,
360 std::list
<Addend
> &ImagAddends
);
362 /// Determine if sum of multiplications of complex numbers can be formed from
363 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
364 /// to it. Return nullptr if it is not possible to construct a complex number.
365 NodePtr
identifyMultiplications(std::vector
<Product
> &RealMuls
,
366 std::vector
<Product
> &ImagMuls
,
367 NodePtr Accumulator
);
369 /// Go through pairs of multiplication (one Real and one Imag) and find all
370 /// possible candidates for partial multiplication and put them into \p
371 /// Candidates. Returns true if all Product has pair with common operand
372 bool collectPartialMuls(const std::vector
<Product
> &RealMuls
,
373 const std::vector
<Product
> &ImagMuls
,
374 std::vector
<PartialMulCandidate
> &Candidates
);
376 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
377 /// the order of complex computation operations may be significantly altered,
378 /// and the real and imaginary parts may not be executed in parallel. This
379 /// function takes this into consideration and employs a more general approach
380 /// to identify complex computations. Initially, it gathers all the addends
381 /// and multiplicands and then constructs a complex expression from them.
382 NodePtr
identifyReassocNodes(Instruction
*I
, Instruction
*J
);
384 NodePtr
identifyRoot(Instruction
*I
);
386 /// Identifies the Deinterleave operation applied to a vector containing
387 /// complex numbers. There are two ways to represent the Deinterleave
389 /// * Using two shufflevectors with even indices for /pReal instruction and
390 /// odd indices for /pImag instructions (only for fixed-width vectors)
391 /// * Using two extractvalue instructions applied to `vector.deinterleave2`
392 /// intrinsic (for both fixed and scalable vectors)
393 NodePtr
identifyDeinterleave(Instruction
*Real
, Instruction
*Imag
);
395 /// identifying the operation that represents a complex number repeated in a
396 /// Splat vector. There are two possible types of splats: ConstantExpr with
397 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
398 /// initialization mask with all values set to zero.
399 NodePtr
identifySplat(Value
*Real
, Value
*Imag
);
401 NodePtr
identifyPHINode(Instruction
*Real
, Instruction
*Imag
);
403 /// Identifies SelectInsts in a loop that has reduction with predication masks
404 /// and/or predicated tail folding
405 NodePtr
identifySelectNode(Instruction
*Real
, Instruction
*Imag
);
407 Value
*replaceNode(IRBuilderBase
&Builder
, RawNodePtr Node
);
409 /// Complete IR modifications after producing new reduction operation:
410 /// * Populate the PHINode generated for
411 /// ComplexDeinterleavingOperation::ReductionPHI
412 /// * Deinterleave the final value outside of the loop and repurpose original
414 void processReductionOperation(Value
*OperationReplacement
, RawNodePtr Node
);
415 void processReductionSingle(Value
*OperationReplacement
, RawNodePtr Node
);
418 void dump() { dump(dbgs()); }
419 void dump(raw_ostream
&OS
) {
420 for (const auto &Node
: CompositeNodes
)
424 /// Returns false if the deinterleaving operation should be cancelled for the
426 bool identifyNodes(Instruction
*RootI
);
428 /// In case \pB is one-block loop, this function seeks potential reductions
429 /// and populates ReductionInfo. Returns true if any reductions were
431 bool collectPotentialReductions(BasicBlock
*B
);
433 void identifyReductionNodes();
435 /// Check that every instruction, from the roots to the leaves, has internal
439 /// Perform the actual replacement of the underlying instruction graph.
443 class ComplexDeinterleaving
{
445 ComplexDeinterleaving(const TargetLowering
*tl
, const TargetLibraryInfo
*tli
)
446 : TL(tl
), TLI(tli
) {}
447 bool runOnFunction(Function
&F
);
450 bool evaluateBasicBlock(BasicBlock
*B
);
452 const TargetLowering
*TL
= nullptr;
453 const TargetLibraryInfo
*TLI
= nullptr;
458 char ComplexDeinterleavingLegacyPass::ID
= 0;
460 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass
, DEBUG_TYPE
,
461 "Complex Deinterleaving", false, false)
462 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass
, DEBUG_TYPE
,
463 "Complex Deinterleaving", false, false)
465 PreservedAnalyses
ComplexDeinterleavingPass::run(Function
&F
,
466 FunctionAnalysisManager
&AM
) {
467 const TargetLowering
*TL
= TM
->getSubtargetImpl(F
)->getTargetLowering();
468 auto &TLI
= AM
.getResult
<llvm::TargetLibraryAnalysis
>(F
);
469 if (!ComplexDeinterleaving(TL
, &TLI
).runOnFunction(F
))
470 return PreservedAnalyses::all();
472 PreservedAnalyses PA
;
473 PA
.preserve
<FunctionAnalysisManagerModuleProxy
>();
477 FunctionPass
*llvm::createComplexDeinterleavingPass(const TargetMachine
*TM
) {
478 return new ComplexDeinterleavingLegacyPass(TM
);
481 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function
&F
) {
482 const auto *TL
= TM
->getSubtargetImpl(F
)->getTargetLowering();
483 auto TLI
= getAnalysis
<TargetLibraryInfoWrapperPass
>().getTLI(F
);
484 return ComplexDeinterleaving(TL
, &TLI
).runOnFunction(F
);
487 bool ComplexDeinterleaving::runOnFunction(Function
&F
) {
488 if (!ComplexDeinterleavingEnabled
) {
490 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
494 if (!TL
->isComplexDeinterleavingSupported()) {
496 dbgs() << "Complex deinterleaving has been disabled, target does "
497 "not support lowering of complex number operations.\n");
501 bool Changed
= false;
503 Changed
|= evaluateBasicBlock(&B
);
508 static bool isInterleavingMask(ArrayRef
<int> Mask
) {
509 // If the size is not even, it's not an interleaving mask
510 if ((Mask
.size() & 1))
513 int HalfNumElements
= Mask
.size() / 2;
514 for (int Idx
= 0; Idx
< HalfNumElements
; ++Idx
) {
515 int MaskIdx
= Idx
* 2;
516 if (Mask
[MaskIdx
] != Idx
|| Mask
[MaskIdx
+ 1] != (Idx
+ HalfNumElements
))
523 static bool isDeinterleavingMask(ArrayRef
<int> Mask
) {
524 int Offset
= Mask
[0];
525 int HalfNumElements
= Mask
.size() / 2;
527 for (int Idx
= 1; Idx
< HalfNumElements
; ++Idx
) {
528 if (Mask
[Idx
] != (Idx
* 2) + Offset
)
535 bool isNeg(Value
*V
) {
536 return match(V
, m_FNeg(m_Value())) || match(V
, m_Neg(m_Value()));
539 Value
*getNegOperand(Value
*V
) {
541 auto *I
= cast
<Instruction
>(V
);
542 if (I
->getOpcode() == Instruction::FNeg
)
543 return I
->getOperand(0);
545 return I
->getOperand(1);
548 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock
*B
) {
549 ComplexDeinterleavingGraph
Graph(TL
, TLI
);
550 if (Graph
.collectPotentialReductions(B
))
551 Graph
.identifyReductionNodes();
554 Graph
.identifyNodes(&I
);
556 if (Graph
.checkNodes()) {
557 Graph
.replaceNodes();
564 ComplexDeinterleavingGraph::NodePtr
565 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
566 Instruction
*Real
, Instruction
*Imag
,
567 std::pair
<Value
*, Value
*> &PartialMatch
) {
568 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real
<< " / " << *Imag
571 if (!Real
->hasOneUse() || !Imag
->hasOneUse()) {
572 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
576 if ((Real
->getOpcode() != Instruction::FMul
&&
577 Real
->getOpcode() != Instruction::Mul
) ||
578 (Imag
->getOpcode() != Instruction::FMul
&&
579 Imag
->getOpcode() != Instruction::Mul
)) {
581 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
585 Value
*R0
= Real
->getOperand(0);
586 Value
*R1
= Real
->getOperand(1);
587 Value
*I0
= Imag
->getOperand(0);
588 Value
*I1
= Imag
->getOperand(1);
590 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
591 // rotations and use the operand.
594 if (match(R0
, m_Neg(m_Value(Op
)))) {
597 } else if (match(R1
, m_Neg(m_Value(Op
)))) {
606 } else if (match(I1
, m_Neg(m_Value(Op
)))) {
612 ComplexDeinterleavingRotation Rotation
= (ComplexDeinterleavingRotation
)Negs
;
614 Value
*CommonOperand
;
615 Value
*UncommonRealOp
;
616 Value
*UncommonImagOp
;
618 if (R0
== I0
|| R0
== I1
) {
621 } else if (R1
== I0
|| R1
== I1
) {
625 LLVM_DEBUG(dbgs() << " - No equal operand\n");
629 UncommonImagOp
= (CommonOperand
== I0
) ? I1
: I0
;
630 if (Rotation
== ComplexDeinterleavingRotation::Rotation_90
||
631 Rotation
== ComplexDeinterleavingRotation::Rotation_270
)
632 std::swap(UncommonRealOp
, UncommonImagOp
);
634 // Between identifyPartialMul and here we need to have found a complete valid
635 // pair from the CommonOperand of each part.
636 if (Rotation
== ComplexDeinterleavingRotation::Rotation_0
||
637 Rotation
== ComplexDeinterleavingRotation::Rotation_180
)
638 PartialMatch
.first
= CommonOperand
;
640 PartialMatch
.second
= CommonOperand
;
642 if (!PartialMatch
.first
|| !PartialMatch
.second
) {
643 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
647 NodePtr CommonNode
= identifyNode(PartialMatch
.first
, PartialMatch
.second
);
649 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
653 NodePtr UncommonNode
= identifyNode(UncommonRealOp
, UncommonImagOp
);
655 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
659 NodePtr Node
= prepareCompositeNode(
660 ComplexDeinterleavingOperation::CMulPartial
, Real
, Imag
);
661 Node
->Rotation
= Rotation
;
662 Node
->addOperand(CommonNode
);
663 Node
->addOperand(UncommonNode
);
664 return submitCompositeNode(Node
);
667 ComplexDeinterleavingGraph::NodePtr
668 ComplexDeinterleavingGraph::identifyPartialMul(Instruction
*Real
,
670 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real
<< " / " << *Imag
672 // Determine rotation
673 auto IsAdd
= [](unsigned Op
) {
674 return Op
== Instruction::FAdd
|| Op
== Instruction::Add
;
676 auto IsSub
= [](unsigned Op
) {
677 return Op
== Instruction::FSub
|| Op
== Instruction::Sub
;
679 ComplexDeinterleavingRotation Rotation
;
680 if (IsAdd(Real
->getOpcode()) && IsAdd(Imag
->getOpcode()))
681 Rotation
= ComplexDeinterleavingRotation::Rotation_0
;
682 else if (IsSub(Real
->getOpcode()) && IsAdd(Imag
->getOpcode()))
683 Rotation
= ComplexDeinterleavingRotation::Rotation_90
;
684 else if (IsSub(Real
->getOpcode()) && IsSub(Imag
->getOpcode()))
685 Rotation
= ComplexDeinterleavingRotation::Rotation_180
;
686 else if (IsAdd(Real
->getOpcode()) && IsSub(Imag
->getOpcode()))
687 Rotation
= ComplexDeinterleavingRotation::Rotation_270
;
689 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
693 if (isa
<FPMathOperator
>(Real
) &&
694 (!Real
->getFastMathFlags().allowContract() ||
695 !Imag
->getFastMathFlags().allowContract())) {
696 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
700 Value
*CR
= Real
->getOperand(0);
701 Instruction
*RealMulI
= dyn_cast
<Instruction
>(Real
->getOperand(1));
704 Value
*CI
= Imag
->getOperand(0);
705 Instruction
*ImagMulI
= dyn_cast
<Instruction
>(Imag
->getOperand(1));
709 if (!RealMulI
->hasOneUse() || !ImagMulI
->hasOneUse()) {
710 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
714 Value
*R0
= RealMulI
->getOperand(0);
715 Value
*R1
= RealMulI
->getOperand(1);
716 Value
*I0
= ImagMulI
->getOperand(0);
717 Value
*I1
= ImagMulI
->getOperand(1);
719 Value
*CommonOperand
;
720 Value
*UncommonRealOp
;
721 Value
*UncommonImagOp
;
723 if (R0
== I0
|| R0
== I1
) {
726 } else if (R1
== I0
|| R1
== I1
) {
730 LLVM_DEBUG(dbgs() << " - No equal operand\n");
734 UncommonImagOp
= (CommonOperand
== I0
) ? I1
: I0
;
735 if (Rotation
== ComplexDeinterleavingRotation::Rotation_90
||
736 Rotation
== ComplexDeinterleavingRotation::Rotation_270
)
737 std::swap(UncommonRealOp
, UncommonImagOp
);
739 std::pair
<Value
*, Value
*> PartialMatch(
740 (Rotation
== ComplexDeinterleavingRotation::Rotation_0
||
741 Rotation
== ComplexDeinterleavingRotation::Rotation_180
)
744 (Rotation
== ComplexDeinterleavingRotation::Rotation_90
||
745 Rotation
== ComplexDeinterleavingRotation::Rotation_270
)
749 auto *CRInst
= dyn_cast
<Instruction
>(CR
);
750 auto *CIInst
= dyn_cast
<Instruction
>(CI
);
752 if (!CRInst
|| !CIInst
) {
753 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
757 NodePtr CNode
= identifyNodeWithImplicitAdd(CRInst
, CIInst
, PartialMatch
);
759 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
763 NodePtr UncommonRes
= identifyNode(UncommonRealOp
, UncommonImagOp
);
765 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
769 assert(PartialMatch
.first
&& PartialMatch
.second
);
770 NodePtr CommonRes
= identifyNode(PartialMatch
.first
, PartialMatch
.second
);
772 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
776 NodePtr Node
= prepareCompositeNode(
777 ComplexDeinterleavingOperation::CMulPartial
, Real
, Imag
);
778 Node
->Rotation
= Rotation
;
779 Node
->addOperand(CommonRes
);
780 Node
->addOperand(UncommonRes
);
781 Node
->addOperand(CNode
);
782 return submitCompositeNode(Node
);
785 ComplexDeinterleavingGraph::NodePtr
786 ComplexDeinterleavingGraph::identifyAdd(Instruction
*Real
, Instruction
*Imag
) {
787 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real
<< " / " << *Imag
<< "\n");
789 // Determine rotation
790 ComplexDeinterleavingRotation Rotation
;
791 if ((Real
->getOpcode() == Instruction::FSub
&&
792 Imag
->getOpcode() == Instruction::FAdd
) ||
793 (Real
->getOpcode() == Instruction::Sub
&&
794 Imag
->getOpcode() == Instruction::Add
))
795 Rotation
= ComplexDeinterleavingRotation::Rotation_90
;
796 else if ((Real
->getOpcode() == Instruction::FAdd
&&
797 Imag
->getOpcode() == Instruction::FSub
) ||
798 (Real
->getOpcode() == Instruction::Add
&&
799 Imag
->getOpcode() == Instruction::Sub
))
800 Rotation
= ComplexDeinterleavingRotation::Rotation_270
;
802 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
806 auto *AR
= dyn_cast
<Instruction
>(Real
->getOperand(0));
807 auto *BI
= dyn_cast
<Instruction
>(Real
->getOperand(1));
808 auto *AI
= dyn_cast
<Instruction
>(Imag
->getOperand(0));
809 auto *BR
= dyn_cast
<Instruction
>(Imag
->getOperand(1));
811 if (!AR
|| !AI
|| !BR
|| !BI
) {
812 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
816 NodePtr ResA
= identifyNode(AR
, AI
);
818 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
821 NodePtr ResB
= identifyNode(BR
, BI
);
823 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
828 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd
, Real
, Imag
);
829 Node
->Rotation
= Rotation
;
830 Node
->addOperand(ResA
);
831 Node
->addOperand(ResB
);
832 return submitCompositeNode(Node
);
835 static bool isInstructionPairAdd(Instruction
*A
, Instruction
*B
) {
836 unsigned OpcA
= A
->getOpcode();
837 unsigned OpcB
= B
->getOpcode();
839 return (OpcA
== Instruction::FSub
&& OpcB
== Instruction::FAdd
) ||
840 (OpcA
== Instruction::FAdd
&& OpcB
== Instruction::FSub
) ||
841 (OpcA
== Instruction::Sub
&& OpcB
== Instruction::Add
) ||
842 (OpcA
== Instruction::Add
&& OpcB
== Instruction::Sub
);
845 static bool isInstructionPairMul(Instruction
*A
, Instruction
*B
) {
847 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
849 return match(A
, Pattern
) && match(B
, Pattern
);
852 static bool isInstructionPotentiallySymmetric(Instruction
*I
) {
853 switch (I
->getOpcode()) {
854 case Instruction::FAdd
:
855 case Instruction::FSub
:
856 case Instruction::FMul
:
857 case Instruction::FNeg
:
858 case Instruction::Add
:
859 case Instruction::Sub
:
860 case Instruction::Mul
:
867 ComplexDeinterleavingGraph::NodePtr
868 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction
*Real
,
870 if (Real
->getOpcode() != Imag
->getOpcode())
873 if (!isInstructionPotentiallySymmetric(Real
) ||
874 !isInstructionPotentiallySymmetric(Imag
))
877 auto *R0
= Real
->getOperand(0);
878 auto *I0
= Imag
->getOperand(0);
880 NodePtr Op0
= identifyNode(R0
, I0
);
881 NodePtr Op1
= nullptr;
885 if (Real
->isBinaryOp()) {
886 auto *R1
= Real
->getOperand(1);
887 auto *I1
= Imag
->getOperand(1);
888 Op1
= identifyNode(R1
, I1
);
893 if (isa
<FPMathOperator
>(Real
) &&
894 Real
->getFastMathFlags() != Imag
->getFastMathFlags())
897 auto Node
= prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric
,
899 Node
->Opcode
= Real
->getOpcode();
900 if (isa
<FPMathOperator
>(Real
))
901 Node
->Flags
= Real
->getFastMathFlags();
903 Node
->addOperand(Op0
);
904 if (Real
->isBinaryOp())
905 Node
->addOperand(Op1
);
907 return submitCompositeNode(Node
);
910 ComplexDeinterleavingGraph::NodePtr
911 ComplexDeinterleavingGraph::identifyDotProduct(Value
*V
) {
913 if (!TL
->isComplexDeinterleavingOperationSupported(
914 ComplexDeinterleavingOperation::CDot
, V
->getType())) {
915 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
916 "operation CDot with the type "
917 << *V
->getType() << "\n");
921 auto *Inst
= cast
<Instruction
>(V
);
922 auto *RealUser
= cast
<Instruction
>(*Inst
->user_begin());
925 prepareCompositeNode(ComplexDeinterleavingOperation::CDot
, Inst
, nullptr);
929 const Intrinsic::ID PartialReduceInt
=
930 Intrinsic::experimental_vector_partial_reduce_add
;
932 Value
*AReal
= nullptr;
933 Value
*AImag
= nullptr;
934 Value
*BReal
= nullptr;
935 Value
*BImag
= nullptr;
936 Value
*Phi
= nullptr;
938 auto UnwrapCast
= [](Value
*V
) -> Value
* {
939 if (auto *CI
= dyn_cast
<CastInst
>(V
))
940 return CI
->getOperand(0);
944 auto PatternRot0
= m_Intrinsic
<PartialReduceInt
>(
945 m_Intrinsic
<PartialReduceInt
>(m_Value(Phi
),
946 m_Mul(m_Value(BReal
), m_Value(AReal
))),
947 m_Neg(m_Mul(m_Value(BImag
), m_Value(AImag
))));
949 auto PatternRot270
= m_Intrinsic
<PartialReduceInt
>(
950 m_Intrinsic
<PartialReduceInt
>(
951 m_Value(Phi
), m_Neg(m_Mul(m_Value(BReal
), m_Value(AImag
)))),
952 m_Mul(m_Value(BImag
), m_Value(AReal
)));
954 if (match(Inst
, PatternRot0
)) {
955 CN
->Rotation
= ComplexDeinterleavingRotation::Rotation_0
;
956 } else if (match(Inst
, PatternRot270
)) {
957 CN
->Rotation
= ComplexDeinterleavingRotation::Rotation_270
;
960 // The rotations 90 and 180 share the same operation pattern, so inspect the
961 // order of the operands, identifying where the real and imaginary
962 // components of A go, to discern between the aforementioned rotations.
963 auto PatternRot90Rot180
= m_Intrinsic
<PartialReduceInt
>(
964 m_Intrinsic
<PartialReduceInt
>(m_Value(Phi
),
965 m_Mul(m_Value(BReal
), m_Value(A0
))),
966 m_Mul(m_Value(BImag
), m_Value(A1
)));
968 if (!match(Inst
, PatternRot90Rot180
))
974 // Test if A0 is real/A1 is imag
975 ANode
= identifyNode(A0
, A1
);
977 // Test if A0 is imag/A1 is real
978 ANode
= identifyNode(A1
, A0
);
979 // Unable to identify operand components, thus unable to identify rotation
982 CN
->Rotation
= ComplexDeinterleavingRotation::Rotation_90
;
988 CN
->Rotation
= ComplexDeinterleavingRotation::Rotation_180
;
992 AReal
= UnwrapCast(AReal
);
993 AImag
= UnwrapCast(AImag
);
994 BReal
= UnwrapCast(BReal
);
995 BImag
= UnwrapCast(BImag
);
997 VectorType
*VTy
= cast
<VectorType
>(V
->getType());
998 Type
*ExpectedOperandTy
= VectorType::getSubdividedVectorType(VTy
, 2);
999 if (AReal
->getType() != ExpectedOperandTy
)
1001 if (AImag
->getType() != ExpectedOperandTy
)
1003 if (BReal
->getType() != ExpectedOperandTy
)
1005 if (BImag
->getType() != ExpectedOperandTy
)
1008 if (Phi
->getType() != VTy
&& RealUser
->getType() != VTy
)
1011 NodePtr Node
= identifyNode(AReal
, AImag
);
1013 // In the case that a node was identified to figure out the rotation, ensure
1014 // that trying to identify a node with AReal and AImag post-unwrap results in
1016 if (ANode
&& Node
!= ANode
) {
1019 << "Identified node is different from previously identified node. "
1020 "Unable to confidently generate a complex operation node\n");
1024 CN
->addOperand(Node
);
1025 CN
->addOperand(identifyNode(BReal
, BImag
));
1026 CN
->addOperand(identifyNode(Phi
, RealUser
));
1028 return submitCompositeNode(CN
);
1031 ComplexDeinterleavingGraph::NodePtr
1032 ComplexDeinterleavingGraph::identifyPartialReduction(Value
*R
, Value
*I
) {
1033 // Partial reductions don't support non-vector types, so check these first
1034 if (!isa
<VectorType
>(R
->getType()) || !isa
<VectorType
>(I
->getType()))
1038 findCommonBetweenCollections
<Value
*>(R
->users(), I
->users());
1042 auto *IInst
= dyn_cast
<IntrinsicInst
>(*CommonUser
);
1043 if (!IInst
|| IInst
->getIntrinsicID() !=
1044 Intrinsic::experimental_vector_partial_reduce_add
)
1047 if (NodePtr CN
= identifyDotProduct(IInst
))
1053 ComplexDeinterleavingGraph::NodePtr
1054 ComplexDeinterleavingGraph::identifyNode(Value
*R
, Value
*I
) {
1055 auto It
= CachedResult
.find({R
, I
});
1056 if (It
!= CachedResult
.end()) {
1057 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1061 if (NodePtr CN
= identifyPartialReduction(R
, I
))
1064 bool IsReduction
= RealPHI
== R
&& (!ImagPHI
|| ImagPHI
== I
);
1065 if (!IsReduction
&& R
->getType() != I
->getType())
1068 if (NodePtr CN
= identifySplat(R
, I
))
1071 auto *Real
= dyn_cast
<Instruction
>(R
);
1072 auto *Imag
= dyn_cast
<Instruction
>(I
);
1076 if (NodePtr CN
= identifyDeinterleave(Real
, Imag
))
1079 if (NodePtr CN
= identifyPHINode(Real
, Imag
))
1082 if (NodePtr CN
= identifySelectNode(Real
, Imag
))
1085 auto *VTy
= cast
<VectorType
>(Real
->getType());
1086 auto *NewVTy
= VectorType::getDoubleElementsVectorType(VTy
);
1088 bool HasCMulSupport
= TL
->isComplexDeinterleavingOperationSupported(
1089 ComplexDeinterleavingOperation::CMulPartial
, NewVTy
);
1090 bool HasCAddSupport
= TL
->isComplexDeinterleavingOperationSupported(
1091 ComplexDeinterleavingOperation::CAdd
, NewVTy
);
1093 if (HasCMulSupport
&& isInstructionPairMul(Real
, Imag
)) {
1094 if (NodePtr CN
= identifyPartialMul(Real
, Imag
))
1098 if (HasCAddSupport
&& isInstructionPairAdd(Real
, Imag
)) {
1099 if (NodePtr CN
= identifyAdd(Real
, Imag
))
1103 if (HasCMulSupport
&& HasCAddSupport
) {
1104 if (NodePtr CN
= identifyReassocNodes(Real
, Imag
))
1108 if (NodePtr CN
= identifySymmetricOperation(Real
, Imag
))
1111 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1112 CachedResult
[{R
, I
}] = nullptr;
1116 ComplexDeinterleavingGraph::NodePtr
1117 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction
*Real
,
1118 Instruction
*Imag
) {
1119 auto IsOperationSupported
= [](unsigned Opcode
) -> bool {
1120 return Opcode
== Instruction::FAdd
|| Opcode
== Instruction::FSub
||
1121 Opcode
== Instruction::FNeg
|| Opcode
== Instruction::Add
||
1122 Opcode
== Instruction::Sub
;
1125 if (!IsOperationSupported(Real
->getOpcode()) ||
1126 !IsOperationSupported(Imag
->getOpcode()))
1129 std::optional
<FastMathFlags
> Flags
;
1130 if (isa
<FPMathOperator
>(Real
)) {
1131 if (Real
->getFastMathFlags() != Imag
->getFastMathFlags()) {
1132 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1137 Flags
= Real
->getFastMathFlags();
1138 if (!Flags
->allowReassoc()) {
1141 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1146 // Collect multiplications and addend instructions from the given instruction
1147 // while traversing it operands. Additionally, verify that all instructions
1148 // have the same fast math flags.
1149 auto Collect
= [&Flags
](Instruction
*Insn
, std::vector
<Product
> &Muls
,
1150 std::list
<Addend
> &Addends
) -> bool {
1151 SmallVector
<PointerIntPair
<Value
*, 1, bool>> Worklist
= {{Insn
, true}};
1152 SmallPtrSet
<Value
*, 8> Visited
;
1153 while (!Worklist
.empty()) {
1154 auto [V
, IsPositive
] = Worklist
.back();
1155 Worklist
.pop_back();
1156 if (!Visited
.insert(V
).second
)
1159 Instruction
*I
= dyn_cast
<Instruction
>(V
);
1161 Addends
.emplace_back(V
, IsPositive
);
1165 // If an instruction has more than one user, it indicates that it either
1166 // has an external user, which will be later checked by the checkNodes
1167 // function, or it is a subexpression utilized by multiple expressions. In
1168 // the latter case, we will attempt to separately identify the complex
1169 // operation from here in order to create a shared
1170 // ComplexDeinterleavingCompositeNode.
1171 if (I
!= Insn
&& I
->getNumUses() > 1) {
1172 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I
<< "\n");
1173 Addends
.emplace_back(I
, IsPositive
);
1176 switch (I
->getOpcode()) {
1177 case Instruction::FAdd
:
1178 case Instruction::Add
:
1179 Worklist
.emplace_back(I
->getOperand(1), IsPositive
);
1180 Worklist
.emplace_back(I
->getOperand(0), IsPositive
);
1182 case Instruction::FSub
:
1183 Worklist
.emplace_back(I
->getOperand(1), !IsPositive
);
1184 Worklist
.emplace_back(I
->getOperand(0), IsPositive
);
1186 case Instruction::Sub
:
1188 Worklist
.emplace_back(getNegOperand(I
), !IsPositive
);
1190 Worklist
.emplace_back(I
->getOperand(1), !IsPositive
);
1191 Worklist
.emplace_back(I
->getOperand(0), IsPositive
);
1194 case Instruction::FMul
:
1195 case Instruction::Mul
: {
1197 if (isNeg(I
->getOperand(0))) {
1198 A
= getNegOperand(I
->getOperand(0));
1199 IsPositive
= !IsPositive
;
1201 A
= I
->getOperand(0);
1204 if (isNeg(I
->getOperand(1))) {
1205 B
= getNegOperand(I
->getOperand(1));
1206 IsPositive
= !IsPositive
;
1208 B
= I
->getOperand(1);
1210 Muls
.push_back(Product
{A
, B
, IsPositive
});
1213 case Instruction::FNeg
:
1214 Worklist
.emplace_back(I
->getOperand(0), !IsPositive
);
1217 Addends
.emplace_back(I
, IsPositive
);
1221 if (Flags
&& I
->getFastMathFlags() != *Flags
) {
1222 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1223 "inconsistent with the root instructions' flags: "
1231 std::vector
<Product
> RealMuls
, ImagMuls
;
1232 std::list
<Addend
> RealAddends
, ImagAddends
;
1233 if (!Collect(Real
, RealMuls
, RealAddends
) ||
1234 !Collect(Imag
, ImagMuls
, ImagAddends
))
1237 if (RealAddends
.size() != ImagAddends
.size())
1241 if (!RealMuls
.empty() || !ImagMuls
.empty()) {
1242 // If there are multiplicands, extract positive addend and use it as an
1244 FinalNode
= extractPositiveAddend(RealAddends
, ImagAddends
);
1245 FinalNode
= identifyMultiplications(RealMuls
, ImagMuls
, FinalNode
);
1250 // Identify and process remaining additions
1251 if (!RealAddends
.empty() || !ImagAddends
.empty()) {
1252 FinalNode
= identifyAdditions(RealAddends
, ImagAddends
, Flags
, FinalNode
);
1256 assert(FinalNode
&& "FinalNode can not be nullptr here");
1257 // Set the Real and Imag fields of the final node and submit it
1258 FinalNode
->Real
= Real
;
1259 FinalNode
->Imag
= Imag
;
1260 submitCompositeNode(FinalNode
);
1264 bool ComplexDeinterleavingGraph::collectPartialMuls(
1265 const std::vector
<Product
> &RealMuls
, const std::vector
<Product
> &ImagMuls
,
1266 std::vector
<PartialMulCandidate
> &PartialMulCandidates
) {
1267 // Helper function to extract a common operand from two products
1268 auto FindCommonInstruction
= [](const Product
&Real
,
1269 const Product
&Imag
) -> Value
* {
1270 if (Real
.Multiplicand
== Imag
.Multiplicand
||
1271 Real
.Multiplicand
== Imag
.Multiplier
)
1272 return Real
.Multiplicand
;
1274 if (Real
.Multiplier
== Imag
.Multiplicand
||
1275 Real
.Multiplier
== Imag
.Multiplier
)
1276 return Real
.Multiplier
;
1281 // Iterating over real and imaginary multiplications to find common operands
1282 // If a common operand is found, a partial multiplication candidate is created
1283 // and added to the candidates vector The function returns false if no common
1284 // operands are found for any product
1285 for (unsigned i
= 0; i
< RealMuls
.size(); ++i
) {
1286 bool FoundCommon
= false;
1287 for (unsigned j
= 0; j
< ImagMuls
.size(); ++j
) {
1288 auto *Common
= FindCommonInstruction(RealMuls
[i
], ImagMuls
[j
]);
1292 auto *A
= RealMuls
[i
].Multiplicand
== Common
? RealMuls
[i
].Multiplier
1293 : RealMuls
[i
].Multiplicand
;
1294 auto *B
= ImagMuls
[j
].Multiplicand
== Common
? ImagMuls
[j
].Multiplier
1295 : ImagMuls
[j
].Multiplicand
;
1297 auto Node
= identifyNode(A
, B
);
1300 PartialMulCandidates
.push_back({Common
, Node
, i
, j
, false});
1303 Node
= identifyNode(B
, A
);
1306 PartialMulCandidates
.push_back({Common
, Node
, i
, j
, true});
1315 ComplexDeinterleavingGraph::NodePtr
1316 ComplexDeinterleavingGraph::identifyMultiplications(
1317 std::vector
<Product
> &RealMuls
, std::vector
<Product
> &ImagMuls
,
1318 NodePtr Accumulator
= nullptr) {
1319 if (RealMuls
.size() != ImagMuls
.size())
1322 std::vector
<PartialMulCandidate
> Info
;
1323 if (!collectPartialMuls(RealMuls
, ImagMuls
, Info
))
1326 // Map to store common instruction to node pointers
1327 std::map
<Value
*, NodePtr
> CommonToNode
;
1328 std::vector
<bool> Processed(Info
.size(), false);
1329 for (unsigned I
= 0; I
< Info
.size(); ++I
) {
1333 PartialMulCandidate
&InfoA
= Info
[I
];
1334 for (unsigned J
= I
+ 1; J
< Info
.size(); ++J
) {
1338 PartialMulCandidate
&InfoB
= Info
[J
];
1339 auto *InfoReal
= &InfoA
;
1340 auto *InfoImag
= &InfoB
;
1342 auto NodeFromCommon
= identifyNode(InfoReal
->Common
, InfoImag
->Common
);
1343 if (!NodeFromCommon
) {
1344 std::swap(InfoReal
, InfoImag
);
1345 NodeFromCommon
= identifyNode(InfoReal
->Common
, InfoImag
->Common
);
1347 if (!NodeFromCommon
)
1350 CommonToNode
[InfoReal
->Common
] = NodeFromCommon
;
1351 CommonToNode
[InfoImag
->Common
] = NodeFromCommon
;
1352 Processed
[I
] = true;
1353 Processed
[J
] = true;
1357 std::vector
<bool> ProcessedReal(RealMuls
.size(), false);
1358 std::vector
<bool> ProcessedImag(ImagMuls
.size(), false);
1359 NodePtr Result
= Accumulator
;
1360 for (auto &PMI
: Info
) {
1361 if (ProcessedReal
[PMI
.RealIdx
] || ProcessedImag
[PMI
.ImagIdx
])
1364 auto It
= CommonToNode
.find(PMI
.Common
);
1365 // TODO: Process independent complex multiplications. Cases like this:
1366 // A.real() * B where both A and B are complex numbers.
1367 if (It
== CommonToNode
.end()) {
1369 dbgs() << "Unprocessed independent partial multiplication:\n";
1370 for (auto *Mul
: {&RealMuls
[PMI
.RealIdx
], &RealMuls
[PMI
.RealIdx
]})
1371 dbgs().indent(4) << (Mul
->IsPositive
? "+" : "-") << *Mul
->Multiplier
1372 << " multiplied by " << *Mul
->Multiplicand
<< "\n";
1377 auto &RealMul
= RealMuls
[PMI
.RealIdx
];
1378 auto &ImagMul
= ImagMuls
[PMI
.ImagIdx
];
1380 auto NodeA
= It
->second
;
1381 auto NodeB
= PMI
.Node
;
1382 auto IsMultiplicandReal
= PMI
.Common
== NodeA
->Real
;
1383 // The following table illustrates the relationship between multiplications
1384 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1387 // Rotation | Real | Imag |
1388 // ---------+--------+--------+
1389 // 0 | x * u | x * v |
1390 // 90 | -y * v | y * u |
1391 // 180 | -x * u | -x * v |
1392 // 270 | y * v | -y * u |
1394 // Check if the candidate can indeed be represented by partial
1396 // TODO: Add support for multiplication by complex one
1397 if ((IsMultiplicandReal
&& PMI
.IsNodeInverted
) ||
1398 (!IsMultiplicandReal
&& !PMI
.IsNodeInverted
))
1401 // Determine the rotation based on the multiplications
1402 ComplexDeinterleavingRotation Rotation
;
1403 if (IsMultiplicandReal
) {
1404 // Detect 0 and 180 degrees rotation
1405 if (RealMul
.IsPositive
&& ImagMul
.IsPositive
)
1406 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_0
;
1407 else if (!RealMul
.IsPositive
&& !ImagMul
.IsPositive
)
1408 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_180
;
1413 // Detect 90 and 270 degrees rotation
1414 if (!RealMul
.IsPositive
&& ImagMul
.IsPositive
)
1415 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_90
;
1416 else if (RealMul
.IsPositive
&& !ImagMul
.IsPositive
)
1417 Rotation
= llvm::ComplexDeinterleavingRotation::Rotation_270
;
1423 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1424 dbgs().indent(4) << "X: " << *NodeA
->Real
<< "\n";
1425 dbgs().indent(4) << "Y: " << *NodeA
->Imag
<< "\n";
1426 dbgs().indent(4) << "U: " << *NodeB
->Real
<< "\n";
1427 dbgs().indent(4) << "V: " << *NodeB
->Imag
<< "\n";
1428 dbgs().indent(4) << "Rotation - " << (int)Rotation
* 90 << "\n";
1431 NodePtr NodeMul
= prepareCompositeNode(
1432 ComplexDeinterleavingOperation::CMulPartial
, nullptr, nullptr);
1433 NodeMul
->Rotation
= Rotation
;
1434 NodeMul
->addOperand(NodeA
);
1435 NodeMul
->addOperand(NodeB
);
1437 NodeMul
->addOperand(Result
);
1438 submitCompositeNode(NodeMul
);
1440 ProcessedReal
[PMI
.RealIdx
] = true;
1441 ProcessedImag
[PMI
.ImagIdx
] = true;
1444 // Ensure all products have been processed, if not return nullptr.
1445 if (!all_of(ProcessedReal
, [](bool V
) { return V
; }) ||
1446 !all_of(ProcessedImag
, [](bool V
) { return V
; })) {
1448 // Dump debug information about which partial multiplications are not
1451 dbgs() << "Unprocessed products (Real):\n";
1452 for (size_t i
= 0; i
< ProcessedReal
.size(); ++i
) {
1453 if (!ProcessedReal
[i
])
1454 dbgs().indent(4) << (RealMuls
[i
].IsPositive
? "+" : "-")
1455 << *RealMuls
[i
].Multiplier
<< " multiplied by "
1456 << *RealMuls
[i
].Multiplicand
<< "\n";
1458 dbgs() << "Unprocessed products (Imag):\n";
1459 for (size_t i
= 0; i
< ProcessedImag
.size(); ++i
) {
1460 if (!ProcessedImag
[i
])
1461 dbgs().indent(4) << (ImagMuls
[i
].IsPositive
? "+" : "-")
1462 << *ImagMuls
[i
].Multiplier
<< " multiplied by "
1463 << *ImagMuls
[i
].Multiplicand
<< "\n";
1472 ComplexDeinterleavingGraph::NodePtr
1473 ComplexDeinterleavingGraph::identifyAdditions(
1474 std::list
<Addend
> &RealAddends
, std::list
<Addend
> &ImagAddends
,
1475 std::optional
<FastMathFlags
> Flags
, NodePtr Accumulator
= nullptr) {
1476 if (RealAddends
.size() != ImagAddends
.size())
1480 // If we have accumulator use it as first addend
1482 Result
= Accumulator
;
1483 // Otherwise find an element with both positive real and imaginary parts.
1485 Result
= extractPositiveAddend(RealAddends
, ImagAddends
);
1490 while (!RealAddends
.empty()) {
1491 auto ItR
= RealAddends
.begin();
1492 auto [R
, IsPositiveR
] = *ItR
;
1494 bool FoundImag
= false;
1495 for (auto ItI
= ImagAddends
.begin(); ItI
!= ImagAddends
.end(); ++ItI
) {
1496 auto [I
, IsPositiveI
] = *ItI
;
1497 ComplexDeinterleavingRotation Rotation
;
1498 if (IsPositiveR
&& IsPositiveI
)
1499 Rotation
= ComplexDeinterleavingRotation::Rotation_0
;
1500 else if (!IsPositiveR
&& IsPositiveI
)
1501 Rotation
= ComplexDeinterleavingRotation::Rotation_90
;
1502 else if (!IsPositiveR
&& !IsPositiveI
)
1503 Rotation
= ComplexDeinterleavingRotation::Rotation_180
;
1505 Rotation
= ComplexDeinterleavingRotation::Rotation_270
;
1508 if (Rotation
== ComplexDeinterleavingRotation::Rotation_0
||
1509 Rotation
== ComplexDeinterleavingRotation::Rotation_180
) {
1510 AddNode
= identifyNode(R
, I
);
1512 AddNode
= identifyNode(I
, R
);
1516 dbgs() << "Identified addition:\n";
1517 dbgs().indent(4) << "X: " << *R
<< "\n";
1518 dbgs().indent(4) << "Y: " << *I
<< "\n";
1519 dbgs().indent(4) << "Rotation - " << (int)Rotation
* 90 << "\n";
1523 if (Rotation
== llvm::ComplexDeinterleavingRotation::Rotation_0
) {
1524 TmpNode
= prepareCompositeNode(
1525 ComplexDeinterleavingOperation::Symmetric
, nullptr, nullptr);
1527 TmpNode
->Opcode
= Instruction::FAdd
;
1528 TmpNode
->Flags
= *Flags
;
1530 TmpNode
->Opcode
= Instruction::Add
;
1532 } else if (Rotation
==
1533 llvm::ComplexDeinterleavingRotation::Rotation_180
) {
1534 TmpNode
= prepareCompositeNode(
1535 ComplexDeinterleavingOperation::Symmetric
, nullptr, nullptr);
1537 TmpNode
->Opcode
= Instruction::FSub
;
1538 TmpNode
->Flags
= *Flags
;
1540 TmpNode
->Opcode
= Instruction::Sub
;
1543 TmpNode
= prepareCompositeNode(ComplexDeinterleavingOperation::CAdd
,
1545 TmpNode
->Rotation
= Rotation
;
1548 TmpNode
->addOperand(Result
);
1549 TmpNode
->addOperand(AddNode
);
1550 submitCompositeNode(TmpNode
);
1552 RealAddends
.erase(ItR
);
1553 ImagAddends
.erase(ItI
);
1564 ComplexDeinterleavingGraph::NodePtr
1565 ComplexDeinterleavingGraph::extractPositiveAddend(
1566 std::list
<Addend
> &RealAddends
, std::list
<Addend
> &ImagAddends
) {
1567 for (auto ItR
= RealAddends
.begin(); ItR
!= RealAddends
.end(); ++ItR
) {
1568 for (auto ItI
= ImagAddends
.begin(); ItI
!= ImagAddends
.end(); ++ItI
) {
1569 auto [R
, IsPositiveR
] = *ItR
;
1570 auto [I
, IsPositiveI
] = *ItI
;
1571 if (IsPositiveR
&& IsPositiveI
) {
1572 auto Result
= identifyNode(R
, I
);
1574 RealAddends
.erase(ItR
);
1575 ImagAddends
.erase(ItI
);
1584 bool ComplexDeinterleavingGraph::identifyNodes(Instruction
*RootI
) {
1585 // This potential root instruction might already have been recognized as
1586 // reduction. Because RootToNode maps both Real and Imaginary parts to
1587 // CompositeNode we should choose only one either Real or Imag instruction to
1588 // use as an anchor for generating complex instruction.
1589 auto It
= RootToNode
.find(RootI
);
1590 if (It
!= RootToNode
.end()) {
1591 auto RootNode
= It
->second
;
1592 assert(RootNode
->Operation
==
1593 ComplexDeinterleavingOperation::ReductionOperation
||
1594 RootNode
->Operation
==
1595 ComplexDeinterleavingOperation::ReductionSingle
);
1596 // Find out which part, Real or Imag, comes later, and only if we come to
1597 // the latest part, add it to OrderedRoots.
1598 auto *R
= cast
<Instruction
>(RootNode
->Real
);
1599 auto *I
= RootNode
->Imag
? cast
<Instruction
>(RootNode
->Imag
) : nullptr;
1601 Instruction
*ReplacementAnchor
;
1603 ReplacementAnchor
= R
->comesBefore(I
) ? I
: R
;
1605 ReplacementAnchor
= R
;
1607 if (ReplacementAnchor
!= RootI
)
1609 OrderedRoots
.push_back(RootI
);
1613 auto RootNode
= identifyRoot(RootI
);
1618 Function
*F
= RootI
->getFunction();
1619 BasicBlock
*B
= RootI
->getParent();
1620 dbgs() << "Complex deinterleaving graph for " << F
->getName()
1621 << "::" << B
->getName() << ".\n";
1625 RootToNode
[RootI
] = RootNode
;
1626 OrderedRoots
.push_back(RootI
);
1630 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock
*B
) {
1631 bool FoundPotentialReduction
= false;
1633 auto *Br
= dyn_cast
<BranchInst
>(B
->getTerminator());
1634 if (!Br
|| Br
->getNumSuccessors() != 2)
1637 // Identify simple one-block loop
1638 if (Br
->getSuccessor(0) != B
&& Br
->getSuccessor(1) != B
)
1641 SmallVector
<PHINode
*> PHIs
;
1642 for (auto &PHI
: B
->phis()) {
1643 if (PHI
.getNumIncomingValues() != 2)
1646 if (!PHI
.getType()->isVectorTy())
1649 auto *ReductionOp
= dyn_cast
<Instruction
>(PHI
.getIncomingValueForBlock(B
));
1653 // Check if final instruction is reduced outside of current block
1654 Instruction
*FinalReduction
= nullptr;
1656 for (auto *U
: ReductionOp
->users()) {
1660 FinalReduction
= dyn_cast
<Instruction
>(U
);
1663 if (NumUsers
!= 2 || !FinalReduction
|| FinalReduction
->getParent() == B
||
1664 isa
<PHINode
>(FinalReduction
))
1667 ReductionInfo
[ReductionOp
] = {&PHI
, FinalReduction
};
1669 auto BackEdgeIdx
= PHI
.getBasicBlockIndex(B
);
1670 auto IncomingIdx
= BackEdgeIdx
== 0 ? 1 : 0;
1671 Incoming
= PHI
.getIncomingBlock(IncomingIdx
);
1672 FoundPotentialReduction
= true;
1674 // If the initial value of PHINode is an Instruction, consider it a leaf
1675 // value of a complex deinterleaving graph.
1677 dyn_cast
<Instruction
>(PHI
.getIncomingValueForBlock(Incoming
)))
1678 FinalInstructions
.insert(InitPHI
);
1680 return FoundPotentialReduction
;
1683 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1684 SmallVector
<bool> Processed(ReductionInfo
.size(), false);
1685 SmallVector
<Instruction
*> OperationInstruction
;
1686 for (auto &P
: ReductionInfo
)
1687 OperationInstruction
.push_back(P
.first
);
1689 // Identify a complex computation by evaluating two reduction operations that
1690 // potentially could be involved
1691 for (size_t i
= 0; i
< OperationInstruction
.size(); ++i
) {
1694 for (size_t j
= i
+ 1; j
< OperationInstruction
.size(); ++j
) {
1697 auto *Real
= OperationInstruction
[i
];
1698 auto *Imag
= OperationInstruction
[j
];
1699 if (Real
->getType() != Imag
->getType())
1702 RealPHI
= ReductionInfo
[Real
].first
;
1703 ImagPHI
= ReductionInfo
[Imag
].first
;
1705 auto Node
= identifyNode(Real
, Imag
);
1707 std::swap(Real
, Imag
);
1708 std::swap(RealPHI
, ImagPHI
);
1709 Node
= identifyNode(Real
, Imag
);
1712 // If a node is identified and reduction PHINode is used in the chain of
1713 // operations, mark its operation instructions as used to prevent
1714 // re-identification and attach the node to the real part
1715 if (Node
&& PHIsFound
) {
1716 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1717 << *Real
<< " / " << *Imag
<< "\n");
1718 Processed
[i
] = true;
1719 Processed
[j
] = true;
1720 auto RootNode
= prepareCompositeNode(
1721 ComplexDeinterleavingOperation::ReductionOperation
, Real
, Imag
);
1722 RootNode
->addOperand(Node
);
1723 RootToNode
[Real
] = RootNode
;
1724 RootToNode
[Imag
] = RootNode
;
1725 submitCompositeNode(RootNode
);
1730 auto *Real
= OperationInstruction
[i
];
1731 // We want to check that we have 2 operands, but the function attributes
1732 // being counted as operands bloats this value.
1733 if (Processed
[i
] || Real
->getNumOperands() < 2)
1736 RealPHI
= ReductionInfo
[Real
].first
;
1739 auto Node
= identifyNode(Real
->getOperand(0), Real
->getOperand(1));
1740 if (Node
&& PHIsFound
) {
1742 dbgs() << "Identified single reduction starting from instruction: "
1743 << *Real
<< "/" << *ReductionInfo
[Real
].second
<< "\n");
1744 Processed
[i
] = true;
1745 auto RootNode
= prepareCompositeNode(
1746 ComplexDeinterleavingOperation::ReductionSingle
, Real
, nullptr);
1747 RootNode
->addOperand(Node
);
1748 RootToNode
[Real
] = RootNode
;
1749 submitCompositeNode(RootNode
);
1757 bool ComplexDeinterleavingGraph::checkNodes() {
1759 bool FoundDeinterleaveNode
= false;
1760 for (NodePtr N
: CompositeNodes
) {
1761 if (!N
->areOperandsValid())
1763 if (N
->Operation
== ComplexDeinterleavingOperation::Deinterleave
)
1764 FoundDeinterleaveNode
= true;
1767 // We need a deinterleave node in order to guarantee that we're working with
1769 if (!FoundDeinterleaveNode
) {
1771 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1772 "guarantee safety during graph transformation.\n");
1776 // Collect all instructions from roots to leaves
1777 SmallPtrSet
<Instruction
*, 16> AllInstructions
;
1778 SmallVector
<Instruction
*, 8> Worklist
;
1779 for (auto &Pair
: RootToNode
)
1780 Worklist
.push_back(Pair
.first
);
1782 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1784 while (!Worklist
.empty()) {
1785 auto *I
= Worklist
.back();
1786 Worklist
.pop_back();
1788 if (!AllInstructions
.insert(I
).second
)
1791 for (Value
*Op
: I
->operands()) {
1792 if (auto *OpI
= dyn_cast
<Instruction
>(Op
)) {
1793 if (!FinalInstructions
.count(I
))
1794 Worklist
.emplace_back(OpI
);
1799 // Find instructions that have users outside of chain
1800 SmallVector
<Instruction
*, 2> OuterInstructions
;
1801 for (auto *I
: AllInstructions
) {
1803 if (RootToNode
.count(I
))
1806 for (User
*U
: I
->users()) {
1807 if (AllInstructions
.count(cast
<Instruction
>(U
)))
1810 // Found an instruction that is not used by XCMLA/XCADD chain
1811 Worklist
.emplace_back(I
);
1816 // If any instructions are found to be used outside, find and remove roots
1817 // that somehow connect to those instructions.
1818 SmallPtrSet
<Instruction
*, 16> Visited
;
1819 while (!Worklist
.empty()) {
1820 auto *I
= Worklist
.back();
1821 Worklist
.pop_back();
1822 if (!Visited
.insert(I
).second
)
1825 // Found an impacted root node. Removing it from the nodes to be
1827 if (RootToNode
.count(I
)) {
1828 LLVM_DEBUG(dbgs() << "Instruction " << *I
1829 << " could be deinterleaved but its chain of complex "
1830 "operations have an outside user\n");
1831 RootToNode
.erase(I
);
1834 if (!AllInstructions
.count(I
) || FinalInstructions
.count(I
))
1837 for (User
*U
: I
->users())
1838 Worklist
.emplace_back(cast
<Instruction
>(U
));
1840 for (Value
*Op
: I
->operands()) {
1841 if (auto *OpI
= dyn_cast
<Instruction
>(Op
))
1842 Worklist
.emplace_back(OpI
);
1845 return !RootToNode
.empty();
1848 ComplexDeinterleavingGraph::NodePtr
1849 ComplexDeinterleavingGraph::identifyRoot(Instruction
*RootI
) {
1850 if (auto *Intrinsic
= dyn_cast
<IntrinsicInst
>(RootI
)) {
1851 if (Intrinsic
->getIntrinsicID() != Intrinsic::vector_interleave2
)
1854 auto *Real
= dyn_cast
<Instruction
>(Intrinsic
->getOperand(0));
1855 auto *Imag
= dyn_cast
<Instruction
>(Intrinsic
->getOperand(1));
1859 return identifyNode(Real
, Imag
);
1862 auto *SVI
= dyn_cast
<ShuffleVectorInst
>(RootI
);
1866 // Look for a shufflevector that takes separate vectors of the real and
1867 // imaginary components and recombines them into a single vector.
1868 if (!isInterleavingMask(SVI
->getShuffleMask()))
1873 if (!match(RootI
, m_Shuffle(m_Instruction(Real
), m_Instruction(Imag
))))
1876 return identifyNode(Real
, Imag
);
1879 ComplexDeinterleavingGraph::NodePtr
1880 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction
*Real
,
1881 Instruction
*Imag
) {
1882 Instruction
*I
= nullptr;
1883 Value
*FinalValue
= nullptr;
1884 if (match(Real
, m_ExtractValue
<0>(m_Instruction(I
))) &&
1885 match(Imag
, m_ExtractValue
<1>(m_Specific(I
))) &&
1886 match(I
, m_Intrinsic
<Intrinsic::vector_deinterleave2
>(
1887 m_Value(FinalValue
)))) {
1888 NodePtr PlaceholderNode
= prepareCompositeNode(
1889 llvm::ComplexDeinterleavingOperation::Deinterleave
, Real
, Imag
);
1890 PlaceholderNode
->ReplacementNode
= FinalValue
;
1891 FinalInstructions
.insert(Real
);
1892 FinalInstructions
.insert(Imag
);
1893 return submitCompositeNode(PlaceholderNode
);
1896 auto *RealShuffle
= dyn_cast
<ShuffleVectorInst
>(Real
);
1897 auto *ImagShuffle
= dyn_cast
<ShuffleVectorInst
>(Imag
);
1898 if (!RealShuffle
|| !ImagShuffle
) {
1899 if (RealShuffle
|| ImagShuffle
)
1900 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1904 Value
*RealOp1
= RealShuffle
->getOperand(1);
1905 if (!isa
<UndefValue
>(RealOp1
) && !isa
<ConstantAggregateZero
>(RealOp1
)) {
1906 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1909 Value
*ImagOp1
= ImagShuffle
->getOperand(1);
1910 if (!isa
<UndefValue
>(ImagOp1
) && !isa
<ConstantAggregateZero
>(ImagOp1
)) {
1911 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1915 Value
*RealOp0
= RealShuffle
->getOperand(0);
1916 Value
*ImagOp0
= ImagShuffle
->getOperand(0);
1918 if (RealOp0
!= ImagOp0
) {
1919 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1923 ArrayRef
<int> RealMask
= RealShuffle
->getShuffleMask();
1924 ArrayRef
<int> ImagMask
= ImagShuffle
->getShuffleMask();
1925 if (!isDeinterleavingMask(RealMask
) || !isDeinterleavingMask(ImagMask
)) {
1926 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1930 if (RealMask
[0] != 0 || ImagMask
[0] != 1) {
1931 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1935 // Type checking, the shuffle type should be a vector type of the same
1936 // scalar type, but half the size
1937 auto CheckType
= [&](ShuffleVectorInst
*Shuffle
) {
1938 Value
*Op
= Shuffle
->getOperand(0);
1939 auto *ShuffleTy
= cast
<FixedVectorType
>(Shuffle
->getType());
1940 auto *OpTy
= cast
<FixedVectorType
>(Op
->getType());
1942 if (OpTy
->getScalarType() != ShuffleTy
->getScalarType())
1944 if ((ShuffleTy
->getNumElements() * 2) != OpTy
->getNumElements())
1950 auto CheckDeinterleavingShuffle
= [&](ShuffleVectorInst
*Shuffle
) -> bool {
1951 if (!CheckType(Shuffle
))
1954 ArrayRef
<int> Mask
= Shuffle
->getShuffleMask();
1955 int Last
= *Mask
.rbegin();
1957 Value
*Op
= Shuffle
->getOperand(0);
1958 auto *OpTy
= cast
<FixedVectorType
>(Op
->getType());
1959 int NumElements
= OpTy
->getNumElements();
1961 // Ensure that the deinterleaving shuffle only pulls from the first
1963 return Last
< NumElements
;
1966 if (RealShuffle
->getType() != ImagShuffle
->getType()) {
1967 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1970 if (!CheckDeinterleavingShuffle(RealShuffle
)) {
1971 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1974 if (!CheckDeinterleavingShuffle(ImagShuffle
)) {
1975 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1979 NodePtr PlaceholderNode
=
1980 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave
,
1981 RealShuffle
, ImagShuffle
);
1982 PlaceholderNode
->ReplacementNode
= RealShuffle
->getOperand(0);
1983 FinalInstructions
.insert(RealShuffle
);
1984 FinalInstructions
.insert(ImagShuffle
);
1985 return submitCompositeNode(PlaceholderNode
);
1988 ComplexDeinterleavingGraph::NodePtr
1989 ComplexDeinterleavingGraph::identifySplat(Value
*R
, Value
*I
) {
1990 auto IsSplat
= [](Value
*V
) -> bool {
1991 // Fixed-width vector with constants
1992 if (isa
<ConstantDataVector
>(V
))
1997 // Splats are represented differently depending on whether the repeated
1998 // value is a constant or an Instruction
1999 if (auto *Const
= dyn_cast
<ConstantExpr
>(V
)) {
2000 if (Const
->getOpcode() != Instruction::ShuffleVector
)
2002 VTy
= cast
<VectorType
>(Const
->getType());
2003 Mask
= Const
->getShuffleMask();
2004 } else if (auto *Shuf
= dyn_cast
<ShuffleVectorInst
>(V
)) {
2005 VTy
= Shuf
->getType();
2006 Mask
= Shuf
->getShuffleMask();
2011 // When the data type is <1 x Type>, it's not possible to differentiate
2012 // between the ComplexDeinterleaving::Deinterleave and
2013 // ComplexDeinterleaving::Splat operations.
2014 if (!VTy
->isScalableTy() && VTy
->getElementCount().getKnownMinValue() == 1)
2017 return all_equal(Mask
) && Mask
[0] == 0;
2020 if (!IsSplat(R
) || !IsSplat(I
))
2023 auto *Real
= dyn_cast
<Instruction
>(R
);
2024 auto *Imag
= dyn_cast
<Instruction
>(I
);
2025 if ((!Real
&& Imag
) || (Real
&& !Imag
))
2029 // Non-constant splats should be in the same basic block
2030 if (Real
->getParent() != Imag
->getParent())
2033 FinalInstructions
.insert(Real
);
2034 FinalInstructions
.insert(Imag
);
2036 NodePtr PlaceholderNode
=
2037 prepareCompositeNode(ComplexDeinterleavingOperation::Splat
, R
, I
);
2038 return submitCompositeNode(PlaceholderNode
);
2041 ComplexDeinterleavingGraph::NodePtr
2042 ComplexDeinterleavingGraph::identifyPHINode(Instruction
*Real
,
2043 Instruction
*Imag
) {
2044 if (Real
!= RealPHI
|| (ImagPHI
&& Imag
!= ImagPHI
))
2048 NodePtr PlaceholderNode
= prepareCompositeNode(
2049 ComplexDeinterleavingOperation::ReductionPHI
, Real
, Imag
);
2050 return submitCompositeNode(PlaceholderNode
);
2053 ComplexDeinterleavingGraph::NodePtr
2054 ComplexDeinterleavingGraph::identifySelectNode(Instruction
*Real
,
2055 Instruction
*Imag
) {
2056 auto *SelectReal
= dyn_cast
<SelectInst
>(Real
);
2057 auto *SelectImag
= dyn_cast
<SelectInst
>(Imag
);
2058 if (!SelectReal
|| !SelectImag
)
2061 Instruction
*MaskA
, *MaskB
;
2062 Instruction
*AR
, *AI
, *RA
, *BI
;
2063 if (!match(Real
, m_Select(m_Instruction(MaskA
), m_Instruction(AR
),
2064 m_Instruction(RA
))) ||
2065 !match(Imag
, m_Select(m_Instruction(MaskB
), m_Instruction(AI
),
2066 m_Instruction(BI
))))
2069 if (MaskA
!= MaskB
&& !MaskA
->isIdenticalTo(MaskB
))
2072 if (!MaskA
->getType()->isVectorTy())
2075 auto NodeA
= identifyNode(AR
, AI
);
2079 auto NodeB
= identifyNode(RA
, BI
);
2083 NodePtr PlaceholderNode
= prepareCompositeNode(
2084 ComplexDeinterleavingOperation::ReductionSelect
, Real
, Imag
);
2085 PlaceholderNode
->addOperand(NodeA
);
2086 PlaceholderNode
->addOperand(NodeB
);
2087 FinalInstructions
.insert(MaskA
);
2088 FinalInstructions
.insert(MaskB
);
2089 return submitCompositeNode(PlaceholderNode
);
2092 static Value
*replaceSymmetricNode(IRBuilderBase
&B
, unsigned Opcode
,
2093 std::optional
<FastMathFlags
> Flags
,
2094 Value
*InputA
, Value
*InputB
) {
2097 case Instruction::FNeg
:
2098 I
= B
.CreateFNeg(InputA
);
2100 case Instruction::FAdd
:
2101 I
= B
.CreateFAdd(InputA
, InputB
);
2103 case Instruction::Add
:
2104 I
= B
.CreateAdd(InputA
, InputB
);
2106 case Instruction::FSub
:
2107 I
= B
.CreateFSub(InputA
, InputB
);
2109 case Instruction::Sub
:
2110 I
= B
.CreateSub(InputA
, InputB
);
2112 case Instruction::FMul
:
2113 I
= B
.CreateFMul(InputA
, InputB
);
2115 case Instruction::Mul
:
2116 I
= B
.CreateMul(InputA
, InputB
);
2119 llvm_unreachable("Incorrect symmetric opcode");
2122 cast
<Instruction
>(I
)->setFastMathFlags(*Flags
);
2126 Value
*ComplexDeinterleavingGraph::replaceNode(IRBuilderBase
&Builder
,
2128 if (Node
->ReplacementNode
)
2129 return Node
->ReplacementNode
;
2131 auto ReplaceOperandIfExist
= [&](RawNodePtr
&Node
, unsigned Idx
) -> Value
* {
2132 return Node
->Operands
.size() > Idx
2133 ? replaceNode(Builder
, Node
->Operands
[Idx
])
2137 Value
*ReplacementNode
;
2138 switch (Node
->Operation
) {
2139 case ComplexDeinterleavingOperation::CDot
: {
2140 Value
*Input0
= ReplaceOperandIfExist(Node
, 0);
2141 Value
*Input1
= ReplaceOperandIfExist(Node
, 1);
2142 Value
*Accumulator
= ReplaceOperandIfExist(Node
, 2);
2143 assert(!Input1
|| (Input0
->getType() == Input1
->getType() &&
2144 "Node inputs need to be of the same type"));
2145 ReplacementNode
= TL
->createComplexDeinterleavingIR(
2146 Builder
, Node
->Operation
, Node
->Rotation
, Input0
, Input1
, Accumulator
);
2149 case ComplexDeinterleavingOperation::CAdd
:
2150 case ComplexDeinterleavingOperation::CMulPartial
:
2151 case ComplexDeinterleavingOperation::Symmetric
: {
2152 Value
*Input0
= ReplaceOperandIfExist(Node
, 0);
2153 Value
*Input1
= ReplaceOperandIfExist(Node
, 1);
2154 Value
*Accumulator
= ReplaceOperandIfExist(Node
, 2);
2155 assert(!Input1
|| (Input0
->getType() == Input1
->getType() &&
2156 "Node inputs need to be of the same type"));
2157 assert(!Accumulator
||
2158 (Input0
->getType() == Accumulator
->getType() &&
2159 "Accumulator and input need to be of the same type"));
2160 if (Node
->Operation
== ComplexDeinterleavingOperation::Symmetric
)
2161 ReplacementNode
= replaceSymmetricNode(Builder
, Node
->Opcode
, Node
->Flags
,
2164 ReplacementNode
= TL
->createComplexDeinterleavingIR(
2165 Builder
, Node
->Operation
, Node
->Rotation
, Input0
, Input1
,
2169 case ComplexDeinterleavingOperation::Deinterleave
:
2170 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2172 case ComplexDeinterleavingOperation::Splat
: {
2173 auto *NewTy
= VectorType::getDoubleElementsVectorType(
2174 cast
<VectorType
>(Node
->Real
->getType()));
2175 auto *R
= dyn_cast
<Instruction
>(Node
->Real
);
2176 auto *I
= dyn_cast
<Instruction
>(Node
->Imag
);
2178 // Splats that are not constant are interleaved where they are located
2179 Instruction
*InsertPoint
= (I
->comesBefore(R
) ? R
: I
)->getNextNode();
2180 IRBuilder
<> IRB(InsertPoint
);
2181 ReplacementNode
= IRB
.CreateIntrinsic(Intrinsic::vector_interleave2
,
2182 NewTy
, {Node
->Real
, Node
->Imag
});
2184 ReplacementNode
= Builder
.CreateIntrinsic(
2185 Intrinsic::vector_interleave2
, NewTy
, {Node
->Real
, Node
->Imag
});
2189 case ComplexDeinterleavingOperation::ReductionPHI
: {
2190 // If Operation is ReductionPHI, a new empty PHINode is created.
2191 // It is filled later when the ReductionOperation is processed.
2192 auto *OldPHI
= cast
<PHINode
>(Node
->Real
);
2193 auto *VTy
= cast
<VectorType
>(Node
->Real
->getType());
2194 auto *NewVTy
= VectorType::getDoubleElementsVectorType(VTy
);
2195 auto *NewPHI
= PHINode::Create(NewVTy
, 0, "", BackEdge
->getFirstNonPHIIt());
2196 OldToNewPHI
[OldPHI
] = NewPHI
;
2197 ReplacementNode
= NewPHI
;
2200 case ComplexDeinterleavingOperation::ReductionSingle
:
2201 ReplacementNode
= replaceNode(Builder
, Node
->Operands
[0]);
2202 processReductionSingle(ReplacementNode
, Node
);
2204 case ComplexDeinterleavingOperation::ReductionOperation
:
2205 ReplacementNode
= replaceNode(Builder
, Node
->Operands
[0]);
2206 processReductionOperation(ReplacementNode
, Node
);
2208 case ComplexDeinterleavingOperation::ReductionSelect
: {
2209 auto *MaskReal
= cast
<Instruction
>(Node
->Real
)->getOperand(0);
2210 auto *MaskImag
= cast
<Instruction
>(Node
->Imag
)->getOperand(0);
2211 auto *A
= replaceNode(Builder
, Node
->Operands
[0]);
2212 auto *B
= replaceNode(Builder
, Node
->Operands
[1]);
2213 auto *NewMaskTy
= VectorType::getDoubleElementsVectorType(
2214 cast
<VectorType
>(MaskReal
->getType()));
2215 auto *NewMask
= Builder
.CreateIntrinsic(Intrinsic::vector_interleave2
,
2216 NewMaskTy
, {MaskReal
, MaskImag
});
2217 ReplacementNode
= Builder
.CreateSelect(NewMask
, A
, B
);
2222 assert(ReplacementNode
&& "Target failed to create Intrinsic call.");
2223 NumComplexTransformations
+= 1;
2224 Node
->ReplacementNode
= ReplacementNode
;
2225 return ReplacementNode
;
2228 void ComplexDeinterleavingGraph::processReductionSingle(
2229 Value
*OperationReplacement
, RawNodePtr Node
) {
2230 auto *Real
= cast
<Instruction
>(Node
->Real
);
2231 auto *OldPHI
= ReductionInfo
[Real
].first
;
2232 auto *NewPHI
= OldToNewPHI
[OldPHI
];
2233 auto *VTy
= cast
<VectorType
>(Real
->getType());
2234 auto *NewVTy
= VectorType::getDoubleElementsVectorType(VTy
);
2236 Value
*Init
= OldPHI
->getIncomingValueForBlock(Incoming
);
2238 IRBuilder
<> Builder(Incoming
->getTerminator());
2240 Value
*NewInit
= nullptr;
2241 if (auto *C
= dyn_cast
<Constant
>(Init
)) {
2242 if (C
->isZeroValue())
2243 NewInit
= Constant::getNullValue(NewVTy
);
2247 NewInit
= Builder
.CreateIntrinsic(Intrinsic::vector_interleave2
, NewVTy
,
2248 {Init
, Constant::getNullValue(VTy
)});
2250 NewPHI
->addIncoming(NewInit
, Incoming
);
2251 NewPHI
->addIncoming(OperationReplacement
, BackEdge
);
2253 auto *FinalReduction
= ReductionInfo
[Real
].second
;
2254 Builder
.SetInsertPoint(&*FinalReduction
->getParent()->getFirstInsertionPt());
2256 auto *AddReduce
= Builder
.CreateAddReduce(OperationReplacement
);
2257 FinalReduction
->replaceAllUsesWith(AddReduce
);
2260 void ComplexDeinterleavingGraph::processReductionOperation(
2261 Value
*OperationReplacement
, RawNodePtr Node
) {
2262 auto *Real
= cast
<Instruction
>(Node
->Real
);
2263 auto *Imag
= cast
<Instruction
>(Node
->Imag
);
2264 auto *OldPHIReal
= ReductionInfo
[Real
].first
;
2265 auto *OldPHIImag
= ReductionInfo
[Imag
].first
;
2266 auto *NewPHI
= OldToNewPHI
[OldPHIReal
];
2268 auto *VTy
= cast
<VectorType
>(Real
->getType());
2269 auto *NewVTy
= VectorType::getDoubleElementsVectorType(VTy
);
2271 // We have to interleave initial origin values coming from IncomingBlock
2272 Value
*InitReal
= OldPHIReal
->getIncomingValueForBlock(Incoming
);
2273 Value
*InitImag
= OldPHIImag
->getIncomingValueForBlock(Incoming
);
2275 IRBuilder
<> Builder(Incoming
->getTerminator());
2276 auto *NewInit
= Builder
.CreateIntrinsic(Intrinsic::vector_interleave2
, NewVTy
,
2277 {InitReal
, InitImag
});
2279 NewPHI
->addIncoming(NewInit
, Incoming
);
2280 NewPHI
->addIncoming(OperationReplacement
, BackEdge
);
2282 // Deinterleave complex vector outside of loop so that it can be finally
2284 auto *FinalReductionReal
= ReductionInfo
[Real
].second
;
2285 auto *FinalReductionImag
= ReductionInfo
[Imag
].second
;
2287 Builder
.SetInsertPoint(
2288 &*FinalReductionReal
->getParent()->getFirstInsertionPt());
2289 auto *Deinterleave
= Builder
.CreateIntrinsic(Intrinsic::vector_deinterleave2
,
2290 OperationReplacement
->getType(),
2291 OperationReplacement
);
2293 auto *NewReal
= Builder
.CreateExtractValue(Deinterleave
, (uint64_t)0);
2294 FinalReductionReal
->replaceUsesOfWith(Real
, NewReal
);
2296 Builder
.SetInsertPoint(FinalReductionImag
);
2297 auto *NewImag
= Builder
.CreateExtractValue(Deinterleave
, 1);
2298 FinalReductionImag
->replaceUsesOfWith(Imag
, NewImag
);
2301 void ComplexDeinterleavingGraph::replaceNodes() {
2302 SmallVector
<Instruction
*, 16> DeadInstrRoots
;
2303 for (auto *RootInstruction
: OrderedRoots
) {
2304 // Check if this potential root went through check process and we can
2306 if (!RootToNode
.count(RootInstruction
))
2309 IRBuilder
<> Builder(RootInstruction
);
2310 auto RootNode
= RootToNode
[RootInstruction
];
2311 Value
*R
= replaceNode(Builder
, RootNode
.get());
2313 if (RootNode
->Operation
==
2314 ComplexDeinterleavingOperation::ReductionOperation
) {
2315 auto *RootReal
= cast
<Instruction
>(RootNode
->Real
);
2316 auto *RootImag
= cast
<Instruction
>(RootNode
->Imag
);
2317 ReductionInfo
[RootReal
].first
->removeIncomingValue(BackEdge
);
2318 ReductionInfo
[RootImag
].first
->removeIncomingValue(BackEdge
);
2319 DeadInstrRoots
.push_back(RootReal
);
2320 DeadInstrRoots
.push_back(RootImag
);
2321 } else if (RootNode
->Operation
==
2322 ComplexDeinterleavingOperation::ReductionSingle
) {
2323 auto *RootInst
= cast
<Instruction
>(RootNode
->Real
);
2324 ReductionInfo
[RootInst
].first
->removeIncomingValue(BackEdge
);
2325 DeadInstrRoots
.push_back(ReductionInfo
[RootInst
].second
);
2327 assert(R
&& "Unable to find replacement for RootInstruction");
2328 DeadInstrRoots
.push_back(RootInstruction
);
2329 RootInstruction
->replaceAllUsesWith(R
);
2333 for (auto *I
: DeadInstrRoots
)
2334 RecursivelyDeleteTriviallyDeadInstructions(I
, TLI
);