1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "clang/Support/RISCVVIntrinsicUtils.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/SmallSet.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringSet.h"
14 #include "llvm/ADT/Twine.h"
15 #include "llvm/Support/ErrorHandling.h"
16 #include "llvm/Support/raw_ostream.h"
25 const PrototypeDescriptor
PrototypeDescriptor::Mask
= PrototypeDescriptor(
26 BaseTypeModifier::Vector
, VectorTypeModifier::MaskVector
);
27 const PrototypeDescriptor
PrototypeDescriptor::VL
=
28 PrototypeDescriptor(BaseTypeModifier::SizeT
);
29 const PrototypeDescriptor
PrototypeDescriptor::Vector
=
30 PrototypeDescriptor(BaseTypeModifier::Vector
);
32 //===----------------------------------------------------------------------===//
33 // Type implementation
34 //===----------------------------------------------------------------------===//
36 LMULType::LMULType(int NewLog2LMUL
) {
37 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
38 assert(NewLog2LMUL
<= 3 && NewLog2LMUL
>= -3 && "Bad LMUL number!");
39 Log2LMUL
= NewLog2LMUL
;
42 std::string
LMULType::str() const {
44 return "mf" + utostr(1ULL << (-Log2LMUL
));
45 return "m" + utostr(1ULL << Log2LMUL
);
48 VScaleVal
LMULType::getScale(unsigned ElementBitwidth
) const {
49 int Log2ScaleResult
= 0;
50 switch (ElementBitwidth
) {
54 Log2ScaleResult
= Log2LMUL
+ 3;
57 Log2ScaleResult
= Log2LMUL
+ 2;
60 Log2ScaleResult
= Log2LMUL
+ 1;
63 Log2ScaleResult
= Log2LMUL
;
66 // Illegal vscale result would be less than 1
67 if (Log2ScaleResult
< 0)
69 return 1 << Log2ScaleResult
;
72 void LMULType::MulLog2LMUL(int log2LMUL
) { Log2LMUL
+= log2LMUL
; }
74 RVVType::RVVType(BasicType BT
, int Log2LMUL
,
75 const PrototypeDescriptor
&prototype
)
76 : BT(BT
), LMUL(LMULType(Log2LMUL
)) {
78 applyModifier(prototype
);
84 initClangBuiltinStr();
// Boolean types are encoded by the ratio n (SEW/LMUL).
// SEW/LMUL | 1         | 2         | 4         | 8        | 16       | 32       | 64
// c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
// IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1  | nxv32i1  | nxv64i1
//
// type\lmul | 1/8    | 1/4     | 1/2     | 1       | 2       | 4        | 8
// --------- | ------ | ------- | ------- | ------- | ------- | -------- | --------
// i64       | N/A    | N/A     | N/A     | nxv1i64 | nxv2i64 | nxv4i64  | nxv8i64
// i32       | N/A    | N/A     | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32  | nxv16i32
// i16       | N/A    | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
// i8        | nxv1i8 | nxv2i8  | nxv4i8  | nxv8i8  | nxv16i8 | nxv32i8  | nxv64i8
// double    | N/A    | N/A     | N/A     | nxv1f64 | nxv2f64 | nxv4f64  | nxv8f64
// float     | N/A    | N/A     | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32  | nxv16f32
// half      | N/A    | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
106 bool RVVType::verifyType() const {
107 if (ScalarType
== Invalid
)
113 if (isFloat() && ElementBitwidth
== 8)
115 if (IsTuple
&& (NF
== 1 || NF
> 8))
117 if (IsTuple
&& (1 << std::max(0, LMUL
.Log2LMUL
)) * NF
> 8)
120 switch (ElementBitwidth
) {
123 // Check Scale is 1,2,4,8,16,32,64
124 return (V
<= 64 && isPowerOf2_32(V
));
126 // Check Scale is 1,2,4,8,16,32
127 return (V
<= 32 && isPowerOf2_32(V
));
129 // Check Scale is 1,2,4,8,16
130 return (V
<= 16 && isPowerOf2_32(V
));
132 // Check Scale is 1,2,4,8
133 return (V
<= 8 && isPowerOf2_32(V
));
138 void RVVType::initBuiltinStr() {
139 assert(isValid() && "RVVType is invalid");
140 switch (ScalarType
) {
141 case ScalarTypeKind::Void
:
144 case ScalarTypeKind::Size_t
:
147 BuiltinStr
= "I" + BuiltinStr
;
151 case ScalarTypeKind::Ptrdiff_t
:
154 case ScalarTypeKind::UnsignedLong
:
157 case ScalarTypeKind::SignedLong
:
160 case ScalarTypeKind::Boolean
:
161 assert(ElementBitwidth
== 1);
164 case ScalarTypeKind::SignedInteger
:
165 case ScalarTypeKind::UnsignedInteger
:
166 switch (ElementBitwidth
) {
180 llvm_unreachable("Unhandled ElementBitwidth!");
182 if (isSignedInteger())
183 BuiltinStr
= "S" + BuiltinStr
;
185 BuiltinStr
= "U" + BuiltinStr
;
187 case ScalarTypeKind::Float
:
188 switch (ElementBitwidth
) {
199 llvm_unreachable("Unhandled ElementBitwidth!");
203 llvm_unreachable("ScalarType is invalid!");
206 BuiltinStr
= "I" + BuiltinStr
;
214 BuiltinStr
= "q" + utostr(*Scale
) + BuiltinStr
;
215 // Pointer to vector types. Defined for segment load intrinsics.
216 // segment load intrinsics have pointer type arguments to store the loaded
222 BuiltinStr
= "T" + utostr(NF
) + BuiltinStr
;
225 void RVVType::initClangBuiltinStr() {
226 assert(isValid() && "RVVType is invalid");
227 assert(isVector() && "Handle Vector type only");
229 ClangBuiltinStr
= "__rvv_";
230 switch (ScalarType
) {
231 case ScalarTypeKind::Boolean
:
232 ClangBuiltinStr
+= "bool" + utostr(64 / *Scale
) + "_t";
234 case ScalarTypeKind::Float
:
235 ClangBuiltinStr
+= "float";
237 case ScalarTypeKind::SignedInteger
:
238 ClangBuiltinStr
+= "int";
240 case ScalarTypeKind::UnsignedInteger
:
241 ClangBuiltinStr
+= "uint";
244 llvm_unreachable("ScalarTypeKind is invalid");
246 ClangBuiltinStr
+= utostr(ElementBitwidth
) + LMUL
.str() +
247 (IsTuple
? "x" + utostr(NF
) : "") + "_t";
250 void RVVType::initTypeStr() {
251 assert(isValid() && "RVVType is invalid");
256 auto getTypeString
= [&](StringRef TypeStr
) {
258 return Twine(TypeStr
+ Twine(ElementBitwidth
) + "_t").str();
259 return Twine("v" + TypeStr
+ Twine(ElementBitwidth
) + LMUL
.str() +
260 (IsTuple
? "x" + utostr(NF
) : "") + "_t")
264 switch (ScalarType
) {
265 case ScalarTypeKind::Void
:
268 case ScalarTypeKind::Size_t
:
273 case ScalarTypeKind::Ptrdiff_t
:
276 case ScalarTypeKind::UnsignedLong
:
277 Str
= "unsigned long";
279 case ScalarTypeKind::SignedLong
:
282 case ScalarTypeKind::Boolean
:
// Vector bool is a special case; the formula is
// `vbool<N>_t = MVT::nxv<64/N>i1`, e.g. vbool16_t = MVT::nxv4i1
288 Str
+= "vbool" + utostr(64 / *Scale
) + "_t";
290 case ScalarTypeKind::Float
:
292 if (ElementBitwidth
== 64)
294 else if (ElementBitwidth
== 32)
296 else if (ElementBitwidth
== 16)
299 llvm_unreachable("Unhandled floating type.");
301 Str
+= getTypeString("float");
303 case ScalarTypeKind::SignedInteger
:
304 Str
+= getTypeString("int");
306 case ScalarTypeKind::UnsignedInteger
:
307 Str
+= getTypeString("uint");
310 llvm_unreachable("ScalarType is invalid!");
316 void RVVType::initShortStr() {
317 switch (ScalarType
) {
318 case ScalarTypeKind::Boolean
:
320 ShortStr
= "b" + utostr(64 / *Scale
);
322 case ScalarTypeKind::Float
:
323 ShortStr
= "f" + utostr(ElementBitwidth
);
325 case ScalarTypeKind::SignedInteger
:
326 ShortStr
= "i" + utostr(ElementBitwidth
);
328 case ScalarTypeKind::UnsignedInteger
:
329 ShortStr
= "u" + utostr(ElementBitwidth
);
332 llvm_unreachable("Unhandled case!");
335 ShortStr
+= LMUL
.str();
337 ShortStr
+= "x" + utostr(NF
);
340 static VectorTypeModifier
getTupleVTM(unsigned NF
) {
341 assert(2 <= NF
&& NF
<= 8 && "2 <= NF <= 8");
342 return static_cast<VectorTypeModifier
>(
343 static_cast<uint8_t>(VectorTypeModifier::Tuple2
) + (NF
- 2));
346 void RVVType::applyBasicType() {
348 case BasicType::Int8
:
350 ScalarType
= ScalarTypeKind::SignedInteger
;
352 case BasicType::Int16
:
353 ElementBitwidth
= 16;
354 ScalarType
= ScalarTypeKind::SignedInteger
;
356 case BasicType::Int32
:
357 ElementBitwidth
= 32;
358 ScalarType
= ScalarTypeKind::SignedInteger
;
360 case BasicType::Int64
:
361 ElementBitwidth
= 64;
362 ScalarType
= ScalarTypeKind::SignedInteger
;
364 case BasicType::Float16
:
365 ElementBitwidth
= 16;
366 ScalarType
= ScalarTypeKind::Float
;
368 case BasicType::Float32
:
369 ElementBitwidth
= 32;
370 ScalarType
= ScalarTypeKind::Float
;
372 case BasicType::Float64
:
373 ElementBitwidth
= 64;
374 ScalarType
= ScalarTypeKind::Float
;
377 llvm_unreachable("Unhandled type code!");
379 assert(ElementBitwidth
!= 0 && "Bad element bitwidth!");
382 std::optional
<PrototypeDescriptor
>
383 PrototypeDescriptor::parsePrototypeDescriptor(
384 llvm::StringRef PrototypeDescriptorStr
) {
385 PrototypeDescriptor PD
;
386 BaseTypeModifier PT
= BaseTypeModifier::Invalid
;
387 VectorTypeModifier VTM
= VectorTypeModifier::NoModifier
;
389 if (PrototypeDescriptorStr
.empty())
392 // Handle base type modifier
393 auto PType
= PrototypeDescriptorStr
.back();
396 PT
= BaseTypeModifier::Scalar
;
399 PT
= BaseTypeModifier::Vector
;
402 PT
= BaseTypeModifier::Vector
;
403 VTM
= VectorTypeModifier::Widening2XVector
;
406 PT
= BaseTypeModifier::Vector
;
407 VTM
= VectorTypeModifier::Widening4XVector
;
410 PT
= BaseTypeModifier::Vector
;
411 VTM
= VectorTypeModifier::Widening8XVector
;
414 PT
= BaseTypeModifier::Vector
;
415 VTM
= VectorTypeModifier::MaskVector
;
418 PT
= BaseTypeModifier::Void
;
421 PT
= BaseTypeModifier::SizeT
;
424 PT
= BaseTypeModifier::Ptrdiff
;
427 PT
= BaseTypeModifier::UnsignedLong
;
430 PT
= BaseTypeModifier::SignedLong
;
433 llvm_unreachable("Illegal primitive type transformers!");
435 PD
.PT
= static_cast<uint8_t>(PT
);
436 PrototypeDescriptorStr
= PrototypeDescriptorStr
.drop_back();
438 // Compute the vector type transformers, it can only appear one time.
439 if (PrototypeDescriptorStr
.startswith("(")) {
440 assert(VTM
== VectorTypeModifier::NoModifier
&&
441 "VectorTypeModifier should only have one modifier");
442 size_t Idx
= PrototypeDescriptorStr
.find(')');
443 assert(Idx
!= StringRef::npos
);
444 StringRef ComplexType
= PrototypeDescriptorStr
.slice(1, Idx
);
445 PrototypeDescriptorStr
= PrototypeDescriptorStr
.drop_front(Idx
+ 1);
446 assert(!PrototypeDescriptorStr
.contains('(') &&
447 "Only allow one vector type modifier");
449 auto ComplexTT
= ComplexType
.split(":");
450 if (ComplexTT
.first
== "Log2EEW") {
452 if (ComplexTT
.second
.getAsInteger(10, Log2EEW
)) {
453 llvm_unreachable("Invalid Log2EEW value!");
458 VTM
= VectorTypeModifier::Log2EEW3
;
461 VTM
= VectorTypeModifier::Log2EEW4
;
464 VTM
= VectorTypeModifier::Log2EEW5
;
467 VTM
= VectorTypeModifier::Log2EEW6
;
470 llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
473 } else if (ComplexTT
.first
== "FixedSEW") {
475 if (ComplexTT
.second
.getAsInteger(10, NewSEW
)) {
476 llvm_unreachable("Invalid FixedSEW value!");
481 VTM
= VectorTypeModifier::FixedSEW8
;
484 VTM
= VectorTypeModifier::FixedSEW16
;
487 VTM
= VectorTypeModifier::FixedSEW32
;
490 VTM
= VectorTypeModifier::FixedSEW64
;
493 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
496 } else if (ComplexTT
.first
== "LFixedLog2LMUL") {
498 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
499 llvm_unreachable("Invalid LFixedLog2LMUL value!");
504 VTM
= VectorTypeModifier::LFixedLog2LMULN3
;
507 VTM
= VectorTypeModifier::LFixedLog2LMULN2
;
510 VTM
= VectorTypeModifier::LFixedLog2LMULN1
;
513 VTM
= VectorTypeModifier::LFixedLog2LMUL0
;
516 VTM
= VectorTypeModifier::LFixedLog2LMUL1
;
519 VTM
= VectorTypeModifier::LFixedLog2LMUL2
;
522 VTM
= VectorTypeModifier::LFixedLog2LMUL3
;
525 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
528 } else if (ComplexTT
.first
== "SFixedLog2LMUL") {
530 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
531 llvm_unreachable("Invalid SFixedLog2LMUL value!");
536 VTM
= VectorTypeModifier::SFixedLog2LMULN3
;
539 VTM
= VectorTypeModifier::SFixedLog2LMULN2
;
542 VTM
= VectorTypeModifier::SFixedLog2LMULN1
;
545 VTM
= VectorTypeModifier::SFixedLog2LMUL0
;
548 VTM
= VectorTypeModifier::SFixedLog2LMUL1
;
551 VTM
= VectorTypeModifier::SFixedLog2LMUL2
;
554 VTM
= VectorTypeModifier::SFixedLog2LMUL3
;
557 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
561 } else if (ComplexTT
.first
== "SEFixedLog2LMUL") {
563 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
564 llvm_unreachable("Invalid SEFixedLog2LMUL value!");
569 VTM
= VectorTypeModifier::SEFixedLog2LMULN3
;
572 VTM
= VectorTypeModifier::SEFixedLog2LMULN2
;
575 VTM
= VectorTypeModifier::SEFixedLog2LMULN1
;
578 VTM
= VectorTypeModifier::SEFixedLog2LMUL0
;
581 VTM
= VectorTypeModifier::SEFixedLog2LMUL1
;
584 VTM
= VectorTypeModifier::SEFixedLog2LMUL2
;
587 VTM
= VectorTypeModifier::SEFixedLog2LMUL3
;
590 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
593 } else if (ComplexTT
.first
== "Tuple") {
595 if (ComplexTT
.second
.getAsInteger(10, NF
)) {
596 llvm_unreachable("Invalid NF value!");
599 VTM
= getTupleVTM(NF
);
601 llvm_unreachable("Illegal complex type transformers!");
604 PD
.VTM
= static_cast<uint8_t>(VTM
);
// Compute the remaining type transformers.
607 TypeModifier TM
= TypeModifier::NoModifier
;
608 for (char I
: PrototypeDescriptorStr
) {
611 if ((TM
& TypeModifier::Const
) == TypeModifier::Const
)
612 llvm_unreachable("'P' transformer cannot be used after 'C'");
613 if ((TM
& TypeModifier::Pointer
) == TypeModifier::Pointer
)
614 llvm_unreachable("'P' transformer cannot be used twice");
615 TM
|= TypeModifier::Pointer
;
618 TM
|= TypeModifier::Const
;
621 TM
|= TypeModifier::Immediate
;
624 TM
|= TypeModifier::UnsignedInteger
;
627 TM
|= TypeModifier::SignedInteger
;
630 TM
|= TypeModifier::Float
;
633 TM
|= TypeModifier::LMUL1
;
636 llvm_unreachable("Illegal non-primitive type transformer!");
639 PD
.TM
= static_cast<uint8_t>(TM
);
644 void RVVType::applyModifier(const PrototypeDescriptor
&Transformer
) {
645 // Handle primitive type transformer
646 switch (static_cast<BaseTypeModifier
>(Transformer
.PT
)) {
647 case BaseTypeModifier::Scalar
:
650 case BaseTypeModifier::Vector
:
651 Scale
= LMUL
.getScale(ElementBitwidth
);
653 case BaseTypeModifier::Void
:
654 ScalarType
= ScalarTypeKind::Void
;
656 case BaseTypeModifier::SizeT
:
657 ScalarType
= ScalarTypeKind::Size_t
;
659 case BaseTypeModifier::Ptrdiff
:
660 ScalarType
= ScalarTypeKind::Ptrdiff_t
;
662 case BaseTypeModifier::UnsignedLong
:
663 ScalarType
= ScalarTypeKind::UnsignedLong
;
665 case BaseTypeModifier::SignedLong
:
666 ScalarType
= ScalarTypeKind::SignedLong
;
668 case BaseTypeModifier::Invalid
:
669 ScalarType
= ScalarTypeKind::Invalid
;
673 switch (static_cast<VectorTypeModifier
>(Transformer
.VTM
)) {
674 case VectorTypeModifier::Widening2XVector
:
675 ElementBitwidth
*= 2;
677 Scale
= LMUL
.getScale(ElementBitwidth
);
679 case VectorTypeModifier::Widening4XVector
:
680 ElementBitwidth
*= 4;
682 Scale
= LMUL
.getScale(ElementBitwidth
);
684 case VectorTypeModifier::Widening8XVector
:
685 ElementBitwidth
*= 8;
687 Scale
= LMUL
.getScale(ElementBitwidth
);
689 case VectorTypeModifier::MaskVector
:
690 ScalarType
= ScalarTypeKind::Boolean
;
691 Scale
= LMUL
.getScale(ElementBitwidth
);
694 case VectorTypeModifier::Log2EEW3
:
697 case VectorTypeModifier::Log2EEW4
:
700 case VectorTypeModifier::Log2EEW5
:
703 case VectorTypeModifier::Log2EEW6
:
706 case VectorTypeModifier::FixedSEW8
:
709 case VectorTypeModifier::FixedSEW16
:
712 case VectorTypeModifier::FixedSEW32
:
715 case VectorTypeModifier::FixedSEW64
:
718 case VectorTypeModifier::LFixedLog2LMULN3
:
719 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan
);
721 case VectorTypeModifier::LFixedLog2LMULN2
:
722 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan
);
724 case VectorTypeModifier::LFixedLog2LMULN1
:
725 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan
);
727 case VectorTypeModifier::LFixedLog2LMUL0
:
728 applyFixedLog2LMUL(0, FixedLMULType::LargerThan
);
730 case VectorTypeModifier::LFixedLog2LMUL1
:
731 applyFixedLog2LMUL(1, FixedLMULType::LargerThan
);
733 case VectorTypeModifier::LFixedLog2LMUL2
:
734 applyFixedLog2LMUL(2, FixedLMULType::LargerThan
);
736 case VectorTypeModifier::LFixedLog2LMUL3
:
737 applyFixedLog2LMUL(3, FixedLMULType::LargerThan
);
739 case VectorTypeModifier::SFixedLog2LMULN3
:
740 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan
);
742 case VectorTypeModifier::SFixedLog2LMULN2
:
743 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan
);
745 case VectorTypeModifier::SFixedLog2LMULN1
:
746 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan
);
748 case VectorTypeModifier::SFixedLog2LMUL0
:
749 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan
);
751 case VectorTypeModifier::SFixedLog2LMUL1
:
752 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan
);
754 case VectorTypeModifier::SFixedLog2LMUL2
:
755 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan
);
757 case VectorTypeModifier::SFixedLog2LMUL3
:
758 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan
);
760 case VectorTypeModifier::SEFixedLog2LMULN3
:
761 applyFixedLog2LMUL(-3, FixedLMULType::SmallerOrEqual
);
763 case VectorTypeModifier::SEFixedLog2LMULN2
:
764 applyFixedLog2LMUL(-2, FixedLMULType::SmallerOrEqual
);
766 case VectorTypeModifier::SEFixedLog2LMULN1
:
767 applyFixedLog2LMUL(-1, FixedLMULType::SmallerOrEqual
);
769 case VectorTypeModifier::SEFixedLog2LMUL0
:
770 applyFixedLog2LMUL(0, FixedLMULType::SmallerOrEqual
);
772 case VectorTypeModifier::SEFixedLog2LMUL1
:
773 applyFixedLog2LMUL(1, FixedLMULType::SmallerOrEqual
);
775 case VectorTypeModifier::SEFixedLog2LMUL2
:
776 applyFixedLog2LMUL(2, FixedLMULType::SmallerOrEqual
);
778 case VectorTypeModifier::SEFixedLog2LMUL3
:
779 applyFixedLog2LMUL(3, FixedLMULType::SmallerOrEqual
);
781 case VectorTypeModifier::Tuple2
:
782 case VectorTypeModifier::Tuple3
:
783 case VectorTypeModifier::Tuple4
:
784 case VectorTypeModifier::Tuple5
:
785 case VectorTypeModifier::Tuple6
:
786 case VectorTypeModifier::Tuple7
:
787 case VectorTypeModifier::Tuple8
: {
789 NF
= 2 + static_cast<uint8_t>(Transformer
.VTM
) -
790 static_cast<uint8_t>(VectorTypeModifier::Tuple2
);
793 case VectorTypeModifier::NoModifier
:
797 // Early return if the current type modifier is already invalid.
798 if (ScalarType
== Invalid
)
801 for (unsigned TypeModifierMaskShift
= 0;
802 TypeModifierMaskShift
<= static_cast<unsigned>(TypeModifier::MaxOffset
);
803 ++TypeModifierMaskShift
) {
804 unsigned TypeModifierMask
= 1 << TypeModifierMaskShift
;
805 if ((static_cast<unsigned>(Transformer
.TM
) & TypeModifierMask
) !=
808 switch (static_cast<TypeModifier
>(TypeModifierMask
)) {
809 case TypeModifier::Pointer
:
812 case TypeModifier::Const
:
815 case TypeModifier::Immediate
:
819 case TypeModifier::UnsignedInteger
:
820 ScalarType
= ScalarTypeKind::UnsignedInteger
;
822 case TypeModifier::SignedInteger
:
823 ScalarType
= ScalarTypeKind::SignedInteger
;
825 case TypeModifier::Float
:
826 ScalarType
= ScalarTypeKind::Float
;
828 case TypeModifier::LMUL1
:
830 // Update ElementBitwidth need to update Scale too.
831 Scale
= LMUL
.getScale(ElementBitwidth
);
834 llvm_unreachable("Unknown type modifier mask!");
839 void RVVType::applyLog2EEW(unsigned Log2EEW
) {
840 // update new elmul = (eew/sew) * lmul
841 LMUL
.MulLog2LMUL(Log2EEW
- Log2_32(ElementBitwidth
));
843 ElementBitwidth
= 1 << Log2EEW
;
844 ScalarType
= ScalarTypeKind::SignedInteger
;
845 Scale
= LMUL
.getScale(ElementBitwidth
);
848 void RVVType::applyFixedSEW(unsigned NewSEW
) {
849 // Set invalid type if src and dst SEW are same.
850 if (ElementBitwidth
== NewSEW
) {
851 ScalarType
= ScalarTypeKind::Invalid
;
855 ElementBitwidth
= NewSEW
;
856 Scale
= LMUL
.getScale(ElementBitwidth
);
859 void RVVType::applyFixedLog2LMUL(int Log2LMUL
, enum FixedLMULType Type
) {
861 case FixedLMULType::LargerThan
:
862 if (Log2LMUL
<= LMUL
.Log2LMUL
) {
863 ScalarType
= ScalarTypeKind::Invalid
;
867 case FixedLMULType::SmallerThan
:
868 if (Log2LMUL
>= LMUL
.Log2LMUL
) {
869 ScalarType
= ScalarTypeKind::Invalid
;
873 case FixedLMULType::SmallerOrEqual
:
874 if (Log2LMUL
> LMUL
.Log2LMUL
) {
875 ScalarType
= ScalarTypeKind::Invalid
;
882 LMUL
= LMULType(Log2LMUL
);
883 Scale
= LMUL
.getScale(ElementBitwidth
);
886 std::optional
<RVVTypes
>
887 RVVTypeCache::computeTypes(BasicType BT
, int Log2LMUL
, unsigned NF
,
888 ArrayRef
<PrototypeDescriptor
> Prototype
) {
890 for (const PrototypeDescriptor
&Proto
: Prototype
) {
891 auto T
= computeType(BT
, Log2LMUL
, Proto
);
894 // Record legal type index
// Compute the hash value of an RVVType; used to cache the result of computeType.
901 static uint64_t computeRVVTypeHashValue(BasicType BT
, int Log2LMUL
,
902 PrototypeDescriptor Proto
) {
903 // Layout of hash value:
905 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
906 assert(Log2LMUL
>= -3 && Log2LMUL
<= 3);
907 return (Log2LMUL
+ 3) | (static_cast<uint64_t>(BT
) & 0xff) << 8 |
908 ((uint64_t)(Proto
.PT
& 0xff) << 16) |
909 ((uint64_t)(Proto
.TM
& 0xff) << 24) |
910 ((uint64_t)(Proto
.VTM
& 0xff) << 32);
913 std::optional
<RVVTypePtr
> RVVTypeCache::computeType(BasicType BT
, int Log2LMUL
,
914 PrototypeDescriptor Proto
) {
915 uint64_t Idx
= computeRVVTypeHashValue(BT
, Log2LMUL
, Proto
);
917 auto It
= LegalTypes
.find(Idx
);
918 if (It
!= LegalTypes
.end())
919 return &(It
->second
);
921 if (IllegalTypes
.count(Idx
))
924 // Compute type and record the result.
925 RVVType
T(BT
, Log2LMUL
, Proto
);
927 // Record legal type index and value.
928 std::pair
<std::unordered_map
<uint64_t, RVVType
>::iterator
, bool>
929 InsertResult
= LegalTypes
.insert({Idx
, T
});
930 return &(InsertResult
.first
->second
);
932 // Record illegal type index.
933 IllegalTypes
.insert(Idx
);
937 //===----------------------------------------------------------------------===//
938 // RVVIntrinsic implementation
939 //===----------------------------------------------------------------------===//
940 RVVIntrinsic::RVVIntrinsic(
941 StringRef NewName
, StringRef Suffix
, StringRef NewOverloadedName
,
942 StringRef OverloadedSuffix
, StringRef IRName
, bool IsMasked
,
943 bool HasMaskedOffOperand
, bool HasVL
, PolicyScheme Scheme
,
944 bool SupportOverloading
, bool HasBuiltinAlias
, StringRef ManualCodegen
,
945 const RVVTypes
&OutInTypes
, const std::vector
<int64_t> &NewIntrinsicTypes
,
946 const std::vector
<StringRef
> &RequiredFeatures
, unsigned NF
,
947 Policy NewPolicyAttrs
, bool HasFRMRoundModeOp
)
948 : IRName(IRName
), IsMasked(IsMasked
),
949 HasMaskedOffOperand(HasMaskedOffOperand
), HasVL(HasVL
), Scheme(Scheme
),
950 SupportOverloading(SupportOverloading
), HasBuiltinAlias(HasBuiltinAlias
),
951 ManualCodegen(ManualCodegen
.str()), NF(NF
), PolicyAttrs(NewPolicyAttrs
) {
953 // Init BuiltinName, Name and OverloadedName
954 BuiltinName
= NewName
.str();
956 if (NewOverloadedName
.empty())
957 OverloadedName
= NewName
.split("_").first
.str();
959 OverloadedName
= NewOverloadedName
.str();
961 Name
+= "_" + Suffix
.str();
962 if (!OverloadedSuffix
.empty())
963 OverloadedName
+= "_" + OverloadedSuffix
.str();
965 updateNamesAndPolicy(IsMasked
, hasPolicy(), Name
, BuiltinName
, OverloadedName
,
966 PolicyAttrs
, HasFRMRoundModeOp
);
968 // Init OutputType and InputTypes
969 OutputType
= OutInTypes
[0];
970 InputTypes
.assign(OutInTypes
.begin() + 1, OutInTypes
.end());
972 // IntrinsicTypes is unmasked TA version index. Need to update it
973 // if there is merge operand (It is always in first operand).
974 IntrinsicTypes
= NewIntrinsicTypes
;
975 if ((IsMasked
&& hasMaskedOffOperand()) ||
976 (!IsMasked
&& hasPassthruOperand())) {
977 for (auto &I
: IntrinsicTypes
) {
984 std::string
RVVIntrinsic::getBuiltinTypeStr() const {
986 S
+= OutputType
->getBuiltinStr();
987 for (const auto &T
: InputTypes
) {
988 S
+= T
->getBuiltinStr();
993 std::string
RVVIntrinsic::getSuffixStr(
994 RVVTypeCache
&TypeCache
, BasicType Type
, int Log2LMUL
,
995 llvm::ArrayRef
<PrototypeDescriptor
> PrototypeDescriptors
) {
996 SmallVector
<std::string
> SuffixStrs
;
997 for (auto PD
: PrototypeDescriptors
) {
998 auto T
= TypeCache
.computeType(Type
, Log2LMUL
, PD
);
999 SuffixStrs
.push_back((*T
)->getShortStr());
1001 return join(SuffixStrs
, "_");
1004 llvm::SmallVector
<PrototypeDescriptor
> RVVIntrinsic::computeBuiltinTypes(
1005 llvm::ArrayRef
<PrototypeDescriptor
> Prototype
, bool IsMasked
,
1006 bool HasMaskedOffOperand
, bool HasVL
, unsigned NF
,
1007 PolicyScheme DefaultScheme
, Policy PolicyAttrs
, bool IsTuple
) {
1008 SmallVector
<PrototypeDescriptor
> NewPrototype(Prototype
.begin(),
1010 bool HasPassthruOp
= DefaultScheme
== PolicyScheme::HasPassthruOperand
;
1012 // If HasMaskedOffOperand, insert result type as first input operand if
1014 if (HasMaskedOffOperand
&& !PolicyAttrs
.isTAMAPolicy()) {
1016 NewPrototype
.insert(NewPrototype
.begin() + 1, NewPrototype
[0]);
1017 } else if (NF
> 1) {
1019 PrototypeDescriptor BasePtrOperand
= Prototype
[1];
1020 PrototypeDescriptor MaskoffType
= PrototypeDescriptor(
1021 static_cast<uint8_t>(BaseTypeModifier::Vector
),
1022 static_cast<uint8_t>(getTupleVTM(NF
)),
1023 BasePtrOperand
.TM
& ~static_cast<uint8_t>(TypeModifier::Pointer
));
1024 NewPrototype
.insert(NewPrototype
.begin() + 1, MaskoffType
);
1027 // (void, op0 address, op1 address, ...)
1029 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1030 PrototypeDescriptor MaskoffType
= NewPrototype
[1];
1031 MaskoffType
.TM
&= ~static_cast<uint8_t>(TypeModifier::Pointer
);
1032 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1, NF
, MaskoffType
);
1036 if (HasMaskedOffOperand
&& NF
> 1) {
1038 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1040 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1043 NewPrototype
.insert(NewPrototype
.begin() + 1,
1044 PrototypeDescriptor::Mask
);
1046 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1,
1047 PrototypeDescriptor::Mask
);
// If IsMasked, insert PrototypeDescriptor::Mask as the first input operand.
1050 NewPrototype
.insert(NewPrototype
.begin() + 1, PrototypeDescriptor::Mask
);
1054 if (PolicyAttrs
.isTUPolicy() && HasPassthruOp
)
1055 NewPrototype
.insert(NewPrototype
.begin(), NewPrototype
[0]);
1056 } else if (PolicyAttrs
.isTUPolicy() && HasPassthruOp
) {
1058 PrototypeDescriptor BasePtrOperand
= Prototype
[0];
1059 PrototypeDescriptor MaskoffType
= PrototypeDescriptor(
1060 static_cast<uint8_t>(BaseTypeModifier::Vector
),
1061 static_cast<uint8_t>(getTupleVTM(NF
)),
1062 BasePtrOperand
.TM
& ~static_cast<uint8_t>(TypeModifier::Pointer
));
1063 NewPrototype
.insert(NewPrototype
.begin(), MaskoffType
);
1065 // NF > 1 cases for segment load operations.
1067 // (void, op0 address, op1 address, ...)
1069 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
1070 PrototypeDescriptor MaskoffType
= Prototype
[1];
1071 MaskoffType
.TM
&= ~static_cast<uint8_t>(TypeModifier::Pointer
);
1072 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1, NF
, MaskoffType
);
// If HasVL, append PrototypeDescriptor::VL as the last operand.
1079 NewPrototype
.push_back(PrototypeDescriptor::VL
);
1081 return NewPrototype
;
1084 llvm::SmallVector
<Policy
> RVVIntrinsic::getSupportedUnMaskedPolicies() {
1085 return {Policy(Policy::PolicyType::Undisturbed
)}; // TU
1088 llvm::SmallVector
<Policy
>
1089 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy
,
1090 bool HasMaskPolicy
) {
1091 if (HasTailPolicy
&& HasMaskPolicy
)
1092 return {Policy(Policy::PolicyType::Undisturbed
,
1093 Policy::PolicyType::Agnostic
), // TUM
1094 Policy(Policy::PolicyType::Undisturbed
,
1095 Policy::PolicyType::Undisturbed
), // TUMU
1096 Policy(Policy::PolicyType::Agnostic
,
1097 Policy::PolicyType::Undisturbed
)}; // MU
1098 if (HasTailPolicy
&& !HasMaskPolicy
)
1099 return {Policy(Policy::PolicyType::Undisturbed
,
1100 Policy::PolicyType::Agnostic
)}; // TU
1101 if (!HasTailPolicy
&& HasMaskPolicy
)
1102 return {Policy(Policy::PolicyType::Agnostic
,
1103 Policy::PolicyType::Undisturbed
)}; // MU
1104 llvm_unreachable("An RVV instruction should not be without both tail policy "
1108 void RVVIntrinsic::updateNamesAndPolicy(
1109 bool IsMasked
, bool HasPolicy
, std::string
&Name
, std::string
&BuiltinName
,
1110 std::string
&OverloadedName
, Policy
&PolicyAttrs
, bool HasFRMRoundModeOp
) {
1112 auto appendPolicySuffix
= [&](const std::string
&suffix
) {
1114 BuiltinName
+= suffix
;
1115 OverloadedName
+= suffix
;
// This follows the naming guideline under riscv-c-api-doc to add the
// `__riscv_` prefix for all RVV intrinsics.
1120 Name
= "__riscv_" + Name
;
1121 OverloadedName
= "__riscv_" + OverloadedName
;
1123 if (HasFRMRoundModeOp
) {
1125 BuiltinName
+= "_rm";
1129 if (PolicyAttrs
.isTUMUPolicy())
1130 appendPolicySuffix("_tumu");
1131 else if (PolicyAttrs
.isTUMAPolicy())
1132 appendPolicySuffix("_tum");
1133 else if (PolicyAttrs
.isTAMUPolicy())
1134 appendPolicySuffix("_mu");
1135 else if (PolicyAttrs
.isTAMAPolicy()) {
1137 BuiltinName
+= "_m";
1139 llvm_unreachable("Unhandled policy condition");
1141 if (PolicyAttrs
.isTUPolicy())
1142 appendPolicySuffix("_tu");
1143 else if (PolicyAttrs
.isTAPolicy()) // no suffix needed
1146 llvm_unreachable("Unhandled policy condition");
1150 SmallVector
<PrototypeDescriptor
> parsePrototypes(StringRef Prototypes
) {
1151 SmallVector
<PrototypeDescriptor
> PrototypeDescriptors
;
1152 const StringRef
Primaries("evwqom0ztul");
1153 while (!Prototypes
.empty()) {
1155 // Skip over complex prototype because it could contain primitive type
1157 if (Prototypes
[0] == '(')
1158 Idx
= Prototypes
.find_first_of(')');
1159 Idx
= Prototypes
.find_first_of(Primaries
, Idx
);
1160 assert(Idx
!= StringRef::npos
);
1161 auto PD
= PrototypeDescriptor::parsePrototypeDescriptor(
1162 Prototypes
.slice(0, Idx
+ 1));
1164 llvm_unreachable("Error during parsing prototype.");
1165 PrototypeDescriptors
.push_back(*PD
);
1166 Prototypes
= Prototypes
.drop_front(Idx
+ 1);
1168 return PrototypeDescriptors
;
1171 raw_ostream
&operator<<(raw_ostream
&OS
, const RVVIntrinsicRecord
&Record
) {
1173 OS
<< "\"" << Record
.Name
<< "\",";
1174 if (Record
.OverloadedName
== nullptr ||
1175 StringRef(Record
.OverloadedName
).empty())
1178 OS
<< "\"" << Record
.OverloadedName
<< "\",";
1179 OS
<< Record
.PrototypeIndex
<< ",";
1180 OS
<< Record
.SuffixIndex
<< ",";
1181 OS
<< Record
.OverloadedSuffixIndex
<< ",";
1182 OS
<< (int)Record
.PrototypeLength
<< ",";
1183 OS
<< (int)Record
.SuffixLength
<< ",";
1184 OS
<< (int)Record
.OverloadedSuffixSize
<< ",";
1185 OS
<< (int)Record
.RequiredExtensions
<< ",";
1186 OS
<< (int)Record
.TypeRangeMask
<< ",";
1187 OS
<< (int)Record
.Log2LMULMask
<< ",";
1188 OS
<< (int)Record
.NF
<< ",";
1189 OS
<< (int)Record
.HasMasked
<< ",";
1190 OS
<< (int)Record
.HasVL
<< ",";
1191 OS
<< (int)Record
.HasMaskedOffOperand
<< ",";
1192 OS
<< (int)Record
.HasTailPolicy
<< ",";
1193 OS
<< (int)Record
.HasMaskPolicy
<< ",";
1194 OS
<< (int)Record
.HasFRMRoundModeOp
<< ",";
1195 OS
<< (int)Record
.IsTuple
<< ",";
1196 OS
<< (int)Record
.UnMaskedPolicyScheme
<< ",";
1197 OS
<< (int)Record
.MaskedPolicyScheme
<< ",";
1202 } // end namespace RISCV
1203 } // end namespace clang