//===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the lowering and legalization of vector instructions to
// VVP_* layer SDNodes.
//
//===----------------------------------------------------------------------===//
#include "VECustomDAG.h"
#include "VEISelLowering.h"

using namespace llvm;

#define DEBUG_TYPE "ve-lower"
21 SDValue
VETargetLowering::splitMaskArithmetic(SDValue Op
,
22 SelectionDAG
&DAG
) const {
23 VECustomDAG
CDAG(DAG
, Op
);
25 CDAG
.getConstant(Op
.getValueType().getVectorNumElements(), MVT::i32
);
26 SDValue A
= Op
->getOperand(0);
27 SDValue B
= Op
->getOperand(1);
28 SDValue LoA
= CDAG
.getUnpack(MVT::v256i1
, A
, PackElem::Lo
, AVL
);
29 SDValue HiA
= CDAG
.getUnpack(MVT::v256i1
, A
, PackElem::Hi
, AVL
);
30 SDValue LoB
= CDAG
.getUnpack(MVT::v256i1
, B
, PackElem::Lo
, AVL
);
31 SDValue HiB
= CDAG
.getUnpack(MVT::v256i1
, B
, PackElem::Hi
, AVL
);
32 unsigned Opc
= Op
.getOpcode();
33 auto LoRes
= CDAG
.getNode(Opc
, MVT::v256i1
, {LoA
, LoB
});
34 auto HiRes
= CDAG
.getNode(Opc
, MVT::v256i1
, {HiA
, HiB
});
35 return CDAG
.getPack(MVT::v512i1
, LoRes
, HiRes
, AVL
);
38 SDValue
VETargetLowering::lowerToVVP(SDValue Op
, SelectionDAG
&DAG
) const {
39 // Can we represent this as a VVP node.
40 const unsigned Opcode
= Op
->getOpcode();
41 auto VVPOpcodeOpt
= getVVPOpcode(Opcode
);
44 unsigned VVPOpcode
= *VVPOpcodeOpt
;
45 const bool FromVP
= ISD::isVPOpcode(Opcode
);
47 // The representative and legalized vector type of this operation.
48 VECustomDAG
CDAG(DAG
, Op
);
49 // Dispatch to complex lowering functions.
52 case VEISD::VVP_STORE
:
53 return lowerVVP_LOAD_STORE(Op
, CDAG
);
54 case VEISD::VVP_GATHER
:
55 case VEISD::VVP_SCATTER
:
56 return lowerVVP_GATHER_SCATTER(Op
, CDAG
);
59 EVT OpVecVT
= *getIdiomaticVectorType(Op
.getNode());
60 EVT LegalVecVT
= getTypeToTransformTo(*DAG
.getContext(), OpVecVT
);
61 auto Packing
= getTypePacking(LegalVecVT
.getSimpleVT());
67 // All upstream VP SDNodes always have a mask and avl.
68 auto MaskIdx
= ISD::getVPMaskIdx(Opcode
);
69 auto AVLIdx
= ISD::getVPExplicitVectorLengthIdx(Opcode
);
71 Mask
= Op
->getOperand(*MaskIdx
);
73 AVL
= Op
->getOperand(*AVLIdx
);
76 // Materialize default mask and avl.
78 AVL
= CDAG
.getConstant(OpVecVT
.getVectorNumElements(), MVT::i32
);
80 Mask
= CDAG
.getConstantMask(Packing
, true);
82 assert(LegalVecVT
.isSimple());
83 if (isVVPUnaryOp(VVPOpcode
))
84 return CDAG
.getNode(VVPOpcode
, LegalVecVT
, {Op
->getOperand(0), Mask
, AVL
});
85 if (isVVPBinaryOp(VVPOpcode
))
86 return CDAG
.getNode(VVPOpcode
, LegalVecVT
,
87 {Op
->getOperand(0), Op
->getOperand(1), Mask
, AVL
});
88 if (isVVPReductionOp(VVPOpcode
)) {
89 auto SrcHasStart
= hasReductionStartParam(Op
->getOpcode());
90 SDValue StartV
= SrcHasStart
? Op
->getOperand(0) : SDValue();
91 SDValue VectorV
= Op
->getOperand(SrcHasStart
? 1 : 0);
92 return CDAG
.getLegalReductionOpVVP(VVPOpcode
, Op
.getValueType(), StartV
,
93 VectorV
, Mask
, AVL
, Op
->getFlags());
98 llvm_unreachable("lowerToVVP called for unexpected SDNode.");
99 case VEISD::VVP_FFMA
: {
100 // VE has a swizzled operand order in FMA (compared to LLVM IR and
102 auto X
= Op
->getOperand(2);
103 auto Y
= Op
->getOperand(0);
104 auto Z
= Op
->getOperand(1);
105 return CDAG
.getNode(VVPOpcode
, LegalVecVT
, {X
, Y
, Z
, Mask
, AVL
});
107 case VEISD::VVP_SELECT
: {
108 auto Mask
= Op
->getOperand(0);
109 auto OnTrue
= Op
->getOperand(1);
110 auto OnFalse
= Op
->getOperand(2);
111 return CDAG
.getNode(VVPOpcode
, LegalVecVT
, {OnTrue
, OnFalse
, Mask
, AVL
});
113 case VEISD::VVP_SETCC
: {
114 EVT LegalResVT
= getTypeToTransformTo(*DAG
.getContext(), Op
.getValueType());
115 auto LHS
= Op
->getOperand(0);
116 auto RHS
= Op
->getOperand(1);
117 auto Pred
= Op
->getOperand(2);
118 return CDAG
.getNode(VVPOpcode
, LegalResVT
, {LHS
, RHS
, Pred
, Mask
, AVL
});
123 SDValue
VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op
,
124 VECustomDAG
&CDAG
) const {
125 auto VVPOpc
= *getVVPOpcode(Op
->getOpcode());
126 const bool IsLoad
= (VVPOpc
== VEISD::VVP_LOAD
);
129 SDValue BasePtr
= getMemoryPtr(Op
);
130 SDValue Mask
= getNodeMask(Op
);
131 SDValue Chain
= getNodeChain(Op
);
132 SDValue AVL
= getNodeAVL(Op
);
134 SDValue Data
= getStoredValue(Op
);
136 SDValue PassThru
= getNodePassthru(Op
);
138 SDValue StrideV
= getLoadStoreStride(Op
, CDAG
);
140 auto DataVT
= *getIdiomaticVectorType(Op
.getNode());
141 auto Packing
= getTypePacking(DataVT
);
143 // TODO: Infer lower AVL from mask.
145 AVL
= CDAG
.getConstant(DataVT
.getVectorNumElements(), MVT::i32
);
147 // Default to the all-true mask.
149 Mask
= CDAG
.getConstantMask(Packing
, true);
152 MVT LegalDataVT
= getLegalVectorType(
153 Packing
, DataVT
.getVectorElementType().getSimpleVT());
155 auto NewLoadV
= CDAG
.getNode(VEISD::VVP_LOAD
, {LegalDataVT
, MVT::Other
},
156 {Chain
, BasePtr
, StrideV
, Mask
, AVL
});
158 if (!PassThru
|| PassThru
->isUndef())
161 // Convert passthru to an explicit select node.
162 SDValue DataV
= CDAG
.getNode(VEISD::VVP_SELECT
, DataVT
,
163 {NewLoadV
, PassThru
, Mask
, AVL
});
164 SDValue NewLoadChainV
= SDValue(NewLoadV
.getNode(), 1);
166 // Merge them back into one node.
167 return CDAG
.getMergeValues({DataV
, NewLoadChainV
});
171 assert(VVPOpc
== VEISD::VVP_STORE
);
172 if (getTypeAction(*CDAG
.getDAG()->getContext(), Data
.getValueType()) !=
173 TargetLowering::TypeLegal
)
174 // Doesn't lower store instruction if an operand is not lowered yet.
175 // If it isn't, return SDValue(). In this way, LLVM will try to lower
176 // store instruction again after lowering all operands.
178 return CDAG
.getNode(VEISD::VVP_STORE
, Op
.getNode()->getVTList(),
179 {Chain
, Data
, BasePtr
, StrideV
, Mask
, AVL
});
182 SDValue
VETargetLowering::splitPackedLoadStore(SDValue Op
,
183 VECustomDAG
&CDAG
) const {
184 auto VVPOC
= *getVVPOpcode(Op
.getOpcode());
185 assert((VVPOC
== VEISD::VVP_LOAD
) || (VVPOC
== VEISD::VVP_STORE
));
187 MVT DataVT
= getIdiomaticVectorType(Op
.getNode())->getSimpleVT();
188 assert(getTypePacking(DataVT
) == Packing::Dense
&&
189 "Can only split packed load/store");
190 MVT SplitDataVT
= splitVectorType(DataVT
);
192 assert(!getNodePassthru(Op
) &&
193 "Should have been folded in lowering to VVP layer");
195 // Analyze the operation
196 SDValue PackedMask
= getNodeMask(Op
);
197 SDValue PackedAVL
= getAnnotatedNodeAVL(Op
).first
;
198 SDValue PackPtr
= getMemoryPtr(Op
);
199 SDValue PackData
= getStoredValue(Op
);
200 SDValue PackStride
= getLoadStoreStride(Op
, CDAG
);
202 unsigned ChainResIdx
= PackData
? 0 : 1;
206 SDValue UpperPartAVL
; // we will use this for packing things back together
207 for (PackElem Part
: {PackElem::Hi
, PackElem::Lo
}) {
208 // VP ops already have an explicit mask and AVL. When expanding from non-VP
209 // attach those additional inputs here.
210 auto SplitTM
= CDAG
.getTargetSplitMask(PackedMask
, PackedAVL
, Part
);
212 // Keep track of the (higher) lvl.
213 if (Part
== PackElem::Hi
)
214 UpperPartAVL
= SplitTM
.AVL
;
216 // Attach non-predicating value operands
217 SmallVector
<SDValue
, 4> OpVec
;
220 OpVec
.push_back(getNodeChain(Op
));
225 CDAG
.getUnpack(SplitDataVT
, PackData
, Part
, SplitTM
.AVL
);
226 OpVec
.push_back(PartData
);
230 // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes)
232 // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode);
233 OpVec
.push_back(CDAG
.getSplitPtrOffset(PackPtr
, PackStride
, Part
));
234 OpVec
.push_back(CDAG
.getSplitPtrStride(PackStride
));
236 // Add predicating args and generate part node
237 OpVec
.push_back(SplitTM
.Mask
);
238 OpVec
.push_back(SplitTM
.AVL
);
242 PartOps
[(int)Part
] = CDAG
.getNode(VVPOC
, MVT::Other
, OpVec
);
246 CDAG
.getNode(VVPOC
, {SplitDataVT
, MVT::Other
}, OpVec
);
251 SDValue LowChain
= SDValue(PartOps
[(int)PackElem::Lo
].getNode(), ChainResIdx
);
252 SDValue HiChain
= SDValue(PartOps
[(int)PackElem::Hi
].getNode(), ChainResIdx
);
253 SDValue FusedChains
=
254 CDAG
.getNode(ISD::TokenFactor
, MVT::Other
, {LowChain
, HiChain
});
256 // Chain only [store]
260 // Re-pack into full packed vector result
262 getLegalVectorType(Packing::Dense
, DataVT
.getVectorElementType());
263 SDValue PackedVals
= CDAG
.getPack(PackedVT
, PartOps
[(int)PackElem::Lo
],
264 PartOps
[(int)PackElem::Hi
], UpperPartAVL
);
266 return CDAG
.getMergeValues({PackedVals
, FusedChains
});
269 SDValue
VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op
,
270 VECustomDAG
&CDAG
) const {
271 EVT DataVT
= *getIdiomaticVectorType(Op
.getNode());
272 auto Packing
= getTypePacking(DataVT
);
274 getLegalVectorType(Packing
, DataVT
.getVectorElementType().getSimpleVT());
276 SDValue AVL
= getAnnotatedNodeAVL(Op
).first
;
277 SDValue Index
= getGatherScatterIndex(Op
);
278 SDValue BasePtr
= getMemoryPtr(Op
);
279 SDValue Mask
= getNodeMask(Op
);
280 SDValue Chain
= getNodeChain(Op
);
281 SDValue Scale
= getGatherScatterScale(Op
);
282 SDValue PassThru
= getNodePassthru(Op
);
283 SDValue StoredValue
= getStoredValue(Op
);
284 if (PassThru
&& PassThru
->isUndef())
285 PassThru
= SDValue();
287 bool IsScatter
= (bool)StoredValue
;
289 // TODO: Infer lower AVL from mask.
291 AVL
= CDAG
.getConstant(DataVT
.getVectorNumElements(), MVT::i32
);
293 // Default to the all-true mask.
295 Mask
= CDAG
.getConstantMask(Packing
, true);
298 CDAG
.getGatherScatterAddress(BasePtr
, Scale
, Index
, Mask
, AVL
);
300 return CDAG
.getNode(VEISD::VVP_SCATTER
, MVT::Other
,
301 {Chain
, StoredValue
, AddressVec
, Mask
, AVL
});
304 SDValue NewLoadV
= CDAG
.getNode(VEISD::VVP_GATHER
, {LegalDataVT
, MVT::Other
},
305 {Chain
, AddressVec
, Mask
, AVL
});
310 // TODO: Use vvp_select
311 SDValue DataV
= CDAG
.getNode(VEISD::VVP_SELECT
, LegalDataVT
,
312 {NewLoadV
, PassThru
, Mask
, AVL
});
313 SDValue NewLoadChainV
= SDValue(NewLoadV
.getNode(), 1);
314 return CDAG
.getMergeValues({DataV
, NewLoadChainV
});
317 SDValue
VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op
,
318 VECustomDAG
&CDAG
) const {
319 LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
320 MVT DataVT
= getIdiomaticVectorType(Op
.getNode())->getSimpleVT();
322 // TODO: Recognize packable load,store.
323 if (isPackedVectorType(DataVT
))
324 return splitPackedLoadStore(Op
, CDAG
);
326 return legalizePackedAVL(Op
, CDAG
);
329 SDValue
VETargetLowering::legalizeInternalVectorOp(SDValue Op
,
330 SelectionDAG
&DAG
) const {
331 LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
332 VECustomDAG
CDAG(DAG
, Op
);
334 // Dispatch to specialized legalization functions.
335 switch (Op
->getOpcode()) {
336 case VEISD::VVP_LOAD
:
337 case VEISD::VVP_STORE
:
338 return legalizeInternalLoadStoreOp(Op
, CDAG
);
341 EVT IdiomVT
= Op
.getValueType();
342 if (isPackedVectorType(IdiomVT
) &&
343 !supportsPackedMode(Op
.getOpcode(), IdiomVT
))
344 return splitVectorOp(Op
, CDAG
);
346 // TODO: Implement odd/even splitting.
347 return legalizePackedAVL(Op
, CDAG
);
350 SDValue
VETargetLowering::splitVectorOp(SDValue Op
, VECustomDAG
&CDAG
) const {
351 MVT ResVT
= splitVectorType(Op
.getValue(0).getSimpleValueType());
353 auto AVLPos
= getAVLPos(Op
->getOpcode());
354 auto MaskPos
= getMaskPos(Op
->getOpcode());
356 SDValue PackedMask
= getNodeMask(Op
);
357 auto AVLPair
= getAnnotatedNodeAVL(Op
);
358 SDValue PackedAVL
= AVLPair
.first
;
359 assert(!AVLPair
.second
&& "Expecting non pack-legalized oepration");
364 SDValue UpperPartAVL
; // we will use this for packing things back together
365 for (PackElem Part
: {PackElem::Hi
, PackElem::Lo
}) {
366 // VP ops already have an explicit mask and AVL. When expanding from non-VP
367 // attach those additional inputs here.
368 auto SplitTM
= CDAG
.getTargetSplitMask(PackedMask
, PackedAVL
, Part
);
370 if (Part
== PackElem::Hi
)
371 UpperPartAVL
= SplitTM
.AVL
;
373 // Attach non-predicating value operands
374 SmallVector
<SDValue
, 4> OpVec
;
375 for (unsigned i
= 0; i
< Op
.getNumOperands(); ++i
) {
376 if (AVLPos
&& ((int)i
) == *AVLPos
)
378 if (MaskPos
&& ((int)i
) == *MaskPos
)
382 auto PackedOperand
= Op
.getOperand(i
);
383 auto UnpackedOpVT
= splitVectorType(PackedOperand
.getSimpleValueType());
385 CDAG
.getUnpack(UnpackedOpVT
, PackedOperand
, Part
, SplitTM
.AVL
);
386 OpVec
.push_back(PartV
);
389 // Add predicating args and generate part node.
390 OpVec
.push_back(SplitTM
.Mask
);
391 OpVec
.push_back(SplitTM
.AVL
);
392 // Emit legal VVP nodes.
394 CDAG
.getNode(Op
.getOpcode(), ResVT
, OpVec
, Op
->getFlags());
397 // Re-package vectors.
398 return CDAG
.getPack(Op
.getValueType(), PartOps
[(int)PackElem::Lo
],
399 PartOps
[(int)PackElem::Hi
], UpperPartAVL
);
402 SDValue
VETargetLowering::legalizePackedAVL(SDValue Op
,
403 VECustomDAG
&CDAG
) const {
404 LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
405 // Only required for VEC and VVP ops.
406 if (!isVVPOrVEC(Op
->getOpcode()))
409 // Operation already has a legal AVL.
410 auto AVL
= getNodeAVL(Op
);
414 // Half and round up EVL for 32bit element types.
415 SDValue LegalAVL
= AVL
;
416 MVT IdiomVT
= getIdiomaticVectorType(Op
.getNode())->getSimpleVT();
417 if (isPackedVectorType(IdiomVT
)) {
418 assert(maySafelyIgnoreMask(Op
) &&
419 "TODO Shift predication from EVL into Mask");
421 if (auto *ConstAVL
= dyn_cast
<ConstantSDNode
>(AVL
)) {
422 LegalAVL
= CDAG
.getConstant((ConstAVL
->getZExtValue() + 1) / 2, MVT::i32
);
424 auto ConstOne
= CDAG
.getConstant(1, MVT::i32
);
425 auto PlusOne
= CDAG
.getNode(ISD::ADD
, MVT::i32
, {AVL
, ConstOne
});
426 LegalAVL
= CDAG
.getNode(ISD::SRL
, MVT::i32
, {PlusOne
, ConstOne
});
430 SDValue AnnotatedLegalAVL
= CDAG
.annotateLegalAVL(LegalAVL
);
432 // Copy the operand list.
433 int NumOp
= Op
->getNumOperands();
434 auto AVLPos
= getAVLPos(Op
->getOpcode());
435 std::vector
<SDValue
> FixedOperands
;
436 for (int i
= 0; i
< NumOp
; ++i
) {
437 if (AVLPos
&& (i
== *AVLPos
)) {
438 FixedOperands
.push_back(AnnotatedLegalAVL
);
441 FixedOperands
.push_back(Op
->getOperand(i
));
444 // Clone the operation with fixed operands.
445 auto Flags
= Op
->getFlags();
447 CDAG
.getNode(Op
->getOpcode(), Op
->getVTList(), FixedOperands
, Flags
);