1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "clang/Support/RISCVVIntrinsicUtils.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/SmallSet.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringSet.h"
14 #include "llvm/ADT/Twine.h"
15 #include "llvm/Support/ErrorHandling.h"
16 #include "llvm/Support/raw_ostream.h"
25 const PrototypeDescriptor
PrototypeDescriptor::Mask
= PrototypeDescriptor(
26 BaseTypeModifier::Vector
, VectorTypeModifier::MaskVector
);
27 const PrototypeDescriptor
PrototypeDescriptor::VL
=
28 PrototypeDescriptor(BaseTypeModifier::SizeT
);
29 const PrototypeDescriptor
PrototypeDescriptor::Vector
=
30 PrototypeDescriptor(BaseTypeModifier::Vector
);
32 //===----------------------------------------------------------------------===//
33 // Type implementation
34 //===----------------------------------------------------------------------===//
36 LMULType::LMULType(int NewLog2LMUL
) {
37 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
38 assert(NewLog2LMUL
<= 3 && NewLog2LMUL
>= -3 && "Bad LMUL number!");
39 Log2LMUL
= NewLog2LMUL
;
42 std::string
LMULType::str() const {
44 return "mf" + utostr(1ULL << (-Log2LMUL
));
45 return "m" + utostr(1ULL << Log2LMUL
);
// Computes the scale value (VScaleVal) of this LMUL for a given element width.
// Each visible switch arm sets Log2ScaleResult to Log2LMUL plus a per-width
// offset (3, 2, 1 or 0); the final statement returns 1 << Log2ScaleResult.
// NOTE(review): the case labels of the switch and the statement guarded by
// the `Log2ScaleResult < 0` check are not visible in this excerpt —
// presumably the guard rejects the illegal (< 1) scale; confirm against the
// complete file.
48 VScaleVal
LMULType::getScale(unsigned ElementBitwidth
) const {
49 int Log2ScaleResult
= 0;
50 switch (ElementBitwidth
) {
54 Log2ScaleResult
= Log2LMUL
+ 3;
57 Log2ScaleResult
= Log2LMUL
+ 2;
60 Log2ScaleResult
= Log2LMUL
+ 1;
63 Log2ScaleResult
= Log2LMUL
;
66 // Illegal vscale result would be less than 1
67 if (Log2ScaleResult
< 0)
69 return 1 << Log2ScaleResult
;
72 void LMULType::MulLog2LMUL(int log2LMUL
) { Log2LMUL
+= log2LMUL
; }
// Builds an RVVType from a basic type, a log2 LMUL and a prototype descriptor.
// The initializer list stores BT and wraps Log2LMUL in an LMULType; the body
// then applies the prototype's modifiers and initialises the Clang builtin
// string.
// NOTE(review): statements between the visible calls (and any conditions
// guarding initClangBuiltinStr) are not visible in this excerpt.
74 RVVType::RVVType(BasicType BT
, int Log2LMUL
,
75 const PrototypeDescriptor
&prototype
)
76 : BT(BT
), LMUL(LMULType(Log2LMUL
)) {
78 applyModifier(prototype
);
84 initClangBuiltinStr();
// Boolean types are encoded by the ratio n = SEW/LMUL.
// SEW/LMUL | 64        | 32        | 16        | 8        | 4        | 2        | 1
// c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
// IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1  | nxv32i1  | nxv64i1
//
// type\lmul | 1/8    | 1/4     | 1/2     | 1       | 2       | 4        | 8
// --------- | ------ | ------- | ------- | ------- | ------- | -------- | --------
// i64       | N/A    | N/A     | N/A     | nxv1i64 | nxv2i64 | nxv4i64  | nxv8i64
// i32       | N/A    | N/A     | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32  | nxv16i32
// i16       | N/A    | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
// i8        | nxv1i8 | nxv2i8  | nxv4i8  | nxv8i8  | nxv16i8 | nxv32i8  | nxv64i8
// double    | N/A    | N/A     | N/A     | nxv1f64 | nxv2f64 | nxv4f64  | nxv8f64
// float     | N/A    | N/A     | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32  | nxv16f32
// half      | N/A    | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
// Checks whether the type described by this object is legal.
// Visible checks, in order:
//   - ScalarType is Invalid;
//   - a floating-point type with 8-bit elements;
//   - a tuple whose NF is 1 or greater than 8;
//   - a tuple for which max(LMUL, 1) * NF exceeds 8;
//   - a per-element-width bound on V, which must be a power of two no larger
//     than 64, 32, 16 or 8 depending on the switch arm.
// NOTE(review): the statements guarded by each `if`, the switch case labels
// and the definition of V are not visible in this excerpt.
106 bool RVVType::verifyType() const {
107 if (ScalarType
== Invalid
)
113 if (isFloat() && ElementBitwidth
== 8)
115 if (IsTuple
&& (NF
== 1 || NF
> 8))
117 if (IsTuple
&& (1 << std::max(0, LMUL
.Log2LMUL
)) * NF
> 8)
120 switch (ElementBitwidth
) {
123 // Check Scale is 1,2,4,8,16,32,64
124 return (V
<= 64 && isPowerOf2_32(V
));
126 // Check Scale is 1,2,4,8,16,32
127 return (V
<= 32 && isPowerOf2_32(V
));
129 // Check Scale is 1,2,4,8,16
130 return (V
<= 16 && isPowerOf2_32(V
));
132 // Check Scale is 1,2,4,8
133 return (V
<= 8 && isPowerOf2_32(V
));
// Builds BuiltinStr, the encoded builtin type string for this RVVType.
// Visible behaviour: asserts the type is valid, switches on ScalarType, and
// prepends prefixes to BuiltinStr: "I", "S" (when isSignedInteger()) or "U",
// "q" followed by *Scale, and "T" followed by NF.
// NOTE(review): the base string assigned in each case, the inner switch case
// labels and the conditions guarding most of the prefixes are not visible in
// this excerpt.
138 void RVVType::initBuiltinStr() {
139 assert(isValid() && "RVVType is invalid");
140 switch (ScalarType
) {
141 case ScalarTypeKind::Void
:
144 case ScalarTypeKind::Size_t
:
147 BuiltinStr
= "I" + BuiltinStr
;
151 case ScalarTypeKind::Ptrdiff_t
:
154 case ScalarTypeKind::UnsignedLong
:
157 case ScalarTypeKind::SignedLong
:
160 case ScalarTypeKind::Boolean
:
// Boolean elements are expected to be exactly one bit wide.
161 assert(ElementBitwidth
== 1);
164 case ScalarTypeKind::SignedInteger
:
165 case ScalarTypeKind::UnsignedInteger
:
166 switch (ElementBitwidth
) {
180 llvm_unreachable("Unhandled ElementBitwidth!");
// Signedness prefix for the integer cases.
182 if (isSignedInteger())
183 BuiltinStr
= "S" + BuiltinStr
;
185 BuiltinStr
= "U" + BuiltinStr
;
187 case ScalarTypeKind::Float
:
188 switch (ElementBitwidth
) {
199 llvm_unreachable("Unhandled ElementBitwidth!");
203 llvm_unreachable("ScalarType is invalid!");
206 BuiltinStr
= "I" + BuiltinStr
;
// Prefix encoding the scale: "q" + *Scale.
214 BuiltinStr
= "q" + utostr(*Scale
) + BuiltinStr
;
215 // Pointer to vector types. Defined for segment load intrinsics.
216 // segment load intrinsics have pointer type arguments to store the loaded
// Prefix encoding the tuple field count: "T" + NF.
222 BuiltinStr
= "T" + utostr(NF
) + BuiltinStr
;
// Builds ClangBuiltinStr, the "__rvv_"-prefixed builtin type name.
// Asserts that the type is valid and is a vector type.
// Boolean types append "bool" + (64 / *Scale) + "_t"; Float, SignedInteger
// and UnsignedInteger append "float", "int" and "uint" respectively. The last
// statement appends the element width, the LMUL string, "x<NF>" for tuples,
// and "_t".
// NOTE(review): the break/return statements between the cases are not visible
// in this excerpt; presumably the Boolean case returns before the common
// suffix is appended — confirm against the complete file.
225 void RVVType::initClangBuiltinStr() {
226 assert(isValid() && "RVVType is invalid");
227 assert(isVector() && "Handle Vector type only");
229 ClangBuiltinStr
= "__rvv_";
230 switch (ScalarType
) {
231 case ScalarTypeKind::Boolean
:
232 ClangBuiltinStr
+= "bool" + utostr(64 / *Scale
) + "_t";
234 case ScalarTypeKind::Float
:
235 ClangBuiltinStr
+= "float";
237 case ScalarTypeKind::SignedInteger
:
238 ClangBuiltinStr
+= "int";
240 case ScalarTypeKind::UnsignedInteger
:
241 ClangBuiltinStr
+= "uint";
244 llvm_unreachable("ScalarTypeKind is invalid");
// Common suffix: <bits><lmul>[x<NF>]_t
246 ClangBuiltinStr
+= utostr(ElementBitwidth
) + LMUL
.str() +
247 (IsTuple
? "x" + utostr(NF
) : "") + "_t";
// Builds the C-level type name of this RVVType into Str.
// The getTypeString lambda has two visible return forms:
//   "<TypeStr><bits>_t"                   and
//   "v<TypeStr><bits><lmul>[x<NF>]_t"
// The switch on ScalarType then fills Str: "unsigned long" for UnsignedLong,
// "vbool<64 / *Scale>_t" for Boolean, and getTypeString("float" / "int" /
// "uint") for the float and integer kinds.
// NOTE(review): the declaration of Str, the condition choosing between the
// two lambda return forms, and the bodies of several cases are not visible in
// this excerpt.
250 void RVVType::initTypeStr() {
251 assert(isValid() && "RVVType is invalid");
256 auto getTypeString
= [&](StringRef TypeStr
) {
258 return Twine(TypeStr
+ Twine(ElementBitwidth
) + "_t").str();
259 return Twine("v" + TypeStr
+ Twine(ElementBitwidth
) + LMUL
.str() +
260 (IsTuple
? "x" + utostr(NF
) : "") + "_t")
264 switch (ScalarType
) {
265 case ScalarTypeKind::Void
:
268 case ScalarTypeKind::Size_t
:
273 case ScalarTypeKind::Ptrdiff_t
:
276 case ScalarTypeKind::UnsignedLong
:
277 Str
= "unsigned long";
279 case ScalarTypeKind::SignedLong
:
282 case ScalarTypeKind::Boolean
:
286 // Vector bool is a special case; the formula is
287 // `vbool<N>_t = MVT::nxv<64/N>i1`, e.g. vbool16_t = MVT::nxv4i1
288 Str
+= "vbool" + utostr(64 / *Scale
) + "_t";
290 case ScalarTypeKind::Float
:
// Only 64-, 32- and 16-bit floating-point widths are handled here.
292 if (ElementBitwidth
== 64)
294 else if (ElementBitwidth
== 32)
296 else if (ElementBitwidth
== 16)
299 llvm_unreachable("Unhandled floating type.");
301 Str
+= getTypeString("float");
303 case ScalarTypeKind::SignedInteger
:
304 Str
+= getTypeString("int");
306 case ScalarTypeKind::UnsignedInteger
:
307 Str
+= getTypeString("uint");
310 llvm_unreachable("ScalarType is invalid!");
// Builds ShortStr, the abbreviated type tag.
// Visible forms: "b<64 / *Scale>" for Boolean, and "f", "i" or "u" followed
// by the element width for Float, SignedInteger and UnsignedInteger. The LMUL
// string and "x<NF>" are appended afterwards.
// NOTE(review): the control flow between cases and the conditions guarding
// the two trailing appends are not visible in this excerpt.
316 void RVVType::initShortStr() {
317 switch (ScalarType
) {
318 case ScalarTypeKind::Boolean
:
320 ShortStr
= "b" + utostr(64 / *Scale
);
322 case ScalarTypeKind::Float
:
323 ShortStr
= "f" + utostr(ElementBitwidth
);
325 case ScalarTypeKind::SignedInteger
:
326 ShortStr
= "i" + utostr(ElementBitwidth
);
328 case ScalarTypeKind::UnsignedInteger
:
329 ShortStr
= "u" + utostr(ElementBitwidth
);
332 llvm_unreachable("Unhandled case!");
335 ShortStr
+= LMUL
.str();
337 ShortStr
+= "x" + utostr(NF
);
340 static VectorTypeModifier
getTupleVTM(unsigned NF
) {
341 assert(2 <= NF
&& NF
<= 8 && "2 <= NF <= 8");
342 return static_cast<VectorTypeModifier
>(
343 static_cast<uint8_t>(VectorTypeModifier::Tuple2
) + (NF
- 2));
// Initialises ElementBitwidth and ScalarType from the basic type.
// Visible mapping: Int8/Int16/Int32/Int64 -> SignedInteger (16, 32 and 64
// bits are visible for the latter three); Float16/Float32/Float64 -> Float
// with 16, 32 and 64 bits. Any other value hits llvm_unreachable. A final
// assert requires ElementBitwidth to be non-zero.
// NOTE(review): the switch header and the width assignment of the Int8 case
// are not visible in this excerpt.
346 void RVVType::applyBasicType() {
348 case BasicType::Int8
:
350 ScalarType
= ScalarTypeKind::SignedInteger
;
352 case BasicType::Int16
:
353 ElementBitwidth
= 16;
354 ScalarType
= ScalarTypeKind::SignedInteger
;
356 case BasicType::Int32
:
357 ElementBitwidth
= 32;
358 ScalarType
= ScalarTypeKind::SignedInteger
;
360 case BasicType::Int64
:
361 ElementBitwidth
= 64;
362 ScalarType
= ScalarTypeKind::SignedInteger
;
364 case BasicType::Float16
:
365 ElementBitwidth
= 16;
366 ScalarType
= ScalarTypeKind::Float
;
368 case BasicType::Float32
:
369 ElementBitwidth
= 32;
370 ScalarType
= ScalarTypeKind::Float
;
372 case BasicType::Float64
:
373 ElementBitwidth
= 64;
374 ScalarType
= ScalarTypeKind::Float
;
377 llvm_unreachable("Unhandled type code!");
379 assert(ElementBitwidth
!= 0 && "Bad element bitwidth!");
// Parses one textual prototype descriptor into a PrototypeDescriptor.
// Visible stages:
//   1. The last character of the string selects the base type modifier (PT);
//      some characters also select a vector type modifier (VTM). That
//      character is then dropped.
//   2. If the remaining string starts with "(", the text up to ")" is split at
//      ":" into a name and an integer value that together select a VTM.
//      Recognised names: Log2EEW, FixedSEW, LFixedLog2LMUL, SFixedLog2LMUL,
//      SEFixedLog2LMUL and Tuple. Only one such group is allowed (asserted).
//   3. Every remaining character ORs a flag into the TypeModifier mask (TM).
// PT, VTM and TM are stored into PD as uint8_t values.
// Malformed input is reported through assert / llvm_unreachable.
// NOTE(review): the case labels of every switch, the declarations of the
// parsed integers, the handling of an empty string and the final return are
// not visible in this excerpt.
382 std::optional
<PrototypeDescriptor
>
383 PrototypeDescriptor::parsePrototypeDescriptor(
384 llvm::StringRef PrototypeDescriptorStr
) {
385 PrototypeDescriptor PD
;
386 BaseTypeModifier PT
= BaseTypeModifier::Invalid
;
387 VectorTypeModifier VTM
= VectorTypeModifier::NoModifier
;
389 if (PrototypeDescriptorStr
.empty())
392 // Handle base type modifier
393 auto PType
= PrototypeDescriptorStr
.back();
396 PT
= BaseTypeModifier::Scalar
;
399 PT
= BaseTypeModifier::Vector
;
402 PT
= BaseTypeModifier::Vector
;
403 VTM
= VectorTypeModifier::Widening2XVector
;
406 PT
= BaseTypeModifier::Vector
;
407 VTM
= VectorTypeModifier::Widening4XVector
;
410 PT
= BaseTypeModifier::Vector
;
411 VTM
= VectorTypeModifier::Widening8XVector
;
414 PT
= BaseTypeModifier::Vector
;
415 VTM
= VectorTypeModifier::MaskVector
;
418 PT
= BaseTypeModifier::Void
;
421 PT
= BaseTypeModifier::SizeT
;
424 PT
= BaseTypeModifier::Ptrdiff
;
427 PT
= BaseTypeModifier::UnsignedLong
;
430 PT
= BaseTypeModifier::SignedLong
;
433 PT
= BaseTypeModifier::Float32
;
436 llvm_unreachable("Illegal primitive type transformers!");
438 PD
.PT
= static_cast<uint8_t>(PT
);
// The base-type character has been consumed; strip it from the string.
439 PrototypeDescriptorStr
= PrototypeDescriptorStr
.drop_back();
441 // Compute the vector type transformers, it can only appear one time.
442 if (PrototypeDescriptorStr
.startswith("(")) {
// A VTM chosen by the base-type character excludes a "(...)" group.
443 assert(VTM
== VectorTypeModifier::NoModifier
&&
444 "VectorTypeModifier should only have one modifier");
445 size_t Idx
= PrototypeDescriptorStr
.find(')');
446 assert(Idx
!= StringRef::npos
);
447 StringRef ComplexType
= PrototypeDescriptorStr
.slice(1, Idx
);
448 PrototypeDescriptorStr
= PrototypeDescriptorStr
.drop_front(Idx
+ 1);
449 assert(!PrototypeDescriptorStr
.contains('(') &&
450 "Only allow one vector type modifier");
// ComplexTT.first is the modifier name, ComplexTT.second its value text.
452 auto ComplexTT
= ComplexType
.split(":");
453 if (ComplexTT
.first
== "Log2EEW") {
455 if (ComplexTT
.second
.getAsInteger(10, Log2EEW
)) {
456 llvm_unreachable("Invalid Log2EEW value!");
461 VTM
= VectorTypeModifier::Log2EEW3
;
464 VTM
= VectorTypeModifier::Log2EEW4
;
467 VTM
= VectorTypeModifier::Log2EEW5
;
470 VTM
= VectorTypeModifier::Log2EEW6
;
473 llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
476 } else if (ComplexTT
.first
== "FixedSEW") {
478 if (ComplexTT
.second
.getAsInteger(10, NewSEW
)) {
479 llvm_unreachable("Invalid FixedSEW value!");
484 VTM
= VectorTypeModifier::FixedSEW8
;
487 VTM
= VectorTypeModifier::FixedSEW16
;
490 VTM
= VectorTypeModifier::FixedSEW32
;
493 VTM
= VectorTypeModifier::FixedSEW64
;
496 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
499 } else if (ComplexTT
.first
== "LFixedLog2LMUL") {
501 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
502 llvm_unreachable("Invalid LFixedLog2LMUL value!");
507 VTM
= VectorTypeModifier::LFixedLog2LMULN3
;
510 VTM
= VectorTypeModifier::LFixedLog2LMULN2
;
513 VTM
= VectorTypeModifier::LFixedLog2LMULN1
;
516 VTM
= VectorTypeModifier::LFixedLog2LMUL0
;
519 VTM
= VectorTypeModifier::LFixedLog2LMUL1
;
522 VTM
= VectorTypeModifier::LFixedLog2LMUL2
;
525 VTM
= VectorTypeModifier::LFixedLog2LMUL3
;
528 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
531 } else if (ComplexTT
.first
== "SFixedLog2LMUL") {
533 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
534 llvm_unreachable("Invalid SFixedLog2LMUL value!");
539 VTM
= VectorTypeModifier::SFixedLog2LMULN3
;
542 VTM
= VectorTypeModifier::SFixedLog2LMULN2
;
545 VTM
= VectorTypeModifier::SFixedLog2LMULN1
;
548 VTM
= VectorTypeModifier::SFixedLog2LMUL0
;
551 VTM
= VectorTypeModifier::SFixedLog2LMUL1
;
554 VTM
= VectorTypeModifier::SFixedLog2LMUL2
;
557 VTM
= VectorTypeModifier::SFixedLog2LMUL3
;
// NOTE(review): this message names LFixedLog2LMUL although this is the
// SFixedLog2LMUL branch — looks like a copy-paste slip.
560 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
564 } else if (ComplexTT
.first
== "SEFixedLog2LMUL") {
566 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
567 llvm_unreachable("Invalid SEFixedLog2LMUL value!");
572 VTM
= VectorTypeModifier::SEFixedLog2LMULN3
;
575 VTM
= VectorTypeModifier::SEFixedLog2LMULN2
;
578 VTM
= VectorTypeModifier::SEFixedLog2LMULN1
;
581 VTM
= VectorTypeModifier::SEFixedLog2LMUL0
;
584 VTM
= VectorTypeModifier::SEFixedLog2LMUL1
;
587 VTM
= VectorTypeModifier::SEFixedLog2LMUL2
;
590 VTM
= VectorTypeModifier::SEFixedLog2LMUL3
;
// NOTE(review): this message names LFixedLog2LMUL although this is the
// SEFixedLog2LMUL branch — looks like a copy-paste slip.
593 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
596 } else if (ComplexTT
.first
== "Tuple") {
598 if (ComplexTT
.second
.getAsInteger(10, NF
)) {
599 llvm_unreachable("Invalid NF value!");
602 VTM
= getTupleVTM(NF
);
604 llvm_unreachable("Illegal complex type transformers!");
607 PD
.VTM
= static_cast<uint8_t>(VTM
);
609 // Compute the remain type transformers
610 TypeModifier TM
= TypeModifier::NoModifier
;
611 for (char I
: PrototypeDescriptorStr
) {
// Pointer may not follow Const and may not be given twice.
614 if ((TM
& TypeModifier::Const
) == TypeModifier::Const
)
615 llvm_unreachable("'P' transformer cannot be used after 'C'");
616 if ((TM
& TypeModifier::Pointer
) == TypeModifier::Pointer
)
617 llvm_unreachable("'P' transformer cannot be used twice");
618 TM
|= TypeModifier::Pointer
;
621 TM
|= TypeModifier::Const
;
624 TM
|= TypeModifier::Immediate
;
627 TM
|= TypeModifier::UnsignedInteger
;
630 TM
|= TypeModifier::SignedInteger
;
633 TM
|= TypeModifier::Float
;
636 TM
|= TypeModifier::LMUL1
;
639 llvm_unreachable("Illegal non-primitive type transformer!");
642 PD
.TM
= static_cast<uint8_t>(TM
);
// Applies a prototype descriptor to this type in three visible stages:
//   1. Base type modifier (Transformer.PT): sets Scale for Vector, or
//      overrides ScalarType (Void, Size_t, Ptrdiff_t, UnsignedLong,
//      SignedLong, Float with a 32-bit element, Invalid).
//   2. Vector type modifier (Transformer.VTM): widening by 2/4/8 (element
//      width multiplied and Scale recomputed), mask vector (Boolean scalar
//      type), fixed-LMUL constraints forwarded to applyFixedLog2LMUL, and
//      tuples (NF recovered from the enumerator's offset from Tuple2).
//   3. Type modifier mask (Transformer.TM): each bit from 0 to
//      TypeModifier::MaxOffset is tested individually and dispatched through a
//      switch.
// NOTE(review): the bodies of the Scalar, Log2EEW*, FixedSEW*, NoModifier,
// Pointer, Const and Immediate cases, all break statements, and the statement
// guarded by the "already invalid" check are not visible in this excerpt.
647 void RVVType::applyModifier(const PrototypeDescriptor
&Transformer
) {
648 // Handle primitive type transformer
649 switch (static_cast<BaseTypeModifier
>(Transformer
.PT
)) {
650 case BaseTypeModifier::Scalar
:
653 case BaseTypeModifier::Vector
:
654 Scale
= LMUL
.getScale(ElementBitwidth
);
656 case BaseTypeModifier::Void
:
657 ScalarType
= ScalarTypeKind::Void
;
659 case BaseTypeModifier::SizeT
:
660 ScalarType
= ScalarTypeKind::Size_t
;
662 case BaseTypeModifier::Ptrdiff
:
663 ScalarType
= ScalarTypeKind::Ptrdiff_t
;
665 case BaseTypeModifier::UnsignedLong
:
666 ScalarType
= ScalarTypeKind::UnsignedLong
;
668 case BaseTypeModifier::SignedLong
:
669 ScalarType
= ScalarTypeKind::SignedLong
;
671 case BaseTypeModifier::Float32
:
672 ElementBitwidth
= 32;
673 ScalarType
= ScalarTypeKind::Float
;
675 case BaseTypeModifier::Invalid
:
676 ScalarType
= ScalarTypeKind::Invalid
;
// Stage 2: vector type modifier.
680 switch (static_cast<VectorTypeModifier
>(Transformer
.VTM
)) {
681 case VectorTypeModifier::Widening2XVector
:
682 ElementBitwidth
*= 2;
684 Scale
= LMUL
.getScale(ElementBitwidth
);
686 case VectorTypeModifier::Widening4XVector
:
687 ElementBitwidth
*= 4;
689 Scale
= LMUL
.getScale(ElementBitwidth
);
691 case VectorTypeModifier::Widening8XVector
:
692 ElementBitwidth
*= 8;
694 Scale
= LMUL
.getScale(ElementBitwidth
);
696 case VectorTypeModifier::MaskVector
:
697 ScalarType
= ScalarTypeKind::Boolean
;
698 Scale
= LMUL
.getScale(ElementBitwidth
);
701 case VectorTypeModifier::Log2EEW3
:
704 case VectorTypeModifier::Log2EEW4
:
707 case VectorTypeModifier::Log2EEW5
:
710 case VectorTypeModifier::Log2EEW6
:
713 case VectorTypeModifier::FixedSEW8
:
716 case VectorTypeModifier::FixedSEW16
:
719 case VectorTypeModifier::FixedSEW32
:
722 case VectorTypeModifier::FixedSEW64
:
// LFixedLog2LMUL*: fixed LMUL checked with FixedLMULType::LargerThan.
725 case VectorTypeModifier::LFixedLog2LMULN3
:
726 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan
);
728 case VectorTypeModifier::LFixedLog2LMULN2
:
729 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan
);
731 case VectorTypeModifier::LFixedLog2LMULN1
:
732 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan
);
734 case VectorTypeModifier::LFixedLog2LMUL0
:
735 applyFixedLog2LMUL(0, FixedLMULType::LargerThan
);
737 case VectorTypeModifier::LFixedLog2LMUL1
:
738 applyFixedLog2LMUL(1, FixedLMULType::LargerThan
);
740 case VectorTypeModifier::LFixedLog2LMUL2
:
741 applyFixedLog2LMUL(2, FixedLMULType::LargerThan
);
743 case VectorTypeModifier::LFixedLog2LMUL3
:
744 applyFixedLog2LMUL(3, FixedLMULType::LargerThan
);
// SFixedLog2LMUL*: fixed LMUL checked with FixedLMULType::SmallerThan.
746 case VectorTypeModifier::SFixedLog2LMULN3
:
747 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan
);
749 case VectorTypeModifier::SFixedLog2LMULN2
:
750 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan
);
752 case VectorTypeModifier::SFixedLog2LMULN1
:
753 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan
);
755 case VectorTypeModifier::SFixedLog2LMUL0
:
756 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan
);
758 case VectorTypeModifier::SFixedLog2LMUL1
:
759 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan
);
761 case VectorTypeModifier::SFixedLog2LMUL2
:
762 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan
);
764 case VectorTypeModifier::SFixedLog2LMUL3
:
765 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan
);
// SEFixedLog2LMUL*: fixed LMUL checked with FixedLMULType::SmallerOrEqual.
767 case VectorTypeModifier::SEFixedLog2LMULN3
:
768 applyFixedLog2LMUL(-3, FixedLMULType::SmallerOrEqual
);
770 case VectorTypeModifier::SEFixedLog2LMULN2
:
771 applyFixedLog2LMUL(-2, FixedLMULType::SmallerOrEqual
);
773 case VectorTypeModifier::SEFixedLog2LMULN1
:
774 applyFixedLog2LMUL(-1, FixedLMULType::SmallerOrEqual
);
776 case VectorTypeModifier::SEFixedLog2LMUL0
:
777 applyFixedLog2LMUL(0, FixedLMULType::SmallerOrEqual
);
779 case VectorTypeModifier::SEFixedLog2LMUL1
:
780 applyFixedLog2LMUL(1, FixedLMULType::SmallerOrEqual
);
782 case VectorTypeModifier::SEFixedLog2LMUL2
:
783 applyFixedLog2LMUL(2, FixedLMULType::SmallerOrEqual
);
785 case VectorTypeModifier::SEFixedLog2LMUL3
:
786 applyFixedLog2LMUL(3, FixedLMULType::SmallerOrEqual
);
788 case VectorTypeModifier::Tuple2
:
789 case VectorTypeModifier::Tuple3
:
790 case VectorTypeModifier::Tuple4
:
791 case VectorTypeModifier::Tuple5
:
792 case VectorTypeModifier::Tuple6
:
793 case VectorTypeModifier::Tuple7
:
794 case VectorTypeModifier::Tuple8
: {
// NF = 2 + (offset of this enumerator from Tuple2).
796 NF
= 2 + static_cast<uint8_t>(Transformer
.VTM
) -
797 static_cast<uint8_t>(VectorTypeModifier::Tuple2
);
800 case VectorTypeModifier::NoModifier
:
804 // Early return if the current type modifier is already invalid.
805 if (ScalarType
== Invalid
)
// Stage 3: walk the TypeModifier bit mask one bit at a time.
808 for (unsigned TypeModifierMaskShift
= 0;
809 TypeModifierMaskShift
<= static_cast<unsigned>(TypeModifier::MaxOffset
);
810 ++TypeModifierMaskShift
) {
811 unsigned TypeModifierMask
= 1 << TypeModifierMaskShift
;
812 if ((static_cast<unsigned>(Transformer
.TM
) & TypeModifierMask
) !=
815 switch (static_cast<TypeModifier
>(TypeModifierMask
)) {
816 case TypeModifier::Pointer
:
819 case TypeModifier::Const
:
822 case TypeModifier::Immediate
:
826 case TypeModifier::UnsignedInteger
:
827 ScalarType
= ScalarTypeKind::UnsignedInteger
;
829 case TypeModifier::SignedInteger
:
830 ScalarType
= ScalarTypeKind::SignedInteger
;
832 case TypeModifier::Float
:
833 ScalarType
= ScalarTypeKind::Float
;
835 case TypeModifier::LMUL1
:
837 // Update ElementBitwidth need to update Scale too.
838 Scale
= LMUL
.getScale(ElementBitwidth
);
841 llvm_unreachable("Unknown type modifier mask!");
846 void RVVType::applyLog2EEW(unsigned Log2EEW
) {
847 // update new elmul = (eew/sew) * lmul
848 LMUL
.MulLog2LMUL(Log2EEW
- Log2_32(ElementBitwidth
));
850 ElementBitwidth
= 1 << Log2EEW
;
851 ScalarType
= ScalarTypeKind::SignedInteger
;
852 Scale
= LMUL
.getScale(ElementBitwidth
);
855 void RVVType::applyFixedSEW(unsigned NewSEW
) {
856 // Set invalid type if src and dst SEW are same.
857 if (ElementBitwidth
== NewSEW
) {
858 ScalarType
= ScalarTypeKind::Invalid
;
862 ElementBitwidth
= NewSEW
;
863 Scale
= LMUL
.getScale(ElementBitwidth
);
// Replaces the current LMUL with a fixed one, subject to an ordering
// constraint against the current LMUL; the type is marked Invalid when the
// constraint is violated:
//   LargerThan     — invalid if Log2LMUL <= current Log2LMUL
//   SmallerThan    — invalid if Log2LMUL >= current Log2LMUL
//   SmallerOrEqual — invalid if Log2LMUL >  current Log2LMUL
// The last two statements install the new LMUL and recompute Scale.
// NOTE(review): the switch header and the statements following each
// invalidation (presumably an early return) are not visible in this excerpt.
866 void RVVType::applyFixedLog2LMUL(int Log2LMUL
, enum FixedLMULType Type
) {
868 case FixedLMULType::LargerThan
:
869 if (Log2LMUL
<= LMUL
.Log2LMUL
) {
870 ScalarType
= ScalarTypeKind::Invalid
;
874 case FixedLMULType::SmallerThan
:
875 if (Log2LMUL
>= LMUL
.Log2LMUL
) {
876 ScalarType
= ScalarTypeKind::Invalid
;
880 case FixedLMULType::SmallerOrEqual
:
881 if (Log2LMUL
> LMUL
.Log2LMUL
) {
882 ScalarType
= ScalarTypeKind::Invalid
;
889 LMUL
= LMULType(Log2LMUL
);
890 Scale
= LMUL
.getScale(ElementBitwidth
);
// Computes the RVVType for every descriptor in Prototype by calling
// computeType(BT, Log2LMUL, Proto) on each in turn.
// NOTE(review): the result container, the handling of a failed computeType,
// the use of NF and the return statements are not visible in this excerpt.
893 std::optional
<RVVTypes
>
894 RVVTypeCache::computeTypes(BasicType BT
, int Log2LMUL
, unsigned NF
,
895 ArrayRef
<PrototypeDescriptor
> Prototype
) {
897 for (const PrototypeDescriptor
&Proto
: Prototype
) {
898 auto T
= computeType(BT
, Log2LMUL
, Proto
);
901 // Record legal type index
907 // Compute the hash value of RVVType, used for cache the result of computeType.
908 static uint64_t computeRVVTypeHashValue(BasicType BT
, int Log2LMUL
,
909 PrototypeDescriptor Proto
) {
910 // Layout of hash value:
912 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
913 assert(Log2LMUL
>= -3 && Log2LMUL
<= 3);
914 return (Log2LMUL
+ 3) | (static_cast<uint64_t>(BT
) & 0xff) << 8 |
915 ((uint64_t)(Proto
.PT
& 0xff) << 16) |
916 ((uint64_t)(Proto
.TM
& 0xff) << 24) |
917 ((uint64_t)(Proto
.VTM
& 0xff) << 32);
// Returns a pointer to the cached RVVType for (BT, Log2LMUL, Proto), creating
// it on first use.
// The key comes from computeRVVTypeHashValue. A hit in LegalTypes returns the
// address of the stored value; IllegalTypes is consulted next. Otherwise a new
// RVVType is constructed and, as the visible statements show, either inserted
// into LegalTypes (returning the address of the stored copy) or its key is
// recorded in IllegalTypes.
// NOTE(review): the statement taken on an IllegalTypes hit, the validity test
// that chooses between the two recording paths, and the final return are not
// visible in this excerpt.
920 std::optional
<RVVTypePtr
> RVVTypeCache::computeType(BasicType BT
, int Log2LMUL
,
921 PrototypeDescriptor Proto
) {
922 uint64_t Idx
= computeRVVTypeHashValue(BT
, Log2LMUL
, Proto
);
924 auto It
= LegalTypes
.find(Idx
);
925 if (It
!= LegalTypes
.end())
926 return &(It
->second
);
928 if (IllegalTypes
.count(Idx
))
931 // Compute type and record the result.
932 RVVType
T(BT
, Log2LMUL
, Proto
);
934 // Record legal type index and value.
935 std::pair
<std::unordered_map
<uint64_t, RVVType
>::iterator
, bool>
936 InsertResult
= LegalTypes
.insert({Idx
, T
});
937 return &(InsertResult
.first
->second
);
939 // Record illegal type index.
940 IllegalTypes
.insert(Idx
);
944 //===----------------------------------------------------------------------===//
945 // RVVIntrinsic implementation
946 //===----------------------------------------------------------------------===//
// Constructs an RVVIntrinsic.
// The initializer list copies the flags, IR name, policy scheme, manual
// codegen text, NF and policy attributes. The body then:
//   - derives BuiltinName, Name and OverloadedName (the overloaded name falls
//     back to the part of NewName before the first "_" when none is given),
//     appending "_" + suffix to each where visible;
//   - calls updateNamesAndPolicy to decorate the names for the policy;
//   - takes OutInTypes[0] as the output type and the rest as input types;
//   - copies NewIntrinsicTypes and, when a maskedoff or passthru operand is
//     present, walks IntrinsicTypes to adjust them.
// NOTE(review): the initial assignment of Name, the conditions guarding some
// of the name updates, the body of the final loop and the use of
// RequiredFeatures are not visible in this excerpt.
947 RVVIntrinsic::RVVIntrinsic(
948 StringRef NewName
, StringRef Suffix
, StringRef NewOverloadedName
,
949 StringRef OverloadedSuffix
, StringRef IRName
, bool IsMasked
,
950 bool HasMaskedOffOperand
, bool HasVL
, PolicyScheme Scheme
,
951 bool SupportOverloading
, bool HasBuiltinAlias
, StringRef ManualCodegen
,
952 const RVVTypes
&OutInTypes
, const std::vector
<int64_t> &NewIntrinsicTypes
,
953 const std::vector
<StringRef
> &RequiredFeatures
, unsigned NF
,
954 Policy NewPolicyAttrs
, bool HasFRMRoundModeOp
)
955 : IRName(IRName
), IsMasked(IsMasked
),
956 HasMaskedOffOperand(HasMaskedOffOperand
), HasVL(HasVL
), Scheme(Scheme
),
957 SupportOverloading(SupportOverloading
), HasBuiltinAlias(HasBuiltinAlias
),
958 ManualCodegen(ManualCodegen
.str()), NF(NF
), PolicyAttrs(NewPolicyAttrs
) {
960 // Init BuiltinName, Name and OverloadedName
961 BuiltinName
= NewName
.str();
963 if (NewOverloadedName
.empty())
964 OverloadedName
= NewName
.split("_").first
.str();
966 OverloadedName
= NewOverloadedName
.str();
968 Name
+= "_" + Suffix
.str();
969 if (!OverloadedSuffix
.empty())
970 OverloadedName
+= "_" + OverloadedSuffix
.str();
972 updateNamesAndPolicy(IsMasked
, hasPolicy(), Name
, BuiltinName
, OverloadedName
,
973 PolicyAttrs
, HasFRMRoundModeOp
);
975 // Init OutputType and InputTypes
976 OutputType
= OutInTypes
[0];
977 InputTypes
.assign(OutInTypes
.begin() + 1, OutInTypes
.end());
979 // IntrinsicTypes is unmasked TA version index. Need to update it
980 // if there is merge operand (It is always in first operand).
981 IntrinsicTypes
= NewIntrinsicTypes
;
982 if ((IsMasked
&& hasMaskedOffOperand()) ||
983 (!IsMasked
&& hasPassthruOperand())) {
984 for (auto &I
: IntrinsicTypes
) {
// Concatenates the builtin strings of the output type and then of every input
// type, in order, into S.
// NOTE(review): the declaration of S and the return statement are not visible
// in this excerpt.
991 std::string
RVVIntrinsic::getBuiltinTypeStr() const {
993 S
+= OutputType
->getBuiltinStr();
994 for (const auto &T
: InputTypes
) {
995 S
+= T
->getBuiltinStr();
1000 std::string
RVVIntrinsic::getSuffixStr(
1001 RVVTypeCache
&TypeCache
, BasicType Type
, int Log2LMUL
,
1002 llvm::ArrayRef
<PrototypeDescriptor
> PrototypeDescriptors
) {
1003 SmallVector
<std::string
> SuffixStrs
;
1004 for (auto PD
: PrototypeDescriptors
) {
1005 auto T
= TypeCache
.computeType(Type
, Log2LMUL
, PD
);
1006 SuffixStrs
.push_back((*T
)->getShortStr());
1008 return join(SuffixStrs
, "_");
// Derives the operand prototype of a builtin from the user-level Prototype.
// Starting from a copy of Prototype, the visible branches insert:
//   - a maskedoff operand: a copy of the result type (NewPrototype[0]), a
//     tuple type built with getTupleVTM(NF) from the TM of an operand with the
//     Pointer bit cleared, or NF copies of an operand type with the Pointer
//     bit cleared;
//   - PrototypeDescriptor::Mask, at index 1 or at index NF + 1;
//   - a passthru operand for tail-undisturbed policies when the default scheme
//     is HasPassthruOperand;
//   - PrototypeDescriptor::VL as the last operand.
// NOTE(review): several guarding conditions (including the ones involving
// IsMasked, IsTuple and HasVL) and the second constructor argument of
// NewPrototype are not visible in this excerpt.
1011 llvm::SmallVector
<PrototypeDescriptor
> RVVIntrinsic::computeBuiltinTypes(
1012 llvm::ArrayRef
<PrototypeDescriptor
> Prototype
, bool IsMasked
,
1013 bool HasMaskedOffOperand
, bool HasVL
, unsigned NF
,
1014 PolicyScheme DefaultScheme
, Policy PolicyAttrs
, bool IsTuple
) {
1015 SmallVector
<PrototypeDescriptor
> NewPrototype(Prototype
.begin(),
1017 bool HasPassthruOp
= DefaultScheme
== PolicyScheme::HasPassthruOperand
;
1019 // If HasMaskedOffOperand, insert result type as first input operand if
1021 if (HasMaskedOffOperand
&& !PolicyAttrs
.isTAMAPolicy()) {
1023 NewPrototype
.insert(NewPrototype
.begin() + 1, NewPrototype
[0]);
1024 } else if (NF
> 1) {
// Tuple maskedoff type: same TM as Prototype[1] minus the Pointer bit.
1026 PrototypeDescriptor BasePtrOperand
= Prototype
[1];
1027 PrototypeDescriptor MaskoffType
= PrototypeDescriptor(
1028 static_cast<uint8_t>(BaseTypeModifier::Vector
),
1029 static_cast<uint8_t>(getTupleVTM(NF
)),
1030 BasePtrOperand
.TM
& ~static_cast<uint8_t>(TypeModifier::Pointer
));
1031 NewPrototype
.insert(NewPrototype
.begin() + 1, MaskoffType
);
1034 // (void, op0 address, op1 address, ...)
1036 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1037 PrototypeDescriptor MaskoffType
= NewPrototype
[1];
1038 MaskoffType
.TM
&= ~static_cast<uint8_t>(TypeModifier::Pointer
);
1039 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1, NF
, MaskoffType
);
1043 if (HasMaskedOffOperand
&& NF
> 1) {
1045 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1047 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1050 NewPrototype
.insert(NewPrototype
.begin() + 1,
1051 PrototypeDescriptor::Mask
);
1053 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1,
1054 PrototypeDescriptor::Mask
);
1056 // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
1057 NewPrototype
.insert(NewPrototype
.begin() + 1, PrototypeDescriptor::Mask
);
1061 if (PolicyAttrs
.isTUPolicy() && HasPassthruOp
)
1062 NewPrototype
.insert(NewPrototype
.begin(), NewPrototype
[0]);
1063 } else if (PolicyAttrs
.isTUPolicy() && HasPassthruOp
) {
1065 PrototypeDescriptor BasePtrOperand
= Prototype
[0];
1066 PrototypeDescriptor MaskoffType
= PrototypeDescriptor(
1067 static_cast<uint8_t>(BaseTypeModifier::Vector
),
1068 static_cast<uint8_t>(getTupleVTM(NF
)),
1069 BasePtrOperand
.TM
& ~static_cast<uint8_t>(TypeModifier::Pointer
));
1070 NewPrototype
.insert(NewPrototype
.begin(), MaskoffType
);
1072 // NF > 1 cases for segment load operations.
1074 // (void, op0 address, op1 address, ...)
1076 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
1077 PrototypeDescriptor MaskoffType
= Prototype
[1];
1078 MaskoffType
.TM
&= ~static_cast<uint8_t>(TypeModifier::Pointer
);
1079 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1, NF
, MaskoffType
);
1084 // If HasVL, append PrototypeDescriptor:VL to last operand
1086 NewPrototype
.push_back(PrototypeDescriptor::VL
);
1088 return NewPrototype
;
1091 llvm::SmallVector
<Policy
> RVVIntrinsic::getSupportedUnMaskedPolicies() {
1092 return {Policy(Policy::PolicyType::Undisturbed
)}; // TU
// Lists the policies supported by masked intrinsics, depending on which
// policy operands the instruction has. Each Policy is built from a
// (tail, mask) pair of PolicyType values:
//   tail and mask policy -> TUM, TUMU, MU
//   tail policy only     -> TU
//   mask policy only     -> MU
// Having neither is treated as unreachable.
// NOTE(review): the remainder of the final llvm_unreachable message is not
// visible in this excerpt.
1095 llvm::SmallVector
<Policy
>
1096 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy
,
1097 bool HasMaskPolicy
) {
1098 if (HasTailPolicy
&& HasMaskPolicy
)
1099 return {Policy(Policy::PolicyType::Undisturbed
,
1100 Policy::PolicyType::Agnostic
), // TUM
1101 Policy(Policy::PolicyType::Undisturbed
,
1102 Policy::PolicyType::Undisturbed
), // TUMU
1103 Policy(Policy::PolicyType::Agnostic
,
1104 Policy::PolicyType::Undisturbed
)}; // MU
1105 if (HasTailPolicy
&& !HasMaskPolicy
)
1106 return {Policy(Policy::PolicyType::Undisturbed
,
1107 Policy::PolicyType::Agnostic
)}; // TU
1108 if (!HasTailPolicy
&& HasMaskPolicy
)
1109 return {Policy(Policy::PolicyType::Agnostic
,
1110 Policy::PolicyType::Undisturbed
)}; // MU
1111 llvm_unreachable("An RVV instruction should not be without both tail policy "
// Decorates the intrinsic names in place according to the policy.
// Visible behaviour:
//   - Name and OverloadedName get the "__riscv_" prefix;
//   - BuiltinName gets "_rm" when HasFRMRoundModeOp is set;
//   - a policy suffix is appended through appendPolicySuffix, which updates
//     BuiltinName and OverloadedName: "_tumu", "_tum" or "_mu" in the first
//     chain, "_tu" in the second; the TAMA case appends "_m" to BuiltinName.
// Unhandled policy combinations hit llvm_unreachable.
// NOTE(review): the conditions selecting between the two policy chains, part
// of the appendPolicySuffix body and the uses of HasPolicy are not visible in
// this excerpt.
1115 void RVVIntrinsic::updateNamesAndPolicy(
1116 bool IsMasked
, bool HasPolicy
, std::string
&Name
, std::string
&BuiltinName
,
1117 std::string
&OverloadedName
, Policy
&PolicyAttrs
, bool HasFRMRoundModeOp
) {
1119 auto appendPolicySuffix
= [&](const std::string
&suffix
) {
1121 BuiltinName
+= suffix
;
1122 OverloadedName
+= suffix
;
1125 // This follows the naming guideline under riscv-c-api-doc to add the
1126 // `__riscv_` prefix for all RVV intrinsics.
1127 Name
= "__riscv_" + Name
;
1128 OverloadedName
= "__riscv_" + OverloadedName
;
1130 if (HasFRMRoundModeOp
) {
1132 BuiltinName
+= "_rm";
1136 if (PolicyAttrs
.isTUMUPolicy())
1137 appendPolicySuffix("_tumu");
1138 else if (PolicyAttrs
.isTUMAPolicy())
1139 appendPolicySuffix("_tum");
1140 else if (PolicyAttrs
.isTAMUPolicy())
1141 appendPolicySuffix("_mu");
1142 else if (PolicyAttrs
.isTAMAPolicy()) {
1144 BuiltinName
+= "_m";
1146 llvm_unreachable("Unhandled policy condition");
1148 if (PolicyAttrs
.isTUPolicy())
1149 appendPolicySuffix("_tu");
1150 else if (PolicyAttrs
.isTAPolicy()) // no suffix needed
1153 llvm_unreachable("Unhandled policy condition");
// Splits a prototype string into its descriptors and parses each one.
// Every descriptor ends at the first primitive-type character (one of
// "evwqom0ztulf"). When the descriptor starts with "(", the search for that
// character begins at the closing ")" so that characters inside the
// parenthesised group are not mistaken for it. Each slice is passed to
// PrototypeDescriptor::parsePrototypeDescriptor and then dropped from the
// front of the input.
// NOTE(review): the declaration of Idx and the condition guarding the
// "Error during parsing prototype." llvm_unreachable are not visible in this
// excerpt.
1157 SmallVector
<PrototypeDescriptor
> parsePrototypes(StringRef Prototypes
) {
1158 SmallVector
<PrototypeDescriptor
> PrototypeDescriptors
;
1159 const StringRef
Primaries("evwqom0ztulf");
1160 while (!Prototypes
.empty()) {
1162 // Skip over complex prototype because it could contain primitive type
1164 if (Prototypes
[0] == '(')
1165 Idx
= Prototypes
.find_first_of(')');
1166 Idx
= Prototypes
.find_first_of(Primaries
, Idx
);
1167 assert(Idx
!= StringRef::npos
);
1168 auto PD
= PrototypeDescriptor::parsePrototypeDescriptor(
1169 Prototypes
.slice(0, Idx
+ 1));
1171 llvm_unreachable("Error during parsing prototype.");
1172 PrototypeDescriptors
.push_back(*PD
);
1173 Prototypes
= Prototypes
.drop_front(Idx
+ 1);
1175 return PrototypeDescriptors
;
// Streams an RVVIntrinsicRecord as a comma-terminated list of fields: the
// quoted Name, the quoted OverloadedName, the three index fields, and then
// the remaining fields cast to int (presumably so that narrow integer and
// boolean fields print as numbers — confirm against the record definition).
// NOTE(review): the output produced when OverloadedName is null or empty, any
// text written before the first or after the last field, and the return
// statement are not visible in this excerpt.
1178 raw_ostream
&operator<<(raw_ostream
&OS
, const RVVIntrinsicRecord
&Record
) {
1180 OS
<< "\"" << Record
.Name
<< "\",";
1181 if (Record
.OverloadedName
== nullptr ||
1182 StringRef(Record
.OverloadedName
).empty())
1185 OS
<< "\"" << Record
.OverloadedName
<< "\",";
1186 OS
<< Record
.PrototypeIndex
<< ",";
1187 OS
<< Record
.SuffixIndex
<< ",";
1188 OS
<< Record
.OverloadedSuffixIndex
<< ",";
1189 OS
<< (int)Record
.PrototypeLength
<< ",";
1190 OS
<< (int)Record
.SuffixLength
<< ",";
1191 OS
<< (int)Record
.OverloadedSuffixSize
<< ",";
1192 OS
<< (int)Record
.RequiredExtensions
<< ",";
1193 OS
<< (int)Record
.TypeRangeMask
<< ",";
1194 OS
<< (int)Record
.Log2LMULMask
<< ",";
1195 OS
<< (int)Record
.NF
<< ",";
1196 OS
<< (int)Record
.HasMasked
<< ",";
1197 OS
<< (int)Record
.HasVL
<< ",";
1198 OS
<< (int)Record
.HasMaskedOffOperand
<< ",";
1199 OS
<< (int)Record
.HasTailPolicy
<< ",";
1200 OS
<< (int)Record
.HasMaskPolicy
<< ",";
1201 OS
<< (int)Record
.HasFRMRoundModeOp
<< ",";
1202 OS
<< (int)Record
.IsTuple
<< ",";
1203 OS
<< (int)Record
.UnMaskedPolicyScheme
<< ",";
1204 OS
<< (int)Record
.MaskedPolicyScheme
<< ",";
1209 } // end namespace RISCV
1210 } // end namespace clang