Add gfx950 mfma instructions to ROCDL dialect (#123361)
[llvm-project.git] / llvm / lib / Target / VE / VVPISelLowering.cpp
blobf1e2d7f717016bd5a524a4664125b9bd05f59fee
1 //===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering and legalization of vector instructions to
10 // VVP_* layer SDNodes.
12 //===----------------------------------------------------------------------===//
14 #include "VECustomDAG.h"
15 #include "VEISelLowering.h"
17 using namespace llvm;
19 #define DEBUG_TYPE "ve-lower"
21 SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
22 SelectionDAG &DAG) const {
23 VECustomDAG CDAG(DAG, Op);
24 SDValue AVL =
25 CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32);
26 SDValue A = Op->getOperand(0);
27 SDValue B = Op->getOperand(1);
28 SDValue LoA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Lo, AVL);
29 SDValue HiA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Hi, AVL);
30 SDValue LoB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Lo, AVL);
31 SDValue HiB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Hi, AVL);
32 unsigned Opc = Op.getOpcode();
33 auto LoRes = CDAG.getNode(Opc, MVT::v256i1, {LoA, LoB});
34 auto HiRes = CDAG.getNode(Opc, MVT::v256i1, {HiA, HiB});
35 return CDAG.getPack(MVT::v512i1, LoRes, HiRes, AVL);
38 SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
39 // Can we represent this as a VVP node.
40 const unsigned Opcode = Op->getOpcode();
41 auto VVPOpcodeOpt = getVVPOpcode(Opcode);
42 if (!VVPOpcodeOpt)
43 return SDValue();
44 unsigned VVPOpcode = *VVPOpcodeOpt;
45 const bool FromVP = ISD::isVPOpcode(Opcode);
47 // The representative and legalized vector type of this operation.
48 VECustomDAG CDAG(DAG, Op);
49 // Dispatch to complex lowering functions.
50 switch (VVPOpcode) {
51 case VEISD::VVP_LOAD:
52 case VEISD::VVP_STORE:
53 return lowerVVP_LOAD_STORE(Op, CDAG);
54 case VEISD::VVP_GATHER:
55 case VEISD::VVP_SCATTER:
56 return lowerVVP_GATHER_SCATTER(Op, CDAG);
59 EVT OpVecVT = *getIdiomaticVectorType(Op.getNode());
60 EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT);
61 auto Packing = getTypePacking(LegalVecVT.getSimpleVT());
63 SDValue AVL;
64 SDValue Mask;
66 if (FromVP) {
67 // All upstream VP SDNodes always have a mask and avl.
68 auto MaskIdx = ISD::getVPMaskIdx(Opcode);
69 auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode);
70 if (MaskIdx)
71 Mask = Op->getOperand(*MaskIdx);
72 if (AVLIdx)
73 AVL = Op->getOperand(*AVLIdx);
76 // Materialize default mask and avl.
77 if (!AVL)
78 AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32);
79 if (!Mask)
80 Mask = CDAG.getConstantMask(Packing, true);
82 assert(LegalVecVT.isSimple());
83 if (isVVPUnaryOp(VVPOpcode))
84 return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL});
85 if (isVVPBinaryOp(VVPOpcode))
86 return CDAG.getNode(VVPOpcode, LegalVecVT,
87 {Op->getOperand(0), Op->getOperand(1), Mask, AVL});
88 if (isVVPReductionOp(VVPOpcode)) {
89 auto SrcHasStart = hasReductionStartParam(Op->getOpcode());
90 SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue();
91 SDValue VectorV = Op->getOperand(SrcHasStart ? 1 : 0);
92 return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV,
93 VectorV, Mask, AVL, Op->getFlags());
96 switch (VVPOpcode) {
97 default:
98 llvm_unreachable("lowerToVVP called for unexpected SDNode.");
99 case VEISD::VVP_FFMA: {
100 // VE has a swizzled operand order in FMA (compared to LLVM IR and
101 // SDNodes).
102 auto X = Op->getOperand(2);
103 auto Y = Op->getOperand(0);
104 auto Z = Op->getOperand(1);
105 return CDAG.getNode(VVPOpcode, LegalVecVT, {X, Y, Z, Mask, AVL});
107 case VEISD::VVP_SELECT: {
108 auto Mask = Op->getOperand(0);
109 auto OnTrue = Op->getOperand(1);
110 auto OnFalse = Op->getOperand(2);
111 return CDAG.getNode(VVPOpcode, LegalVecVT, {OnTrue, OnFalse, Mask, AVL});
113 case VEISD::VVP_SETCC: {
114 EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
115 auto LHS = Op->getOperand(0);
116 auto RHS = Op->getOperand(1);
117 auto Pred = Op->getOperand(2);
118 return CDAG.getNode(VVPOpcode, LegalResVT, {LHS, RHS, Pred, Mask, AVL});
123 SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
124 VECustomDAG &CDAG) const {
125 auto VVPOpc = *getVVPOpcode(Op->getOpcode());
126 const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD);
128 // Shares.
129 SDValue BasePtr = getMemoryPtr(Op);
130 SDValue Mask = getNodeMask(Op);
131 SDValue Chain = getNodeChain(Op);
132 SDValue AVL = getNodeAVL(Op);
133 // Store specific.
134 SDValue Data = getStoredValue(Op);
135 // Load specific.
136 SDValue PassThru = getNodePassthru(Op);
138 SDValue StrideV = getLoadStoreStride(Op, CDAG);
140 auto DataVT = *getIdiomaticVectorType(Op.getNode());
141 auto Packing = getTypePacking(DataVT);
143 // TODO: Infer lower AVL from mask.
144 if (!AVL)
145 AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
147 // Default to the all-true mask.
148 if (!Mask)
149 Mask = CDAG.getConstantMask(Packing, true);
151 if (IsLoad) {
152 MVT LegalDataVT = getLegalVectorType(
153 Packing, DataVT.getVectorElementType().getSimpleVT());
155 auto NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other},
156 {Chain, BasePtr, StrideV, Mask, AVL});
158 if (!PassThru || PassThru->isUndef())
159 return NewLoadV;
161 // Convert passthru to an explicit select node.
162 SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, DataVT,
163 {NewLoadV, PassThru, Mask, AVL});
164 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
166 // Merge them back into one node.
167 return CDAG.getMergeValues({DataV, NewLoadChainV});
170 // VVP_STORE
171 assert(VVPOpc == VEISD::VVP_STORE);
172 if (getTypeAction(*CDAG.getDAG()->getContext(), Data.getValueType()) !=
173 TargetLowering::TypeLegal)
174 // Doesn't lower store instruction if an operand is not lowered yet.
175 // If it isn't, return SDValue(). In this way, LLVM will try to lower
176 // store instruction again after lowering all operands.
177 return SDValue();
178 return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(),
179 {Chain, Data, BasePtr, StrideV, Mask, AVL});
182 SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
183 VECustomDAG &CDAG) const {
184 auto VVPOC = *getVVPOpcode(Op.getOpcode());
185 assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));
187 MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
188 assert(getTypePacking(DataVT) == Packing::Dense &&
189 "Can only split packed load/store");
190 MVT SplitDataVT = splitVectorType(DataVT);
192 assert(!getNodePassthru(Op) &&
193 "Should have been folded in lowering to VVP layer");
195 // Analyze the operation
196 SDValue PackedMask = getNodeMask(Op);
197 SDValue PackedAVL = getAnnotatedNodeAVL(Op).first;
198 SDValue PackPtr = getMemoryPtr(Op);
199 SDValue PackData = getStoredValue(Op);
200 SDValue PackStride = getLoadStoreStride(Op, CDAG);
202 unsigned ChainResIdx = PackData ? 0 : 1;
204 SDValue PartOps[2];
206 SDValue UpperPartAVL; // we will use this for packing things back together
207 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
208 // VP ops already have an explicit mask and AVL. When expanding from non-VP
209 // attach those additional inputs here.
210 auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
212 // Keep track of the (higher) lvl.
213 if (Part == PackElem::Hi)
214 UpperPartAVL = SplitTM.AVL;
216 // Attach non-predicating value operands
217 SmallVector<SDValue, 4> OpVec;
219 // Chain
220 OpVec.push_back(getNodeChain(Op));
222 // Data
223 if (PackData) {
224 SDValue PartData =
225 CDAG.getUnpack(SplitDataVT, PackData, Part, SplitTM.AVL);
226 OpVec.push_back(PartData);
229 // Ptr & Stride
230 // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes)
231 // Stride info
232 // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode);
233 OpVec.push_back(CDAG.getSplitPtrOffset(PackPtr, PackStride, Part));
234 OpVec.push_back(CDAG.getSplitPtrStride(PackStride));
236 // Add predicating args and generate part node
237 OpVec.push_back(SplitTM.Mask);
238 OpVec.push_back(SplitTM.AVL);
240 if (PackData) {
241 // Store
242 PartOps[(int)Part] = CDAG.getNode(VVPOC, MVT::Other, OpVec);
243 } else {
244 // Load
245 PartOps[(int)Part] =
246 CDAG.getNode(VVPOC, {SplitDataVT, MVT::Other}, OpVec);
250 // Merge the chains
251 SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx);
252 SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx);
253 SDValue FusedChains =
254 CDAG.getNode(ISD::TokenFactor, MVT::Other, {LowChain, HiChain});
256 // Chain only [store]
257 if (PackData)
258 return FusedChains;
260 // Re-pack into full packed vector result
261 MVT PackedVT =
262 getLegalVectorType(Packing::Dense, DataVT.getVectorElementType());
263 SDValue PackedVals = CDAG.getPack(PackedVT, PartOps[(int)PackElem::Lo],
264 PartOps[(int)PackElem::Hi], UpperPartAVL);
266 return CDAG.getMergeValues({PackedVals, FusedChains});
269 SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
270 VECustomDAG &CDAG) const {
271 EVT DataVT = *getIdiomaticVectorType(Op.getNode());
272 auto Packing = getTypePacking(DataVT);
273 MVT LegalDataVT =
274 getLegalVectorType(Packing, DataVT.getVectorElementType().getSimpleVT());
276 SDValue AVL = getAnnotatedNodeAVL(Op).first;
277 SDValue Index = getGatherScatterIndex(Op);
278 SDValue BasePtr = getMemoryPtr(Op);
279 SDValue Mask = getNodeMask(Op);
280 SDValue Chain = getNodeChain(Op);
281 SDValue Scale = getGatherScatterScale(Op);
282 SDValue PassThru = getNodePassthru(Op);
283 SDValue StoredValue = getStoredValue(Op);
284 if (PassThru && PassThru->isUndef())
285 PassThru = SDValue();
287 bool IsScatter = (bool)StoredValue;
289 // TODO: Infer lower AVL from mask.
290 if (!AVL)
291 AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
293 // Default to the all-true mask.
294 if (!Mask)
295 Mask = CDAG.getConstantMask(Packing, true);
297 SDValue AddressVec =
298 CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL);
299 if (IsScatter)
300 return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other,
301 {Chain, StoredValue, AddressVec, Mask, AVL});
303 // Gather.
304 SDValue NewLoadV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other},
305 {Chain, AddressVec, Mask, AVL});
307 if (!PassThru)
308 return NewLoadV;
310 // TODO: Use vvp_select
311 SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,
312 {NewLoadV, PassThru, Mask, AVL});
313 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
314 return CDAG.getMergeValues({DataV, NewLoadChainV});
317 SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
318 VECustomDAG &CDAG) const {
319 LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
320 MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
322 // TODO: Recognize packable load,store.
323 if (isPackedVectorType(DataVT))
324 return splitPackedLoadStore(Op, CDAG);
326 return legalizePackedAVL(Op, CDAG);
329 SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
330 SelectionDAG &DAG) const {
331 LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
332 VECustomDAG CDAG(DAG, Op);
334 // Dispatch to specialized legalization functions.
335 switch (Op->getOpcode()) {
336 case VEISD::VVP_LOAD:
337 case VEISD::VVP_STORE:
338 return legalizeInternalLoadStoreOp(Op, CDAG);
341 EVT IdiomVT = Op.getValueType();
342 if (isPackedVectorType(IdiomVT) &&
343 !supportsPackedMode(Op.getOpcode(), IdiomVT))
344 return splitVectorOp(Op, CDAG);
346 // TODO: Implement odd/even splitting.
347 return legalizePackedAVL(Op, CDAG);
350 SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
351 MVT ResVT = splitVectorType(Op.getValue(0).getSimpleValueType());
353 auto AVLPos = getAVLPos(Op->getOpcode());
354 auto MaskPos = getMaskPos(Op->getOpcode());
356 SDValue PackedMask = getNodeMask(Op);
357 auto AVLPair = getAnnotatedNodeAVL(Op);
358 SDValue PackedAVL = AVLPair.first;
359 assert(!AVLPair.second && "Expecting non pack-legalized oepration");
361 // request the parts
362 SDValue PartOps[2];
364 SDValue UpperPartAVL; // we will use this for packing things back together
365 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
366 // VP ops already have an explicit mask and AVL. When expanding from non-VP
367 // attach those additional inputs here.
368 auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
370 if (Part == PackElem::Hi)
371 UpperPartAVL = SplitTM.AVL;
373 // Attach non-predicating value operands
374 SmallVector<SDValue, 4> OpVec;
375 for (unsigned i = 0; i < Op.getNumOperands(); ++i) {
376 if (AVLPos && ((int)i) == *AVLPos)
377 continue;
378 if (MaskPos && ((int)i) == *MaskPos)
379 continue;
381 // Value operand
382 auto PackedOperand = Op.getOperand(i);
383 auto UnpackedOpVT = splitVectorType(PackedOperand.getSimpleValueType());
384 SDValue PartV =
385 CDAG.getUnpack(UnpackedOpVT, PackedOperand, Part, SplitTM.AVL);
386 OpVec.push_back(PartV);
389 // Add predicating args and generate part node.
390 OpVec.push_back(SplitTM.Mask);
391 OpVec.push_back(SplitTM.AVL);
392 // Emit legal VVP nodes.
393 PartOps[(int)Part] =
394 CDAG.getNode(Op.getOpcode(), ResVT, OpVec, Op->getFlags());
397 // Re-package vectors.
398 return CDAG.getPack(Op.getValueType(), PartOps[(int)PackElem::Lo],
399 PartOps[(int)PackElem::Hi], UpperPartAVL);
402 SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
403 VECustomDAG &CDAG) const {
404 LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
405 // Only required for VEC and VVP ops.
406 if (!isVVPOrVEC(Op->getOpcode()))
407 return Op;
409 // Operation already has a legal AVL.
410 auto AVL = getNodeAVL(Op);
411 if (isLegalAVL(AVL))
412 return Op;
414 // Half and round up EVL for 32bit element types.
415 SDValue LegalAVL = AVL;
416 MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
417 if (isPackedVectorType(IdiomVT)) {
418 assert(maySafelyIgnoreMask(Op) &&
419 "TODO Shift predication from EVL into Mask");
421 if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) {
422 LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
423 } else {
424 auto ConstOne = CDAG.getConstant(1, MVT::i32);
425 auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne});
426 LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne});
430 SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL);
432 // Copy the operand list.
433 int NumOp = Op->getNumOperands();
434 auto AVLPos = getAVLPos(Op->getOpcode());
435 std::vector<SDValue> FixedOperands;
436 for (int i = 0; i < NumOp; ++i) {
437 if (AVLPos && (i == *AVLPos)) {
438 FixedOperands.push_back(AnnotatedLegalAVL);
439 continue;
441 FixedOperands.push_back(Op->getOperand(i));
444 // Clone the operation with fixed operands.
445 auto Flags = Op->getFlags();
446 SDValue NewN =
447 CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands, Flags);
448 return NewN;