1 //===-- VECustomDAG.h - VE Custom DAG Nodes ------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file defines the interfaces that VE uses to lower LLVM code into a
// selection DAG.
12 //===----------------------------------------------------------------------===//
14 #include "VECustomDAG.h"
17 #define DEBUG_TYPE "vecustomdag"
22 bool isPackedVectorType(EVT SomeVT
) {
23 if (!SomeVT
.isVector())
25 return SomeVT
.getVectorNumElements() > StandardVectorWidth
;
28 MVT
splitVectorType(MVT VT
) {
31 return MVT::getVectorVT(VT
.getVectorElementType(), StandardVectorWidth
);
34 MVT
getLegalVectorType(Packing P
, MVT ElemVT
) {
35 return MVT::getVectorVT(ElemVT
, P
== Packing::Normal
? StandardVectorWidth
39 Packing
getTypePacking(EVT VT
) {
40 assert(VT
.isVector());
41 return isPackedVectorType(VT
) ? Packing::Dense
: Packing::Normal
;
44 bool isMaskType(EVT SomeVT
) {
45 if (!SomeVT
.isVector())
47 return SomeVT
.getVectorElementType() == MVT::i1
;
50 bool isMaskArithmetic(SDValue Op
) {
51 switch (Op
.getOpcode()) {
57 return isMaskType(Op
.getValueType());
61 /// \returns the VVP_* SDNode opcode corresponsing to \p OC.
62 std::optional
<unsigned> getVVPOpcode(unsigned Opcode
) {
65 return VEISD::VVP_LOAD
;
67 return VEISD::VVP_STORE
;
68 #define HANDLE_VP_TO_VVP(VPOPC, VVPNAME) \
70 return VEISD::VVPNAME;
71 #define ADD_VVP_OP(VVPNAME, SDNAME) \
72 case VEISD::VVPNAME: \
74 return VEISD::VVPNAME;
75 #include "VVPNodes.def"
76 // TODO: Map those in VVPNodes.def too
77 case ISD::EXPERIMENTAL_VP_STRIDED_LOAD
:
78 return VEISD::VVP_LOAD
;
79 case ISD::EXPERIMENTAL_VP_STRIDED_STORE
:
80 return VEISD::VVP_STORE
;
85 bool maySafelyIgnoreMask(SDValue Op
) {
86 auto VVPOpc
= getVVPOpcode(Op
->getOpcode());
87 auto Opc
= VVPOpc
.value_or(Op
->getOpcode());
93 case VEISD::VVP_SELECT
:
101 bool supportsPackedMode(unsigned Opcode
, EVT IdiomVT
) {
102 bool IsPackedOp
= isPackedVectorType(IdiomVT
);
103 bool IsMaskOp
= isMaskType(IdiomVT
);
108 case VEISD::VEC_BROADCAST
:
110 #define REGISTER_PACKED(VVP_NAME) case VEISD::VVP_NAME:
111 #include "VVPNodes.def"
112 return IsPackedOp
&& !IsMaskOp
;
116 bool isPackingSupportOpcode(unsigned Opc
) {
118 case VEISD::VEC_PACK
:
119 case VEISD::VEC_UNPACK_LO
:
120 case VEISD::VEC_UNPACK_HI
:
126 bool isVVPOrVEC(unsigned Opcode
) {
128 case VEISD::VEC_BROADCAST
:
129 #define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME:
130 #include "VVPNodes.def"
136 bool isVVPUnaryOp(unsigned VVPOpcode
) {
138 #define ADD_UNARY_VVP_OP(VVPNAME, ...) \
139 case VEISD::VVPNAME: \
141 #include "VVPNodes.def"
146 bool isVVPBinaryOp(unsigned VVPOpcode
) {
148 #define ADD_BINARY_VVP_OP(VVPNAME, ...) \
149 case VEISD::VVPNAME: \
151 #include "VVPNodes.def"
156 bool isVVPReductionOp(unsigned Opcode
) {
158 #define ADD_REDUCE_VVP_OP(VVP_NAME, SDNAME) case VEISD::VVP_NAME:
159 #include "VVPNodes.def"
165 // Return the AVL operand position for this VVP or VEC Op.
166 std::optional
<int> getAVLPos(unsigned Opc
) {
167 // This is only available for VP SDNodes
168 auto PosOpt
= ISD::getVPExplicitVectorLengthIdx(Opc
);
173 if (isVVPBinaryOp(Opc
))
178 case VEISD::VEC_BROADCAST
:
180 case VEISD::VVP_SELECT
:
182 case VEISD::VVP_LOAD
:
184 case VEISD::VVP_STORE
:
191 std::optional
<int> getMaskPos(unsigned Opc
) {
192 // This is only available for VP SDNodes
193 auto PosOpt
= ISD::getVPMaskIdx(Opc
);
198 if (isVVPBinaryOp(Opc
))
207 case VEISD::VVP_SELECT
:
214 bool isLegalAVL(SDValue AVL
) { return AVL
->getOpcode() == VEISD::LEGALAVL
; }
216 /// Node Properties {
218 SDValue
getNodeChain(SDValue Op
) {
219 if (MemSDNode
*MemN
= dyn_cast
<MemSDNode
>(Op
.getNode()))
220 return MemN
->getChain();
222 switch (Op
->getOpcode()) {
223 case VEISD::VVP_LOAD
:
224 case VEISD::VVP_STORE
:
225 return Op
->getOperand(0);
230 SDValue
getMemoryPtr(SDValue Op
) {
231 if (auto *MemN
= dyn_cast
<MemSDNode
>(Op
.getNode()))
232 return MemN
->getBasePtr();
234 switch (Op
->getOpcode()) {
235 case VEISD::VVP_LOAD
:
236 return Op
->getOperand(1);
237 case VEISD::VVP_STORE
:
238 return Op
->getOperand(2);
243 std::optional
<EVT
> getIdiomaticVectorType(SDNode
*Op
) {
244 unsigned OC
= Op
->getOpcode();
246 // For memory ops -> the transfered data type
247 if (auto MemN
= dyn_cast
<MemSDNode
>(Op
))
248 return MemN
->getMemoryVT();
252 case ISD::SELECT
: // not aliased with VVP_SELECT
253 case ISD::CONCAT_VECTORS
:
254 case ISD::EXTRACT_SUBVECTOR
:
255 case ISD::VECTOR_SHUFFLE
:
256 case ISD::BUILD_VECTOR
:
257 case ISD::SCALAR_TO_VECTOR
:
258 return Op
->getValueType(0);
261 // Translate to VVP where possible.
262 unsigned OriginalOC
= OC
;
263 if (auto VVPOpc
= getVVPOpcode(OC
))
266 if (isVVPReductionOp(OC
))
267 return Op
->getOperand(hasReductionStartParam(OriginalOC
) ? 1 : 0)
272 case VEISD::VVP_SETCC
:
273 return Op
->getOperand(0).getValueType();
275 case VEISD::VVP_SELECT
:
276 #define ADD_BINARY_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME:
277 #include "VVPNodes.def"
278 return Op
->getValueType(0);
280 case VEISD::VVP_LOAD
:
281 return Op
->getValueType(0);
283 case VEISD::VVP_STORE
:
284 return Op
->getOperand(1)->getValueType(0);
287 case VEISD::VEC_BROADCAST
:
288 return Op
->getValueType(0);
292 SDValue
getLoadStoreStride(SDValue Op
, VECustomDAG
&CDAG
) {
293 switch (Op
->getOpcode()) {
294 case VEISD::VVP_STORE
:
295 return Op
->getOperand(3);
296 case VEISD::VVP_LOAD
:
297 return Op
->getOperand(2);
300 if (auto *StoreN
= dyn_cast
<VPStridedStoreSDNode
>(Op
.getNode()))
301 return StoreN
->getStride();
302 if (auto *StoreN
= dyn_cast
<VPStridedLoadSDNode
>(Op
.getNode()))
303 return StoreN
->getStride();
305 if (isa
<MemSDNode
>(Op
.getNode())) {
306 // Regular MLOAD/MSTORE/LOAD/STORE
307 // No stride argument -> use the contiguous element size as stride.
308 uint64_t ElemStride
= getIdiomaticVectorType(Op
.getNode())
309 ->getVectorElementType()
311 return CDAG
.getConstant(ElemStride
, MVT::i64
);
316 SDValue
getGatherScatterIndex(SDValue Op
) {
317 if (auto *N
= dyn_cast
<MaskedGatherScatterSDNode
>(Op
.getNode()))
318 return N
->getIndex();
319 if (auto *N
= dyn_cast
<VPGatherScatterSDNode
>(Op
.getNode()))
320 return N
->getIndex();
324 SDValue
getGatherScatterScale(SDValue Op
) {
325 if (auto *N
= dyn_cast
<MaskedGatherScatterSDNode
>(Op
.getNode()))
326 return N
->getScale();
327 if (auto *N
= dyn_cast
<VPGatherScatterSDNode
>(Op
.getNode()))
328 return N
->getScale();
332 SDValue
getStoredValue(SDValue Op
) {
333 switch (Op
->getOpcode()) {
334 case ISD::EXPERIMENTAL_VP_STRIDED_STORE
:
335 case VEISD::VVP_STORE
:
336 return Op
->getOperand(1);
338 if (auto *StoreN
= dyn_cast
<StoreSDNode
>(Op
.getNode()))
339 return StoreN
->getValue();
340 if (auto *StoreN
= dyn_cast
<MaskedStoreSDNode
>(Op
.getNode()))
341 return StoreN
->getValue();
342 if (auto *StoreN
= dyn_cast
<VPStridedStoreSDNode
>(Op
.getNode()))
343 return StoreN
->getValue();
344 if (auto *StoreN
= dyn_cast
<VPStoreSDNode
>(Op
.getNode()))
345 return StoreN
->getValue();
346 if (auto *StoreN
= dyn_cast
<MaskedScatterSDNode
>(Op
.getNode()))
347 return StoreN
->getValue();
348 if (auto *StoreN
= dyn_cast
<VPScatterSDNode
>(Op
.getNode()))
349 return StoreN
->getValue();
353 SDValue
getNodePassthru(SDValue Op
) {
354 if (auto *N
= dyn_cast
<MaskedLoadSDNode
>(Op
.getNode()))
355 return N
->getPassThru();
356 if (auto *N
= dyn_cast
<MaskedGatherSDNode
>(Op
.getNode()))
357 return N
->getPassThru();
362 bool hasReductionStartParam(unsigned OPC
) {
363 // TODO: Ordered reduction opcodes.
364 if (ISD::isVPReduction(OPC
))
369 unsigned getScalarReductionOpcode(unsigned VVPOC
, bool IsMask
) {
370 assert(!IsMask
&& "Mask reduction isel");
373 #define HANDLE_VVP_REDUCE_TO_SCALAR(VVP_RED_ISD, REDUCE_ISD) \
374 case VEISD::VVP_RED_ISD: \
375 return ISD::REDUCE_ISD;
376 #include "VVPNodes.def"
380 llvm_unreachable("Cannot not scalarize this reduction Opcode!");
383 /// } Node Properties
385 SDValue
getNodeAVL(SDValue Op
) {
386 auto PosOpt
= getAVLPos(Op
->getOpcode());
387 return PosOpt
? Op
->getOperand(*PosOpt
) : SDValue();
390 SDValue
getNodeMask(SDValue Op
) {
391 auto PosOpt
= getMaskPos(Op
->getOpcode());
392 return PosOpt
? Op
->getOperand(*PosOpt
) : SDValue();
395 std::pair
<SDValue
, bool> getAnnotatedNodeAVL(SDValue Op
) {
396 SDValue AVL
= getNodeAVL(Op
);
398 return {SDValue(), true};
400 return {AVL
->getOperand(0), true};
404 SDValue
VECustomDAG::getConstant(uint64_t Val
, EVT VT
, bool IsTarget
,
405 bool IsOpaque
) const {
406 return DAG
.getConstant(Val
, DL
, VT
, IsTarget
, IsOpaque
);
409 SDValue
VECustomDAG::getConstantMask(Packing Packing
, bool AllTrue
) const {
410 auto MaskVT
= getLegalVectorType(Packing
, MVT::i1
);
412 // VEISelDAGtoDAG will replace this pattern with the constant-true VM.
413 auto TrueVal
= DAG
.getAllOnesConstant(DL
, MVT::i32
);
414 auto AVL
= getConstant(MaskVT
.getVectorNumElements(), MVT::i32
);
415 auto Res
= getNode(VEISD::VEC_BROADCAST
, MaskVT
, {TrueVal
, AVL
});
419 return DAG
.getNOT(DL
, Res
, Res
.getValueType());
422 SDValue
VECustomDAG::getMaskBroadcast(EVT ResultVT
, SDValue Scalar
,
424 // Constant mask splat.
425 if (auto BcConst
= dyn_cast
<ConstantSDNode
>(Scalar
))
426 return getConstantMask(getTypePacking(ResultVT
),
427 BcConst
->getSExtValue() != 0);
429 // Expand the broadcast to a vector comparison.
430 auto ScalarBoolVT
= Scalar
.getSimpleValueType();
431 assert(ScalarBoolVT
== MVT::i32
);
434 SDValue CmpElem
= DAG
.getSExtOrTrunc(Scalar
, DL
, MVT::i32
);
435 unsigned ElemCount
= ResultVT
.getVectorNumElements();
436 MVT CmpVecTy
= MVT::getVectorVT(ScalarBoolVT
, ElemCount
);
438 // Broadcast to vector.
440 DAG
.getNode(VEISD::VEC_BROADCAST
, DL
, CmpVecTy
, {CmpElem
, AVL
});
442 getBroadcast(CmpVecTy
, {DAG
.getConstant(0, DL
, ScalarBoolVT
)}, AVL
);
444 MVT BoolVecTy
= MVT::getVectorVT(MVT::i1
, ElemCount
);
446 // Broadcast(Data) != Broadcast(0)
447 // TODO: Use a VVP operation for this.
448 return DAG
.getSetCC(DL
, BoolVecTy
, BCVec
, ZeroVec
, ISD::CondCode::SETNE
);
451 SDValue
VECustomDAG::getBroadcast(EVT ResultVT
, SDValue Scalar
,
453 assert(ResultVT
.isVector());
454 auto ScaVT
= Scalar
.getValueType();
456 if (isMaskType(ResultVT
))
457 return getMaskBroadcast(ResultVT
, Scalar
, AVL
);
459 if (isPackedVectorType(ResultVT
)) {
460 // v512x packed mode broadcast
461 // Replicate the scalar reg (f32 or i32) onto the opposing half of the full
462 // scalar register. If it's an I64 type, assume that this has already
464 if (ScaVT
== MVT::f32
) {
465 Scalar
= getNode(VEISD::REPL_F32
, MVT::i64
, Scalar
);
466 } else if (ScaVT
== MVT::i32
) {
467 Scalar
= getNode(VEISD::REPL_I32
, MVT::i64
, Scalar
);
471 return getNode(VEISD::VEC_BROADCAST
, ResultVT
, {Scalar
, AVL
});
474 SDValue
VECustomDAG::annotateLegalAVL(SDValue AVL
) const {
477 return getNode(VEISD::LEGALAVL
, AVL
.getValueType(), AVL
);
480 SDValue
VECustomDAG::getUnpack(EVT DestVT
, SDValue Vec
, PackElem Part
,
482 assert(getAnnotatedNodeAVL(AVL
).second
&& "Expected a pack-legalized AVL");
484 // TODO: Peek through VEC_PACK and VEC_BROADCAST(REPL_<sth> ..) operands.
486 (Part
== PackElem::Lo
) ? VEISD::VEC_UNPACK_LO
: VEISD::VEC_UNPACK_HI
;
487 return DAG
.getNode(OC
, DL
, DestVT
, Vec
, AVL
);
490 SDValue
VECustomDAG::getPack(EVT DestVT
, SDValue LoVec
, SDValue HiVec
,
492 assert(getAnnotatedNodeAVL(AVL
).second
&& "Expected a pack-legalized AVL");
494 // TODO: Peek through VEC_UNPACK_LO|HI operands.
495 return DAG
.getNode(VEISD::VEC_PACK
, DL
, DestVT
, LoVec
, HiVec
, AVL
);
498 VETargetMasks
VECustomDAG::getTargetSplitMask(SDValue RawMask
, SDValue RawAVL
,
499 PackElem Part
) const {
500 // Adjust AVL for this part
502 SDValue OneV
= getConstant(1, MVT::i32
);
503 if (Part
== PackElem::Hi
)
504 NewAVL
= getNode(ISD::ADD
, MVT::i32
, {RawAVL
, OneV
});
507 NewAVL
= getNode(ISD::SRL
, MVT::i32
, {NewAVL
, OneV
});
509 NewAVL
= annotateLegalAVL(NewAVL
);
511 // Legalize Mask (unpack or all-true)
514 NewMask
= getConstantMask(Packing::Normal
, true);
516 NewMask
= getUnpack(MVT::v256i1
, RawMask
, Part
, NewAVL
);
518 return VETargetMasks(NewMask
, NewAVL
);
521 SDValue
VECustomDAG::getSplitPtrOffset(SDValue Ptr
, SDValue ByteStride
,
522 PackElem Part
) const {
523 // High starts at base ptr but has more significant bits in the 64bit vector
525 if (Part
== PackElem::Hi
)
527 return getNode(ISD::ADD
, MVT::i64
, {Ptr
, ByteStride
});
530 SDValue
VECustomDAG::getSplitPtrStride(SDValue PackStride
) const {
531 if (auto ConstBytes
= dyn_cast
<ConstantSDNode
>(PackStride
))
532 return getConstant(2 * ConstBytes
->getSExtValue(), MVT::i64
);
533 return getNode(ISD::SHL
, MVT::i64
, {PackStride
, getConstant(1, MVT::i32
)});
536 SDValue
VECustomDAG::getGatherScatterAddress(SDValue BasePtr
, SDValue Scale
,
537 SDValue Index
, SDValue Mask
,
539 EVT IndexVT
= Index
.getValueType();
543 if (!Scale
|| isOneConstant(Scale
))
546 SDValue ScaleBroadcast
= getBroadcast(IndexVT
, Scale
, AVL
);
548 getNode(VEISD::VVP_MUL
, IndexVT
, {Index
, ScaleBroadcast
, Mask
, AVL
});
552 if (isNullConstant(BasePtr
))
555 // re-constitute pointer vector (basePtr + index * scale)
556 SDValue BaseBroadcast
= getBroadcast(IndexVT
, BasePtr
, AVL
);
558 getNode(VEISD::VVP_ADD
, IndexVT
, {BaseBroadcast
, ScaledIndex
, Mask
, AVL
});
562 SDValue
VECustomDAG::getLegalReductionOpVVP(unsigned VVPOpcode
, EVT ResVT
,
563 SDValue StartV
, SDValue VectorV
,
564 SDValue Mask
, SDValue AVL
,
565 SDNodeFlags Flags
) const {
567 // Optionally attach the start param with a scalar op (where it is
569 bool scalarizeStartParam
= StartV
&& !hasReductionStartParam(VVPOpcode
);
570 bool IsMaskReduction
= isMaskType(VectorV
.getValueType());
571 assert(!IsMaskReduction
&& "TODO Implement");
572 auto AttachStartValue
= [&](SDValue ReductionResV
) {
573 if (!scalarizeStartParam
)
574 return ReductionResV
;
575 auto ScalarOC
= getScalarReductionOpcode(VVPOpcode
, IsMaskReduction
);
576 return getNode(ScalarOC
, ResVT
, {StartV
, ReductionResV
});
579 // Fixup: Always Use sequential 'fmul' reduction.
580 if (!scalarizeStartParam
&& StartV
) {
581 assert(hasReductionStartParam(VVPOpcode
));
582 return AttachStartValue(
583 getNode(VVPOpcode
, ResVT
, {StartV
, VectorV
, Mask
, AVL
}, Flags
));
585 return AttachStartValue(
586 getNode(VVPOpcode
, ResVT
, {VectorV
, Mask
, AVL
}, Flags
));