Add gfx950 mfma instructions to ROCDL dialect (#123361)
[llvm-project.git] / llvm / lib / Target / VE / VECustomDAG.cpp
blob2855a65f654c96e1649174b5d4c06e35b2bb21c3
1 //===-- VECustomDAG.cpp - VE Custom DAG Nodes ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the helpers that VE uses to build and inspect the
10 // custom VVP_* / VEC_* nodes of its selection DAG during lowering.
12 //===----------------------------------------------------------------------===//
14 #include "VECustomDAG.h"
16 #ifndef DEBUG_TYPE
17 #define DEBUG_TYPE "vecustomdag"
18 #endif
20 namespace llvm {
22 bool isPackedVectorType(EVT SomeVT) {
23 if (!SomeVT.isVector())
24 return false;
25 return SomeVT.getVectorNumElements() > StandardVectorWidth;
28 MVT splitVectorType(MVT VT) {
29 if (!VT.isVector())
30 return VT;
31 return MVT::getVectorVT(VT.getVectorElementType(), StandardVectorWidth);
34 MVT getLegalVectorType(Packing P, MVT ElemVT) {
35 return MVT::getVectorVT(ElemVT, P == Packing::Normal ? StandardVectorWidth
36 : PackedVectorWidth);
39 Packing getTypePacking(EVT VT) {
40 assert(VT.isVector());
41 return isPackedVectorType(VT) ? Packing::Dense : Packing::Normal;
44 bool isMaskType(EVT SomeVT) {
45 if (!SomeVT.isVector())
46 return false;
47 return SomeVT.getVectorElementType() == MVT::i1;
50 bool isMaskArithmetic(SDValue Op) {
51 switch (Op.getOpcode()) {
52 default:
53 return false;
54 case ISD::AND:
55 case ISD::XOR:
56 case ISD::OR:
57 return isMaskType(Op.getValueType());
61 /// \returns the VVP_* SDNode opcode corresponsing to \p OC.
62 std::optional<unsigned> getVVPOpcode(unsigned Opcode) {
63 switch (Opcode) {
64 case ISD::MLOAD:
65 return VEISD::VVP_LOAD;
66 case ISD::MSTORE:
67 return VEISD::VVP_STORE;
68 #define HANDLE_VP_TO_VVP(VPOPC, VVPNAME) \
69 case ISD::VPOPC: \
70 return VEISD::VVPNAME;
71 #define ADD_VVP_OP(VVPNAME, SDNAME) \
72 case VEISD::VVPNAME: \
73 case ISD::SDNAME: \
74 return VEISD::VVPNAME;
75 #include "VVPNodes.def"
76 // TODO: Map those in VVPNodes.def too
77 case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
78 return VEISD::VVP_LOAD;
79 case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
80 return VEISD::VVP_STORE;
82 return std::nullopt;
85 bool maySafelyIgnoreMask(SDValue Op) {
86 auto VVPOpc = getVVPOpcode(Op->getOpcode());
87 auto Opc = VVPOpc.value_or(Op->getOpcode());
89 switch (Opc) {
90 case VEISD::VVP_SDIV:
91 case VEISD::VVP_UDIV:
92 case VEISD::VVP_FDIV:
93 case VEISD::VVP_SELECT:
94 return false;
96 default:
97 return true;
101 bool supportsPackedMode(unsigned Opcode, EVT IdiomVT) {
102 bool IsPackedOp = isPackedVectorType(IdiomVT);
103 bool IsMaskOp = isMaskType(IdiomVT);
104 switch (Opcode) {
105 default:
106 return false;
108 case VEISD::VEC_BROADCAST:
109 return true;
110 #define REGISTER_PACKED(VVP_NAME) case VEISD::VVP_NAME:
111 #include "VVPNodes.def"
112 return IsPackedOp && !IsMaskOp;
116 bool isPackingSupportOpcode(unsigned Opc) {
117 switch (Opc) {
118 case VEISD::VEC_PACK:
119 case VEISD::VEC_UNPACK_LO:
120 case VEISD::VEC_UNPACK_HI:
121 return true;
123 return false;
126 bool isVVPOrVEC(unsigned Opcode) {
127 switch (Opcode) {
128 case VEISD::VEC_BROADCAST:
129 #define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME:
130 #include "VVPNodes.def"
131 return true;
133 return false;
136 bool isVVPUnaryOp(unsigned VVPOpcode) {
137 switch (VVPOpcode) {
138 #define ADD_UNARY_VVP_OP(VVPNAME, ...) \
139 case VEISD::VVPNAME: \
140 return true;
141 #include "VVPNodes.def"
143 return false;
146 bool isVVPBinaryOp(unsigned VVPOpcode) {
147 switch (VVPOpcode) {
148 #define ADD_BINARY_VVP_OP(VVPNAME, ...) \
149 case VEISD::VVPNAME: \
150 return true;
151 #include "VVPNodes.def"
153 return false;
156 bool isVVPReductionOp(unsigned Opcode) {
157 switch (Opcode) {
158 #define ADD_REDUCE_VVP_OP(VVP_NAME, SDNAME) case VEISD::VVP_NAME:
159 #include "VVPNodes.def"
160 return true;
162 return false;
165 // Return the AVL operand position for this VVP or VEC Op.
166 std::optional<int> getAVLPos(unsigned Opc) {
167 // This is only available for VP SDNodes
168 auto PosOpt = ISD::getVPExplicitVectorLengthIdx(Opc);
169 if (PosOpt)
170 return *PosOpt;
172 // VVP Opcodes.
173 if (isVVPBinaryOp(Opc))
174 return 3;
176 // VM Opcodes.
177 switch (Opc) {
178 case VEISD::VEC_BROADCAST:
179 return 1;
180 case VEISD::VVP_SELECT:
181 return 3;
182 case VEISD::VVP_LOAD:
183 return 4;
184 case VEISD::VVP_STORE:
185 return 5;
188 return std::nullopt;
191 std::optional<int> getMaskPos(unsigned Opc) {
192 // This is only available for VP SDNodes
193 auto PosOpt = ISD::getVPMaskIdx(Opc);
194 if (PosOpt)
195 return *PosOpt;
197 // VVP Opcodes.
198 if (isVVPBinaryOp(Opc))
199 return 2;
201 // Other opcodes.
202 switch (Opc) {
203 case ISD::MSTORE:
204 return 4;
205 case ISD::MLOAD:
206 return 3;
207 case VEISD::VVP_SELECT:
208 return 2;
211 return std::nullopt;
214 bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; }
216 /// Node Properties {
218 SDValue getNodeChain(SDValue Op) {
219 if (MemSDNode *MemN = dyn_cast<MemSDNode>(Op.getNode()))
220 return MemN->getChain();
222 switch (Op->getOpcode()) {
223 case VEISD::VVP_LOAD:
224 case VEISD::VVP_STORE:
225 return Op->getOperand(0);
227 return SDValue();
230 SDValue getMemoryPtr(SDValue Op) {
231 if (auto *MemN = dyn_cast<MemSDNode>(Op.getNode()))
232 return MemN->getBasePtr();
234 switch (Op->getOpcode()) {
235 case VEISD::VVP_LOAD:
236 return Op->getOperand(1);
237 case VEISD::VVP_STORE:
238 return Op->getOperand(2);
240 return SDValue();
243 std::optional<EVT> getIdiomaticVectorType(SDNode *Op) {
244 unsigned OC = Op->getOpcode();
246 // For memory ops -> the transfered data type
247 if (auto MemN = dyn_cast<MemSDNode>(Op))
248 return MemN->getMemoryVT();
250 switch (OC) {
251 // Standard ISD.
252 case ISD::SELECT: // not aliased with VVP_SELECT
253 case ISD::CONCAT_VECTORS:
254 case ISD::EXTRACT_SUBVECTOR:
255 case ISD::VECTOR_SHUFFLE:
256 case ISD::BUILD_VECTOR:
257 case ISD::SCALAR_TO_VECTOR:
258 return Op->getValueType(0);
261 // Translate to VVP where possible.
262 unsigned OriginalOC = OC;
263 if (auto VVPOpc = getVVPOpcode(OC))
264 OC = *VVPOpc;
266 if (isVVPReductionOp(OC))
267 return Op->getOperand(hasReductionStartParam(OriginalOC) ? 1 : 0)
268 .getValueType();
270 switch (OC) {
271 default:
272 case VEISD::VVP_SETCC:
273 return Op->getOperand(0).getValueType();
275 case VEISD::VVP_SELECT:
276 #define ADD_BINARY_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME:
277 #include "VVPNodes.def"
278 return Op->getValueType(0);
280 case VEISD::VVP_LOAD:
281 return Op->getValueType(0);
283 case VEISD::VVP_STORE:
284 return Op->getOperand(1)->getValueType(0);
286 // VEC
287 case VEISD::VEC_BROADCAST:
288 return Op->getValueType(0);
292 SDValue getLoadStoreStride(SDValue Op, VECustomDAG &CDAG) {
293 switch (Op->getOpcode()) {
294 case VEISD::VVP_STORE:
295 return Op->getOperand(3);
296 case VEISD::VVP_LOAD:
297 return Op->getOperand(2);
300 if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode()))
301 return StoreN->getStride();
302 if (auto *StoreN = dyn_cast<VPStridedLoadSDNode>(Op.getNode()))
303 return StoreN->getStride();
305 if (isa<MemSDNode>(Op.getNode())) {
306 // Regular MLOAD/MSTORE/LOAD/STORE
307 // No stride argument -> use the contiguous element size as stride.
308 uint64_t ElemStride = getIdiomaticVectorType(Op.getNode())
309 ->getVectorElementType()
310 .getStoreSize();
311 return CDAG.getConstant(ElemStride, MVT::i64);
313 return SDValue();
316 SDValue getGatherScatterIndex(SDValue Op) {
317 if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Op.getNode()))
318 return N->getIndex();
319 if (auto *N = dyn_cast<VPGatherScatterSDNode>(Op.getNode()))
320 return N->getIndex();
321 return SDValue();
324 SDValue getGatherScatterScale(SDValue Op) {
325 if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Op.getNode()))
326 return N->getScale();
327 if (auto *N = dyn_cast<VPGatherScatterSDNode>(Op.getNode()))
328 return N->getScale();
329 return SDValue();
332 SDValue getStoredValue(SDValue Op) {
333 switch (Op->getOpcode()) {
334 case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
335 case VEISD::VVP_STORE:
336 return Op->getOperand(1);
338 if (auto *StoreN = dyn_cast<StoreSDNode>(Op.getNode()))
339 return StoreN->getValue();
340 if (auto *StoreN = dyn_cast<MaskedStoreSDNode>(Op.getNode()))
341 return StoreN->getValue();
342 if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode()))
343 return StoreN->getValue();
344 if (auto *StoreN = dyn_cast<VPStoreSDNode>(Op.getNode()))
345 return StoreN->getValue();
346 if (auto *StoreN = dyn_cast<MaskedScatterSDNode>(Op.getNode()))
347 return StoreN->getValue();
348 if (auto *StoreN = dyn_cast<VPScatterSDNode>(Op.getNode()))
349 return StoreN->getValue();
350 return SDValue();
353 SDValue getNodePassthru(SDValue Op) {
354 if (auto *N = dyn_cast<MaskedLoadSDNode>(Op.getNode()))
355 return N->getPassThru();
356 if (auto *N = dyn_cast<MaskedGatherSDNode>(Op.getNode()))
357 return N->getPassThru();
359 return SDValue();
362 bool hasReductionStartParam(unsigned OPC) {
363 // TODO: Ordered reduction opcodes.
364 if (ISD::isVPReduction(OPC))
365 return true;
366 return false;
369 unsigned getScalarReductionOpcode(unsigned VVPOC, bool IsMask) {
370 assert(!IsMask && "Mask reduction isel");
372 switch (VVPOC) {
373 #define HANDLE_VVP_REDUCE_TO_SCALAR(VVP_RED_ISD, REDUCE_ISD) \
374 case VEISD::VVP_RED_ISD: \
375 return ISD::REDUCE_ISD;
376 #include "VVPNodes.def"
377 default:
378 break;
380 llvm_unreachable("Cannot not scalarize this reduction Opcode!");
383 /// } Node Properties
385 SDValue getNodeAVL(SDValue Op) {
386 auto PosOpt = getAVLPos(Op->getOpcode());
387 return PosOpt ? Op->getOperand(*PosOpt) : SDValue();
390 SDValue getNodeMask(SDValue Op) {
391 auto PosOpt = getMaskPos(Op->getOpcode());
392 return PosOpt ? Op->getOperand(*PosOpt) : SDValue();
395 std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) {
396 SDValue AVL = getNodeAVL(Op);
397 if (!AVL)
398 return {SDValue(), true};
399 if (isLegalAVL(AVL))
400 return {AVL->getOperand(0), true};
401 return {AVL, false};
404 SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget,
405 bool IsOpaque) const {
406 return DAG.getConstant(Val, DL, VT, IsTarget, IsOpaque);
409 SDValue VECustomDAG::getConstantMask(Packing Packing, bool AllTrue) const {
410 auto MaskVT = getLegalVectorType(Packing, MVT::i1);
412 // VEISelDAGtoDAG will replace this pattern with the constant-true VM.
413 auto TrueVal = DAG.getAllOnesConstant(DL, MVT::i32);
414 auto AVL = getConstant(MaskVT.getVectorNumElements(), MVT::i32);
415 auto Res = getNode(VEISD::VEC_BROADCAST, MaskVT, {TrueVal, AVL});
416 if (AllTrue)
417 return Res;
419 return DAG.getNOT(DL, Res, Res.getValueType());
422 SDValue VECustomDAG::getMaskBroadcast(EVT ResultVT, SDValue Scalar,
423 SDValue AVL) const {
424 // Constant mask splat.
425 if (auto BcConst = dyn_cast<ConstantSDNode>(Scalar))
426 return getConstantMask(getTypePacking(ResultVT),
427 BcConst->getSExtValue() != 0);
429 // Expand the broadcast to a vector comparison.
430 auto ScalarBoolVT = Scalar.getSimpleValueType();
431 assert(ScalarBoolVT == MVT::i32);
433 // Cast to i32 ty.
434 SDValue CmpElem = DAG.getSExtOrTrunc(Scalar, DL, MVT::i32);
435 unsigned ElemCount = ResultVT.getVectorNumElements();
436 MVT CmpVecTy = MVT::getVectorVT(ScalarBoolVT, ElemCount);
438 // Broadcast to vector.
439 SDValue BCVec =
440 DAG.getNode(VEISD::VEC_BROADCAST, DL, CmpVecTy, {CmpElem, AVL});
441 SDValue ZeroVec =
442 getBroadcast(CmpVecTy, {DAG.getConstant(0, DL, ScalarBoolVT)}, AVL);
444 MVT BoolVecTy = MVT::getVectorVT(MVT::i1, ElemCount);
446 // Broadcast(Data) != Broadcast(0)
447 // TODO: Use a VVP operation for this.
448 return DAG.getSetCC(DL, BoolVecTy, BCVec, ZeroVec, ISD::CondCode::SETNE);
451 SDValue VECustomDAG::getBroadcast(EVT ResultVT, SDValue Scalar,
452 SDValue AVL) const {
453 assert(ResultVT.isVector());
454 auto ScaVT = Scalar.getValueType();
456 if (isMaskType(ResultVT))
457 return getMaskBroadcast(ResultVT, Scalar, AVL);
459 if (isPackedVectorType(ResultVT)) {
460 // v512x packed mode broadcast
461 // Replicate the scalar reg (f32 or i32) onto the opposing half of the full
462 // scalar register. If it's an I64 type, assume that this has already
463 // happened.
464 if (ScaVT == MVT::f32) {
465 Scalar = getNode(VEISD::REPL_F32, MVT::i64, Scalar);
466 } else if (ScaVT == MVT::i32) {
467 Scalar = getNode(VEISD::REPL_I32, MVT::i64, Scalar);
471 return getNode(VEISD::VEC_BROADCAST, ResultVT, {Scalar, AVL});
474 SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const {
475 if (isLegalAVL(AVL))
476 return AVL;
477 return getNode(VEISD::LEGALAVL, AVL.getValueType(), AVL);
480 SDValue VECustomDAG::getUnpack(EVT DestVT, SDValue Vec, PackElem Part,
481 SDValue AVL) const {
482 assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");
484 // TODO: Peek through VEC_PACK and VEC_BROADCAST(REPL_<sth> ..) operands.
485 unsigned OC =
486 (Part == PackElem::Lo) ? VEISD::VEC_UNPACK_LO : VEISD::VEC_UNPACK_HI;
487 return DAG.getNode(OC, DL, DestVT, Vec, AVL);
490 SDValue VECustomDAG::getPack(EVT DestVT, SDValue LoVec, SDValue HiVec,
491 SDValue AVL) const {
492 assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");
494 // TODO: Peek through VEC_UNPACK_LO|HI operands.
495 return DAG.getNode(VEISD::VEC_PACK, DL, DestVT, LoVec, HiVec, AVL);
498 VETargetMasks VECustomDAG::getTargetSplitMask(SDValue RawMask, SDValue RawAVL,
499 PackElem Part) const {
500 // Adjust AVL for this part
501 SDValue NewAVL;
502 SDValue OneV = getConstant(1, MVT::i32);
503 if (Part == PackElem::Hi)
504 NewAVL = getNode(ISD::ADD, MVT::i32, {RawAVL, OneV});
505 else
506 NewAVL = RawAVL;
507 NewAVL = getNode(ISD::SRL, MVT::i32, {NewAVL, OneV});
509 NewAVL = annotateLegalAVL(NewAVL);
511 // Legalize Mask (unpack or all-true)
512 SDValue NewMask;
513 if (!RawMask)
514 NewMask = getConstantMask(Packing::Normal, true);
515 else
516 NewMask = getUnpack(MVT::v256i1, RawMask, Part, NewAVL);
518 return VETargetMasks(NewMask, NewAVL);
521 SDValue VECustomDAG::getSplitPtrOffset(SDValue Ptr, SDValue ByteStride,
522 PackElem Part) const {
523 // High starts at base ptr but has more significant bits in the 64bit vector
524 // element.
525 if (Part == PackElem::Hi)
526 return Ptr;
527 return getNode(ISD::ADD, MVT::i64, {Ptr, ByteStride});
530 SDValue VECustomDAG::getSplitPtrStride(SDValue PackStride) const {
531 if (auto ConstBytes = dyn_cast<ConstantSDNode>(PackStride))
532 return getConstant(2 * ConstBytes->getSExtValue(), MVT::i64);
533 return getNode(ISD::SHL, MVT::i64, {PackStride, getConstant(1, MVT::i32)});
536 SDValue VECustomDAG::getGatherScatterAddress(SDValue BasePtr, SDValue Scale,
537 SDValue Index, SDValue Mask,
538 SDValue AVL) const {
539 EVT IndexVT = Index.getValueType();
541 // Apply scale.
542 SDValue ScaledIndex;
543 if (!Scale || isOneConstant(Scale))
544 ScaledIndex = Index;
545 else {
546 SDValue ScaleBroadcast = getBroadcast(IndexVT, Scale, AVL);
547 ScaledIndex =
548 getNode(VEISD::VVP_MUL, IndexVT, {Index, ScaleBroadcast, Mask, AVL});
551 // Add basePtr.
552 if (isNullConstant(BasePtr))
553 return ScaledIndex;
555 // re-constitute pointer vector (basePtr + index * scale)
556 SDValue BaseBroadcast = getBroadcast(IndexVT, BasePtr, AVL);
557 auto ResPtr =
558 getNode(VEISD::VVP_ADD, IndexVT, {BaseBroadcast, ScaledIndex, Mask, AVL});
559 return ResPtr;
562 SDValue VECustomDAG::getLegalReductionOpVVP(unsigned VVPOpcode, EVT ResVT,
563 SDValue StartV, SDValue VectorV,
564 SDValue Mask, SDValue AVL,
565 SDNodeFlags Flags) const {
567 // Optionally attach the start param with a scalar op (where it is
568 // unsupported).
569 bool scalarizeStartParam = StartV && !hasReductionStartParam(VVPOpcode);
570 bool IsMaskReduction = isMaskType(VectorV.getValueType());
571 assert(!IsMaskReduction && "TODO Implement");
572 auto AttachStartValue = [&](SDValue ReductionResV) {
573 if (!scalarizeStartParam)
574 return ReductionResV;
575 auto ScalarOC = getScalarReductionOpcode(VVPOpcode, IsMaskReduction);
576 return getNode(ScalarOC, ResVT, {StartV, ReductionResV});
579 // Fixup: Always Use sequential 'fmul' reduction.
580 if (!scalarizeStartParam && StartV) {
581 assert(hasReductionStartParam(VVPOpcode));
582 return AttachStartValue(
583 getNode(VVPOpcode, ResVT, {StartV, VectorV, Mask, AVL}, Flags));
584 } else
585 return AttachStartValue(
586 getNode(VVPOpcode, ResVT, {VectorV, Mask, AVL}, Flags));
589 } // namespace llvm