[AMDGPU][AsmParser][NFC] Translate parsed MIMG instructions to MCInsts automatically.
[llvm-project.git] / llvm / lib / CodeGen / ComplexDeinterleavingPass.cpp
blob23827b9a2fd7073b0b577df7cc3beeadc8e5810a
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Each node is expected to be validated upon creation, and any
19 // validation error should halt traversal and prevent further graph
20 // construction.
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas fixed-width vectors are recognized in both forms:
25 // shufflevector instructions and the intrinsics.
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
48 // Each node maintains a reference to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
52 // Operands contains pointers to other CompositeNodes, acting as the edges in
53 // the graph. ReplacementNode is the transformed Value* that has been emitted
54 // to the IR.
56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57 // ReplacementNode fields of that Node are relevant, where the ReplacementNode
58 // should be pre-populated.
60 //===----------------------------------------------------------------------===//
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/Statistic.h"
64 #include "llvm/Analysis/TargetLibraryInfo.h"
65 #include "llvm/Analysis/TargetTransformInfo.h"
66 #include "llvm/CodeGen/TargetLowering.h"
67 #include "llvm/CodeGen/TargetPassConfig.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
74 #include <algorithm>
76 using namespace llvm;
77 using namespace PatternMatch;
79 #define DEBUG_TYPE "complex-deinterleaving"
81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83 static cl::opt<bool> ComplexDeinterleavingEnabled(
84 "enable-complex-deinterleaving",
85 cl::desc("Enable generation of complex instructions"), cl::init(true),
86 cl::Hidden);
88 /// Checks the given mask, and determines whether said mask is interleaving.
89 ///
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef<int> Mask);
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
96 ///
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
98 /// with 0 or 1.
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100 /// <1, 3, 5, 7>).
101 static bool isDeinterleavingMask(ArrayRef<int> Mask);
103 namespace {
105 class ComplexDeinterleavingLegacyPass : public FunctionPass {
106 public:
107 static char ID;
109 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
110 : FunctionPass(ID), TM(TM) {
111 initializeComplexDeinterleavingLegacyPassPass(
112 *PassRegistry::getPassRegistry());
115 StringRef getPassName() const override {
116 return "Complex Deinterleaving Pass";
119 bool runOnFunction(Function &F) override;
120 void getAnalysisUsage(AnalysisUsage &AU) const override {
121 AU.addRequired<TargetLibraryInfoWrapperPass>();
122 AU.setPreservesCFG();
125 private:
126 const TargetMachine *TM;
129 class ComplexDeinterleavingGraph;
130 struct ComplexDeinterleavingCompositeNode {
132 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
133 Value *R, Value *I)
134 : Operation(Op), Real(R), Imag(I) {}
136 private:
137 friend class ComplexDeinterleavingGraph;
138 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
139 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
141 public:
142 ComplexDeinterleavingOperation Operation;
143 Value *Real;
144 Value *Imag;
146 // This two members are required exclusively for generating
147 // ComplexDeinterleavingOperation::Symmetric operations.
148 unsigned Opcode;
149 FastMathFlags Flags;
151 ComplexDeinterleavingRotation Rotation =
152 ComplexDeinterleavingRotation::Rotation_0;
153 SmallVector<RawNodePtr> Operands;
154 Value *ReplacementNode = nullptr;
156 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
158 void dump() { dump(dbgs()); }
159 void dump(raw_ostream &OS) {
160 auto PrintValue = [&](Value *V) {
161 if (V) {
162 OS << "\"";
163 V->print(OS, true);
164 OS << "\"\n";
165 } else
166 OS << "nullptr\n";
168 auto PrintNodeRef = [&](RawNodePtr Ptr) {
169 if (Ptr)
170 OS << Ptr << "\n";
171 else
172 OS << "nullptr\n";
175 OS << "- CompositeNode: " << this << "\n";
176 OS << " Real: ";
177 PrintValue(Real);
178 OS << " Imag: ";
179 PrintValue(Imag);
180 OS << " ReplacementNode: ";
181 PrintValue(ReplacementNode);
182 OS << " Operation: " << (int)Operation << "\n";
183 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
184 OS << " Operands: \n";
185 for (const auto &Op : Operands) {
186 OS << " - ";
187 PrintNodeRef(Op);
192 class ComplexDeinterleavingGraph {
193 public:
194 struct Product {
195 Value *Multiplier;
196 Value *Multiplicand;
197 bool IsPositive;
200 using Addend = std::pair<Value *, bool>;
201 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
202 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
204 // Helper struct for holding info about potential partial multiplication
205 // candidates
206 struct PartialMulCandidate {
207 Value *Common;
208 NodePtr Node;
209 unsigned RealIdx;
210 unsigned ImagIdx;
211 bool IsNodeInverted;
214 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
215 const TargetLibraryInfo *TLI)
216 : TL(TL), TLI(TLI) {}
218 private:
219 const TargetLowering *TL = nullptr;
220 const TargetLibraryInfo *TLI = nullptr;
221 SmallVector<NodePtr> CompositeNodes;
223 SmallPtrSet<Instruction *, 16> FinalInstructions;
225 /// Root instructions are instructions from which complex computation starts
226 std::map<Instruction *, NodePtr> RootToNode;
228 /// Topologically sorted root instructions
229 SmallVector<Instruction *, 1> OrderedRoots;
231 /// When examining a basic block for complex deinterleaving, if it is a simple
232 /// one-block loop, then the only incoming block is 'Incoming' and the
233 /// 'BackEdge' block is the block itself."
234 BasicBlock *BackEdge = nullptr;
235 BasicBlock *Incoming = nullptr;
237 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
238 /// %OutsideUser as it is shown in the IR:
240 /// vector.body:
241 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
242 /// [ %ReductionOp, %vector.body ]
243 /// ...
244 /// %ReductionOp = fadd i64 ...
245 /// ...
246 /// br i1 %condition, label %vector.body, %middle.block
248 /// middle.block:
249 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
251 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
252 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
253 std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
255 /// In the process of detecting a reduction, we consider a pair of
256 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
257 /// traverse the use-tree to detect complex operations. As this is a reduction
258 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
259 /// to the %ReductionOPs that we suspect to be complex.
260 /// RealPHI and ImagPHI are used by the identifyPHINode method.
261 PHINode *RealPHI = nullptr;
262 PHINode *ImagPHI = nullptr;
264 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
265 /// detection.
266 bool PHIsFound = false;
268 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
269 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
270 /// This mapping is populated during
271 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
272 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
273 /// replacement process.
274 std::map<PHINode *, PHINode *> OldToNewPHI;
276 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
277 Value *R, Value *I) {
278 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
279 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
280 (R && I)) &&
281 "Reduction related nodes must have Real and Imaginary parts");
282 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
286 NodePtr submitCompositeNode(NodePtr Node) {
287 CompositeNodes.push_back(Node);
288 return Node;
291 NodePtr getContainingComposite(Value *R, Value *I) {
292 for (const auto &CN : CompositeNodes) {
293 if (CN->Real == R && CN->Imag == I)
294 return CN;
296 return nullptr;
299 /// Identifies a complex partial multiply pattern and its rotation, based on
300 /// the following patterns
302 /// 0: r: cr + ar * br
303 /// i: ci + ar * bi
304 /// 90: r: cr - ai * bi
305 /// i: ci + ai * br
306 /// 180: r: cr - ar * br
307 /// i: ci - ar * bi
308 /// 270: r: cr + ai * bi
309 /// i: ci - ai * br
310 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
312 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
313 /// is partially known from identifyPartialMul, filling in the other half of
314 /// the complex pair.
315 NodePtr
316 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
317 std::pair<Value *, Value *> &CommonOperandI);
319 /// Identifies a complex add pattern and its rotation, based on the following
320 /// patterns.
322 /// 90: r: ar - bi
323 /// i: ai + br
324 /// 270: r: ar + bi
325 /// i: ai - br
326 NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
327 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
329 NodePtr identifyNode(Value *R, Value *I);
331 /// Determine if a sum of complex numbers can be formed from \p RealAddends
332 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
333 /// Return nullptr if it is not possible to construct a complex number.
334 /// \p Flags are needed to generate symmetric Add and Sub operations.
335 NodePtr identifyAdditions(std::list<Addend> &RealAddends,
336 std::list<Addend> &ImagAddends, FastMathFlags Flags,
337 NodePtr Accumulator);
339 /// Extract one addend that have both real and imaginary parts positive.
340 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
341 std::list<Addend> &ImagAddends);
343 /// Determine if sum of multiplications of complex numbers can be formed from
344 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
345 /// to it. Return nullptr if it is not possible to construct a complex number.
346 NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
347 std::vector<Product> &ImagMuls,
348 NodePtr Accumulator);
350 /// Go through pairs of multiplication (one Real and one Imag) and find all
351 /// possible candidates for partial multiplication and put them into \p
352 /// Candidates. Returns true if all Product has pair with common operand
353 bool collectPartialMuls(const std::vector<Product> &RealMuls,
354 const std::vector<Product> &ImagMuls,
355 std::vector<PartialMulCandidate> &Candidates);
357 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
358 /// the order of complex computation operations may be significantly altered,
359 /// and the real and imaginary parts may not be executed in parallel. This
360 /// function takes this into consideration and employs a more general approach
361 /// to identify complex computations. Initially, it gathers all the addends
362 /// and multiplicands and then constructs a complex expression from them.
363 NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
365 NodePtr identifyRoot(Instruction *I);
367 /// Identifies the Deinterleave operation applied to a vector containing
368 /// complex numbers. There are two ways to represent the Deinterleave
369 /// operation:
370 /// * Using two shufflevectors with even indices for /pReal instruction and
371 /// odd indices for /pImag instructions (only for fixed-width vectors)
372 /// * Using two extractvalue instructions applied to `vector.deinterleave2`
373 /// intrinsic (for both fixed and scalable vectors)
374 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
376 /// identifying the operation that represents a complex number repeated in a
377 /// Splat vector. There are two possible types of splats: ConstantExpr with
378 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
379 /// initialization mask with all values set to zero.
380 NodePtr identifySplat(Value *Real, Value *Imag);
382 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
384 /// Identifies SelectInsts in a loop that has reduction with predication masks
385 /// and/or predicated tail folding
386 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
388 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
390 /// Complete IR modifications after producing new reduction operation:
391 /// * Populate the PHINode generated for
392 /// ComplexDeinterleavingOperation::ReductionPHI
393 /// * Deinterleave the final value outside of the loop and repurpose original
394 /// reduction users
395 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
397 public:
398 void dump() { dump(dbgs()); }
399 void dump(raw_ostream &OS) {
400 for (const auto &Node : CompositeNodes)
401 Node->dump(OS);
404 /// Returns false if the deinterleaving operation should be cancelled for the
405 /// current graph.
406 bool identifyNodes(Instruction *RootI);
408 /// In case \pB is one-block loop, this function seeks potential reductions
409 /// and populates ReductionInfo. Returns true if any reductions were
410 /// identified.
411 bool collectPotentialReductions(BasicBlock *B);
413 void identifyReductionNodes();
415 /// Check that every instruction, from the roots to the leaves, has internal
416 /// uses.
417 bool checkNodes();
419 /// Perform the actual replacement of the underlying instruction graph.
420 void replaceNodes();
423 class ComplexDeinterleaving {
424 public:
425 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
426 : TL(tl), TLI(tli) {}
427 bool runOnFunction(Function &F);
429 private:
430 bool evaluateBasicBlock(BasicBlock *B);
432 const TargetLowering *TL = nullptr;
433 const TargetLibraryInfo *TLI = nullptr;
436 } // namespace
438 char ComplexDeinterleavingLegacyPass::ID = 0;
440 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
441 "Complex Deinterleaving", false, false)
442 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
443 "Complex Deinterleaving", false, false)
445 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
446 FunctionAnalysisManager &AM) {
447 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
448 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
449 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
450 return PreservedAnalyses::all();
452 PreservedAnalyses PA;
453 PA.preserve<FunctionAnalysisManagerModuleProxy>();
454 return PA;
457 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
458 return new ComplexDeinterleavingLegacyPass(TM);
461 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
462 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
463 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
464 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
467 bool ComplexDeinterleaving::runOnFunction(Function &F) {
468 if (!ComplexDeinterleavingEnabled) {
469 LLVM_DEBUG(
470 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
471 return false;
474 if (!TL->isComplexDeinterleavingSupported()) {
475 LLVM_DEBUG(
476 dbgs() << "Complex deinterleaving has been disabled, target does "
477 "not support lowering of complex number operations.\n");
478 return false;
481 bool Changed = false;
482 for (auto &B : F)
483 Changed |= evaluateBasicBlock(&B);
485 return Changed;
488 static bool isInterleavingMask(ArrayRef<int> Mask) {
489 // If the size is not even, it's not an interleaving mask
490 if ((Mask.size() & 1))
491 return false;
493 int HalfNumElements = Mask.size() / 2;
494 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
495 int MaskIdx = Idx * 2;
496 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
497 return false;
500 return true;
503 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
504 int Offset = Mask[0];
505 int HalfNumElements = Mask.size() / 2;
507 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
508 if (Mask[Idx] != (Idx * 2) + Offset)
509 return false;
512 return true;
515 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
516 ComplexDeinterleavingGraph Graph(TL, TLI);
517 if (Graph.collectPotentialReductions(B))
518 Graph.identifyReductionNodes();
520 for (auto &I : *B)
521 Graph.identifyNodes(&I);
523 if (Graph.checkNodes()) {
524 Graph.replaceNodes();
525 return true;
528 return false;
531 ComplexDeinterleavingGraph::NodePtr
532 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
533 Instruction *Real, Instruction *Imag,
534 std::pair<Value *, Value *> &PartialMatch) {
535 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
536 << "\n");
538 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
539 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
540 return nullptr;
543 if (Real->getOpcode() != Instruction::FMul ||
544 Imag->getOpcode() != Instruction::FMul) {
545 LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n");
546 return nullptr;
549 Value *R0 = Real->getOperand(0);
550 Value *R1 = Real->getOperand(1);
551 Value *I0 = Imag->getOperand(0);
552 Value *I1 = Imag->getOperand(1);
554 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
555 // rotations and use the operand.
556 unsigned Negs = 0;
557 Value *Op;
558 if (match(R0, m_Neg(m_Value(Op)))) {
559 Negs |= 1;
560 R0 = Op;
561 } else if (match(R1, m_Neg(m_Value(Op)))) {
562 Negs |= 1;
563 R1 = Op;
566 if (match(I0, m_Neg(m_Value(Op)))) {
567 Negs |= 2;
568 Negs ^= 1;
569 I0 = Op;
570 } else if (match(I1, m_Neg(m_Value(Op)))) {
571 Negs |= 2;
572 Negs ^= 1;
573 I1 = Op;
576 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
578 Value *CommonOperand;
579 Value *UncommonRealOp;
580 Value *UncommonImagOp;
582 if (R0 == I0 || R0 == I1) {
583 CommonOperand = R0;
584 UncommonRealOp = R1;
585 } else if (R1 == I0 || R1 == I1) {
586 CommonOperand = R1;
587 UncommonRealOp = R0;
588 } else {
589 LLVM_DEBUG(dbgs() << " - No equal operand\n");
590 return nullptr;
593 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
594 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
595 Rotation == ComplexDeinterleavingRotation::Rotation_270)
596 std::swap(UncommonRealOp, UncommonImagOp);
598 // Between identifyPartialMul and here we need to have found a complete valid
599 // pair from the CommonOperand of each part.
600 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
601 Rotation == ComplexDeinterleavingRotation::Rotation_180)
602 PartialMatch.first = CommonOperand;
603 else
604 PartialMatch.second = CommonOperand;
606 if (!PartialMatch.first || !PartialMatch.second) {
607 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
608 return nullptr;
611 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
612 if (!CommonNode) {
613 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
614 return nullptr;
617 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
618 if (!UncommonNode) {
619 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
620 return nullptr;
623 NodePtr Node = prepareCompositeNode(
624 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
625 Node->Rotation = Rotation;
626 Node->addOperand(CommonNode);
627 Node->addOperand(UncommonNode);
628 return submitCompositeNode(Node);
631 ComplexDeinterleavingGraph::NodePtr
632 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
633 Instruction *Imag) {
634 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
635 << "\n");
636 // Determine rotation
637 ComplexDeinterleavingRotation Rotation;
638 if (Real->getOpcode() == Instruction::FAdd &&
639 Imag->getOpcode() == Instruction::FAdd)
640 Rotation = ComplexDeinterleavingRotation::Rotation_0;
641 else if (Real->getOpcode() == Instruction::FSub &&
642 Imag->getOpcode() == Instruction::FAdd)
643 Rotation = ComplexDeinterleavingRotation::Rotation_90;
644 else if (Real->getOpcode() == Instruction::FSub &&
645 Imag->getOpcode() == Instruction::FSub)
646 Rotation = ComplexDeinterleavingRotation::Rotation_180;
647 else if (Real->getOpcode() == Instruction::FAdd &&
648 Imag->getOpcode() == Instruction::FSub)
649 Rotation = ComplexDeinterleavingRotation::Rotation_270;
650 else {
651 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
652 return nullptr;
655 if (!Real->getFastMathFlags().allowContract() ||
656 !Imag->getFastMathFlags().allowContract()) {
657 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
658 return nullptr;
661 Value *CR = Real->getOperand(0);
662 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
663 if (!RealMulI)
664 return nullptr;
665 Value *CI = Imag->getOperand(0);
666 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
667 if (!ImagMulI)
668 return nullptr;
670 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
671 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
672 return nullptr;
675 Value *R0 = RealMulI->getOperand(0);
676 Value *R1 = RealMulI->getOperand(1);
677 Value *I0 = ImagMulI->getOperand(0);
678 Value *I1 = ImagMulI->getOperand(1);
680 Value *CommonOperand;
681 Value *UncommonRealOp;
682 Value *UncommonImagOp;
684 if (R0 == I0 || R0 == I1) {
685 CommonOperand = R0;
686 UncommonRealOp = R1;
687 } else if (R1 == I0 || R1 == I1) {
688 CommonOperand = R1;
689 UncommonRealOp = R0;
690 } else {
691 LLVM_DEBUG(dbgs() << " - No equal operand\n");
692 return nullptr;
695 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
696 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
697 Rotation == ComplexDeinterleavingRotation::Rotation_270)
698 std::swap(UncommonRealOp, UncommonImagOp);
700 std::pair<Value *, Value *> PartialMatch(
701 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
702 Rotation == ComplexDeinterleavingRotation::Rotation_180)
703 ? CommonOperand
704 : nullptr,
705 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
706 Rotation == ComplexDeinterleavingRotation::Rotation_270)
707 ? CommonOperand
708 : nullptr);
710 auto *CRInst = dyn_cast<Instruction>(CR);
711 auto *CIInst = dyn_cast<Instruction>(CI);
713 if (!CRInst || !CIInst) {
714 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
715 return nullptr;
718 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
719 if (!CNode) {
720 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
721 return nullptr;
724 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
725 if (!UncommonRes) {
726 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
727 return nullptr;
730 assert(PartialMatch.first && PartialMatch.second);
731 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
732 if (!CommonRes) {
733 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
734 return nullptr;
737 NodePtr Node = prepareCompositeNode(
738 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
739 Node->Rotation = Rotation;
740 Node->addOperand(CommonRes);
741 Node->addOperand(UncommonRes);
742 Node->addOperand(CNode);
743 return submitCompositeNode(Node);
746 ComplexDeinterleavingGraph::NodePtr
747 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
748 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
750 // Determine rotation
751 ComplexDeinterleavingRotation Rotation;
752 if ((Real->getOpcode() == Instruction::FSub &&
753 Imag->getOpcode() == Instruction::FAdd) ||
754 (Real->getOpcode() == Instruction::Sub &&
755 Imag->getOpcode() == Instruction::Add))
756 Rotation = ComplexDeinterleavingRotation::Rotation_90;
757 else if ((Real->getOpcode() == Instruction::FAdd &&
758 Imag->getOpcode() == Instruction::FSub) ||
759 (Real->getOpcode() == Instruction::Add &&
760 Imag->getOpcode() == Instruction::Sub))
761 Rotation = ComplexDeinterleavingRotation::Rotation_270;
762 else {
763 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
764 return nullptr;
767 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
768 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
769 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
770 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
772 if (!AR || !AI || !BR || !BI) {
773 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
774 return nullptr;
777 NodePtr ResA = identifyNode(AR, AI);
778 if (!ResA) {
779 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
780 return nullptr;
782 NodePtr ResB = identifyNode(BR, BI);
783 if (!ResB) {
784 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
785 return nullptr;
788 NodePtr Node =
789 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
790 Node->Rotation = Rotation;
791 Node->addOperand(ResA);
792 Node->addOperand(ResB);
793 return submitCompositeNode(Node);
796 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
797 unsigned OpcA = A->getOpcode();
798 unsigned OpcB = B->getOpcode();
800 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
801 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
802 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
803 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
806 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
807 auto Pattern =
808 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
810 return match(A, Pattern) && match(B, Pattern);
813 static bool isInstructionPotentiallySymmetric(Instruction *I) {
814 switch (I->getOpcode()) {
815 case Instruction::FAdd:
816 case Instruction::FSub:
817 case Instruction::FMul:
818 case Instruction::FNeg:
819 return true;
820 default:
821 return false;
825 ComplexDeinterleavingGraph::NodePtr
826 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
827 Instruction *Imag) {
828 if (Real->getOpcode() != Imag->getOpcode())
829 return nullptr;
831 if (!isInstructionPotentiallySymmetric(Real) ||
832 !isInstructionPotentiallySymmetric(Imag))
833 return nullptr;
835 auto *R0 = Real->getOperand(0);
836 auto *I0 = Imag->getOperand(0);
838 NodePtr Op0 = identifyNode(R0, I0);
839 NodePtr Op1 = nullptr;
840 if (Op0 == nullptr)
841 return nullptr;
843 if (Real->isBinaryOp()) {
844 auto *R1 = Real->getOperand(1);
845 auto *I1 = Imag->getOperand(1);
846 Op1 = identifyNode(R1, I1);
847 if (Op1 == nullptr)
848 return nullptr;
851 if (isa<FPMathOperator>(Real) &&
852 Real->getFastMathFlags() != Imag->getFastMathFlags())
853 return nullptr;
855 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
856 Real, Imag);
857 Node->Opcode = Real->getOpcode();
858 if (isa<FPMathOperator>(Real))
859 Node->Flags = Real->getFastMathFlags();
861 Node->addOperand(Op0);
862 if (Real->isBinaryOp())
863 Node->addOperand(Op1);
865 return submitCompositeNode(Node);
868 ComplexDeinterleavingGraph::NodePtr
869 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
870 LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
871 assert(R->getType() == I->getType() &&
872 "Real and imaginary parts should not have different types");
873 if (NodePtr CN = getContainingComposite(R, I)) {
874 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
875 return CN;
878 if (NodePtr CN = identifySplat(R, I))
879 return CN;
881 auto *Real = dyn_cast<Instruction>(R);
882 auto *Imag = dyn_cast<Instruction>(I);
883 if (!Real || !Imag)
884 return nullptr;
886 if (NodePtr CN = identifyDeinterleave(Real, Imag))
887 return CN;
889 if (NodePtr CN = identifyPHINode(Real, Imag))
890 return CN;
892 if (NodePtr CN = identifySelectNode(Real, Imag))
893 return CN;
895 auto *VTy = cast<VectorType>(Real->getType());
896 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
898 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
899 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
900 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
901 ComplexDeinterleavingOperation::CAdd, NewVTy);
903 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
904 if (NodePtr CN = identifyPartialMul(Real, Imag))
905 return CN;
908 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
909 if (NodePtr CN = identifyAdd(Real, Imag))
910 return CN;
913 if (HasCMulSupport && HasCAddSupport) {
914 if (NodePtr CN = identifyReassocNodes(Real, Imag))
915 return CN;
918 if (NodePtr CN = identifySymmetricOperation(Real, Imag))
919 return CN;
921 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
922 return nullptr;
925 ComplexDeinterleavingGraph::NodePtr
926 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
927 Instruction *Imag) {
929 if ((Real->getOpcode() != Instruction::FAdd &&
930 Real->getOpcode() != Instruction::FSub &&
931 Real->getOpcode() != Instruction::FNeg) ||
932 (Imag->getOpcode() != Instruction::FAdd &&
933 Imag->getOpcode() != Instruction::FSub &&
934 Imag->getOpcode() != Instruction::FNeg))
935 return nullptr;
937 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
938 LLVM_DEBUG(
939 dbgs()
940 << "The flags in Real and Imaginary instructions are not identical\n");
941 return nullptr;
944 FastMathFlags Flags = Real->getFastMathFlags();
945 if (!Flags.allowReassoc()) {
946 LLVM_DEBUG(
947 dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n");
948 return nullptr;
951 // Collect multiplications and addend instructions from the given instruction
952 // while traversing it operands. Additionally, verify that all instructions
953 // have the same fast math flags.
954 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
955 std::list<Addend> &Addends) -> bool {
956 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
957 SmallPtrSet<Value *, 8> Visited;
958 while (!Worklist.empty()) {
959 auto [V, IsPositive] = Worklist.back();
960 Worklist.pop_back();
961 if (!Visited.insert(V).second)
962 continue;
964 Instruction *I = dyn_cast<Instruction>(V);
965 if (!I) {
966 Addends.emplace_back(V, IsPositive);
967 continue;
970 // If an instruction has more than one user, it indicates that it either
971 // has an external user, which will be later checked by the checkNodes
972 // function, or it is a subexpression utilized by multiple expressions. In
973 // the latter case, we will attempt to separately identify the complex
974 // operation from here in order to create a shared
975 // ComplexDeinterleavingCompositeNode.
976 if (I != Insn && I->getNumUses() > 1) {
977 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
978 Addends.emplace_back(I, IsPositive);
979 continue;
982 if (I->getOpcode() == Instruction::FAdd) {
983 Worklist.emplace_back(I->getOperand(1), IsPositive);
984 Worklist.emplace_back(I->getOperand(0), IsPositive);
985 } else if (I->getOpcode() == Instruction::FSub) {
986 Worklist.emplace_back(I->getOperand(1), !IsPositive);
987 Worklist.emplace_back(I->getOperand(0), IsPositive);
988 } else if (I->getOpcode() == Instruction::FMul) {
989 Value *A, *B;
990 if (match(I->getOperand(0), m_FNeg(m_Value(A)))) {
991 IsPositive = !IsPositive;
992 } else {
993 A = I->getOperand(0);
996 if (match(I->getOperand(1), m_FNeg(m_Value(B)))) {
997 IsPositive = !IsPositive;
998 } else {
999 B = I->getOperand(1);
1001 Muls.push_back(Product{A, B, IsPositive});
1002 } else if (I->getOpcode() == Instruction::FNeg) {
1003 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1004 } else {
1005 Addends.emplace_back(I, IsPositive);
1006 continue;
1009 if (I->getFastMathFlags() != Flags) {
1010 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1011 "inconsistent with the root instructions' flags: "
1012 << *I << "\n");
1013 return false;
1016 return true;
1019 std::vector<Product> RealMuls, ImagMuls;
1020 std::list<Addend> RealAddends, ImagAddends;
1021 if (!Collect(Real, RealMuls, RealAddends) ||
1022 !Collect(Imag, ImagMuls, ImagAddends))
1023 return nullptr;
1025 if (RealAddends.size() != ImagAddends.size())
1026 return nullptr;
1028 NodePtr FinalNode;
1029 if (!RealMuls.empty() || !ImagMuls.empty()) {
1030 // If there are multiplicands, extract positive addend and use it as an
1031 // accumulator
1032 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1033 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1034 if (!FinalNode)
1035 return nullptr;
1038 // Identify and process remaining additions
1039 if (!RealAddends.empty() || !ImagAddends.empty()) {
1040 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1041 if (!FinalNode)
1042 return nullptr;
1044 assert(FinalNode && "FinalNode can not be nullptr here");
1045 // Set the Real and Imag fields of the final node and submit it
1046 FinalNode->Real = Real;
1047 FinalNode->Imag = Imag;
1048 submitCompositeNode(FinalNode);
1049 return FinalNode;
1052 bool ComplexDeinterleavingGraph::collectPartialMuls(
1053 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1054 std::vector<PartialMulCandidate> &PartialMulCandidates) {
1055 // Helper function to extract a common operand from two products
1056 auto FindCommonInstruction = [](const Product &Real,
1057 const Product &Imag) -> Value * {
1058 if (Real.Multiplicand == Imag.Multiplicand ||
1059 Real.Multiplicand == Imag.Multiplier)
1060 return Real.Multiplicand;
1062 if (Real.Multiplier == Imag.Multiplicand ||
1063 Real.Multiplier == Imag.Multiplier)
1064 return Real.Multiplier;
1066 return nullptr;
1069 // Iterating over real and imaginary multiplications to find common operands
1070 // If a common operand is found, a partial multiplication candidate is created
1071 // and added to the candidates vector The function returns false if no common
1072 // operands are found for any product
1073 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1074 bool FoundCommon = false;
1075 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1076 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1077 if (!Common)
1078 continue;
1080 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1081 : RealMuls[i].Multiplicand;
1082 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1083 : ImagMuls[j].Multiplicand;
1085 auto Node = identifyNode(A, B);
1086 if (Node) {
1087 FoundCommon = true;
1088 PartialMulCandidates.push_back({Common, Node, i, j, false});
1091 Node = identifyNode(B, A);
1092 if (Node) {
1093 FoundCommon = true;
1094 PartialMulCandidates.push_back({Common, Node, i, j, true});
1097 if (!FoundCommon)
1098 return false;
1100 return true;
1103 ComplexDeinterleavingGraph::NodePtr
1104 ComplexDeinterleavingGraph::identifyMultiplications(
1105 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1106 NodePtr Accumulator = nullptr) {
1107 if (RealMuls.size() != ImagMuls.size())
1108 return nullptr;
1110 std::vector<PartialMulCandidate> Info;
1111 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1112 return nullptr;
1114 // Map to store common instruction to node pointers
1115 std::map<Value *, NodePtr> CommonToNode;
1116 std::vector<bool> Processed(Info.size(), false);
1117 for (unsigned I = 0; I < Info.size(); ++I) {
1118 if (Processed[I])
1119 continue;
1121 PartialMulCandidate &InfoA = Info[I];
1122 for (unsigned J = I + 1; J < Info.size(); ++J) {
1123 if (Processed[J])
1124 continue;
1126 PartialMulCandidate &InfoB = Info[J];
1127 auto *InfoReal = &InfoA;
1128 auto *InfoImag = &InfoB;
1130 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1131 if (!NodeFromCommon) {
1132 std::swap(InfoReal, InfoImag);
1133 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1135 if (!NodeFromCommon)
1136 continue;
1138 CommonToNode[InfoReal->Common] = NodeFromCommon;
1139 CommonToNode[InfoImag->Common] = NodeFromCommon;
1140 Processed[I] = true;
1141 Processed[J] = true;
1145 std::vector<bool> ProcessedReal(RealMuls.size(), false);
1146 std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1147 NodePtr Result = Accumulator;
1148 for (auto &PMI : Info) {
1149 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1150 continue;
1152 auto It = CommonToNode.find(PMI.Common);
1153 // TODO: Process independent complex multiplications. Cases like this:
1154 // A.real() * B where both A and B are complex numbers.
1155 if (It == CommonToNode.end()) {
1156 LLVM_DEBUG({
1157 dbgs() << "Unprocessed independent partial multiplication:\n";
1158 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1159 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1160 << " multiplied by " << *Mul->Multiplicand << "\n";
1162 return nullptr;
1165 auto &RealMul = RealMuls[PMI.RealIdx];
1166 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1168 auto NodeA = It->second;
1169 auto NodeB = PMI.Node;
1170 auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1171 // The following table illustrates the relationship between multiplications
1172 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1173 // can see:
1175 // Rotation | Real | Imag |
1176 // ---------+--------+--------+
1177 // 0 | x * u | x * v |
1178 // 90 | -y * v | y * u |
1179 // 180 | -x * u | -x * v |
1180 // 270 | y * v | -y * u |
1182 // Check if the candidate can indeed be represented by partial
1183 // multiplication
1184 // TODO: Add support for multiplication by complex one
1185 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1186 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1187 continue;
1189 // Determine the rotation based on the multiplications
1190 ComplexDeinterleavingRotation Rotation;
1191 if (IsMultiplicandReal) {
1192 // Detect 0 and 180 degrees rotation
1193 if (RealMul.IsPositive && ImagMul.IsPositive)
1194 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1195 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1196 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1197 else
1198 continue;
1200 } else {
1201 // Detect 90 and 270 degrees rotation
1202 if (!RealMul.IsPositive && ImagMul.IsPositive)
1203 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1204 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1205 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1206 else
1207 continue;
1210 LLVM_DEBUG({
1211 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1212 dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1213 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1214 dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1215 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1216 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1219 NodePtr NodeMul = prepareCompositeNode(
1220 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1221 NodeMul->Rotation = Rotation;
1222 NodeMul->addOperand(NodeA);
1223 NodeMul->addOperand(NodeB);
1224 if (Result)
1225 NodeMul->addOperand(Result);
1226 submitCompositeNode(NodeMul);
1227 Result = NodeMul;
1228 ProcessedReal[PMI.RealIdx] = true;
1229 ProcessedImag[PMI.ImagIdx] = true;
1232 // Ensure all products have been processed, if not return nullptr.
1233 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1234 !all_of(ProcessedImag, [](bool V) { return V; })) {
1236 // Dump debug information about which partial multiplications are not
1237 // processed.
1238 LLVM_DEBUG({
1239 dbgs() << "Unprocessed products (Real):\n";
1240 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1241 if (!ProcessedReal[i])
1242 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1243 << *RealMuls[i].Multiplier << " multiplied by "
1244 << *RealMuls[i].Multiplicand << "\n";
1246 dbgs() << "Unprocessed products (Imag):\n";
1247 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1248 if (!ProcessedImag[i])
1249 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1250 << *ImagMuls[i].Multiplier << " multiplied by "
1251 << *ImagMuls[i].Multiplicand << "\n";
1254 return nullptr;
1257 return Result;
1260 ComplexDeinterleavingGraph::NodePtr
1261 ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
1262 std::list<Addend> &ImagAddends,
1263 FastMathFlags Flags,
1264 NodePtr Accumulator = nullptr) {
1265 if (RealAddends.size() != ImagAddends.size())
1266 return nullptr;
1268 NodePtr Result;
1269 // If we have accumulator use it as first addend
1270 if (Accumulator)
1271 Result = Accumulator;
1272 // Otherwise find an element with both positive real and imaginary parts.
1273 else
1274 Result = extractPositiveAddend(RealAddends, ImagAddends);
1276 if (!Result)
1277 return nullptr;
1279 while (!RealAddends.empty()) {
1280 auto ItR = RealAddends.begin();
1281 auto [R, IsPositiveR] = *ItR;
1283 bool FoundImag = false;
1284 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1285 auto [I, IsPositiveI] = *ItI;
1286 ComplexDeinterleavingRotation Rotation;
1287 if (IsPositiveR && IsPositiveI)
1288 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1289 else if (!IsPositiveR && IsPositiveI)
1290 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1291 else if (!IsPositiveR && !IsPositiveI)
1292 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1293 else
1294 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1296 NodePtr AddNode;
1297 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1298 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1299 AddNode = identifyNode(R, I);
1300 } else {
1301 AddNode = identifyNode(I, R);
1303 if (AddNode) {
1304 LLVM_DEBUG({
1305 dbgs() << "Identified addition:\n";
1306 dbgs().indent(4) << "X: " << *R << "\n";
1307 dbgs().indent(4) << "Y: " << *I << "\n";
1308 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1311 NodePtr TmpNode;
1312 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1313 TmpNode = prepareCompositeNode(
1314 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1315 TmpNode->Opcode = Instruction::FAdd;
1316 TmpNode->Flags = Flags;
1317 } else if (Rotation ==
1318 llvm::ComplexDeinterleavingRotation::Rotation_180) {
1319 TmpNode = prepareCompositeNode(
1320 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1321 TmpNode->Opcode = Instruction::FSub;
1322 TmpNode->Flags = Flags;
1323 } else {
1324 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1325 nullptr, nullptr);
1326 TmpNode->Rotation = Rotation;
1329 TmpNode->addOperand(Result);
1330 TmpNode->addOperand(AddNode);
1331 submitCompositeNode(TmpNode);
1332 Result = TmpNode;
1333 RealAddends.erase(ItR);
1334 ImagAddends.erase(ItI);
1335 FoundImag = true;
1336 break;
1339 if (!FoundImag)
1340 return nullptr;
1342 return Result;
1345 ComplexDeinterleavingGraph::NodePtr
1346 ComplexDeinterleavingGraph::extractPositiveAddend(
1347 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1348 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1349 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1350 auto [R, IsPositiveR] = *ItR;
1351 auto [I, IsPositiveI] = *ItI;
1352 if (IsPositiveR && IsPositiveI) {
1353 auto Result = identifyNode(R, I);
1354 if (Result) {
1355 RealAddends.erase(ItR);
1356 ImagAddends.erase(ItI);
1357 return Result;
1362 return nullptr;
1365 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1366 // This potential root instruction might already have been recognized as
1367 // reduction. Because RootToNode maps both Real and Imaginary parts to
1368 // CompositeNode we should choose only one either Real or Imag instruction to
1369 // use as an anchor for generating complex instruction.
1370 auto It = RootToNode.find(RootI);
1371 if (It != RootToNode.end() && It->second->Real == RootI) {
1372 OrderedRoots.push_back(RootI);
1373 return true;
1376 auto RootNode = identifyRoot(RootI);
1377 if (!RootNode)
1378 return false;
1380 LLVM_DEBUG({
1381 Function *F = RootI->getFunction();
1382 BasicBlock *B = RootI->getParent();
1383 dbgs() << "Complex deinterleaving graph for " << F->getName()
1384 << "::" << B->getName() << ".\n";
1385 dump(dbgs());
1386 dbgs() << "\n";
1388 RootToNode[RootI] = RootNode;
1389 OrderedRoots.push_back(RootI);
1390 return true;
1393 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1394 bool FoundPotentialReduction = false;
1396 auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1397 if (!Br || Br->getNumSuccessors() != 2)
1398 return false;
1400 // Identify simple one-block loop
1401 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1402 return false;
1404 SmallVector<PHINode *> PHIs;
1405 for (auto &PHI : B->phis()) {
1406 if (PHI.getNumIncomingValues() != 2)
1407 continue;
1409 if (!PHI.getType()->isVectorTy())
1410 continue;
1412 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1413 if (!ReductionOp)
1414 continue;
1416 // Check if final instruction is reduced outside of current block
1417 Instruction *FinalReduction = nullptr;
1418 auto NumUsers = 0u;
1419 for (auto *U : ReductionOp->users()) {
1420 ++NumUsers;
1421 if (U == &PHI)
1422 continue;
1423 FinalReduction = dyn_cast<Instruction>(U);
1426 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1427 isa<PHINode>(FinalReduction))
1428 continue;
1430 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1431 BackEdge = B;
1432 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1433 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1434 Incoming = PHI.getIncomingBlock(IncomingIdx);
1435 FoundPotentialReduction = true;
1437 // If the initial value of PHINode is an Instruction, consider it a leaf
1438 // value of a complex deinterleaving graph.
1439 if (auto *InitPHI =
1440 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1441 FinalInstructions.insert(InitPHI);
1443 return FoundPotentialReduction;
1446 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1447 SmallVector<bool> Processed(ReductionInfo.size(), false);
1448 SmallVector<Instruction *> OperationInstruction;
1449 for (auto &P : ReductionInfo)
1450 OperationInstruction.push_back(P.first);
1452 // Identify a complex computation by evaluating two reduction operations that
1453 // potentially could be involved
1454 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1455 if (Processed[i])
1456 continue;
1457 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1458 if (Processed[j])
1459 continue;
1461 auto *Real = OperationInstruction[i];
1462 auto *Imag = OperationInstruction[j];
1463 if (Real->getType() != Imag->getType())
1464 continue;
1466 RealPHI = ReductionInfo[Real].first;
1467 ImagPHI = ReductionInfo[Imag].first;
1468 PHIsFound = false;
1469 auto Node = identifyNode(Real, Imag);
1470 if (!Node) {
1471 std::swap(Real, Imag);
1472 std::swap(RealPHI, ImagPHI);
1473 Node = identifyNode(Real, Imag);
1476 // If a node is identified and reduction PHINode is used in the chain of
1477 // operations, mark its operation instructions as used to prevent
1478 // re-identification and attach the node to the real part
1479 if (Node && PHIsFound) {
1480 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1481 << *Real << " / " << *Imag << "\n");
1482 Processed[i] = true;
1483 Processed[j] = true;
1484 auto RootNode = prepareCompositeNode(
1485 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1486 RootNode->addOperand(Node);
1487 RootToNode[Real] = RootNode;
1488 RootToNode[Imag] = RootNode;
1489 submitCompositeNode(RootNode);
1490 break;
1495 RealPHI = nullptr;
1496 ImagPHI = nullptr;
1499 bool ComplexDeinterleavingGraph::checkNodes() {
1500 // Collect all instructions from roots to leaves
1501 SmallPtrSet<Instruction *, 16> AllInstructions;
1502 SmallVector<Instruction *, 8> Worklist;
1503 for (auto &Pair : RootToNode)
1504 Worklist.push_back(Pair.first);
1506 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1507 // chains
1508 while (!Worklist.empty()) {
1509 auto *I = Worklist.back();
1510 Worklist.pop_back();
1512 if (!AllInstructions.insert(I).second)
1513 continue;
1515 for (Value *Op : I->operands()) {
1516 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1517 if (!FinalInstructions.count(I))
1518 Worklist.emplace_back(OpI);
1523 // Find instructions that have users outside of chain
1524 SmallVector<Instruction *, 2> OuterInstructions;
1525 for (auto *I : AllInstructions) {
1526 // Skip root nodes
1527 if (RootToNode.count(I))
1528 continue;
1530 for (User *U : I->users()) {
1531 if (AllInstructions.count(cast<Instruction>(U)))
1532 continue;
1534 // Found an instruction that is not used by XCMLA/XCADD chain
1535 Worklist.emplace_back(I);
1536 break;
1540 // If any instructions are found to be used outside, find and remove roots
1541 // that somehow connect to those instructions.
1542 SmallPtrSet<Instruction *, 16> Visited;
1543 while (!Worklist.empty()) {
1544 auto *I = Worklist.back();
1545 Worklist.pop_back();
1546 if (!Visited.insert(I).second)
1547 continue;
1549 // Found an impacted root node. Removing it from the nodes to be
1550 // deinterleaved
1551 if (RootToNode.count(I)) {
1552 LLVM_DEBUG(dbgs() << "Instruction " << *I
1553 << " could be deinterleaved but its chain of complex "
1554 "operations have an outside user\n");
1555 RootToNode.erase(I);
1558 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1559 continue;
1561 for (User *U : I->users())
1562 Worklist.emplace_back(cast<Instruction>(U));
1564 for (Value *Op : I->operands()) {
1565 if (auto *OpI = dyn_cast<Instruction>(Op))
1566 Worklist.emplace_back(OpI);
1569 return !RootToNode.empty();
1572 ComplexDeinterleavingGraph::NodePtr
1573 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1574 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1575 if (Intrinsic->getIntrinsicID() !=
1576 Intrinsic::experimental_vector_interleave2)
1577 return nullptr;
1579 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1580 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1581 if (!Real || !Imag)
1582 return nullptr;
1584 return identifyNode(Real, Imag);
1587 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1588 if (!SVI)
1589 return nullptr;
1591 // Look for a shufflevector that takes separate vectors of the real and
1592 // imaginary components and recombines them into a single vector.
1593 if (!isInterleavingMask(SVI->getShuffleMask()))
1594 return nullptr;
1596 Instruction *Real;
1597 Instruction *Imag;
1598 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1599 return nullptr;
1601 return identifyNode(Real, Imag);
1604 ComplexDeinterleavingGraph::NodePtr
1605 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1606 Instruction *Imag) {
1607 Instruction *I = nullptr;
1608 Value *FinalValue = nullptr;
1609 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1610 match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1611 match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1612 m_Value(FinalValue)))) {
1613 NodePtr PlaceholderNode = prepareCompositeNode(
1614 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1615 PlaceholderNode->ReplacementNode = FinalValue;
1616 FinalInstructions.insert(Real);
1617 FinalInstructions.insert(Imag);
1618 return submitCompositeNode(PlaceholderNode);
1621 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1622 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1623 if (!RealShuffle || !ImagShuffle) {
1624 if (RealShuffle || ImagShuffle)
1625 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1626 return nullptr;
1629 Value *RealOp1 = RealShuffle->getOperand(1);
1630 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1631 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1632 return nullptr;
1634 Value *ImagOp1 = ImagShuffle->getOperand(1);
1635 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1636 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1637 return nullptr;
1640 Value *RealOp0 = RealShuffle->getOperand(0);
1641 Value *ImagOp0 = ImagShuffle->getOperand(0);
1643 if (RealOp0 != ImagOp0) {
1644 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1645 return nullptr;
1648 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1649 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1650 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1651 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1652 return nullptr;
1655 if (RealMask[0] != 0 || ImagMask[0] != 1) {
1656 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1657 return nullptr;
1660 // Type checking, the shuffle type should be a vector type of the same
1661 // scalar type, but half the size
1662 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1663 Value *Op = Shuffle->getOperand(0);
1664 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1665 auto *OpTy = cast<FixedVectorType>(Op->getType());
1667 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1668 return false;
1669 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1670 return false;
1672 return true;
1675 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1676 if (!CheckType(Shuffle))
1677 return false;
1679 ArrayRef<int> Mask = Shuffle->getShuffleMask();
1680 int Last = *Mask.rbegin();
1682 Value *Op = Shuffle->getOperand(0);
1683 auto *OpTy = cast<FixedVectorType>(Op->getType());
1684 int NumElements = OpTy->getNumElements();
1686 // Ensure that the deinterleaving shuffle only pulls from the first
1687 // shuffle operand.
1688 return Last < NumElements;
1691 if (RealShuffle->getType() != ImagShuffle->getType()) {
1692 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1693 return nullptr;
1695 if (!CheckDeinterleavingShuffle(RealShuffle)) {
1696 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1697 return nullptr;
1699 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1700 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1701 return nullptr;
1704 NodePtr PlaceholderNode =
1705 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1706 RealShuffle, ImagShuffle);
1707 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1708 FinalInstructions.insert(RealShuffle);
1709 FinalInstructions.insert(ImagShuffle);
1710 return submitCompositeNode(PlaceholderNode);
1713 ComplexDeinterleavingGraph::NodePtr
1714 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1715 auto IsSplat = [](Value *V) -> bool {
1716 // Fixed-width vector with constants
1717 if (isa<ConstantDataVector>(V))
1718 return true;
1720 VectorType *VTy;
1721 ArrayRef<int> Mask;
1722 // Splats are represented differently depending on whether the repeated
1723 // value is a constant or an Instruction
1724 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1725 if (Const->getOpcode() != Instruction::ShuffleVector)
1726 return false;
1727 VTy = cast<VectorType>(Const->getType());
1728 Mask = Const->getShuffleMask();
1729 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1730 VTy = Shuf->getType();
1731 Mask = Shuf->getShuffleMask();
1732 } else {
1733 return false;
1736 // When the data type is <1 x Type>, it's not possible to differentiate
1737 // between the ComplexDeinterleaving::Deinterleave and
1738 // ComplexDeinterleaving::Splat operations.
1739 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1740 return false;
1742 return all_equal(Mask) && Mask[0] == 0;
1745 if (!IsSplat(R) || !IsSplat(I))
1746 return nullptr;
1748 auto *Real = dyn_cast<Instruction>(R);
1749 auto *Imag = dyn_cast<Instruction>(I);
1750 if ((!Real && Imag) || (Real && !Imag))
1751 return nullptr;
1753 if (Real && Imag) {
1754 // Non-constant splats should be in the same basic block
1755 if (Real->getParent() != Imag->getParent())
1756 return nullptr;
1758 FinalInstructions.insert(Real);
1759 FinalInstructions.insert(Imag);
1761 NodePtr PlaceholderNode =
1762 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1763 return submitCompositeNode(PlaceholderNode);
1766 ComplexDeinterleavingGraph::NodePtr
1767 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1768 Instruction *Imag) {
1769 if (Real != RealPHI || Imag != ImagPHI)
1770 return nullptr;
1772 PHIsFound = true;
1773 NodePtr PlaceholderNode = prepareCompositeNode(
1774 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1775 return submitCompositeNode(PlaceholderNode);
1778 ComplexDeinterleavingGraph::NodePtr
1779 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1780 Instruction *Imag) {
1781 auto *SelectReal = dyn_cast<SelectInst>(Real);
1782 auto *SelectImag = dyn_cast<SelectInst>(Imag);
1783 if (!SelectReal || !SelectImag)
1784 return nullptr;
1786 Instruction *MaskA, *MaskB;
1787 Instruction *AR, *AI, *RA, *BI;
1788 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1789 m_Instruction(RA))) ||
1790 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1791 m_Instruction(BI))))
1792 return nullptr;
1794 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1795 return nullptr;
1797 if (!MaskA->getType()->isVectorTy())
1798 return nullptr;
1800 auto NodeA = identifyNode(AR, AI);
1801 if (!NodeA)
1802 return nullptr;
1804 auto NodeB = identifyNode(RA, BI);
1805 if (!NodeB)
1806 return nullptr;
1808 NodePtr PlaceholderNode = prepareCompositeNode(
1809 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1810 PlaceholderNode->addOperand(NodeA);
1811 PlaceholderNode->addOperand(NodeB);
1812 FinalInstructions.insert(MaskA);
1813 FinalInstructions.insert(MaskB);
1814 return submitCompositeNode(PlaceholderNode);
1817 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1818 FastMathFlags Flags, Value *InputA,
1819 Value *InputB) {
1820 Value *I;
1821 switch (Opcode) {
1822 case Instruction::FNeg:
1823 I = B.CreateFNeg(InputA);
1824 break;
1825 case Instruction::FAdd:
1826 I = B.CreateFAdd(InputA, InputB);
1827 break;
1828 case Instruction::FSub:
1829 I = B.CreateFSub(InputA, InputB);
1830 break;
1831 case Instruction::FMul:
1832 I = B.CreateFMul(InputA, InputB);
1833 break;
1834 default:
1835 llvm_unreachable("Incorrect symmetric opcode");
1837 cast<Instruction>(I)->setFastMathFlags(Flags);
1838 return I;
1841 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1842 RawNodePtr Node) {
1843 if (Node->ReplacementNode)
1844 return Node->ReplacementNode;
1846 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1847 return Node->Operands.size() > Idx
1848 ? replaceNode(Builder, Node->Operands[Idx])
1849 : nullptr;
1852 Value *ReplacementNode;
1853 switch (Node->Operation) {
1854 case ComplexDeinterleavingOperation::CAdd:
1855 case ComplexDeinterleavingOperation::CMulPartial:
1856 case ComplexDeinterleavingOperation::Symmetric: {
1857 Value *Input0 = ReplaceOperandIfExist(Node, 0);
1858 Value *Input1 = ReplaceOperandIfExist(Node, 1);
1859 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1860 assert(!Input1 || (Input0->getType() == Input1->getType() &&
1861 "Node inputs need to be of the same type"));
1862 assert(!Accumulator ||
1863 (Input0->getType() == Accumulator->getType() &&
1864 "Accumulator and input need to be of the same type"));
1865 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1866 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1867 Input0, Input1);
1868 else
1869 ReplacementNode = TL->createComplexDeinterleavingIR(
1870 Builder, Node->Operation, Node->Rotation, Input0, Input1,
1871 Accumulator);
1872 break;
1874 case ComplexDeinterleavingOperation::Deinterleave:
1875 llvm_unreachable("Deinterleave node should already have ReplacementNode");
1876 break;
1877 case ComplexDeinterleavingOperation::Splat: {
1878 auto *NewTy = VectorType::getDoubleElementsVectorType(
1879 cast<VectorType>(Node->Real->getType()));
1880 auto *R = dyn_cast<Instruction>(Node->Real);
1881 auto *I = dyn_cast<Instruction>(Node->Imag);
1882 if (R && I) {
1883 // Splats that are not constant are interleaved where they are located
1884 Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1885 IRBuilder<> IRB(InsertPoint);
1886 ReplacementNode =
1887 IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
1888 {Node->Real, Node->Imag});
1889 } else {
1890 ReplacementNode =
1891 Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1892 NewTy, {Node->Real, Node->Imag});
1894 break;
1896 case ComplexDeinterleavingOperation::ReductionPHI: {
1897 // If Operation is ReductionPHI, a new empty PHINode is created.
1898 // It is filled later when the ReductionOperation is processed.
1899 auto *VTy = cast<VectorType>(Node->Real->getType());
1900 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1901 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1902 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1903 ReplacementNode = NewPHI;
1904 break;
1906 case ComplexDeinterleavingOperation::ReductionOperation:
1907 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1908 processReductionOperation(ReplacementNode, Node);
1909 break;
1910 case ComplexDeinterleavingOperation::ReductionSelect: {
1911 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1912 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1913 auto *A = replaceNode(Builder, Node->Operands[0]);
1914 auto *B = replaceNode(Builder, Node->Operands[1]);
1915 auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1916 cast<VectorType>(MaskReal->getType()));
1917 auto *NewMask =
1918 Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1919 NewMaskTy, {MaskReal, MaskImag});
1920 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1921 break;
1925 assert(ReplacementNode && "Target failed to create Intrinsic call.");
1926 NumComplexTransformations += 1;
1927 Node->ReplacementNode = ReplacementNode;
1928 return ReplacementNode;
1931 void ComplexDeinterleavingGraph::processReductionOperation(
1932 Value *OperationReplacement, RawNodePtr Node) {
1933 auto *Real = cast<Instruction>(Node->Real);
1934 auto *Imag = cast<Instruction>(Node->Imag);
1935 auto *OldPHIReal = ReductionInfo[Real].first;
1936 auto *OldPHIImag = ReductionInfo[Imag].first;
1937 auto *NewPHI = OldToNewPHI[OldPHIReal];
1939 auto *VTy = cast<VectorType>(Real->getType());
1940 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1942 // We have to interleave initial origin values coming from IncomingBlock
1943 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
1944 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
1946 IRBuilder<> Builder(Incoming->getTerminator());
1947 auto *NewInit = Builder.CreateIntrinsic(
1948 Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
1950 NewPHI->addIncoming(NewInit, Incoming);
1951 NewPHI->addIncoming(OperationReplacement, BackEdge);
1953 // Deinterleave complex vector outside of loop so that it can be finally
1954 // reduced
1955 auto *FinalReductionReal = ReductionInfo[Real].second;
1956 auto *FinalReductionImag = ReductionInfo[Imag].second;
1958 Builder.SetInsertPoint(
1959 &*FinalReductionReal->getParent()->getFirstInsertionPt());
1960 auto *Deinterleave = Builder.CreateIntrinsic(
1961 Intrinsic::experimental_vector_deinterleave2,
1962 OperationReplacement->getType(), OperationReplacement);
1964 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
1965 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
1967 Builder.SetInsertPoint(FinalReductionImag);
1968 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
1969 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
1972 void ComplexDeinterleavingGraph::replaceNodes() {
1973 SmallVector<Instruction *, 16> DeadInstrRoots;
1974 for (auto *RootInstruction : OrderedRoots) {
1975 // Check if this potential root went through check process and we can
1976 // deinterleave it
1977 if (!RootToNode.count(RootInstruction))
1978 continue;
1980 IRBuilder<> Builder(RootInstruction);
1981 auto RootNode = RootToNode[RootInstruction];
1982 Value *R = replaceNode(Builder, RootNode.get());
1984 if (RootNode->Operation ==
1985 ComplexDeinterleavingOperation::ReductionOperation) {
1986 auto *RootReal = cast<Instruction>(RootNode->Real);
1987 auto *RootImag = cast<Instruction>(RootNode->Imag);
1988 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
1989 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
1990 DeadInstrRoots.push_back(cast<Instruction>(RootReal));
1991 DeadInstrRoots.push_back(cast<Instruction>(RootImag));
1992 } else {
1993 assert(R && "Unable to find replacement for RootInstruction");
1994 DeadInstrRoots.push_back(RootInstruction);
1995 RootInstruction->replaceAllUsesWith(R);
1999 for (auto *I : DeadInstrRoots)
2000 RecursivelyDeleteTriviallyDeadInstructions(I, TLI);