1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "clang/Support/RISCVVIntrinsicUtils.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/SmallSet.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringSet.h"
14 #include "llvm/ADT/Twine.h"
15 #include "llvm/Support/ErrorHandling.h"
16 #include "llvm/Support/raw_ostream.h"
25 const PrototypeDescriptor
PrototypeDescriptor::Mask
= PrototypeDescriptor(
26 BaseTypeModifier::Vector
, VectorTypeModifier::MaskVector
);
27 const PrototypeDescriptor
PrototypeDescriptor::VL
=
28 PrototypeDescriptor(BaseTypeModifier::SizeT
);
29 const PrototypeDescriptor
PrototypeDescriptor::Vector
=
30 PrototypeDescriptor(BaseTypeModifier::Vector
);
32 //===----------------------------------------------------------------------===//
33 // Type implementation
34 //===----------------------------------------------------------------------===//
36 LMULType::LMULType(int NewLog2LMUL
) {
37 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
38 assert(NewLog2LMUL
<= 3 && NewLog2LMUL
>= -3 && "Bad LMUL number!");
39 Log2LMUL
= NewLog2LMUL
;
42 std::string
LMULType::str() const {
44 return "mf" + utostr(1ULL << (-Log2LMUL
));
45 return "m" + utostr(1ULL << Log2LMUL
);
48 VScaleVal
LMULType::getScale(unsigned ElementBitwidth
) const {
49 int Log2ScaleResult
= 0;
50 switch (ElementBitwidth
) {
54 Log2ScaleResult
= Log2LMUL
+ 3;
57 Log2ScaleResult
= Log2LMUL
+ 2;
60 Log2ScaleResult
= Log2LMUL
+ 1;
63 Log2ScaleResult
= Log2LMUL
;
66 // Illegal vscale result would be less than 1
67 if (Log2ScaleResult
< 0)
69 return 1 << Log2ScaleResult
;
72 void LMULType::MulLog2LMUL(int log2LMUL
) { Log2LMUL
+= log2LMUL
; }
// Constructs an RVVType from a basic element type, a log2(LMUL) value and a
// prototype descriptor: stores BT, wraps Log2LMUL in an LMULType, applies the
// descriptor via applyModifier() and calls initClangBuiltinStr().
// NOTE(review): original lines 77 and 79-83 are absent from this listing, so
// further initialisation/validation steps (and the condition guarding
// initClangBuiltinStr) are not visible here - confirm against upstream.
74 RVVType::RVVType(BasicType BT
, int Log2LMUL
,
75 const PrototypeDescriptor
&prototype
)
76 : BT(BT
), LMUL(LMULType(Log2LMUL
)) {
78 applyModifier(prototype
);
84 initClangBuiltinStr();
90 // Boolean types are encoded by the ratio n = SEW/LMUL:
91 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64
92 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
93 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
95 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
96 // -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
97 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
98 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
99 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
100 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
101 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
102 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
103 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
104 // bfloat16 | N/A | nxv1bf16 | nxv2bf16| nxv4bf16| nxv8bf16 | nxv16bf16| nxv32bf16
// Checks whether the current combination of scalar kind, element width,
// tuple size (NF) and LMUL describes a legal RVV type.
// Visible conditions: invalid scalar kind; float with 8-bit elements; bfloat
// with a width other than 16; tuple with NF == 1 or NF > 8; tuple whose
// max(LMUL, 1) * NF exceeds 8. The final switch bounds the scale V by element
// width and requires it to be a power of two.
// NOTE(review): the statements executed when each `if` fires, the definition
// of V and the case labels are absent from this listing; presumably each
// failing check rejects the type - confirm against upstream.
107 bool RVVType::verifyType() const {
108 if (ScalarType
== Invalid
)
114 if (isFloat() && ElementBitwidth
== 8)
116 if (isBFloat() && ElementBitwidth
!= 16)
118 if (IsTuple
&& (NF
== 1 || NF
> 8))
120 if (IsTuple
&& (1 << std::max(0, LMUL
.Log2LMUL
)) * NF
> 8)
123 switch (ElementBitwidth
) {
126 // Check Scale is 1,2,4,8,16,32,64
127 return (V
<= 64 && isPowerOf2_32(V
));
129 // Check Scale is 1,2,4,8,16,32
130 return (V
<= 32 && isPowerOf2_32(V
));
132 // Check Scale is 1,2,4,8,16
133 return (V
<= 16 && isPowerOf2_32(V
));
135 // Check Scale is 1,2,4,8
136 return (V
<= 8 && isPowerOf2_32(V
));
// Builds BuiltinStr, the encoded builtin type string for this type.
// Visible steps: a per-ScalarType switch chooses the base encoding; integer
// types get an "S" or "U" prefix depending on signedness; an "I" prefix is
// added in some cases; vectors get a "q<Scale>" prefix and tuples a "T<NF>"
// prefix. Requires isValid().
// NOTE(review): the base-encoding literals, the conditions guarding the
// "I"/"q"/"T" prefixes and all break/return statements are absent from this
// listing - confirm against upstream before relying on details.
141 void RVVType::initBuiltinStr() {
142 assert(isValid() && "RVVType is invalid");
143 switch (ScalarType
) {
144 case ScalarTypeKind::Void
:
147 case ScalarTypeKind::Size_t
:
150 BuiltinStr
= "I" + BuiltinStr
;
154 case ScalarTypeKind::Ptrdiff_t
:
157 case ScalarTypeKind::UnsignedLong
:
160 case ScalarTypeKind::SignedLong
:
163 case ScalarTypeKind::Boolean
:
164 assert(ElementBitwidth
== 1);
167 case ScalarTypeKind::SignedInteger
:
168 case ScalarTypeKind::UnsignedInteger
:
169 switch (ElementBitwidth
) {
183 llvm_unreachable("Unhandled ElementBitwidth!");
185 if (isSignedInteger())
186 BuiltinStr
= "S" + BuiltinStr
;
188 BuiltinStr
= "U" + BuiltinStr
;
190 case ScalarTypeKind::Float
:
191 switch (ElementBitwidth
) {
202 llvm_unreachable("Unhandled ElementBitwidth!");
205 case ScalarTypeKind::BFloat
:
209 llvm_unreachable("ScalarType is invalid!");
212 BuiltinStr
= "I" + BuiltinStr
;
220 BuiltinStr
= "q" + utostr(*Scale
) + BuiltinStr
;
221 // Pointer to vector types. Defined for segment load intrinsics.
222 // segment load intrinsics have pointer type arguments to store the loaded
228 BuiltinStr
= "T" + utostr(NF
) + BuiltinStr
;
231 void RVVType::initClangBuiltinStr() {
232 assert(isValid() && "RVVType is invalid");
233 assert(isVector() && "Handle Vector type only");
235 ClangBuiltinStr
= "__rvv_";
236 switch (ScalarType
) {
237 case ScalarTypeKind::Boolean
:
238 ClangBuiltinStr
+= "bool" + utostr(64 / *Scale
) + "_t";
240 case ScalarTypeKind::Float
:
241 ClangBuiltinStr
+= "float";
243 case ScalarTypeKind::BFloat
:
244 ClangBuiltinStr
+= "bfloat";
246 case ScalarTypeKind::SignedInteger
:
247 ClangBuiltinStr
+= "int";
249 case ScalarTypeKind::UnsignedInteger
:
250 ClangBuiltinStr
+= "uint";
253 llvm_unreachable("ScalarTypeKind is invalid");
255 ClangBuiltinStr
+= utostr(ElementBitwidth
) + LMUL
.str() +
256 (IsTuple
? "x" + utostr(NF
) : "") + "_t";
// Builds Str, the C-level spelling of this type. Requires isValid().
// The getTypeString lambda formats "<TypeStr><ElementBitwidth>_t" in one
// return and "v<TypeStr><ElementBitwidth><LMUL>[x<NF>]_t" in the other
// (the "x<NF>" part only when IsTuple); mask vectors use "vbool<64/Scale>_t".
// NOTE(review): the condition selecting between the two lambda returns, the
// literals assigned for Void/Size_t/Ptrdiff_t/SignedLong and for scalar
// float widths, and any pointer/const decoration are absent from this
// listing - confirm against upstream.
259 void RVVType::initTypeStr() {
260 assert(isValid() && "RVVType is invalid");
265 auto getTypeString
= [&](StringRef TypeStr
) {
267 return Twine(TypeStr
+ Twine(ElementBitwidth
) + "_t").str();
268 return Twine("v" + TypeStr
+ Twine(ElementBitwidth
) + LMUL
.str() +
269 (IsTuple
? "x" + utostr(NF
) : "") + "_t")
273 switch (ScalarType
) {
274 case ScalarTypeKind::Void
:
277 case ScalarTypeKind::Size_t
:
282 case ScalarTypeKind::Ptrdiff_t
:
285 case ScalarTypeKind::UnsignedLong
:
286 Str
= "unsigned long";
288 case ScalarTypeKind::SignedLong
:
291 case ScalarTypeKind::Boolean
:
295 // Vector bool is a special case; the formula is
296 // `vbool<N>_t = MVT::nxv<64/N>i1`, e.g. vbool16_t = MVT::nxv4i1
297 Str
+= "vbool" + utostr(64 / *Scale
) + "_t";
299 case ScalarTypeKind::Float
:
301 if (ElementBitwidth
== 64)
303 else if (ElementBitwidth
== 32)
305 else if (ElementBitwidth
== 16)
308 llvm_unreachable("Unhandled floating type.");
310 Str
+= getTypeString("float");
312 case ScalarTypeKind::BFloat
:
314 if (ElementBitwidth
== 16)
317 llvm_unreachable("Unhandled floating type.");
319 Str
+= getTypeString("bfloat");
321 case ScalarTypeKind::SignedInteger
:
322 Str
+= getTypeString("int");
324 case ScalarTypeKind::UnsignedInteger
:
325 Str
+= getTypeString("uint");
328 llvm_unreachable("ScalarType is invalid!");
// Builds ShortStr, the compact type tag: "b<64/Scale>" for masks, otherwise
// "f", "bf", "i" or "u" followed by the element width; the LMUL string and
// an "x<NF>" tuple suffix are appended afterwards.
// NOTE(review): the break/return statements and the conditions guarding the
// LMUL and "x<NF>" appends are absent from this listing, so which kinds
// receive those suffixes cannot be confirmed here.
334 void RVVType::initShortStr() {
335 switch (ScalarType
) {
336 case ScalarTypeKind::Boolean
:
338 ShortStr
= "b" + utostr(64 / *Scale
);
340 case ScalarTypeKind::Float
:
341 ShortStr
= "f" + utostr(ElementBitwidth
);
343 case ScalarTypeKind::BFloat
:
344 ShortStr
= "bf" + utostr(ElementBitwidth
);
346 case ScalarTypeKind::SignedInteger
:
347 ShortStr
= "i" + utostr(ElementBitwidth
);
349 case ScalarTypeKind::UnsignedInteger
:
350 ShortStr
= "u" + utostr(ElementBitwidth
);
353 llvm_unreachable("Unhandled case!");
356 ShortStr
+= LMUL
.str();
358 ShortStr
+= "x" + utostr(NF
);
361 static VectorTypeModifier
getTupleVTM(unsigned NF
) {
362 assert(2 <= NF
&& NF
<= 8 && "2 <= NF <= 8");
363 return static_cast<VectorTypeModifier
>(
364 static_cast<uint8_t>(VectorTypeModifier::Tuple2
) + (NF
- 2));
367 void RVVType::applyBasicType() {
369 case BasicType::Int8
:
371 ScalarType
= ScalarTypeKind::SignedInteger
;
373 case BasicType::Int16
:
374 ElementBitwidth
= 16;
375 ScalarType
= ScalarTypeKind::SignedInteger
;
377 case BasicType::Int32
:
378 ElementBitwidth
= 32;
379 ScalarType
= ScalarTypeKind::SignedInteger
;
381 case BasicType::Int64
:
382 ElementBitwidth
= 64;
383 ScalarType
= ScalarTypeKind::SignedInteger
;
385 case BasicType::Float16
:
386 ElementBitwidth
= 16;
387 ScalarType
= ScalarTypeKind::Float
;
389 case BasicType::Float32
:
390 ElementBitwidth
= 32;
391 ScalarType
= ScalarTypeKind::Float
;
393 case BasicType::Float64
:
394 ElementBitwidth
= 64;
395 ScalarType
= ScalarTypeKind::Float
;
397 case BasicType::BFloat16
:
398 ElementBitwidth
= 16;
399 ScalarType
= ScalarTypeKind::BFloat
;
402 llvm_unreachable("Unhandled type code!");
404 assert(ElementBitwidth
!= 0 && "Bad element bitwidth!");
407 std::optional
<PrototypeDescriptor
>
408 PrototypeDescriptor::parsePrototypeDescriptor(
409 llvm::StringRef PrototypeDescriptorStr
) {
410 PrototypeDescriptor PD
;
411 BaseTypeModifier PT
= BaseTypeModifier::Invalid
;
412 VectorTypeModifier VTM
= VectorTypeModifier::NoModifier
;
414 if (PrototypeDescriptorStr
.empty())
417 // Handle base type modifier
418 auto PType
= PrototypeDescriptorStr
.back();
421 PT
= BaseTypeModifier::Scalar
;
424 PT
= BaseTypeModifier::Vector
;
427 PT
= BaseTypeModifier::Vector
;
428 VTM
= VectorTypeModifier::Widening2XVector
;
431 PT
= BaseTypeModifier::Vector
;
432 VTM
= VectorTypeModifier::Widening4XVector
;
435 PT
= BaseTypeModifier::Vector
;
436 VTM
= VectorTypeModifier::Widening8XVector
;
439 PT
= BaseTypeModifier::Vector
;
440 VTM
= VectorTypeModifier::MaskVector
;
443 PT
= BaseTypeModifier::Void
;
446 PT
= BaseTypeModifier::SizeT
;
449 PT
= BaseTypeModifier::Ptrdiff
;
452 PT
= BaseTypeModifier::UnsignedLong
;
455 PT
= BaseTypeModifier::SignedLong
;
458 PT
= BaseTypeModifier::Float32
;
461 llvm_unreachable("Illegal primitive type transformers!");
463 PD
.PT
= static_cast<uint8_t>(PT
);
464 PrototypeDescriptorStr
= PrototypeDescriptorStr
.drop_back();
466 // Compute the vector type transformers, it can only appear one time.
467 if (PrototypeDescriptorStr
.starts_with("(")) {
468 assert(VTM
== VectorTypeModifier::NoModifier
&&
469 "VectorTypeModifier should only have one modifier");
470 size_t Idx
= PrototypeDescriptorStr
.find(')');
471 assert(Idx
!= StringRef::npos
);
472 StringRef ComplexType
= PrototypeDescriptorStr
.slice(1, Idx
);
473 PrototypeDescriptorStr
= PrototypeDescriptorStr
.drop_front(Idx
+ 1);
474 assert(!PrototypeDescriptorStr
.contains('(') &&
475 "Only allow one vector type modifier");
477 auto ComplexTT
= ComplexType
.split(":");
478 if (ComplexTT
.first
== "Log2EEW") {
480 if (ComplexTT
.second
.getAsInteger(10, Log2EEW
)) {
481 llvm_unreachable("Invalid Log2EEW value!");
486 VTM
= VectorTypeModifier::Log2EEW3
;
489 VTM
= VectorTypeModifier::Log2EEW4
;
492 VTM
= VectorTypeModifier::Log2EEW5
;
495 VTM
= VectorTypeModifier::Log2EEW6
;
498 llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
501 } else if (ComplexTT
.first
== "FixedSEW") {
503 if (ComplexTT
.second
.getAsInteger(10, NewSEW
)) {
504 llvm_unreachable("Invalid FixedSEW value!");
509 VTM
= VectorTypeModifier::FixedSEW8
;
512 VTM
= VectorTypeModifier::FixedSEW16
;
515 VTM
= VectorTypeModifier::FixedSEW32
;
518 VTM
= VectorTypeModifier::FixedSEW64
;
521 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
524 } else if (ComplexTT
.first
== "LFixedLog2LMUL") {
526 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
527 llvm_unreachable("Invalid LFixedLog2LMUL value!");
532 VTM
= VectorTypeModifier::LFixedLog2LMULN3
;
535 VTM
= VectorTypeModifier::LFixedLog2LMULN2
;
538 VTM
= VectorTypeModifier::LFixedLog2LMULN1
;
541 VTM
= VectorTypeModifier::LFixedLog2LMUL0
;
544 VTM
= VectorTypeModifier::LFixedLog2LMUL1
;
547 VTM
= VectorTypeModifier::LFixedLog2LMUL2
;
550 VTM
= VectorTypeModifier::LFixedLog2LMUL3
;
553 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
556 } else if (ComplexTT
.first
== "SFixedLog2LMUL") {
558 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
559 llvm_unreachable("Invalid SFixedLog2LMUL value!");
564 VTM
= VectorTypeModifier::SFixedLog2LMULN3
;
567 VTM
= VectorTypeModifier::SFixedLog2LMULN2
;
570 VTM
= VectorTypeModifier::SFixedLog2LMULN1
;
573 VTM
= VectorTypeModifier::SFixedLog2LMUL0
;
576 VTM
= VectorTypeModifier::SFixedLog2LMUL1
;
579 VTM
= VectorTypeModifier::SFixedLog2LMUL2
;
582 VTM
= VectorTypeModifier::SFixedLog2LMUL3
;
585 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
589 } else if (ComplexTT
.first
== "SEFixedLog2LMUL") {
591 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
592 llvm_unreachable("Invalid SEFixedLog2LMUL value!");
597 VTM
= VectorTypeModifier::SEFixedLog2LMULN3
;
600 VTM
= VectorTypeModifier::SEFixedLog2LMULN2
;
603 VTM
= VectorTypeModifier::SEFixedLog2LMULN1
;
606 VTM
= VectorTypeModifier::SEFixedLog2LMUL0
;
609 VTM
= VectorTypeModifier::SEFixedLog2LMUL1
;
612 VTM
= VectorTypeModifier::SEFixedLog2LMUL2
;
615 VTM
= VectorTypeModifier::SEFixedLog2LMUL3
;
618 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
621 } else if (ComplexTT
.first
== "Tuple") {
623 if (ComplexTT
.second
.getAsInteger(10, NF
)) {
624 llvm_unreachable("Invalid NF value!");
627 VTM
= getTupleVTM(NF
);
629 llvm_unreachable("Illegal complex type transformers!");
632 PD
.VTM
= static_cast<uint8_t>(VTM
);
634 // Compute the remain type transformers
635 TypeModifier TM
= TypeModifier::NoModifier
;
636 for (char I
: PrototypeDescriptorStr
) {
639 if ((TM
& TypeModifier::Const
) == TypeModifier::Const
)
640 llvm_unreachable("'P' transformer cannot be used after 'C'");
641 if ((TM
& TypeModifier::Pointer
) == TypeModifier::Pointer
)
642 llvm_unreachable("'P' transformer cannot be used twice");
643 TM
|= TypeModifier::Pointer
;
646 TM
|= TypeModifier::Const
;
649 TM
|= TypeModifier::Immediate
;
652 TM
|= TypeModifier::UnsignedInteger
;
655 TM
|= TypeModifier::SignedInteger
;
658 TM
|= TypeModifier::Float
;
661 TM
|= TypeModifier::LMUL1
;
664 llvm_unreachable("Illegal non-primitive type transformer!");
667 PD
.TM
= static_cast<uint8_t>(TM
);
// Applies a PrototypeDescriptor to this type in three stages:
//   1. the base type modifier (Transformer.PT) selects scalar/vector/void/
//      size_t/ptrdiff_t/long/float32 handling;
//   2. the vector type modifier (Transformer.VTM) widens the element type,
//      turns the type into a mask, applies a fixed EEW/SEW/LMUL constraint or
//      marks it as a tuple with NF fields;
//   3. every bit of Transformer.TM, from bit 0 up to TypeModifier::MaxOffset,
//      is applied in ascending order.
// Stage 3 is not reached when stage 1/2 left ScalarType == Invalid.
// NOTE(review): many statements (break/return, the bodies of the Log2EEW*/
// FixedSEW*/Pointer/Const/Immediate/LMUL1 cases, any LMUL adjustment on
// widening) are absent from this listing - confirm details against upstream.
672 void RVVType::applyModifier(const PrototypeDescriptor
&Transformer
) {
673 // Handle primitive type transformer
674 switch (static_cast<BaseTypeModifier
>(Transformer
.PT
)) {
675 case BaseTypeModifier::Scalar
:
678 case BaseTypeModifier::Vector
:
679 Scale
= LMUL
.getScale(ElementBitwidth
);
681 case BaseTypeModifier::Void
:
682 ScalarType
= ScalarTypeKind::Void
;
684 case BaseTypeModifier::SizeT
:
685 ScalarType
= ScalarTypeKind::Size_t
;
687 case BaseTypeModifier::Ptrdiff
:
688 ScalarType
= ScalarTypeKind::Ptrdiff_t
;
690 case BaseTypeModifier::UnsignedLong
:
691 ScalarType
= ScalarTypeKind::UnsignedLong
;
693 case BaseTypeModifier::SignedLong
:
694 ScalarType
= ScalarTypeKind::SignedLong
;
696 case BaseTypeModifier::Float32
:
697 ElementBitwidth
= 32;
698 ScalarType
= ScalarTypeKind::Float
;
700 case BaseTypeModifier::Invalid
:
701 ScalarType
= ScalarTypeKind::Invalid
;
705 switch (static_cast<VectorTypeModifier
>(Transformer
.VTM
)) {
706 case VectorTypeModifier::Widening2XVector
:
707 ElementBitwidth
*= 2;
709 Scale
= LMUL
.getScale(ElementBitwidth
);
711 case VectorTypeModifier::Widening4XVector
:
712 ElementBitwidth
*= 4;
714 Scale
= LMUL
.getScale(ElementBitwidth
);
716 case VectorTypeModifier::Widening8XVector
:
717 ElementBitwidth
*= 8;
719 Scale
= LMUL
.getScale(ElementBitwidth
);
721 case VectorTypeModifier::MaskVector
:
722 ScalarType
= ScalarTypeKind::Boolean
;
723 Scale
= LMUL
.getScale(ElementBitwidth
);
726 case VectorTypeModifier::Log2EEW3
:
729 case VectorTypeModifier::Log2EEW4
:
732 case VectorTypeModifier::Log2EEW5
:
735 case VectorTypeModifier::Log2EEW6
:
738 case VectorTypeModifier::FixedSEW8
:
741 case VectorTypeModifier::FixedSEW16
:
744 case VectorTypeModifier::FixedSEW32
:
747 case VectorTypeModifier::FixedSEW64
:
750 case VectorTypeModifier::LFixedLog2LMULN3
:
751 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan
);
753 case VectorTypeModifier::LFixedLog2LMULN2
:
754 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan
);
756 case VectorTypeModifier::LFixedLog2LMULN1
:
757 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan
);
759 case VectorTypeModifier::LFixedLog2LMUL0
:
760 applyFixedLog2LMUL(0, FixedLMULType::LargerThan
);
762 case VectorTypeModifier::LFixedLog2LMUL1
:
763 applyFixedLog2LMUL(1, FixedLMULType::LargerThan
);
765 case VectorTypeModifier::LFixedLog2LMUL2
:
766 applyFixedLog2LMUL(2, FixedLMULType::LargerThan
);
768 case VectorTypeModifier::LFixedLog2LMUL3
:
769 applyFixedLog2LMUL(3, FixedLMULType::LargerThan
);
771 case VectorTypeModifier::SFixedLog2LMULN3
:
772 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan
);
774 case VectorTypeModifier::SFixedLog2LMULN2
:
775 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan
);
777 case VectorTypeModifier::SFixedLog2LMULN1
:
778 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan
);
780 case VectorTypeModifier::SFixedLog2LMUL0
:
781 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan
);
783 case VectorTypeModifier::SFixedLog2LMUL1
:
784 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan
);
786 case VectorTypeModifier::SFixedLog2LMUL2
:
787 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan
);
789 case VectorTypeModifier::SFixedLog2LMUL3
:
790 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan
);
792 case VectorTypeModifier::SEFixedLog2LMULN3
:
793 applyFixedLog2LMUL(-3, FixedLMULType::SmallerOrEqual
);
795 case VectorTypeModifier::SEFixedLog2LMULN2
:
796 applyFixedLog2LMUL(-2, FixedLMULType::SmallerOrEqual
);
798 case VectorTypeModifier::SEFixedLog2LMULN1
:
799 applyFixedLog2LMUL(-1, FixedLMULType::SmallerOrEqual
);
801 case VectorTypeModifier::SEFixedLog2LMUL0
:
802 applyFixedLog2LMUL(0, FixedLMULType::SmallerOrEqual
);
804 case VectorTypeModifier::SEFixedLog2LMUL1
:
805 applyFixedLog2LMUL(1, FixedLMULType::SmallerOrEqual
);
807 case VectorTypeModifier::SEFixedLog2LMUL2
:
808 applyFixedLog2LMUL(2, FixedLMULType::SmallerOrEqual
);
810 case VectorTypeModifier::SEFixedLog2LMUL3
:
811 applyFixedLog2LMUL(3, FixedLMULType::SmallerOrEqual
);
813 case VectorTypeModifier::Tuple2
:
814 case VectorTypeModifier::Tuple3
:
815 case VectorTypeModifier::Tuple4
:
816 case VectorTypeModifier::Tuple5
:
817 case VectorTypeModifier::Tuple6
:
818 case VectorTypeModifier::Tuple7
:
819 case VectorTypeModifier::Tuple8
: {
// NF is recovered from the enumerator's distance to Tuple2 (inverse of
// getTupleVTM).
821 NF
= 2 + static_cast<uint8_t>(Transformer
.VTM
) -
822 static_cast<uint8_t>(VectorTypeModifier::Tuple2
);
825 case VectorTypeModifier::NoModifier
:
829 // Early return if the current type modifier is already invalid.
830 if (ScalarType
== Invalid
)
833 for (unsigned TypeModifierMaskShift
= 0;
834 TypeModifierMaskShift
<= static_cast<unsigned>(TypeModifier::MaxOffset
);
835 ++TypeModifierMaskShift
) {
836 unsigned TypeModifierMask
= 1 << TypeModifierMaskShift
;
837 if ((static_cast<unsigned>(Transformer
.TM
) & TypeModifierMask
) !=
840 switch (static_cast<TypeModifier
>(TypeModifierMask
)) {
841 case TypeModifier::Pointer
:
844 case TypeModifier::Const
:
847 case TypeModifier::Immediate
:
851 case TypeModifier::UnsignedInteger
:
852 ScalarType
= ScalarTypeKind::UnsignedInteger
;
854 case TypeModifier::SignedInteger
:
855 ScalarType
= ScalarTypeKind::SignedInteger
;
857 case TypeModifier::Float
:
858 ScalarType
= ScalarTypeKind::Float
;
860 case TypeModifier::BFloat
:
861 ScalarType
= ScalarTypeKind::BFloat
;
863 case TypeModifier::LMUL1
:
865 // Update ElementBitwidth need to update Scale too.
866 Scale
= LMUL
.getScale(ElementBitwidth
);
869 llvm_unreachable("Unknown type modifier mask!");
874 void RVVType::applyLog2EEW(unsigned Log2EEW
) {
875 // update new elmul = (eew/sew) * lmul
876 LMUL
.MulLog2LMUL(Log2EEW
- Log2_32(ElementBitwidth
));
878 ElementBitwidth
= 1 << Log2EEW
;
879 ScalarType
= ScalarTypeKind::SignedInteger
;
880 Scale
= LMUL
.getScale(ElementBitwidth
);
883 void RVVType::applyFixedSEW(unsigned NewSEW
) {
884 // Set invalid type if src and dst SEW are same.
885 if (ElementBitwidth
== NewSEW
) {
886 ScalarType
= ScalarTypeKind::Invalid
;
890 ElementBitwidth
= NewSEW
;
891 Scale
= LMUL
.getScale(ElementBitwidth
);
894 void RVVType::applyFixedLog2LMUL(int Log2LMUL
, enum FixedLMULType Type
) {
896 case FixedLMULType::LargerThan
:
897 if (Log2LMUL
<= LMUL
.Log2LMUL
) {
898 ScalarType
= ScalarTypeKind::Invalid
;
902 case FixedLMULType::SmallerThan
:
903 if (Log2LMUL
>= LMUL
.Log2LMUL
) {
904 ScalarType
= ScalarTypeKind::Invalid
;
908 case FixedLMULType::SmallerOrEqual
:
909 if (Log2LMUL
> LMUL
.Log2LMUL
) {
910 ScalarType
= ScalarTypeKind::Invalid
;
917 LMUL
= LMULType(Log2LMUL
);
918 Scale
= LMUL
.getScale(ElementBitwidth
);
921 std::optional
<RVVTypes
>
922 RVVTypeCache::computeTypes(BasicType BT
, int Log2LMUL
, unsigned NF
,
923 ArrayRef
<PrototypeDescriptor
> Prototype
) {
925 for (const PrototypeDescriptor
&Proto
: Prototype
) {
926 auto T
= computeType(BT
, Log2LMUL
, Proto
);
929 // Record legal type index
935 // Compute the hash value of RVVType, used for cache the result of computeType.
936 static uint64_t computeRVVTypeHashValue(BasicType BT
, int Log2LMUL
,
937 PrototypeDescriptor Proto
) {
938 // Layout of hash value:
940 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
941 assert(Log2LMUL
>= -3 && Log2LMUL
<= 3);
942 return (Log2LMUL
+ 3) | (static_cast<uint64_t>(BT
) & 0xff) << 8 |
943 ((uint64_t)(Proto
.PT
& 0xff) << 16) |
944 ((uint64_t)(Proto
.TM
& 0xff) << 24) |
945 ((uint64_t)(Proto
.VTM
& 0xff) << 32);
948 std::optional
<RVVTypePtr
> RVVTypeCache::computeType(BasicType BT
, int Log2LMUL
,
949 PrototypeDescriptor Proto
) {
950 uint64_t Idx
= computeRVVTypeHashValue(BT
, Log2LMUL
, Proto
);
952 auto It
= LegalTypes
.find(Idx
);
953 if (It
!= LegalTypes
.end())
954 return &(It
->second
);
956 if (IllegalTypes
.count(Idx
))
959 // Compute type and record the result.
960 RVVType
T(BT
, Log2LMUL
, Proto
);
962 // Record legal type index and value.
963 std::pair
<std::unordered_map
<uint64_t, RVVType
>::iterator
, bool>
964 InsertResult
= LegalTypes
.insert({Idx
, T
});
965 return &(InsertResult
.first
->second
);
967 // Record illegal type index.
968 IllegalTypes
.insert(Idx
);
972 //===----------------------------------------------------------------------===//
973 // RVVIntrinsic implementation
974 //===----------------------------------------------------------------------===//
// Constructs an RVVIntrinsic description.
// Visible steps: copy the flag/scheme/NF/policy arguments into members;
// derive BuiltinName from NewName and OverloadedName from NewOverloadedName
// (or, when that is empty, from the part of NewName before the first '_');
// append "_<Suffix>" / "_<OverloadedSuffix>"; let updateNamesAndPolicy() add
// the prefix and policy suffixes; take OutInTypes[0] as the output type and
// the rest as input types; copy NewIntrinsicTypes and then adjust the copy
// when a masked-off or passthru operand is present.
// NOTE(review): the initial assignment of Name, the conditions guarding the
// Suffix append and the body of the IntrinsicTypes adjustment loop are
// absent from this listing. RequiredFeatures is not referenced in the
// visible lines.
975 RVVIntrinsic::RVVIntrinsic(
976 StringRef NewName
, StringRef Suffix
, StringRef NewOverloadedName
,
977 StringRef OverloadedSuffix
, StringRef IRName
, bool IsMasked
,
978 bool HasMaskedOffOperand
, bool HasVL
, PolicyScheme Scheme
,
979 bool SupportOverloading
, bool HasBuiltinAlias
, StringRef ManualCodegen
,
980 const RVVTypes
&OutInTypes
, const std::vector
<int64_t> &NewIntrinsicTypes
,
981 const std::vector
<StringRef
> &RequiredFeatures
, unsigned NF
,
982 Policy NewPolicyAttrs
, bool HasFRMRoundModeOp
)
983 : IRName(IRName
), IsMasked(IsMasked
),
984 HasMaskedOffOperand(HasMaskedOffOperand
), HasVL(HasVL
), Scheme(Scheme
),
985 SupportOverloading(SupportOverloading
), HasBuiltinAlias(HasBuiltinAlias
),
986 ManualCodegen(ManualCodegen
.str()), NF(NF
), PolicyAttrs(NewPolicyAttrs
) {
988 // Init BuiltinName, Name and OverloadedName
989 BuiltinName
= NewName
.str();
991 if (NewOverloadedName
.empty())
992 OverloadedName
= NewName
.split("_").first
.str();
994 OverloadedName
= NewOverloadedName
.str();
996 Name
+= "_" + Suffix
.str();
997 if (!OverloadedSuffix
.empty())
998 OverloadedName
+= "_" + OverloadedSuffix
.str();
1000 updateNamesAndPolicy(IsMasked
, hasPolicy(), Name
, BuiltinName
, OverloadedName
,
1001 PolicyAttrs
, HasFRMRoundModeOp
);
1003 // Init OutputType and InputTypes
1004 OutputType
= OutInTypes
[0];
1005 InputTypes
.assign(OutInTypes
.begin() + 1, OutInTypes
.end());
1007 // IntrinsicTypes is unmasked TA version index. Need to update it
1008 // if there is merge operand (It is always in first operand).
1009 IntrinsicTypes
= NewIntrinsicTypes
;
1010 if ((IsMasked
&& hasMaskedOffOperand()) ||
1011 (!IsMasked
&& hasPassthruOperand())) {
1012 for (auto &I
: IntrinsicTypes
) {
1019 std::string
RVVIntrinsic::getBuiltinTypeStr() const {
1021 S
+= OutputType
->getBuiltinStr();
1022 for (const auto &T
: InputTypes
) {
1023 S
+= T
->getBuiltinStr();
1028 std::string
RVVIntrinsic::getSuffixStr(
1029 RVVTypeCache
&TypeCache
, BasicType Type
, int Log2LMUL
,
1030 llvm::ArrayRef
<PrototypeDescriptor
> PrototypeDescriptors
) {
1031 SmallVector
<std::string
> SuffixStrs
;
1032 for (auto PD
: PrototypeDescriptors
) {
1033 auto T
= TypeCache
.computeType(Type
, Log2LMUL
, PD
);
1034 SuffixStrs
.push_back((*T
)->getShortStr());
1036 return join(SuffixStrs
, "_");
// Derives the builtin's full prototype list from the user-facing prototype.
// Works on a copy (NewPrototype) and, depending on IsMasked,
// HasMaskedOffOperand, the policy and NF, inserts masked-off/passthru
// operand descriptors and PrototypeDescriptor::Mask; per the comment near the
// end, PrototypeDescriptor::VL is appended when HasVL. For segment (NF > 1)
// cases the inserted operand type is derived from an existing operand with
// its Pointer modifier bit cleared, or is a Tuple<NF> vector descriptor.
// Returns the resulting list.
// NOTE(review): several conditions (including the IsMasked / IsTuple /
// HasVL tests) and parts of the explanatory comments are absent from this
// listing, so the exact branch structure cannot be confirmed here.
1039 llvm::SmallVector
<PrototypeDescriptor
> RVVIntrinsic::computeBuiltinTypes(
1040 llvm::ArrayRef
<PrototypeDescriptor
> Prototype
, bool IsMasked
,
1041 bool HasMaskedOffOperand
, bool HasVL
, unsigned NF
,
1042 PolicyScheme DefaultScheme
, Policy PolicyAttrs
, bool IsTuple
) {
1043 SmallVector
<PrototypeDescriptor
> NewPrototype(Prototype
.begin(),
1045 bool HasPassthruOp
= DefaultScheme
== PolicyScheme::HasPassthruOperand
;
1047 // If HasMaskedOffOperand, insert result type as first input operand if
1049 if (HasMaskedOffOperand
&& !PolicyAttrs
.isTAMAPolicy()) {
1051 NewPrototype
.insert(NewPrototype
.begin() + 1, NewPrototype
[0]);
1052 } else if (NF
> 1) {
1054 PrototypeDescriptor BasePtrOperand
= Prototype
[1];
1055 PrototypeDescriptor MaskoffType
= PrototypeDescriptor(
1056 static_cast<uint8_t>(BaseTypeModifier::Vector
),
1057 static_cast<uint8_t>(getTupleVTM(NF
)),
1058 BasePtrOperand
.TM
& ~static_cast<uint8_t>(TypeModifier::Pointer
));
1059 NewPrototype
.insert(NewPrototype
.begin() + 1, MaskoffType
);
1062 // (void, op0 address, op1 address, ...)
1064 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1065 PrototypeDescriptor MaskoffType
= NewPrototype
[1];
1066 MaskoffType
.TM
&= ~static_cast<uint8_t>(TypeModifier::Pointer
);
1067 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1, NF
, MaskoffType
);
1071 if (HasMaskedOffOperand
&& NF
> 1) {
1073 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1075 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1078 NewPrototype
.insert(NewPrototype
.begin() + 1,
1079 PrototypeDescriptor::Mask
);
1081 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1,
1082 PrototypeDescriptor::Mask
);
1084 // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
1085 NewPrototype
.insert(NewPrototype
.begin() + 1, PrototypeDescriptor::Mask
);
1089 if (PolicyAttrs
.isTUPolicy() && HasPassthruOp
)
1090 NewPrototype
.insert(NewPrototype
.begin(), NewPrototype
[0]);
1091 } else if (PolicyAttrs
.isTUPolicy() && HasPassthruOp
) {
1093 PrototypeDescriptor BasePtrOperand
= Prototype
[0];
1094 PrototypeDescriptor MaskoffType
= PrototypeDescriptor(
1095 static_cast<uint8_t>(BaseTypeModifier::Vector
),
1096 static_cast<uint8_t>(getTupleVTM(NF
)),
1097 BasePtrOperand
.TM
& ~static_cast<uint8_t>(TypeModifier::Pointer
));
1098 NewPrototype
.insert(NewPrototype
.begin(), MaskoffType
);
1100 // NF > 1 cases for segment load operations.
1102 // (void, op0 address, op1 address, ...)
1104 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
1105 PrototypeDescriptor MaskoffType
= Prototype
[1];
1106 MaskoffType
.TM
&= ~static_cast<uint8_t>(TypeModifier::Pointer
);
1107 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1, NF
, MaskoffType
);
1112 // If HasVL, append PrototypeDescriptor:VL to last operand
1114 NewPrototype
.push_back(PrototypeDescriptor::VL
);
1116 return NewPrototype
;
1119 llvm::SmallVector
<Policy
> RVVIntrinsic::getSupportedUnMaskedPolicies() {
1120 return {Policy(Policy::PolicyType::Undisturbed
)}; // TU
// Returns the policies supported by a masked intrinsic, given which policy
// operands it has. Each Policy is built from (tail policy, mask policy):
//   tail and mask : TUM, TUMU, MU
//   tail only     : TU (tail undisturbed, mask agnostic)
//   mask only     : MU (tail agnostic, mask undisturbed)
// Having neither is treated as unreachable.
// NOTE(review): the remainder of the llvm_unreachable message is absent from
// this listing.
1123 llvm::SmallVector
<Policy
>
1124 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy
,
1125 bool HasMaskPolicy
) {
1126 if (HasTailPolicy
&& HasMaskPolicy
)
1127 return {Policy(Policy::PolicyType::Undisturbed
,
1128 Policy::PolicyType::Agnostic
), // TUM
1129 Policy(Policy::PolicyType::Undisturbed
,
1130 Policy::PolicyType::Undisturbed
), // TUMU
1131 Policy(Policy::PolicyType::Agnostic
,
1132 Policy::PolicyType::Undisturbed
)}; // MU
1133 if (HasTailPolicy
&& !HasMaskPolicy
)
1134 return {Policy(Policy::PolicyType::Undisturbed
,
1135 Policy::PolicyType::Agnostic
)}; // TU
1136 if (!HasTailPolicy
&& HasMaskPolicy
)
1137 return {Policy(Policy::PolicyType::Agnostic
,
1138 Policy::PolicyType::Undisturbed
)}; // MU
1139 llvm_unreachable("An RVV instruction should not be without both tail policy "
// Finalises the three name strings (updated in place) for an intrinsic:
// prepends "__riscv_" to Name and OverloadedName, appends "_rm" to
// BuiltinName when the intrinsic has a rounding-mode operand, and appends a
// policy suffix ("_tumu", "_tum", "_mu", "_tu", or "_m" on BuiltinName for
// the TAMA case) chosen from PolicyAttrs. An unrecognised policy
// combination is unreachable.
// NOTE(review): part of appendPolicySuffix, the IsMasked/HasPolicy branching
// and any writes to PolicyAttrs are absent from this listing - confirm
// against upstream.
1143 void RVVIntrinsic::updateNamesAndPolicy(
1144 bool IsMasked
, bool HasPolicy
, std::string
&Name
, std::string
&BuiltinName
,
1145 std::string
&OverloadedName
, Policy
&PolicyAttrs
, bool HasFRMRoundModeOp
) {
1147 auto appendPolicySuffix
= [&](const std::string
&suffix
) {
1149 BuiltinName
+= suffix
;
1150 OverloadedName
+= suffix
;
1153 // This follows the naming guideline under riscv-c-api-doc to add the
1154 // `__riscv_` prefix for all RVV intrinsics.
1155 Name
= "__riscv_" + Name
;
1156 OverloadedName
= "__riscv_" + OverloadedName
;
1158 if (HasFRMRoundModeOp
) {
1160 BuiltinName
+= "_rm";
1164 if (PolicyAttrs
.isTUMUPolicy())
1165 appendPolicySuffix("_tumu");
1166 else if (PolicyAttrs
.isTUMAPolicy())
1167 appendPolicySuffix("_tum");
1168 else if (PolicyAttrs
.isTAMUPolicy())
1169 appendPolicySuffix("_mu");
1170 else if (PolicyAttrs
.isTAMAPolicy()) {
1172 BuiltinName
+= "_m";
1174 llvm_unreachable("Unhandled policy condition");
1176 if (PolicyAttrs
.isTUPolicy())
1177 appendPolicySuffix("_tu");
1178 else if (PolicyAttrs
.isTAPolicy()) // no suffix needed
1181 llvm_unreachable("Unhandled policy condition");
1185 SmallVector
<PrototypeDescriptor
> parsePrototypes(StringRef Prototypes
) {
1186 SmallVector
<PrototypeDescriptor
> PrototypeDescriptors
;
1187 const StringRef
Primaries("evwqom0ztulf");
1188 while (!Prototypes
.empty()) {
1190 // Skip over complex prototype because it could contain primitive type
1192 if (Prototypes
[0] == '(')
1193 Idx
= Prototypes
.find_first_of(')');
1194 Idx
= Prototypes
.find_first_of(Primaries
, Idx
);
1195 assert(Idx
!= StringRef::npos
);
1196 auto PD
= PrototypeDescriptor::parsePrototypeDescriptor(
1197 Prototypes
.slice(0, Idx
+ 1));
1199 llvm_unreachable("Error during parsing prototype.");
1200 PrototypeDescriptors
.push_back(*PD
);
1201 Prototypes
= Prototypes
.drop_front(Idx
+ 1);
1203 return PrototypeDescriptors
;
// Streams an RVVIntrinsicRecord to OS as a sequence of comma-terminated
// fields in declaration order: the quoted Name, the quoted OverloadedName
// (with separate handling when it is null or empty), then the index, length,
// mask and flag fields. Most numeric fields are cast to int before printing,
// presumably so that narrow integer members are emitted as numbers rather
// than characters - confirm against the record's field types.
// NOTE(review): the opening/closing text, the output for a null/empty
// OverloadedName and the return statement are absent from this listing.
1206 raw_ostream
&operator<<(raw_ostream
&OS
, const RVVIntrinsicRecord
&Record
) {
1208 OS
<< "\"" << Record
.Name
<< "\",";
1209 if (Record
.OverloadedName
== nullptr ||
1210 StringRef(Record
.OverloadedName
).empty())
1213 OS
<< "\"" << Record
.OverloadedName
<< "\",";
1214 OS
<< Record
.PrototypeIndex
<< ",";
1215 OS
<< Record
.SuffixIndex
<< ",";
1216 OS
<< Record
.OverloadedSuffixIndex
<< ",";
1217 OS
<< (int)Record
.PrototypeLength
<< ",";
1218 OS
<< (int)Record
.SuffixLength
<< ",";
1219 OS
<< (int)Record
.OverloadedSuffixSize
<< ",";
1220 OS
<< Record
.RequiredExtensions
<< ",";
1221 OS
<< (int)Record
.TypeRangeMask
<< ",";
1222 OS
<< (int)Record
.Log2LMULMask
<< ",";
1223 OS
<< (int)Record
.NF
<< ",";
1224 OS
<< (int)Record
.HasMasked
<< ",";
1225 OS
<< (int)Record
.HasVL
<< ",";
1226 OS
<< (int)Record
.HasMaskedOffOperand
<< ",";
1227 OS
<< (int)Record
.HasTailPolicy
<< ",";
1228 OS
<< (int)Record
.HasMaskPolicy
<< ",";
1229 OS
<< (int)Record
.HasFRMRoundModeOp
<< ",";
1230 OS
<< (int)Record
.IsTuple
<< ",";
1231 OS
<< (int)Record
.UnMaskedPolicyScheme
<< ",";
1232 OS
<< (int)Record
.MaskedPolicyScheme
<< ",";
1237 } // end namespace RISCV
1238 } // end namespace clang