[MachineScheduler] Fix physreg dependencies of ExitSU (#123541)
[llvm-project.git] / llvm / lib / CodeGen / ComplexDeinterleavingPass.cpp
blob92053ed5619010077efececed0f36080ffea9b0b
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Validity of each node is expected to be done upon creation, and any
19 // validation errors should halt traversal and prevent further graph
20 // construction.
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas, fixed-width vectors are recognized for both
25 // shufflevector instruction and intrinsics.
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
48 // Each node maintains a reference to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
52 // Operands contains pointers to other CompositeNodes, acting as the edges in
53 // the graph. ReplacementValue is the transformed Value* that has been emitted
54 // to the IR.
56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue
58 // should be pre-populated.
60 //===----------------------------------------------------------------------===//
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/MapVector.h"
64 #include "llvm/ADT/Statistic.h"
65 #include "llvm/Analysis/TargetLibraryInfo.h"
66 #include "llvm/Analysis/TargetTransformInfo.h"
67 #include "llvm/CodeGen/TargetLowering.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
74 #include <algorithm>
76 using namespace llvm;
77 using namespace PatternMatch;
79 #define DEBUG_TYPE "complex-deinterleaving"
81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83 static cl::opt<bool> ComplexDeinterleavingEnabled(
84 "enable-complex-deinterleaving",
85 cl::desc("Enable generation of complex instructions"), cl::init(true),
86 cl::Hidden);
88 /// Checks the given mask, and determines whether said mask is interleaving.
89 ///
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef<int> Mask);
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
96 ///
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
98 /// with 0 or 1.
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100 /// <1, 3, 5, 7>).
101 static bool isDeinterleavingMask(ArrayRef<int> Mask);
103 /// Returns true if the operation is a negation of V, and it works for both
104 /// integers and floats.
105 static bool isNeg(Value *V);
107 /// Returns the operand for negation operation.
108 static Value *getNegOperand(Value *V);
110 namespace {
111 template <typename T, typename IterT>
112 std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
113 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
114 if (Common != A.end())
115 return std::make_optional(*Common);
116 return std::nullopt;
119 class ComplexDeinterleavingLegacyPass : public FunctionPass {
120 public:
121 static char ID;
123 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
124 : FunctionPass(ID), TM(TM) {
125 initializeComplexDeinterleavingLegacyPassPass(
126 *PassRegistry::getPassRegistry());
129 StringRef getPassName() const override {
130 return "Complex Deinterleaving Pass";
133 bool runOnFunction(Function &F) override;
134 void getAnalysisUsage(AnalysisUsage &AU) const override {
135 AU.addRequired<TargetLibraryInfoWrapperPass>();
136 AU.setPreservesCFG();
139 private:
140 const TargetMachine *TM;
143 class ComplexDeinterleavingGraph;
144 struct ComplexDeinterleavingCompositeNode {
146 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
147 Value *R, Value *I)
148 : Operation(Op), Real(R), Imag(I) {}
150 private:
151 friend class ComplexDeinterleavingGraph;
152 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
153 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
154 bool OperandsValid = true;
156 public:
157 ComplexDeinterleavingOperation Operation;
158 Value *Real;
159 Value *Imag;
161 // This two members are required exclusively for generating
162 // ComplexDeinterleavingOperation::Symmetric operations.
163 unsigned Opcode;
164 std::optional<FastMathFlags> Flags;
166 ComplexDeinterleavingRotation Rotation =
167 ComplexDeinterleavingRotation::Rotation_0;
168 SmallVector<RawNodePtr> Operands;
169 Value *ReplacementNode = nullptr;
171 void addOperand(NodePtr Node) {
172 if (!Node || !Node.get())
173 OperandsValid = false;
174 Operands.push_back(Node.get());
177 void dump() { dump(dbgs()); }
178 void dump(raw_ostream &OS) {
179 auto PrintValue = [&](Value *V) {
180 if (V) {
181 OS << "\"";
182 V->print(OS, true);
183 OS << "\"\n";
184 } else
185 OS << "nullptr\n";
187 auto PrintNodeRef = [&](RawNodePtr Ptr) {
188 if (Ptr)
189 OS << Ptr << "\n";
190 else
191 OS << "nullptr\n";
194 OS << "- CompositeNode: " << this << "\n";
195 OS << " Real: ";
196 PrintValue(Real);
197 OS << " Imag: ";
198 PrintValue(Imag);
199 OS << " ReplacementNode: ";
200 PrintValue(ReplacementNode);
201 OS << " Operation: " << (int)Operation << "\n";
202 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
203 OS << " Operands: \n";
204 for (const auto &Op : Operands) {
205 OS << " - ";
206 PrintNodeRef(Op);
210 bool areOperandsValid() { return OperandsValid; }
213 class ComplexDeinterleavingGraph {
214 public:
215 struct Product {
216 Value *Multiplier;
217 Value *Multiplicand;
218 bool IsPositive;
221 using Addend = std::pair<Value *, bool>;
222 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
223 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
225 // Helper struct for holding info about potential partial multiplication
226 // candidates
227 struct PartialMulCandidate {
228 Value *Common;
229 NodePtr Node;
230 unsigned RealIdx;
231 unsigned ImagIdx;
232 bool IsNodeInverted;
235 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
236 const TargetLibraryInfo *TLI)
237 : TL(TL), TLI(TLI) {}
239 private:
240 const TargetLowering *TL = nullptr;
241 const TargetLibraryInfo *TLI = nullptr;
242 SmallVector<NodePtr> CompositeNodes;
243 DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
245 SmallPtrSet<Instruction *, 16> FinalInstructions;
247 /// Root instructions are instructions from which complex computation starts
248 std::map<Instruction *, NodePtr> RootToNode;
250 /// Topologically sorted root instructions
251 SmallVector<Instruction *, 1> OrderedRoots;
253 /// When examining a basic block for complex deinterleaving, if it is a simple
254 /// one-block loop, then the only incoming block is 'Incoming' and the
255 /// 'BackEdge' block is the block itself."
256 BasicBlock *BackEdge = nullptr;
257 BasicBlock *Incoming = nullptr;
259 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
260 /// %OutsideUser as it is shown in the IR:
262 /// vector.body:
263 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
264 /// [ %ReductionOp, %vector.body ]
265 /// ...
266 /// %ReductionOp = fadd i64 ...
267 /// ...
268 /// br i1 %condition, label %vector.body, %middle.block
270 /// middle.block:
271 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
273 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
274 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
275 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
277 /// In the process of detecting a reduction, we consider a pair of
278 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
279 /// traverse the use-tree to detect complex operations. As this is a reduction
280 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
281 /// to the %ReductionOPs that we suspect to be complex.
282 /// RealPHI and ImagPHI are used by the identifyPHINode method.
283 PHINode *RealPHI = nullptr;
284 PHINode *ImagPHI = nullptr;
286 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
287 /// detection.
288 bool PHIsFound = false;
290 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
291 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
292 /// This mapping is populated during
293 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
294 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
295 /// replacement process.
296 std::map<PHINode *, PHINode *> OldToNewPHI;
298 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
299 Value *R, Value *I) {
300 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
301 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
302 (R && I)) &&
303 "Reduction related nodes must have Real and Imaginary parts");
304 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
308 NodePtr submitCompositeNode(NodePtr Node) {
309 CompositeNodes.push_back(Node);
310 if (Node->Real)
311 CachedResult[{Node->Real, Node->Imag}] = Node;
312 return Node;
315 /// Identifies a complex partial multiply pattern and its rotation, based on
316 /// the following patterns
318 /// 0: r: cr + ar * br
319 /// i: ci + ar * bi
320 /// 90: r: cr - ai * bi
321 /// i: ci + ai * br
322 /// 180: r: cr - ar * br
323 /// i: ci - ar * bi
324 /// 270: r: cr + ai * bi
325 /// i: ci - ai * br
326 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
328 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
329 /// is partially known from identifyPartialMul, filling in the other half of
330 /// the complex pair.
331 NodePtr
332 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
333 std::pair<Value *, Value *> &CommonOperandI);
335 /// Identifies a complex add pattern and its rotation, based on the following
336 /// patterns.
338 /// 90: r: ar - bi
339 /// i: ai + br
340 /// 270: r: ar + bi
341 /// i: ai - br
342 NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
343 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
344 NodePtr identifyPartialReduction(Value *R, Value *I);
345 NodePtr identifyDotProduct(Value *Inst);
347 NodePtr identifyNode(Value *R, Value *I);
349 /// Determine if a sum of complex numbers can be formed from \p RealAddends
350 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
351 /// Return nullptr if it is not possible to construct a complex number.
352 /// \p Flags are needed to generate symmetric Add and Sub operations.
353 NodePtr identifyAdditions(std::list<Addend> &RealAddends,
354 std::list<Addend> &ImagAddends,
355 std::optional<FastMathFlags> Flags,
356 NodePtr Accumulator);
358 /// Extract one addend that have both real and imaginary parts positive.
359 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
360 std::list<Addend> &ImagAddends);
362 /// Determine if sum of multiplications of complex numbers can be formed from
363 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
364 /// to it. Return nullptr if it is not possible to construct a complex number.
365 NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
366 std::vector<Product> &ImagMuls,
367 NodePtr Accumulator);
369 /// Go through pairs of multiplication (one Real and one Imag) and find all
370 /// possible candidates for partial multiplication and put them into \p
371 /// Candidates. Returns true if all Product has pair with common operand
372 bool collectPartialMuls(const std::vector<Product> &RealMuls,
373 const std::vector<Product> &ImagMuls,
374 std::vector<PartialMulCandidate> &Candidates);
376 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
377 /// the order of complex computation operations may be significantly altered,
378 /// and the real and imaginary parts may not be executed in parallel. This
379 /// function takes this into consideration and employs a more general approach
380 /// to identify complex computations. Initially, it gathers all the addends
381 /// and multiplicands and then constructs a complex expression from them.
382 NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
384 NodePtr identifyRoot(Instruction *I);
386 /// Identifies the Deinterleave operation applied to a vector containing
387 /// complex numbers. There are two ways to represent the Deinterleave
388 /// operation:
389 /// * Using two shufflevectors with even indices for /pReal instruction and
390 /// odd indices for /pImag instructions (only for fixed-width vectors)
391 /// * Using two extractvalue instructions applied to `vector.deinterleave2`
392 /// intrinsic (for both fixed and scalable vectors)
393 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
395 /// identifying the operation that represents a complex number repeated in a
396 /// Splat vector. There are two possible types of splats: ConstantExpr with
397 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
398 /// initialization mask with all values set to zero.
399 NodePtr identifySplat(Value *Real, Value *Imag);
401 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
403 /// Identifies SelectInsts in a loop that has reduction with predication masks
404 /// and/or predicated tail folding
405 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
407 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
409 /// Complete IR modifications after producing new reduction operation:
410 /// * Populate the PHINode generated for
411 /// ComplexDeinterleavingOperation::ReductionPHI
412 /// * Deinterleave the final value outside of the loop and repurpose original
413 /// reduction users
414 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
415 void processReductionSingle(Value *OperationReplacement, RawNodePtr Node);
417 public:
418 void dump() { dump(dbgs()); }
419 void dump(raw_ostream &OS) {
420 for (const auto &Node : CompositeNodes)
421 Node->dump(OS);
424 /// Returns false if the deinterleaving operation should be cancelled for the
425 /// current graph.
426 bool identifyNodes(Instruction *RootI);
428 /// In case \pB is one-block loop, this function seeks potential reductions
429 /// and populates ReductionInfo. Returns true if any reductions were
430 /// identified.
431 bool collectPotentialReductions(BasicBlock *B);
433 void identifyReductionNodes();
435 /// Check that every instruction, from the roots to the leaves, has internal
436 /// uses.
437 bool checkNodes();
439 /// Perform the actual replacement of the underlying instruction graph.
440 void replaceNodes();
443 class ComplexDeinterleaving {
444 public:
445 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
446 : TL(tl), TLI(tli) {}
447 bool runOnFunction(Function &F);
449 private:
450 bool evaluateBasicBlock(BasicBlock *B);
452 const TargetLowering *TL = nullptr;
453 const TargetLibraryInfo *TLI = nullptr;
456 } // namespace
458 char ComplexDeinterleavingLegacyPass::ID = 0;
460 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
461 "Complex Deinterleaving", false, false)
462 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
463 "Complex Deinterleaving", false, false)
465 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
466 FunctionAnalysisManager &AM) {
467 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
468 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
469 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
470 return PreservedAnalyses::all();
472 PreservedAnalyses PA;
473 PA.preserve<FunctionAnalysisManagerModuleProxy>();
474 return PA;
477 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
478 return new ComplexDeinterleavingLegacyPass(TM);
481 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
482 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
483 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
484 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
487 bool ComplexDeinterleaving::runOnFunction(Function &F) {
488 if (!ComplexDeinterleavingEnabled) {
489 LLVM_DEBUG(
490 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
491 return false;
494 if (!TL->isComplexDeinterleavingSupported()) {
495 LLVM_DEBUG(
496 dbgs() << "Complex deinterleaving has been disabled, target does "
497 "not support lowering of complex number operations.\n");
498 return false;
501 bool Changed = false;
502 for (auto &B : F)
503 Changed |= evaluateBasicBlock(&B);
505 return Changed;
508 static bool isInterleavingMask(ArrayRef<int> Mask) {
509 // If the size is not even, it's not an interleaving mask
510 if ((Mask.size() & 1))
511 return false;
513 int HalfNumElements = Mask.size() / 2;
514 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
515 int MaskIdx = Idx * 2;
516 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
517 return false;
520 return true;
523 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
524 int Offset = Mask[0];
525 int HalfNumElements = Mask.size() / 2;
527 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
528 if (Mask[Idx] != (Idx * 2) + Offset)
529 return false;
532 return true;
535 bool isNeg(Value *V) {
536 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
539 Value *getNegOperand(Value *V) {
540 assert(isNeg(V));
541 auto *I = cast<Instruction>(V);
542 if (I->getOpcode() == Instruction::FNeg)
543 return I->getOperand(0);
545 return I->getOperand(1);
548 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
549 ComplexDeinterleavingGraph Graph(TL, TLI);
550 if (Graph.collectPotentialReductions(B))
551 Graph.identifyReductionNodes();
553 for (auto &I : *B)
554 Graph.identifyNodes(&I);
556 if (Graph.checkNodes()) {
557 Graph.replaceNodes();
558 return true;
561 return false;
564 ComplexDeinterleavingGraph::NodePtr
565 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
566 Instruction *Real, Instruction *Imag,
567 std::pair<Value *, Value *> &PartialMatch) {
568 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
569 << "\n");
571 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
572 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
573 return nullptr;
576 if ((Real->getOpcode() != Instruction::FMul &&
577 Real->getOpcode() != Instruction::Mul) ||
578 (Imag->getOpcode() != Instruction::FMul &&
579 Imag->getOpcode() != Instruction::Mul)) {
580 LLVM_DEBUG(
581 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
582 return nullptr;
585 Value *R0 = Real->getOperand(0);
586 Value *R1 = Real->getOperand(1);
587 Value *I0 = Imag->getOperand(0);
588 Value *I1 = Imag->getOperand(1);
590 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
591 // rotations and use the operand.
592 unsigned Negs = 0;
593 Value *Op;
594 if (match(R0, m_Neg(m_Value(Op)))) {
595 Negs |= 1;
596 R0 = Op;
597 } else if (match(R1, m_Neg(m_Value(Op)))) {
598 Negs |= 1;
599 R1 = Op;
602 if (isNeg(I0)) {
603 Negs |= 2;
604 Negs ^= 1;
605 I0 = Op;
606 } else if (match(I1, m_Neg(m_Value(Op)))) {
607 Negs |= 2;
608 Negs ^= 1;
609 I1 = Op;
612 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
614 Value *CommonOperand;
615 Value *UncommonRealOp;
616 Value *UncommonImagOp;
618 if (R0 == I0 || R0 == I1) {
619 CommonOperand = R0;
620 UncommonRealOp = R1;
621 } else if (R1 == I0 || R1 == I1) {
622 CommonOperand = R1;
623 UncommonRealOp = R0;
624 } else {
625 LLVM_DEBUG(dbgs() << " - No equal operand\n");
626 return nullptr;
629 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
630 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
631 Rotation == ComplexDeinterleavingRotation::Rotation_270)
632 std::swap(UncommonRealOp, UncommonImagOp);
634 // Between identifyPartialMul and here we need to have found a complete valid
635 // pair from the CommonOperand of each part.
636 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
637 Rotation == ComplexDeinterleavingRotation::Rotation_180)
638 PartialMatch.first = CommonOperand;
639 else
640 PartialMatch.second = CommonOperand;
642 if (!PartialMatch.first || !PartialMatch.second) {
643 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
644 return nullptr;
647 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
648 if (!CommonNode) {
649 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
650 return nullptr;
653 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
654 if (!UncommonNode) {
655 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
656 return nullptr;
659 NodePtr Node = prepareCompositeNode(
660 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
661 Node->Rotation = Rotation;
662 Node->addOperand(CommonNode);
663 Node->addOperand(UncommonNode);
664 return submitCompositeNode(Node);
667 ComplexDeinterleavingGraph::NodePtr
668 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
669 Instruction *Imag) {
670 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
671 << "\n");
672 // Determine rotation
673 auto IsAdd = [](unsigned Op) {
674 return Op == Instruction::FAdd || Op == Instruction::Add;
676 auto IsSub = [](unsigned Op) {
677 return Op == Instruction::FSub || Op == Instruction::Sub;
679 ComplexDeinterleavingRotation Rotation;
680 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
681 Rotation = ComplexDeinterleavingRotation::Rotation_0;
682 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
683 Rotation = ComplexDeinterleavingRotation::Rotation_90;
684 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
685 Rotation = ComplexDeinterleavingRotation::Rotation_180;
686 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
687 Rotation = ComplexDeinterleavingRotation::Rotation_270;
688 else {
689 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
690 return nullptr;
693 if (isa<FPMathOperator>(Real) &&
694 (!Real->getFastMathFlags().allowContract() ||
695 !Imag->getFastMathFlags().allowContract())) {
696 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
697 return nullptr;
700 Value *CR = Real->getOperand(0);
701 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
702 if (!RealMulI)
703 return nullptr;
704 Value *CI = Imag->getOperand(0);
705 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
706 if (!ImagMulI)
707 return nullptr;
709 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
710 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
711 return nullptr;
714 Value *R0 = RealMulI->getOperand(0);
715 Value *R1 = RealMulI->getOperand(1);
716 Value *I0 = ImagMulI->getOperand(0);
717 Value *I1 = ImagMulI->getOperand(1);
719 Value *CommonOperand;
720 Value *UncommonRealOp;
721 Value *UncommonImagOp;
723 if (R0 == I0 || R0 == I1) {
724 CommonOperand = R0;
725 UncommonRealOp = R1;
726 } else if (R1 == I0 || R1 == I1) {
727 CommonOperand = R1;
728 UncommonRealOp = R0;
729 } else {
730 LLVM_DEBUG(dbgs() << " - No equal operand\n");
731 return nullptr;
734 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
735 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
736 Rotation == ComplexDeinterleavingRotation::Rotation_270)
737 std::swap(UncommonRealOp, UncommonImagOp);
739 std::pair<Value *, Value *> PartialMatch(
740 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
741 Rotation == ComplexDeinterleavingRotation::Rotation_180)
742 ? CommonOperand
743 : nullptr,
744 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
745 Rotation == ComplexDeinterleavingRotation::Rotation_270)
746 ? CommonOperand
747 : nullptr);
749 auto *CRInst = dyn_cast<Instruction>(CR);
750 auto *CIInst = dyn_cast<Instruction>(CI);
752 if (!CRInst || !CIInst) {
753 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
754 return nullptr;
757 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
758 if (!CNode) {
759 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
760 return nullptr;
763 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
764 if (!UncommonRes) {
765 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
766 return nullptr;
769 assert(PartialMatch.first && PartialMatch.second);
770 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
771 if (!CommonRes) {
772 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
773 return nullptr;
776 NodePtr Node = prepareCompositeNode(
777 ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
778 Node->Rotation = Rotation;
779 Node->addOperand(CommonRes);
780 Node->addOperand(UncommonRes);
781 Node->addOperand(CNode);
782 return submitCompositeNode(Node);
785 ComplexDeinterleavingGraph::NodePtr
786 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
787 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
789 // Determine rotation
790 ComplexDeinterleavingRotation Rotation;
791 if ((Real->getOpcode() == Instruction::FSub &&
792 Imag->getOpcode() == Instruction::FAdd) ||
793 (Real->getOpcode() == Instruction::Sub &&
794 Imag->getOpcode() == Instruction::Add))
795 Rotation = ComplexDeinterleavingRotation::Rotation_90;
796 else if ((Real->getOpcode() == Instruction::FAdd &&
797 Imag->getOpcode() == Instruction::FSub) ||
798 (Real->getOpcode() == Instruction::Add &&
799 Imag->getOpcode() == Instruction::Sub))
800 Rotation = ComplexDeinterleavingRotation::Rotation_270;
801 else {
802 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
803 return nullptr;
806 auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
807 auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
808 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
809 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
811 if (!AR || !AI || !BR || !BI) {
812 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
813 return nullptr;
816 NodePtr ResA = identifyNode(AR, AI);
817 if (!ResA) {
818 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
819 return nullptr;
821 NodePtr ResB = identifyNode(BR, BI);
822 if (!ResB) {
823 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
824 return nullptr;
827 NodePtr Node =
828 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
829 Node->Rotation = Rotation;
830 Node->addOperand(ResA);
831 Node->addOperand(ResB);
832 return submitCompositeNode(Node);
835 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
836 unsigned OpcA = A->getOpcode();
837 unsigned OpcB = B->getOpcode();
839 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
840 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
841 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
842 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
845 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
846 auto Pattern =
847 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
849 return match(A, Pattern) && match(B, Pattern);
852 static bool isInstructionPotentiallySymmetric(Instruction *I) {
853 switch (I->getOpcode()) {
854 case Instruction::FAdd:
855 case Instruction::FSub:
856 case Instruction::FMul:
857 case Instruction::FNeg:
858 case Instruction::Add:
859 case Instruction::Sub:
860 case Instruction::Mul:
861 return true;
862 default:
863 return false;
867 ComplexDeinterleavingGraph::NodePtr
868 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
869 Instruction *Imag) {
870 if (Real->getOpcode() != Imag->getOpcode())
871 return nullptr;
873 if (!isInstructionPotentiallySymmetric(Real) ||
874 !isInstructionPotentiallySymmetric(Imag))
875 return nullptr;
877 auto *R0 = Real->getOperand(0);
878 auto *I0 = Imag->getOperand(0);
880 NodePtr Op0 = identifyNode(R0, I0);
881 NodePtr Op1 = nullptr;
882 if (Op0 == nullptr)
883 return nullptr;
885 if (Real->isBinaryOp()) {
886 auto *R1 = Real->getOperand(1);
887 auto *I1 = Imag->getOperand(1);
888 Op1 = identifyNode(R1, I1);
889 if (Op1 == nullptr)
890 return nullptr;
893 if (isa<FPMathOperator>(Real) &&
894 Real->getFastMathFlags() != Imag->getFastMathFlags())
895 return nullptr;
897 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
898 Real, Imag);
899 Node->Opcode = Real->getOpcode();
900 if (isa<FPMathOperator>(Real))
901 Node->Flags = Real->getFastMathFlags();
903 Node->addOperand(Op0);
904 if (Real->isBinaryOp())
905 Node->addOperand(Op1);
907 return submitCompositeNode(Node);
910 ComplexDeinterleavingGraph::NodePtr
911 ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
913 if (!TL->isComplexDeinterleavingOperationSupported(
914 ComplexDeinterleavingOperation::CDot, V->getType())) {
915 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
916 "operation CDot with the type "
917 << *V->getType() << "\n");
918 return nullptr;
921 auto *Inst = cast<Instruction>(V);
922 auto *RealUser = cast<Instruction>(*Inst->user_begin());
924 NodePtr CN =
925 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
927 NodePtr ANode;
929 const Intrinsic::ID PartialReduceInt =
930 Intrinsic::experimental_vector_partial_reduce_add;
932 Value *AReal = nullptr;
933 Value *AImag = nullptr;
934 Value *BReal = nullptr;
935 Value *BImag = nullptr;
936 Value *Phi = nullptr;
938 auto UnwrapCast = [](Value *V) -> Value * {
939 if (auto *CI = dyn_cast<CastInst>(V))
940 return CI->getOperand(0);
941 return V;
944 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
945 m_Intrinsic<PartialReduceInt>(m_Value(Phi),
946 m_Mul(m_Value(BReal), m_Value(AReal))),
947 m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
949 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
950 m_Intrinsic<PartialReduceInt>(
951 m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
952 m_Mul(m_Value(BImag), m_Value(AReal)));
954 if (match(Inst, PatternRot0)) {
955 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
956 } else if (match(Inst, PatternRot270)) {
957 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
958 } else {
959 Value *A0, *A1;
960 // The rotations 90 and 180 share the same operation pattern, so inspect the
961 // order of the operands, identifying where the real and imaginary
962 // components of A go, to discern between the aforementioned rotations.
963 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
964 m_Intrinsic<PartialReduceInt>(m_Value(Phi),
965 m_Mul(m_Value(BReal), m_Value(A0))),
966 m_Mul(m_Value(BImag), m_Value(A1)));
968 if (!match(Inst, PatternRot90Rot180))
969 return nullptr;
971 A0 = UnwrapCast(A0);
972 A1 = UnwrapCast(A1);
974 // Test if A0 is real/A1 is imag
975 ANode = identifyNode(A0, A1);
976 if (!ANode) {
977 // Test if A0 is imag/A1 is real
978 ANode = identifyNode(A1, A0);
979 // Unable to identify operand components, thus unable to identify rotation
980 if (!ANode)
981 return nullptr;
982 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
983 AReal = A1;
984 AImag = A0;
985 } else {
986 AReal = A0;
987 AImag = A1;
988 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
992 AReal = UnwrapCast(AReal);
993 AImag = UnwrapCast(AImag);
994 BReal = UnwrapCast(BReal);
995 BImag = UnwrapCast(BImag);
997 VectorType *VTy = cast<VectorType>(V->getType());
998 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
999 if (AReal->getType() != ExpectedOperandTy)
1000 return nullptr;
1001 if (AImag->getType() != ExpectedOperandTy)
1002 return nullptr;
1003 if (BReal->getType() != ExpectedOperandTy)
1004 return nullptr;
1005 if (BImag->getType() != ExpectedOperandTy)
1006 return nullptr;
1008 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1009 return nullptr;
1011 NodePtr Node = identifyNode(AReal, AImag);
1013 // In the case that a node was identified to figure out the rotation, ensure
1014 // that trying to identify a node with AReal and AImag post-unwrap results in
1015 // the same node
1016 if (ANode && Node != ANode) {
1017 LLVM_DEBUG(
1018 dbgs()
1019 << "Identified node is different from previously identified node. "
1020 "Unable to confidently generate a complex operation node\n");
1021 return nullptr;
1024 CN->addOperand(Node);
1025 CN->addOperand(identifyNode(BReal, BImag));
1026 CN->addOperand(identifyNode(Phi, RealUser));
1028 return submitCompositeNode(CN);
1031 ComplexDeinterleavingGraph::NodePtr
1032 ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1033 // Partial reductions don't support non-vector types, so check these first
1034 if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
1035 return nullptr;
1037 auto CommonUser =
1038 findCommonBetweenCollections<Value *>(R->users(), I->users());
1039 if (!CommonUser)
1040 return nullptr;
1042 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1043 if (!IInst || IInst->getIntrinsicID() !=
1044 Intrinsic::experimental_vector_partial_reduce_add)
1045 return nullptr;
1047 if (NodePtr CN = identifyDotProduct(IInst))
1048 return CN;
1050 return nullptr;
1053 ComplexDeinterleavingGraph::NodePtr
1054 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
1055 auto It = CachedResult.find({R, I});
1056 if (It != CachedResult.end()) {
1057 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1058 return It->second;
1061 if (NodePtr CN = identifyPartialReduction(R, I))
1062 return CN;
1064 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1065 if (!IsReduction && R->getType() != I->getType())
1066 return nullptr;
1068 if (NodePtr CN = identifySplat(R, I))
1069 return CN;
1071 auto *Real = dyn_cast<Instruction>(R);
1072 auto *Imag = dyn_cast<Instruction>(I);
1073 if (!Real || !Imag)
1074 return nullptr;
1076 if (NodePtr CN = identifyDeinterleave(Real, Imag))
1077 return CN;
1079 if (NodePtr CN = identifyPHINode(Real, Imag))
1080 return CN;
1082 if (NodePtr CN = identifySelectNode(Real, Imag))
1083 return CN;
1085 auto *VTy = cast<VectorType>(Real->getType());
1086 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1088 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1089 ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1090 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1091 ComplexDeinterleavingOperation::CAdd, NewVTy);
1093 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
1094 if (NodePtr CN = identifyPartialMul(Real, Imag))
1095 return CN;
1098 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
1099 if (NodePtr CN = identifyAdd(Real, Imag))
1100 return CN;
1103 if (HasCMulSupport && HasCAddSupport) {
1104 if (NodePtr CN = identifyReassocNodes(Real, Imag))
1105 return CN;
1108 if (NodePtr CN = identifySymmetricOperation(Real, Imag))
1109 return CN;
1111 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1112 CachedResult[{R, I}] = nullptr;
1113 return nullptr;
1116 ComplexDeinterleavingGraph::NodePtr
1117 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1118 Instruction *Imag) {
1119 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1120 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1121 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1122 Opcode == Instruction::Sub;
1125 if (!IsOperationSupported(Real->getOpcode()) ||
1126 !IsOperationSupported(Imag->getOpcode()))
1127 return nullptr;
1129 std::optional<FastMathFlags> Flags;
1130 if (isa<FPMathOperator>(Real)) {
1131 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1132 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1133 "not identical\n");
1134 return nullptr;
1137 Flags = Real->getFastMathFlags();
1138 if (!Flags->allowReassoc()) {
1139 LLVM_DEBUG(
1140 dbgs()
1141 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1142 return nullptr;
1146 // Collect multiplications and addend instructions from the given instruction
1147 // while traversing it operands. Additionally, verify that all instructions
1148 // have the same fast math flags.
1149 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
1150 std::list<Addend> &Addends) -> bool {
1151 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1152 SmallPtrSet<Value *, 8> Visited;
1153 while (!Worklist.empty()) {
1154 auto [V, IsPositive] = Worklist.back();
1155 Worklist.pop_back();
1156 if (!Visited.insert(V).second)
1157 continue;
1159 Instruction *I = dyn_cast<Instruction>(V);
1160 if (!I) {
1161 Addends.emplace_back(V, IsPositive);
1162 continue;
1165 // If an instruction has more than one user, it indicates that it either
1166 // has an external user, which will be later checked by the checkNodes
1167 // function, or it is a subexpression utilized by multiple expressions. In
1168 // the latter case, we will attempt to separately identify the complex
1169 // operation from here in order to create a shared
1170 // ComplexDeinterleavingCompositeNode.
1171 if (I != Insn && I->getNumUses() > 1) {
1172 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1173 Addends.emplace_back(I, IsPositive);
1174 continue;
1176 switch (I->getOpcode()) {
1177 case Instruction::FAdd:
1178 case Instruction::Add:
1179 Worklist.emplace_back(I->getOperand(1), IsPositive);
1180 Worklist.emplace_back(I->getOperand(0), IsPositive);
1181 break;
1182 case Instruction::FSub:
1183 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1184 Worklist.emplace_back(I->getOperand(0), IsPositive);
1185 break;
1186 case Instruction::Sub:
1187 if (isNeg(I)) {
1188 Worklist.emplace_back(getNegOperand(I), !IsPositive);
1189 } else {
1190 Worklist.emplace_back(I->getOperand(1), !IsPositive);
1191 Worklist.emplace_back(I->getOperand(0), IsPositive);
1193 break;
1194 case Instruction::FMul:
1195 case Instruction::Mul: {
1196 Value *A, *B;
1197 if (isNeg(I->getOperand(0))) {
1198 A = getNegOperand(I->getOperand(0));
1199 IsPositive = !IsPositive;
1200 } else {
1201 A = I->getOperand(0);
1204 if (isNeg(I->getOperand(1))) {
1205 B = getNegOperand(I->getOperand(1));
1206 IsPositive = !IsPositive;
1207 } else {
1208 B = I->getOperand(1);
1210 Muls.push_back(Product{A, B, IsPositive});
1211 break;
1213 case Instruction::FNeg:
1214 Worklist.emplace_back(I->getOperand(0), !IsPositive);
1215 break;
1216 default:
1217 Addends.emplace_back(I, IsPositive);
1218 continue;
1221 if (Flags && I->getFastMathFlags() != *Flags) {
1222 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1223 "inconsistent with the root instructions' flags: "
1224 << *I << "\n");
1225 return false;
1228 return true;
1231 std::vector<Product> RealMuls, ImagMuls;
1232 std::list<Addend> RealAddends, ImagAddends;
1233 if (!Collect(Real, RealMuls, RealAddends) ||
1234 !Collect(Imag, ImagMuls, ImagAddends))
1235 return nullptr;
1237 if (RealAddends.size() != ImagAddends.size())
1238 return nullptr;
1240 NodePtr FinalNode;
1241 if (!RealMuls.empty() || !ImagMuls.empty()) {
1242 // If there are multiplicands, extract positive addend and use it as an
1243 // accumulator
1244 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1245 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1246 if (!FinalNode)
1247 return nullptr;
1250 // Identify and process remaining additions
1251 if (!RealAddends.empty() || !ImagAddends.empty()) {
1252 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1253 if (!FinalNode)
1254 return nullptr;
1256 assert(FinalNode && "FinalNode can not be nullptr here");
1257 // Set the Real and Imag fields of the final node and submit it
1258 FinalNode->Real = Real;
1259 FinalNode->Imag = Imag;
1260 submitCompositeNode(FinalNode);
1261 return FinalNode;
1264 bool ComplexDeinterleavingGraph::collectPartialMuls(
1265 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1266 std::vector<PartialMulCandidate> &PartialMulCandidates) {
1267 // Helper function to extract a common operand from two products
1268 auto FindCommonInstruction = [](const Product &Real,
1269 const Product &Imag) -> Value * {
1270 if (Real.Multiplicand == Imag.Multiplicand ||
1271 Real.Multiplicand == Imag.Multiplier)
1272 return Real.Multiplicand;
1274 if (Real.Multiplier == Imag.Multiplicand ||
1275 Real.Multiplier == Imag.Multiplier)
1276 return Real.Multiplier;
1278 return nullptr;
1281 // Iterating over real and imaginary multiplications to find common operands
1282 // If a common operand is found, a partial multiplication candidate is created
1283 // and added to the candidates vector The function returns false if no common
1284 // operands are found for any product
1285 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1286 bool FoundCommon = false;
1287 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1288 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1289 if (!Common)
1290 continue;
1292 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1293 : RealMuls[i].Multiplicand;
1294 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1295 : ImagMuls[j].Multiplicand;
1297 auto Node = identifyNode(A, B);
1298 if (Node) {
1299 FoundCommon = true;
1300 PartialMulCandidates.push_back({Common, Node, i, j, false});
1303 Node = identifyNode(B, A);
1304 if (Node) {
1305 FoundCommon = true;
1306 PartialMulCandidates.push_back({Common, Node, i, j, true});
1309 if (!FoundCommon)
1310 return false;
1312 return true;
1315 ComplexDeinterleavingGraph::NodePtr
1316 ComplexDeinterleavingGraph::identifyMultiplications(
1317 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1318 NodePtr Accumulator = nullptr) {
1319 if (RealMuls.size() != ImagMuls.size())
1320 return nullptr;
1322 std::vector<PartialMulCandidate> Info;
1323 if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1324 return nullptr;
1326 // Map to store common instruction to node pointers
1327 std::map<Value *, NodePtr> CommonToNode;
1328 std::vector<bool> Processed(Info.size(), false);
1329 for (unsigned I = 0; I < Info.size(); ++I) {
1330 if (Processed[I])
1331 continue;
1333 PartialMulCandidate &InfoA = Info[I];
1334 for (unsigned J = I + 1; J < Info.size(); ++J) {
1335 if (Processed[J])
1336 continue;
1338 PartialMulCandidate &InfoB = Info[J];
1339 auto *InfoReal = &InfoA;
1340 auto *InfoImag = &InfoB;
1342 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1343 if (!NodeFromCommon) {
1344 std::swap(InfoReal, InfoImag);
1345 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1347 if (!NodeFromCommon)
1348 continue;
1350 CommonToNode[InfoReal->Common] = NodeFromCommon;
1351 CommonToNode[InfoImag->Common] = NodeFromCommon;
1352 Processed[I] = true;
1353 Processed[J] = true;
1357 std::vector<bool> ProcessedReal(RealMuls.size(), false);
1358 std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1359 NodePtr Result = Accumulator;
1360 for (auto &PMI : Info) {
1361 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1362 continue;
1364 auto It = CommonToNode.find(PMI.Common);
1365 // TODO: Process independent complex multiplications. Cases like this:
1366 // A.real() * B where both A and B are complex numbers.
1367 if (It == CommonToNode.end()) {
1368 LLVM_DEBUG({
1369 dbgs() << "Unprocessed independent partial multiplication:\n";
1370 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1371 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1372 << " multiplied by " << *Mul->Multiplicand << "\n";
1374 return nullptr;
1377 auto &RealMul = RealMuls[PMI.RealIdx];
1378 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1380 auto NodeA = It->second;
1381 auto NodeB = PMI.Node;
1382 auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1383 // The following table illustrates the relationship between multiplications
1384 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1385 // can see:
1387 // Rotation | Real | Imag |
1388 // ---------+--------+--------+
1389 // 0 | x * u | x * v |
1390 // 90 | -y * v | y * u |
1391 // 180 | -x * u | -x * v |
1392 // 270 | y * v | -y * u |
1394 // Check if the candidate can indeed be represented by partial
1395 // multiplication
1396 // TODO: Add support for multiplication by complex one
1397 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1398 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1399 continue;
1401 // Determine the rotation based on the multiplications
1402 ComplexDeinterleavingRotation Rotation;
1403 if (IsMultiplicandReal) {
1404 // Detect 0 and 180 degrees rotation
1405 if (RealMul.IsPositive && ImagMul.IsPositive)
1406 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1407 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1408 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1409 else
1410 continue;
1412 } else {
1413 // Detect 90 and 270 degrees rotation
1414 if (!RealMul.IsPositive && ImagMul.IsPositive)
1415 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1416 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1417 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1418 else
1419 continue;
1422 LLVM_DEBUG({
1423 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1424 dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1425 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1426 dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1427 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1428 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1431 NodePtr NodeMul = prepareCompositeNode(
1432 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1433 NodeMul->Rotation = Rotation;
1434 NodeMul->addOperand(NodeA);
1435 NodeMul->addOperand(NodeB);
1436 if (Result)
1437 NodeMul->addOperand(Result);
1438 submitCompositeNode(NodeMul);
1439 Result = NodeMul;
1440 ProcessedReal[PMI.RealIdx] = true;
1441 ProcessedImag[PMI.ImagIdx] = true;
1444 // Ensure all products have been processed, if not return nullptr.
1445 if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1446 !all_of(ProcessedImag, [](bool V) { return V; })) {
1448 // Dump debug information about which partial multiplications are not
1449 // processed.
1450 LLVM_DEBUG({
1451 dbgs() << "Unprocessed products (Real):\n";
1452 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1453 if (!ProcessedReal[i])
1454 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1455 << *RealMuls[i].Multiplier << " multiplied by "
1456 << *RealMuls[i].Multiplicand << "\n";
1458 dbgs() << "Unprocessed products (Imag):\n";
1459 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1460 if (!ProcessedImag[i])
1461 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1462 << *ImagMuls[i].Multiplier << " multiplied by "
1463 << *ImagMuls[i].Multiplicand << "\n";
1466 return nullptr;
1469 return Result;
1472 ComplexDeinterleavingGraph::NodePtr
1473 ComplexDeinterleavingGraph::identifyAdditions(
1474 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1475 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1476 if (RealAddends.size() != ImagAddends.size())
1477 return nullptr;
1479 NodePtr Result;
1480 // If we have accumulator use it as first addend
1481 if (Accumulator)
1482 Result = Accumulator;
1483 // Otherwise find an element with both positive real and imaginary parts.
1484 else
1485 Result = extractPositiveAddend(RealAddends, ImagAddends);
1487 if (!Result)
1488 return nullptr;
1490 while (!RealAddends.empty()) {
1491 auto ItR = RealAddends.begin();
1492 auto [R, IsPositiveR] = *ItR;
1494 bool FoundImag = false;
1495 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1496 auto [I, IsPositiveI] = *ItI;
1497 ComplexDeinterleavingRotation Rotation;
1498 if (IsPositiveR && IsPositiveI)
1499 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1500 else if (!IsPositiveR && IsPositiveI)
1501 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1502 else if (!IsPositiveR && !IsPositiveI)
1503 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1504 else
1505 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1507 NodePtr AddNode;
1508 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1509 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1510 AddNode = identifyNode(R, I);
1511 } else {
1512 AddNode = identifyNode(I, R);
1514 if (AddNode) {
1515 LLVM_DEBUG({
1516 dbgs() << "Identified addition:\n";
1517 dbgs().indent(4) << "X: " << *R << "\n";
1518 dbgs().indent(4) << "Y: " << *I << "\n";
1519 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1522 NodePtr TmpNode;
1523 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1524 TmpNode = prepareCompositeNode(
1525 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1526 if (Flags) {
1527 TmpNode->Opcode = Instruction::FAdd;
1528 TmpNode->Flags = *Flags;
1529 } else {
1530 TmpNode->Opcode = Instruction::Add;
1532 } else if (Rotation ==
1533 llvm::ComplexDeinterleavingRotation::Rotation_180) {
1534 TmpNode = prepareCompositeNode(
1535 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1536 if (Flags) {
1537 TmpNode->Opcode = Instruction::FSub;
1538 TmpNode->Flags = *Flags;
1539 } else {
1540 TmpNode->Opcode = Instruction::Sub;
1542 } else {
1543 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1544 nullptr, nullptr);
1545 TmpNode->Rotation = Rotation;
1548 TmpNode->addOperand(Result);
1549 TmpNode->addOperand(AddNode);
1550 submitCompositeNode(TmpNode);
1551 Result = TmpNode;
1552 RealAddends.erase(ItR);
1553 ImagAddends.erase(ItI);
1554 FoundImag = true;
1555 break;
1558 if (!FoundImag)
1559 return nullptr;
1561 return Result;
1564 ComplexDeinterleavingGraph::NodePtr
1565 ComplexDeinterleavingGraph::extractPositiveAddend(
1566 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1567 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1568 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1569 auto [R, IsPositiveR] = *ItR;
1570 auto [I, IsPositiveI] = *ItI;
1571 if (IsPositiveR && IsPositiveI) {
1572 auto Result = identifyNode(R, I);
1573 if (Result) {
1574 RealAddends.erase(ItR);
1575 ImagAddends.erase(ItI);
1576 return Result;
1581 return nullptr;
1584 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1585 // This potential root instruction might already have been recognized as
1586 // reduction. Because RootToNode maps both Real and Imaginary parts to
1587 // CompositeNode we should choose only one either Real or Imag instruction to
1588 // use as an anchor for generating complex instruction.
1589 auto It = RootToNode.find(RootI);
1590 if (It != RootToNode.end()) {
1591 auto RootNode = It->second;
1592 assert(RootNode->Operation ==
1593 ComplexDeinterleavingOperation::ReductionOperation ||
1594 RootNode->Operation ==
1595 ComplexDeinterleavingOperation::ReductionSingle);
1596 // Find out which part, Real or Imag, comes later, and only if we come to
1597 // the latest part, add it to OrderedRoots.
1598 auto *R = cast<Instruction>(RootNode->Real);
1599 auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
1601 Instruction *ReplacementAnchor;
1602 if (I)
1603 ReplacementAnchor = R->comesBefore(I) ? I : R;
1604 else
1605 ReplacementAnchor = R;
1607 if (ReplacementAnchor != RootI)
1608 return false;
1609 OrderedRoots.push_back(RootI);
1610 return true;
1613 auto RootNode = identifyRoot(RootI);
1614 if (!RootNode)
1615 return false;
1617 LLVM_DEBUG({
1618 Function *F = RootI->getFunction();
1619 BasicBlock *B = RootI->getParent();
1620 dbgs() << "Complex deinterleaving graph for " << F->getName()
1621 << "::" << B->getName() << ".\n";
1622 dump(dbgs());
1623 dbgs() << "\n";
1625 RootToNode[RootI] = RootNode;
1626 OrderedRoots.push_back(RootI);
1627 return true;
1630 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1631 bool FoundPotentialReduction = false;
1633 auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1634 if (!Br || Br->getNumSuccessors() != 2)
1635 return false;
1637 // Identify simple one-block loop
1638 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1639 return false;
1641 SmallVector<PHINode *> PHIs;
1642 for (auto &PHI : B->phis()) {
1643 if (PHI.getNumIncomingValues() != 2)
1644 continue;
1646 if (!PHI.getType()->isVectorTy())
1647 continue;
1649 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1650 if (!ReductionOp)
1651 continue;
1653 // Check if final instruction is reduced outside of current block
1654 Instruction *FinalReduction = nullptr;
1655 auto NumUsers = 0u;
1656 for (auto *U : ReductionOp->users()) {
1657 ++NumUsers;
1658 if (U == &PHI)
1659 continue;
1660 FinalReduction = dyn_cast<Instruction>(U);
1663 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1664 isa<PHINode>(FinalReduction))
1665 continue;
1667 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1668 BackEdge = B;
1669 auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1670 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1671 Incoming = PHI.getIncomingBlock(IncomingIdx);
1672 FoundPotentialReduction = true;
1674 // If the initial value of PHINode is an Instruction, consider it a leaf
1675 // value of a complex deinterleaving graph.
1676 if (auto *InitPHI =
1677 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1678 FinalInstructions.insert(InitPHI);
1680 return FoundPotentialReduction;
1683 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1684 SmallVector<bool> Processed(ReductionInfo.size(), false);
1685 SmallVector<Instruction *> OperationInstruction;
1686 for (auto &P : ReductionInfo)
1687 OperationInstruction.push_back(P.first);
1689 // Identify a complex computation by evaluating two reduction operations that
1690 // potentially could be involved
1691 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1692 if (Processed[i])
1693 continue;
1694 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1695 if (Processed[j])
1696 continue;
1697 auto *Real = OperationInstruction[i];
1698 auto *Imag = OperationInstruction[j];
1699 if (Real->getType() != Imag->getType())
1700 continue;
1702 RealPHI = ReductionInfo[Real].first;
1703 ImagPHI = ReductionInfo[Imag].first;
1704 PHIsFound = false;
1705 auto Node = identifyNode(Real, Imag);
1706 if (!Node) {
1707 std::swap(Real, Imag);
1708 std::swap(RealPHI, ImagPHI);
1709 Node = identifyNode(Real, Imag);
1712 // If a node is identified and reduction PHINode is used in the chain of
1713 // operations, mark its operation instructions as used to prevent
1714 // re-identification and attach the node to the real part
1715 if (Node && PHIsFound) {
1716 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1717 << *Real << " / " << *Imag << "\n");
1718 Processed[i] = true;
1719 Processed[j] = true;
1720 auto RootNode = prepareCompositeNode(
1721 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1722 RootNode->addOperand(Node);
1723 RootToNode[Real] = RootNode;
1724 RootToNode[Imag] = RootNode;
1725 submitCompositeNode(RootNode);
1726 break;
1730 auto *Real = OperationInstruction[i];
1731 // We want to check that we have 2 operands, but the function attributes
1732 // being counted as operands bloats this value.
1733 if (Processed[i] || Real->getNumOperands() < 2)
1734 continue;
1736 RealPHI = ReductionInfo[Real].first;
1737 ImagPHI = nullptr;
1738 PHIsFound = false;
1739 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1740 if (Node && PHIsFound) {
1741 LLVM_DEBUG(
1742 dbgs() << "Identified single reduction starting from instruction: "
1743 << *Real << "/" << *ReductionInfo[Real].second << "\n");
1744 Processed[i] = true;
1745 auto RootNode = prepareCompositeNode(
1746 ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
1747 RootNode->addOperand(Node);
1748 RootToNode[Real] = RootNode;
1749 submitCompositeNode(RootNode);
1753 RealPHI = nullptr;
1754 ImagPHI = nullptr;
1757 bool ComplexDeinterleavingGraph::checkNodes() {
1759 bool FoundDeinterleaveNode = false;
1760 for (NodePtr N : CompositeNodes) {
1761 if (!N->areOperandsValid())
1762 return false;
1763 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1764 FoundDeinterleaveNode = true;
1767 // We need a deinterleave node in order to guarantee that we're working with
1768 // complex numbers.
1769 if (!FoundDeinterleaveNode) {
1770 LLVM_DEBUG(
1771 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1772 "guarantee safety during graph transformation.\n");
1773 return false;
1776 // Collect all instructions from roots to leaves
1777 SmallPtrSet<Instruction *, 16> AllInstructions;
1778 SmallVector<Instruction *, 8> Worklist;
1779 for (auto &Pair : RootToNode)
1780 Worklist.push_back(Pair.first);
1782 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1783 // chains
1784 while (!Worklist.empty()) {
1785 auto *I = Worklist.back();
1786 Worklist.pop_back();
1788 if (!AllInstructions.insert(I).second)
1789 continue;
1791 for (Value *Op : I->operands()) {
1792 if (auto *OpI = dyn_cast<Instruction>(Op)) {
1793 if (!FinalInstructions.count(I))
1794 Worklist.emplace_back(OpI);
1799 // Find instructions that have users outside of chain
1800 SmallVector<Instruction *, 2> OuterInstructions;
1801 for (auto *I : AllInstructions) {
1802 // Skip root nodes
1803 if (RootToNode.count(I))
1804 continue;
1806 for (User *U : I->users()) {
1807 if (AllInstructions.count(cast<Instruction>(U)))
1808 continue;
1810 // Found an instruction that is not used by XCMLA/XCADD chain
1811 Worklist.emplace_back(I);
1812 break;
1816 // If any instructions are found to be used outside, find and remove roots
1817 // that somehow connect to those instructions.
1818 SmallPtrSet<Instruction *, 16> Visited;
1819 while (!Worklist.empty()) {
1820 auto *I = Worklist.back();
1821 Worklist.pop_back();
1822 if (!Visited.insert(I).second)
1823 continue;
1825 // Found an impacted root node. Removing it from the nodes to be
1826 // deinterleaved
1827 if (RootToNode.count(I)) {
1828 LLVM_DEBUG(dbgs() << "Instruction " << *I
1829 << " could be deinterleaved but its chain of complex "
1830 "operations have an outside user\n");
1831 RootToNode.erase(I);
1834 if (!AllInstructions.count(I) || FinalInstructions.count(I))
1835 continue;
1837 for (User *U : I->users())
1838 Worklist.emplace_back(cast<Instruction>(U));
1840 for (Value *Op : I->operands()) {
1841 if (auto *OpI = dyn_cast<Instruction>(Op))
1842 Worklist.emplace_back(OpI);
1845 return !RootToNode.empty();
1848 ComplexDeinterleavingGraph::NodePtr
1849 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1850 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1851 if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
1852 return nullptr;
1854 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1855 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1856 if (!Real || !Imag)
1857 return nullptr;
1859 return identifyNode(Real, Imag);
1862 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1863 if (!SVI)
1864 return nullptr;
1866 // Look for a shufflevector that takes separate vectors of the real and
1867 // imaginary components and recombines them into a single vector.
1868 if (!isInterleavingMask(SVI->getShuffleMask()))
1869 return nullptr;
1871 Instruction *Real;
1872 Instruction *Imag;
1873 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1874 return nullptr;
1876 return identifyNode(Real, Imag);
1879 ComplexDeinterleavingGraph::NodePtr
1880 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1881 Instruction *Imag) {
1882 Instruction *I = nullptr;
1883 Value *FinalValue = nullptr;
1884 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1885 match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1886 match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
1887 m_Value(FinalValue)))) {
1888 NodePtr PlaceholderNode = prepareCompositeNode(
1889 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1890 PlaceholderNode->ReplacementNode = FinalValue;
1891 FinalInstructions.insert(Real);
1892 FinalInstructions.insert(Imag);
1893 return submitCompositeNode(PlaceholderNode);
1896 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1897 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1898 if (!RealShuffle || !ImagShuffle) {
1899 if (RealShuffle || ImagShuffle)
1900 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1901 return nullptr;
1904 Value *RealOp1 = RealShuffle->getOperand(1);
1905 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1906 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1907 return nullptr;
1909 Value *ImagOp1 = ImagShuffle->getOperand(1);
1910 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1911 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1912 return nullptr;
1915 Value *RealOp0 = RealShuffle->getOperand(0);
1916 Value *ImagOp0 = ImagShuffle->getOperand(0);
1918 if (RealOp0 != ImagOp0) {
1919 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1920 return nullptr;
1923 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1924 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1925 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1926 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1927 return nullptr;
1930 if (RealMask[0] != 0 || ImagMask[0] != 1) {
1931 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1932 return nullptr;
1935 // Type checking, the shuffle type should be a vector type of the same
1936 // scalar type, but half the size
1937 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1938 Value *Op = Shuffle->getOperand(0);
1939 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1940 auto *OpTy = cast<FixedVectorType>(Op->getType());
1942 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1943 return false;
1944 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1945 return false;
1947 return true;
1950 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1951 if (!CheckType(Shuffle))
1952 return false;
1954 ArrayRef<int> Mask = Shuffle->getShuffleMask();
1955 int Last = *Mask.rbegin();
1957 Value *Op = Shuffle->getOperand(0);
1958 auto *OpTy = cast<FixedVectorType>(Op->getType());
1959 int NumElements = OpTy->getNumElements();
1961 // Ensure that the deinterleaving shuffle only pulls from the first
1962 // shuffle operand.
1963 return Last < NumElements;
1966 if (RealShuffle->getType() != ImagShuffle->getType()) {
1967 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1968 return nullptr;
1970 if (!CheckDeinterleavingShuffle(RealShuffle)) {
1971 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1972 return nullptr;
1974 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1975 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1976 return nullptr;
1979 NodePtr PlaceholderNode =
1980 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1981 RealShuffle, ImagShuffle);
1982 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1983 FinalInstructions.insert(RealShuffle);
1984 FinalInstructions.insert(ImagShuffle);
1985 return submitCompositeNode(PlaceholderNode);
1988 ComplexDeinterleavingGraph::NodePtr
1989 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1990 auto IsSplat = [](Value *V) -> bool {
1991 // Fixed-width vector with constants
1992 if (isa<ConstantDataVector>(V))
1993 return true;
1995 VectorType *VTy;
1996 ArrayRef<int> Mask;
1997 // Splats are represented differently depending on whether the repeated
1998 // value is a constant or an Instruction
1999 if (auto *Const = dyn_cast<ConstantExpr>(V)) {
2000 if (Const->getOpcode() != Instruction::ShuffleVector)
2001 return false;
2002 VTy = cast<VectorType>(Const->getType());
2003 Mask = Const->getShuffleMask();
2004 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2005 VTy = Shuf->getType();
2006 Mask = Shuf->getShuffleMask();
2007 } else {
2008 return false;
2011 // When the data type is <1 x Type>, it's not possible to differentiate
2012 // between the ComplexDeinterleaving::Deinterleave and
2013 // ComplexDeinterleaving::Splat operations.
2014 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2015 return false;
2017 return all_equal(Mask) && Mask[0] == 0;
2020 if (!IsSplat(R) || !IsSplat(I))
2021 return nullptr;
2023 auto *Real = dyn_cast<Instruction>(R);
2024 auto *Imag = dyn_cast<Instruction>(I);
2025 if ((!Real && Imag) || (Real && !Imag))
2026 return nullptr;
2028 if (Real && Imag) {
2029 // Non-constant splats should be in the same basic block
2030 if (Real->getParent() != Imag->getParent())
2031 return nullptr;
2033 FinalInstructions.insert(Real);
2034 FinalInstructions.insert(Imag);
2036 NodePtr PlaceholderNode =
2037 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
2038 return submitCompositeNode(PlaceholderNode);
2041 ComplexDeinterleavingGraph::NodePtr
2042 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2043 Instruction *Imag) {
2044 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2045 return nullptr;
2047 PHIsFound = true;
2048 NodePtr PlaceholderNode = prepareCompositeNode(
2049 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2050 return submitCompositeNode(PlaceholderNode);
2053 ComplexDeinterleavingGraph::NodePtr
2054 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2055 Instruction *Imag) {
2056 auto *SelectReal = dyn_cast<SelectInst>(Real);
2057 auto *SelectImag = dyn_cast<SelectInst>(Imag);
2058 if (!SelectReal || !SelectImag)
2059 return nullptr;
2061 Instruction *MaskA, *MaskB;
2062 Instruction *AR, *AI, *RA, *BI;
2063 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
2064 m_Instruction(RA))) ||
2065 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
2066 m_Instruction(BI))))
2067 return nullptr;
2069 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
2070 return nullptr;
2072 if (!MaskA->getType()->isVectorTy())
2073 return nullptr;
2075 auto NodeA = identifyNode(AR, AI);
2076 if (!NodeA)
2077 return nullptr;
2079 auto NodeB = identifyNode(RA, BI);
2080 if (!NodeB)
2081 return nullptr;
2083 NodePtr PlaceholderNode = prepareCompositeNode(
2084 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2085 PlaceholderNode->addOperand(NodeA);
2086 PlaceholderNode->addOperand(NodeB);
2087 FinalInstructions.insert(MaskA);
2088 FinalInstructions.insert(MaskB);
2089 return submitCompositeNode(PlaceholderNode);
2092 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2093 std::optional<FastMathFlags> Flags,
2094 Value *InputA, Value *InputB) {
2095 Value *I;
2096 switch (Opcode) {
2097 case Instruction::FNeg:
2098 I = B.CreateFNeg(InputA);
2099 break;
2100 case Instruction::FAdd:
2101 I = B.CreateFAdd(InputA, InputB);
2102 break;
2103 case Instruction::Add:
2104 I = B.CreateAdd(InputA, InputB);
2105 break;
2106 case Instruction::FSub:
2107 I = B.CreateFSub(InputA, InputB);
2108 break;
2109 case Instruction::Sub:
2110 I = B.CreateSub(InputA, InputB);
2111 break;
2112 case Instruction::FMul:
2113 I = B.CreateFMul(InputA, InputB);
2114 break;
2115 case Instruction::Mul:
2116 I = B.CreateMul(InputA, InputB);
2117 break;
2118 default:
2119 llvm_unreachable("Incorrect symmetric opcode");
2121 if (Flags)
2122 cast<Instruction>(I)->setFastMathFlags(*Flags);
2123 return I;
2126 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2127 RawNodePtr Node) {
2128 if (Node->ReplacementNode)
2129 return Node->ReplacementNode;
2131 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
2132 return Node->Operands.size() > Idx
2133 ? replaceNode(Builder, Node->Operands[Idx])
2134 : nullptr;
2137 Value *ReplacementNode;
2138 switch (Node->Operation) {
2139 case ComplexDeinterleavingOperation::CDot: {
2140 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2141 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2142 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2143 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2144 "Node inputs need to be of the same type"));
2145 ReplacementNode = TL->createComplexDeinterleavingIR(
2146 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
2147 break;
2149 case ComplexDeinterleavingOperation::CAdd:
2150 case ComplexDeinterleavingOperation::CMulPartial:
2151 case ComplexDeinterleavingOperation::Symmetric: {
2152 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2153 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2154 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2155 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2156 "Node inputs need to be of the same type"));
2157 assert(!Accumulator ||
2158 (Input0->getType() == Accumulator->getType() &&
2159 "Accumulator and input need to be of the same type"));
2160 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2161 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
2162 Input0, Input1);
2163 else
2164 ReplacementNode = TL->createComplexDeinterleavingIR(
2165 Builder, Node->Operation, Node->Rotation, Input0, Input1,
2166 Accumulator);
2167 break;
2169 case ComplexDeinterleavingOperation::Deinterleave:
2170 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2171 break;
2172 case ComplexDeinterleavingOperation::Splat: {
2173 auto *NewTy = VectorType::getDoubleElementsVectorType(
2174 cast<VectorType>(Node->Real->getType()));
2175 auto *R = dyn_cast<Instruction>(Node->Real);
2176 auto *I = dyn_cast<Instruction>(Node->Imag);
2177 if (R && I) {
2178 // Splats that are not constant are interleaved where they are located
2179 Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
2180 IRBuilder<> IRB(InsertPoint);
2181 ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
2182 NewTy, {Node->Real, Node->Imag});
2183 } else {
2184 ReplacementNode = Builder.CreateIntrinsic(
2185 Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
2187 break;
2189 case ComplexDeinterleavingOperation::ReductionPHI: {
2190 // If Operation is ReductionPHI, a new empty PHINode is created.
2191 // It is filled later when the ReductionOperation is processed.
2192 auto *OldPHI = cast<PHINode>(Node->Real);
2193 auto *VTy = cast<VectorType>(Node->Real->getType());
2194 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2195 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
2196 OldToNewPHI[OldPHI] = NewPHI;
2197 ReplacementNode = NewPHI;
2198 break;
2200 case ComplexDeinterleavingOperation::ReductionSingle:
2201 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2202 processReductionSingle(ReplacementNode, Node);
2203 break;
2204 case ComplexDeinterleavingOperation::ReductionOperation:
2205 ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2206 processReductionOperation(ReplacementNode, Node);
2207 break;
2208 case ComplexDeinterleavingOperation::ReductionSelect: {
2209 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
2210 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
2211 auto *A = replaceNode(Builder, Node->Operands[0]);
2212 auto *B = replaceNode(Builder, Node->Operands[1]);
2213 auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
2214 cast<VectorType>(MaskReal->getType()));
2215 auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
2216 NewMaskTy, {MaskReal, MaskImag});
2217 ReplacementNode = Builder.CreateSelect(NewMask, A, B);
2218 break;
2222 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2223 NumComplexTransformations += 1;
2224 Node->ReplacementNode = ReplacementNode;
2225 return ReplacementNode;
2228 void ComplexDeinterleavingGraph::processReductionSingle(
2229 Value *OperationReplacement, RawNodePtr Node) {
2230 auto *Real = cast<Instruction>(Node->Real);
2231 auto *OldPHI = ReductionInfo[Real].first;
2232 auto *NewPHI = OldToNewPHI[OldPHI];
2233 auto *VTy = cast<VectorType>(Real->getType());
2234 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2236 Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2238 IRBuilder<> Builder(Incoming->getTerminator());
2240 Value *NewInit = nullptr;
2241 if (auto *C = dyn_cast<Constant>(Init)) {
2242 if (C->isZeroValue())
2243 NewInit = Constant::getNullValue(NewVTy);
2246 if (!NewInit)
2247 NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2248 {Init, Constant::getNullValue(VTy)});
2250 NewPHI->addIncoming(NewInit, Incoming);
2251 NewPHI->addIncoming(OperationReplacement, BackEdge);
2253 auto *FinalReduction = ReductionInfo[Real].second;
2254 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2256 auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2257 FinalReduction->replaceAllUsesWith(AddReduce);
2260 void ComplexDeinterleavingGraph::processReductionOperation(
2261 Value *OperationReplacement, RawNodePtr Node) {
2262 auto *Real = cast<Instruction>(Node->Real);
2263 auto *Imag = cast<Instruction>(Node->Imag);
2264 auto *OldPHIReal = ReductionInfo[Real].first;
2265 auto *OldPHIImag = ReductionInfo[Imag].first;
2266 auto *NewPHI = OldToNewPHI[OldPHIReal];
2268 auto *VTy = cast<VectorType>(Real->getType());
2269 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2271 // We have to interleave initial origin values coming from IncomingBlock
2272 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2273 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2275 IRBuilder<> Builder(Incoming->getTerminator());
2276 auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2277 {InitReal, InitImag});
2279 NewPHI->addIncoming(NewInit, Incoming);
2280 NewPHI->addIncoming(OperationReplacement, BackEdge);
2282 // Deinterleave complex vector outside of loop so that it can be finally
2283 // reduced
2284 auto *FinalReductionReal = ReductionInfo[Real].second;
2285 auto *FinalReductionImag = ReductionInfo[Imag].second;
2287 Builder.SetInsertPoint(
2288 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2289 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2290 OperationReplacement->getType(),
2291 OperationReplacement);
2293 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2294 FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2296 Builder.SetInsertPoint(FinalReductionImag);
2297 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2298 FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2301 void ComplexDeinterleavingGraph::replaceNodes() {
2302 SmallVector<Instruction *, 16> DeadInstrRoots;
2303 for (auto *RootInstruction : OrderedRoots) {
2304 // Check if this potential root went through check process and we can
2305 // deinterleave it
2306 if (!RootToNode.count(RootInstruction))
2307 continue;
2309 IRBuilder<> Builder(RootInstruction);
2310 auto RootNode = RootToNode[RootInstruction];
2311 Value *R = replaceNode(Builder, RootNode.get());
2313 if (RootNode->Operation ==
2314 ComplexDeinterleavingOperation::ReductionOperation) {
2315 auto *RootReal = cast<Instruction>(RootNode->Real);
2316 auto *RootImag = cast<Instruction>(RootNode->Imag);
2317 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2318 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2319 DeadInstrRoots.push_back(RootReal);
2320 DeadInstrRoots.push_back(RootImag);
2321 } else if (RootNode->Operation ==
2322 ComplexDeinterleavingOperation::ReductionSingle) {
2323 auto *RootInst = cast<Instruction>(RootNode->Real);
2324 ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
2325 DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
2326 } else {
2327 assert(R && "Unable to find replacement for RootInstruction");
2328 DeadInstrRoots.push_back(RootInstruction);
2329 RootInstruction->replaceAllUsesWith(R);
2333 for (auto *I : DeadInstrRoots)
2334 RecursivelyDeleteTriviallyDeadInstructions(I, TLI);