1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "clang/Support/RISCVVIntrinsicUtils.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/SmallSet.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringMap.h"
14 #include "llvm/ADT/StringSet.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Support/ErrorHandling.h"
17 #include "llvm/Support/raw_ostream.h"
26 const PrototypeDescriptor
PrototypeDescriptor::Mask
= PrototypeDescriptor(
27 BaseTypeModifier::Vector
, VectorTypeModifier::MaskVector
);
28 const PrototypeDescriptor
PrototypeDescriptor::VL
=
29 PrototypeDescriptor(BaseTypeModifier::SizeT
);
30 const PrototypeDescriptor
PrototypeDescriptor::Vector
=
31 PrototypeDescriptor(BaseTypeModifier::Vector
);
33 //===----------------------------------------------------------------------===//
34 // Type implementation
35 //===----------------------------------------------------------------------===//
// Construct an LMUL from its base-2 logarithm. The assert below restricts the
// encoding to [-3, 3] (LMUL 1/8 through 8) in builds with assertions enabled.
37 LMULType::LMULType(int NewLog2LMUL
) {
38 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
39 assert(NewLog2LMUL
<= 3 && NewLog2LMUL
>= -3 && "Bad LMUL number!");
40 Log2LMUL
= NewLog2LMUL
;
// Return the LMUL mnemonic used in type and intrinsic names. Fractional LMULs
// render as "mf<N>" with N = 2^-Log2LMUL; integral LMULs render as "m<N>" with
// N = 2^Log2LMUL.
// NOTE(review): the condition choosing between the two returns is not visible
// in this excerpt; presumably it tests Log2LMUL < 0 — confirm.
43 std::string
LMULType::str() const {
45 return "mf" + utostr(1ULL << (-Log2LMUL
));
46 return "m" + utostr(1ULL << Log2LMUL
);
// Compute the vscale multiplier (element count per vscale) of a vector with
// ElementBitwidth-bit elements at this LMUL, returned as 1 << log2-result.
// The visible switch arms set the log2 result to Log2LMUL + 3, + 2, + 1 and
// + 0. Their case labels are not visible here; presumably they are SEW 8, 16,
// 32 and 64 — confirm.
49 VScaleVal
LMULType::getScale(unsigned ElementBitwidth
) const {
50 int Log2ScaleResult
= 0;
51 switch (ElementBitwidth
) {
55 Log2ScaleResult
= Log2LMUL
+ 3;
58 Log2ScaleResult
= Log2LMUL
+ 2;
61 Log2ScaleResult
= Log2LMUL
+ 1;
64 Log2ScaleResult
= Log2LMUL
;
// NOTE(review): the statement guarded by the following `if` is not visible in
// this excerpt. Callers dereference VScaleVal with `*`, so presumably an empty
// (invalid) value is returned here — confirm.
67 // Illegal vscale result would be less than 1
68 if (Log2ScaleResult
< 0)
70 return 1 << Log2ScaleResult
;
73 void LMULType::MulLog2LMUL(int log2LMUL
) { Log2LMUL
+= log2LMUL
; }
// Build an RVVType for basic type BT at the given log2 LMUL. The prototype
// descriptor's modifiers are applied via applyModifier(), and the Clang
// builtin type name is initialized.
// NOTE(review): several statements of this constructor are not visible in this
// excerpt. initClangBuiltinStr() asserts isVector(), so the visible call is
// presumably guarded by a vector check on a hidden line — confirm.
75 RVVType::RVVType(BasicType BT
, int Log2LMUL
,
76 const PrototypeDescriptor
&prototype
)
77 : BT(BT
), LMUL(LMULType(Log2LMUL
)) {
79 applyModifier(prototype
);
85 initClangBuiltinStr();
91 // Boolean (mask) types vbool<N>_t are named by the ratio N = SEW/LMUL and are
// lowered to nxv<64/N>i1; the first row below is that vscale factor 64/N.
92 // 64/N | 1 | 2 | 4 | 8 | 16 | 32 | 64
93 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
94 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
96 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
97 // -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
98 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
99 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
100 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
101 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
102 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
103 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
104 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
// Report whether this type is one the RVV intrinsics support.
// Visible checks:
//   - An Invalid scalar kind and 8-bit floats are rejected (the statements
//     they guard are not visible here).
//   - The value V must be a power of two no larger than a bound that depends
//     on ElementBitwidth (64, 32, 16 or 8).
// NOTE(review): V's definition and the case labels are not visible. The inline
// comments say "Check Scale", so V presumably holds *Scale, and the bounds
// presumably belong to SEW 8/16/32/64 — confirm.
107 bool RVVType::verifyType() const {
108 if (ScalarType
== Invalid
)
114 if (isFloat() && ElementBitwidth
== 8)
117 switch (ElementBitwidth
) {
120 // Check Scale is 1,2,4,8,16,32,64
121 return (V
<= 64 && isPowerOf2_32(V
));
123 // Check Scale is 1,2,4,8,16,32
124 return (V
<= 32 && isPowerOf2_32(V
));
126 // Check Scale is 1,2,4,8,16
127 return (V
<= 16 && isPowerOf2_32(V
));
129 // Check Scale is 1,2,4,8
130 return (V
<= 8 && isPowerOf2_32(V
));
// Build BuiltinStr, the encoding of this type used in Clang builtin signatures
// (presumably Clang's Builtins.def type-string scheme — confirm).
// Visible pieces:
//   - Size_t gets an "I" prefix.
//   - Boolean requires ElementBitwidth == 1.
//   - Integers get an "S" (signed) or "U" (unsigned) prefix.
//   - Another "I" prefix is added under a condition not visible here.
//   - A "q<Scale>" prefix is added, presumably for vector types only.
// Unsupported widths and kinds reach llvm_unreachable.
135 void RVVType::initBuiltinStr() {
136 assert(isValid() && "RVVType is invalid");
137 switch (ScalarType
) {
138 case ScalarTypeKind::Void
:
141 case ScalarTypeKind::Size_t
:
144 BuiltinStr
= "I" + BuiltinStr
;
148 case ScalarTypeKind::Ptrdiff_t
:
151 case ScalarTypeKind::UnsignedLong
:
154 case ScalarTypeKind::SignedLong
:
157 case ScalarTypeKind::Boolean
:
158 assert(ElementBitwidth
== 1);
161 case ScalarTypeKind::SignedInteger
:
162 case ScalarTypeKind::UnsignedInteger
:
163 switch (ElementBitwidth
) {
177 llvm_unreachable("Unhandled ElementBitwidth!");
179 if (isSignedInteger())
180 BuiltinStr
= "S" + BuiltinStr
;
182 BuiltinStr
= "U" + BuiltinStr
;
184 case ScalarTypeKind::Float
:
185 switch (ElementBitwidth
) {
196 llvm_unreachable("Unhandled ElementBitwidth!");
200 llvm_unreachable("ScalarType is invalid!");
203 BuiltinStr
= "I" + BuiltinStr
;
211 BuiltinStr
= "q" + utostr(*Scale
) + BuiltinStr
;
212 // Pointer to vector types. Defined for segment load intrinsics.
213 // segment load intrinsics have pointer type arguments to store the loaded
// Build ClangBuiltinStr, the name of the Clang builtin RVV vector type.
// Examples: "__rvv_int32m1_t" (prefix, element kind, element width, LMUL
// mnemonic, "_t") and "__rvv_bool<64/Scale>_t" for masks. Only valid vector
// types are accepted (asserted).
// NOTE(review): the statement ending the Boolean arm is not visible. The
// Boolean name already ends in "_t", so that arm presumably returns before the
// common "<width><LMUL>_t" suffix is appended — confirm.
219 void RVVType::initClangBuiltinStr() {
220 assert(isValid() && "RVVType is invalid");
221 assert(isVector() && "Handle Vector type only");
223 ClangBuiltinStr
= "__rvv_";
224 switch (ScalarType
) {
225 case ScalarTypeKind::Boolean
:
226 ClangBuiltinStr
+= "bool" + utostr(64 / *Scale
) + "_t";
228 case ScalarTypeKind::Float
:
229 ClangBuiltinStr
+= "float";
231 case ScalarTypeKind::SignedInteger
:
232 ClangBuiltinStr
+= "int";
234 case ScalarTypeKind::UnsignedInteger
:
235 ClangBuiltinStr
+= "uint";
238 llvm_unreachable("ScalarTypeKind is invalid");
240 ClangBuiltinStr
+= utostr(ElementBitwidth
) + LMUL
.str() + "_t";
// Build Str, the C spelling of this type.
// The getTypeString helper yields "<kind><width>_t" (e.g. "int32_t") in one
// form and "v<kind><width><LMUL>_t" (e.g. "vint32m1_t") in the other. The
// condition selecting between the two forms is not visible here; presumably it
// is scalar vs. vector — confirm.
// UnsignedLong is spelled "unsigned long", and masks are spelled
// "vbool<64/Scale>_t".
243 void RVVType::initTypeStr() {
244 assert(isValid() && "RVVType is invalid");
249 auto getTypeString
= [&](StringRef TypeStr
) {
251 return Twine(TypeStr
+ Twine(ElementBitwidth
) + "_t").str();
252 return Twine("v" + TypeStr
+ Twine(ElementBitwidth
) + LMUL
.str() + "_t")
256 switch (ScalarType
) {
257 case ScalarTypeKind::Void
:
260 case ScalarTypeKind::Size_t
:
265 case ScalarTypeKind::Ptrdiff_t
:
268 case ScalarTypeKind::UnsignedLong
:
269 Str
= "unsigned long";
271 case ScalarTypeKind::SignedLong
:
274 case ScalarTypeKind::Boolean
:
278 // Vector bool is special case, the formula is
279 // `vbool<N>_t = MVT::nxv<64/N>i1`, e.g. vbool16_t = MVT::nxv4i1
280 Str
+= "vbool" + utostr(64 / *Scale
) + "_t";
282 case ScalarTypeKind::Float
:
284 if (ElementBitwidth
== 64)
286 else if (ElementBitwidth
== 32)
288 else if (ElementBitwidth
== 16)
291 llvm_unreachable("Unhandled floating type.");
293 Str
+= getTypeString("float");
295 case ScalarTypeKind::SignedInteger
:
296 Str
+= getTypeString("int");
298 case ScalarTypeKind::UnsignedInteger
:
299 Str
+= getTypeString("uint");
302 llvm_unreachable("ScalarType is invalid!");
// Build ShortStr, the compact type suffix used in intrinsic names.
//   - Masks: "b<64/Scale>".
//   - Other types: "f", "i" or "u" (float, signed, unsigned) followed by the
//     element width, then the LMUL mnemonic (e.g. "i32m1").
// NOTE(review): the statements ending the Boolean arm and any guard on the
// final LMUL append are not visible. Presumably masks return early and only
// vectors get the LMUL suffix — confirm.
308 void RVVType::initShortStr() {
309 switch (ScalarType
) {
310 case ScalarTypeKind::Boolean
:
312 ShortStr
= "b" + utostr(64 / *Scale
);
314 case ScalarTypeKind::Float
:
315 ShortStr
= "f" + utostr(ElementBitwidth
);
317 case ScalarTypeKind::SignedInteger
:
318 ShortStr
= "i" + utostr(ElementBitwidth
);
320 case ScalarTypeKind::UnsignedInteger
:
321 ShortStr
= "u" + utostr(ElementBitwidth
);
324 llvm_unreachable("Unhandled case!");
327 ShortStr
+= LMUL
.str();
// Initialize ElementBitwidth and ScalarType from the basic element type.
// Int16/32/64 map to signed integers and Float16/32/64 to floats of the
// matching width. Unknown basic types reach llvm_unreachable, and a zero width
// afterwards is asserted against.
// NOTE(review): the switch header and the Int8 width assignment are not
// visible in this excerpt; presumably `switch (BT)` and a width of 8 — confirm.
330 void RVVType::applyBasicType() {
332 case BasicType::Int8
:
334 ScalarType
= ScalarTypeKind::SignedInteger
;
336 case BasicType::Int16
:
337 ElementBitwidth
= 16;
338 ScalarType
= ScalarTypeKind::SignedInteger
;
340 case BasicType::Int32
:
341 ElementBitwidth
= 32;
342 ScalarType
= ScalarTypeKind::SignedInteger
;
344 case BasicType::Int64
:
345 ElementBitwidth
= 64;
346 ScalarType
= ScalarTypeKind::SignedInteger
;
348 case BasicType::Float16
:
349 ElementBitwidth
= 16;
350 ScalarType
= ScalarTypeKind::Float
;
352 case BasicType::Float32
:
353 ElementBitwidth
= 32;
354 ScalarType
= ScalarTypeKind::Float
;
356 case BasicType::Float64
:
357 ElementBitwidth
= 64;
358 ScalarType
= ScalarTypeKind::Float
;
361 llvm_unreachable("Unhandled type code!");
363 assert(ElementBitwidth
!= 0 && "Bad element bitwidth!");
// Parse one prototype descriptor string into a PrototypeDescriptor.
//   1. The last character selects the base type modifier: scalar, vector,
//      2x/4x/8x widened vector, mask vector, void, size_t, ptrdiff_t, or
//      unsigned/signed long. Widening and mask characters also set the
//      vector type modifier.
//   2. An optional leading "(Kind:Value)" group selects exactly one vector
//      type modifier: Log2EEW in [3, 6], FixedSEW in {8, 16, 32, 64}, or
//      LFixedLog2LMUL / SFixedLog2LMUL in [-3, 3].
//   3. Each remaining character ORs in a type modifier bit: pointer, const,
//      immediate, unsigned, signed, float or LMUL1. Pointer ('P') may appear
//      at most once and not after const ('C').
// Malformed input reaches llvm_unreachable rather than being diagnosed, so
// callers must pass well-formed descriptors.
// NOTE(review): the character values of the switch case labels are not visible
// in this excerpt.
366 std::optional
<PrototypeDescriptor
>
367 PrototypeDescriptor::parsePrototypeDescriptor(
368 llvm::StringRef PrototypeDescriptorStr
) {
369 PrototypeDescriptor PD
;
370 BaseTypeModifier PT
= BaseTypeModifier::Invalid
;
371 VectorTypeModifier VTM
= VectorTypeModifier::NoModifier
;
373 if (PrototypeDescriptorStr
.empty())
376 // Handle base type modifier
377 auto PType
= PrototypeDescriptorStr
.back();
380 PT
= BaseTypeModifier::Scalar
;
383 PT
= BaseTypeModifier::Vector
;
386 PT
= BaseTypeModifier::Vector
;
387 VTM
= VectorTypeModifier::Widening2XVector
;
390 PT
= BaseTypeModifier::Vector
;
391 VTM
= VectorTypeModifier::Widening4XVector
;
394 PT
= BaseTypeModifier::Vector
;
395 VTM
= VectorTypeModifier::Widening8XVector
;
398 PT
= BaseTypeModifier::Vector
;
399 VTM
= VectorTypeModifier::MaskVector
;
402 PT
= BaseTypeModifier::Void
;
405 PT
= BaseTypeModifier::SizeT
;
408 PT
= BaseTypeModifier::Ptrdiff
;
411 PT
= BaseTypeModifier::UnsignedLong
;
414 PT
= BaseTypeModifier::SignedLong
;
417 llvm_unreachable("Illegal primitive type transformers!");
419 PD
.PT
= static_cast<uint8_t>(PT
);
420 PrototypeDescriptorStr
= PrototypeDescriptorStr
.drop_back();
422 // Compute the vector type transformers, it can only appear one time.
423 if (PrototypeDescriptorStr
.startswith("(")) {
424 assert(VTM
== VectorTypeModifier::NoModifier
&&
425 "VectorTypeModifier should only have one modifier");
426 size_t Idx
= PrototypeDescriptorStr
.find(')');
427 assert(Idx
!= StringRef::npos
);
428 StringRef ComplexType
= PrototypeDescriptorStr
.slice(1, Idx
);
429 PrototypeDescriptorStr
= PrototypeDescriptorStr
.drop_front(Idx
+ 1);
430 assert(!PrototypeDescriptorStr
.contains('(') &&
431 "Only allow one vector type modifier");
433 auto ComplexTT
= ComplexType
.split(":");
434 if (ComplexTT
.first
== "Log2EEW") {
436 if (ComplexTT
.second
.getAsInteger(10, Log2EEW
)) {
437 llvm_unreachable("Invalid Log2EEW value!");
442 VTM
= VectorTypeModifier::Log2EEW3
;
445 VTM
= VectorTypeModifier::Log2EEW4
;
448 VTM
= VectorTypeModifier::Log2EEW5
;
451 VTM
= VectorTypeModifier::Log2EEW6
;
454 llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
457 } else if (ComplexTT
.first
== "FixedSEW") {
459 if (ComplexTT
.second
.getAsInteger(10, NewSEW
)) {
460 llvm_unreachable("Invalid FixedSEW value!");
465 VTM
= VectorTypeModifier::FixedSEW8
;
468 VTM
= VectorTypeModifier::FixedSEW16
;
471 VTM
= VectorTypeModifier::FixedSEW32
;
474 VTM
= VectorTypeModifier::FixedSEW64
;
477 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
480 } else if (ComplexTT
.first
== "LFixedLog2LMUL") {
482 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
483 llvm_unreachable("Invalid LFixedLog2LMUL value!");
488 VTM
= VectorTypeModifier::LFixedLog2LMULN3
;
491 VTM
= VectorTypeModifier::LFixedLog2LMULN2
;
494 VTM
= VectorTypeModifier::LFixedLog2LMULN1
;
497 VTM
= VectorTypeModifier::LFixedLog2LMUL0
;
500 VTM
= VectorTypeModifier::LFixedLog2LMUL1
;
503 VTM
= VectorTypeModifier::LFixedLog2LMUL2
;
506 VTM
= VectorTypeModifier::LFixedLog2LMUL3
;
509 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
512 } else if (ComplexTT
.first
== "SFixedLog2LMUL") {
514 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
515 llvm_unreachable("Invalid SFixedLog2LMUL value!");
520 VTM
= VectorTypeModifier::SFixedLog2LMULN3
;
523 VTM
= VectorTypeModifier::SFixedLog2LMULN2
;
526 VTM
= VectorTypeModifier::SFixedLog2LMULN1
;
529 VTM
= VectorTypeModifier::SFixedLog2LMUL0
;
532 VTM
= VectorTypeModifier::SFixedLog2LMUL1
;
535 VTM
= VectorTypeModifier::SFixedLog2LMUL2
;
538 VTM
= VectorTypeModifier::SFixedLog2LMUL3
;
// NOTE(review): this is the SFixedLog2LMUL branch, but the message below names
// "LFixedLog2LMUL". It looks like a copy-paste slip; consider changing the text
// to "Invalid SFixedLog2LMUL value, should be [-3, 3]".
541 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
546 llvm_unreachable("Illegal complex type transformers!");
549 PD
.VTM
= static_cast<uint8_t>(VTM
);
551 // Compute the remaining type transformers
552 TypeModifier TM
= TypeModifier::NoModifier
;
553 for (char I
: PrototypeDescriptorStr
) {
556 if ((TM
& TypeModifier::Const
) == TypeModifier::Const
)
557 llvm_unreachable("'P' transformer cannot be used after 'C'");
558 if ((TM
& TypeModifier::Pointer
) == TypeModifier::Pointer
)
559 llvm_unreachable("'P' transformer cannot be used twice");
560 TM
|= TypeModifier::Pointer
;
563 TM
|= TypeModifier::Const
;
566 TM
|= TypeModifier::Immediate
;
569 TM
|= TypeModifier::UnsignedInteger
;
572 TM
|= TypeModifier::SignedInteger
;
575 TM
|= TypeModifier::Float
;
578 TM
|= TypeModifier::LMUL1
;
581 llvm_unreachable("Illegal non-primitive type transformer!");
584 PD
.TM
= static_cast<uint8_t>(TM
);
// Apply a parsed prototype descriptor to this type in three stages:
//   1. Base type modifier (PT): sets the scalar kind (void, size_t, ptrdiff_t,
//      unsigned/signed long, invalid) or, for vectors, computes Scale.
//   2. Vector type modifier (VTM): widens the element 2x/4x/8x, turns the type
//      into a mask, or applies a Log2EEW / FixedSEW / fixed-LMUL constraint.
//      The visible widening and mask arms recompute Scale afterwards.
//   3. Type modifier mask (TM): each bit from 0 to MaxOffset is tested and
//      applied in ascending bit order.
// NOTE(review): the bodies of several arms (Scalar, Log2EEW*, FixedSEW*,
// Pointer, Const, Immediate) are not visible in this excerpt.
589 void RVVType::applyModifier(const PrototypeDescriptor
&Transformer
) {
590 // Handle primitive type transformer
591 switch (static_cast<BaseTypeModifier
>(Transformer
.PT
)) {
592 case BaseTypeModifier::Scalar
:
595 case BaseTypeModifier::Vector
:
596 Scale
= LMUL
.getScale(ElementBitwidth
);
598 case BaseTypeModifier::Void
:
599 ScalarType
= ScalarTypeKind::Void
;
601 case BaseTypeModifier::SizeT
:
602 ScalarType
= ScalarTypeKind::Size_t
;
604 case BaseTypeModifier::Ptrdiff
:
605 ScalarType
= ScalarTypeKind::Ptrdiff_t
;
607 case BaseTypeModifier::UnsignedLong
:
608 ScalarType
= ScalarTypeKind::UnsignedLong
;
610 case BaseTypeModifier::SignedLong
:
611 ScalarType
= ScalarTypeKind::SignedLong
;
613 case BaseTypeModifier::Invalid
:
614 ScalarType
= ScalarTypeKind::Invalid
;
618 switch (static_cast<VectorTypeModifier
>(Transformer
.VTM
)) {
619 case VectorTypeModifier::Widening2XVector
:
620 ElementBitwidth
*= 2;
622 Scale
= LMUL
.getScale(ElementBitwidth
);
624 case VectorTypeModifier::Widening4XVector
:
625 ElementBitwidth
*= 4;
627 Scale
= LMUL
.getScale(ElementBitwidth
);
629 case VectorTypeModifier::Widening8XVector
:
630 ElementBitwidth
*= 8;
632 Scale
= LMUL
.getScale(ElementBitwidth
);
634 case VectorTypeModifier::MaskVector
:
635 ScalarType
= ScalarTypeKind::Boolean
;
636 Scale
= LMUL
.getScale(ElementBitwidth
);
639 case VectorTypeModifier::Log2EEW3
:
642 case VectorTypeModifier::Log2EEW4
:
645 case VectorTypeModifier::Log2EEW5
:
648 case VectorTypeModifier::Log2EEW6
:
651 case VectorTypeModifier::FixedSEW8
:
654 case VectorTypeModifier::FixedSEW16
:
657 case VectorTypeModifier::FixedSEW32
:
660 case VectorTypeModifier::FixedSEW64
:
663 case VectorTypeModifier::LFixedLog2LMULN3
:
664 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan
);
666 case VectorTypeModifier::LFixedLog2LMULN2
:
667 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan
);
669 case VectorTypeModifier::LFixedLog2LMULN1
:
670 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan
);
672 case VectorTypeModifier::LFixedLog2LMUL0
:
673 applyFixedLog2LMUL(0, FixedLMULType::LargerThan
);
675 case VectorTypeModifier::LFixedLog2LMUL1
:
676 applyFixedLog2LMUL(1, FixedLMULType::LargerThan
);
678 case VectorTypeModifier::LFixedLog2LMUL2
:
679 applyFixedLog2LMUL(2, FixedLMULType::LargerThan
);
681 case VectorTypeModifier::LFixedLog2LMUL3
:
682 applyFixedLog2LMUL(3, FixedLMULType::LargerThan
);
684 case VectorTypeModifier::SFixedLog2LMULN3
:
685 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan
);
687 case VectorTypeModifier::SFixedLog2LMULN2
:
688 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan
);
690 case VectorTypeModifier::SFixedLog2LMULN1
:
691 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan
);
693 case VectorTypeModifier::SFixedLog2LMUL0
:
694 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan
);
696 case VectorTypeModifier::SFixedLog2LMUL1
:
697 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan
);
699 case VectorTypeModifier::SFixedLog2LMUL2
:
700 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan
);
702 case VectorTypeModifier::SFixedLog2LMUL3
:
703 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan
);
705 case VectorTypeModifier::NoModifier
:
// Walk the type-modifier bit mask one bit at a time, from bit 0 up to MaxOffset.
709 for (unsigned TypeModifierMaskShift
= 0;
710 TypeModifierMaskShift
<= static_cast<unsigned>(TypeModifier::MaxOffset
);
711 ++TypeModifierMaskShift
) {
712 unsigned TypeModifierMask
= 1 << TypeModifierMaskShift
;
// Bits not set in Transformer.TM are presumably skipped here (the rest of this
// comparison is not visible in this excerpt).
713 if ((static_cast<unsigned>(Transformer
.TM
) & TypeModifierMask
) !=
716 switch (static_cast<TypeModifier
>(TypeModifierMask
)) {
717 case TypeModifier::Pointer
:
720 case TypeModifier::Const
:
723 case TypeModifier::Immediate
:
727 case TypeModifier::UnsignedInteger
:
728 ScalarType
= ScalarTypeKind::UnsignedInteger
;
730 case TypeModifier::SignedInteger
:
731 ScalarType
= ScalarTypeKind::SignedInteger
;
733 case TypeModifier::Float
:
734 ScalarType
= ScalarTypeKind::Float
;
736 case TypeModifier::LMUL1
:
738 // Any change to LMUL or ElementBitwidth requires recomputing Scale.
739 Scale
= LMUL
.getScale(ElementBitwidth
);
742 llvm_unreachable("Unknown type modifier mask!");
// Re-type this vector to use 2^Log2EEW-bit signed integer elements, and
// recompute Scale.
// LMUL is scaled to EMUL = (EEW / SEW) * LMUL, so the EMUL/EEW ratio (and
// hence the element count) matches the original LMUL/SEW.
// NOTE(review): `Log2EEW - Log2_32(...)` is evaluated in unsigned arithmetic.
// When EEW < SEW it wraps, and the intended negative value is recovered only by
// the conversion to MulLog2LMUL's int parameter. That conversion is modular
// since C++20 and implementation-defined before that.
747 void RVVType::applyLog2EEW(unsigned Log2EEW
) {
748 // update new EMUL = (EEW/SEW) * LMUL, i.e. add log2(EEW) - log2(SEW) to Log2LMUL
749 LMUL
.MulLog2LMUL(Log2EEW
- Log2_32(ElementBitwidth
));
751 ElementBitwidth
= 1 << Log2EEW
;
752 ScalarType
= ScalarTypeKind::SignedInteger
;
753 Scale
= LMUL
.getScale(ElementBitwidth
);
// Replace the element width with NewSEW and recompute Scale. Asking for the
// width the type already has marks the type Invalid.
// NOTE(review): the statements following the Invalid assignment are not
// visible in this excerpt; presumably the function returns there — confirm.
756 void RVVType::applyFixedSEW(unsigned NewSEW
) {
757 // Set invalid type if src and dst SEW are same.
758 if (ElementBitwidth
== NewSEW
) {
759 ScalarType
= ScalarTypeKind::Invalid
;
763 ElementBitwidth
= NewSEW
;
764 Scale
= LMUL
.getScale(ElementBitwidth
);
// Force LMUL to 2^Log2LMUL and recompute Scale.
//   - LargerThan: the type is marked Invalid when the requested LMUL is
//     smaller than the current one.
//   - SmallerThan: the type is marked Invalid when the requested LMUL is
//     larger than the current one.
// NOTE(review): the switch header and the statements after each Invalid
// assignment are not visible. Presumably they are `switch (Type)` and early
// returns — confirm.
767 void RVVType::applyFixedLog2LMUL(int Log2LMUL
, enum FixedLMULType Type
) {
769 case FixedLMULType::LargerThan
:
770 if (Log2LMUL
< LMUL
.Log2LMUL
) {
771 ScalarType
= ScalarTypeKind::Invalid
;
775 case FixedLMULType::SmallerThan
:
776 if (Log2LMUL
> LMUL
.Log2LMUL
) {
777 ScalarType
= ScalarTypeKind::Invalid
;
784 LMUL
= LMULType(Log2LMUL
);
785 Scale
= LMUL
.getScale(ElementBitwidth
);
// Compute the RVVType of every descriptor in Prototype for basic type BT,
// log2 LMUL and segment count NF, going through the computeType cache.
// Combinations with LMUL * NF > 8 are rejected up front. The check is guarded
// by Log2LMUL >= 1, so only integral LMULs above 1 are tested.
// NOTE(review): the early-exit statements are not visible in this excerpt;
// presumably an illegal combination or element type yields std::nullopt.
788 std::optional
<RVVTypes
>
789 RVVTypeCache::computeTypes(BasicType BT
, int Log2LMUL
, unsigned NF
,
790 ArrayRef
<PrototypeDescriptor
> Prototype
) {
791 // LMUL x NF must be less than or equal to 8.
792 if ((Log2LMUL
>= 1) && (1 << Log2LMUL
) * NF
> 8)
796 for (const PrototypeDescriptor
&Proto
: Prototype
) {
797 auto T
= computeType(BT
, Log2LMUL
, Proto
);
800 // Record legal type index
806 // Compute the hash value of RVVType, used to cache the result of computeType.
// Every field occupies its own byte, so distinct inputs give distinct keys as
// long as each field fits in 8 bits. Log2LMUL is biased by 3 into [0, 6]; the
// PT/TM/VTM fields are stored as uint8_t; BT is masked with 0xff.
// Without assertions, a Log2LMUL below -3 would make the int term negative,
// and sign-extending it to uint64_t would overwrite the other fields.
807 static uint64_t computeRVVTypeHashValue(BasicType BT
, int Log2LMUL
,
808 PrototypeDescriptor Proto
) {
809 // Layout of hash value:
811 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
// Bits:  [7:0]        [15:8] [23:16]   [31:24]    [39:32]
812 assert(Log2LMUL
>= -3 && Log2LMUL
<= 3);
813 return (Log2LMUL
+ 3) | (static_cast<uint64_t>(BT
) & 0xff) << 8 |
814 ((uint64_t)(Proto
.PT
& 0xff) << 16) |
815 ((uint64_t)(Proto
.TM
& 0xff) << 24) |
816 ((uint64_t)(Proto
.VTM
& 0xff) << 32);
// Return a pointer to the cached RVVType for (BT, Log2LMUL, Proto), computing
// and caching it on first use. Keys already recorded in IllegalTypes are not
// recomputed.
// The returned pointer refers to an element of the std::unordered_map
// LegalTypes. References to unordered_map elements stay valid across rehashing,
// so the pointer remains usable as long as that entry is not erased.
// NOTE(review): the legality check and the early returns for cached or newly
// found illegal types are not visible here; presumably they return
// std::nullopt — confirm.
819 std::optional
<RVVTypePtr
> RVVTypeCache::computeType(BasicType BT
, int Log2LMUL
,
820 PrototypeDescriptor Proto
) {
821 uint64_t Idx
= computeRVVTypeHashValue(BT
, Log2LMUL
, Proto
);
823 auto It
= LegalTypes
.find(Idx
);
824 if (It
!= LegalTypes
.end())
825 return &(It
->second
);
827 if (IllegalTypes
.count(Idx
))
830 // Compute type and record the result.
831 RVVType
T(BT
, Log2LMUL
, Proto
);
833 // Record legal type index and value.
834 std::pair
<std::unordered_map
<uint64_t, RVVType
>::iterator
, bool>
835 InsertResult
= LegalTypes
.insert({Idx
, T
});
836 return &(InsertResult
.first
->second
);
838 // Record illegal type index.
839 IllegalTypes
.insert(Idx
);
843 //===----------------------------------------------------------------------===//
844 // RVVIntrinsic implementation
845 //===----------------------------------------------------------------------===//
// Construct an intrinsic description:
//   - Derive BuiltinName, Name and OverloadedName. OverloadedName defaults to
//     the part of NewName before the first '_' when NewOverloadedName is
//     empty. Suffixes are appended, then updateNamesAndPolicy adds the
//     "__riscv_" prefix and policy suffixes.
//   - Split OutInTypes into OutputType (element 0) and InputTypes (the rest).
//   - Adjust IntrinsicTypes when a masked-off or passthru operand is present.
// NOTE(review): several lines are not visible in this excerpt: Name's initial
// value, the guard around the Suffix append, the end of the
// updateNamesAndPolicy call, and the IntrinsicTypes loop body.
846 RVVIntrinsic::RVVIntrinsic(StringRef NewName
, StringRef Suffix
,
847 StringRef NewOverloadedName
,
848 StringRef OverloadedSuffix
, StringRef IRName
,
849 bool IsMasked
, bool HasMaskedOffOperand
, bool HasVL
,
850 PolicyScheme Scheme
, bool SupportOverloading
,
851 bool HasBuiltinAlias
, StringRef ManualCodegen
,
852 const RVVTypes
&OutInTypes
,
853 const std::vector
<int64_t> &NewIntrinsicTypes
,
854 const std::vector
<StringRef
> &RequiredFeatures
,
855 unsigned NF
, Policy NewPolicyAttrs
)
856 : IRName(IRName
), IsMasked(IsMasked
),
857 HasMaskedOffOperand(HasMaskedOffOperand
), HasVL(HasVL
), Scheme(Scheme
),
858 SupportOverloading(SupportOverloading
), HasBuiltinAlias(HasBuiltinAlias
),
859 ManualCodegen(ManualCodegen
.str()), NF(NF
), PolicyAttrs(NewPolicyAttrs
) {
861 // Init BuiltinName, Name and OverloadedName
862 BuiltinName
= NewName
.str();
864 if (NewOverloadedName
.empty())
865 OverloadedName
= NewName
.split("_").first
.str();
867 OverloadedName
= NewOverloadedName
.str();
869 Name
+= "_" + Suffix
.str();
870 if (!OverloadedSuffix
.empty())
871 OverloadedName
+= "_" + OverloadedSuffix
.str();
873 updateNamesAndPolicy(IsMasked
, hasPolicy(), Name
, BuiltinName
, OverloadedName
,
876 // Init OutputType and InputTypes
877 OutputType
= OutInTypes
[0];
878 InputTypes
.assign(OutInTypes
.begin() + 1, OutInTypes
.end());
880 // IntrinsicTypes is unmasked TA version index. Need to update it
881 // if there is merge operand (It is always in first operand).
882 IntrinsicTypes
= NewIntrinsicTypes
;
883 if ((IsMasked
&& hasMaskedOffOperand()) ||
884 (!IsMasked
&& hasPassthruOperand())) {
885 for (auto &I
: IntrinsicTypes
) {
// Concatenate the builtin type strings of the output type followed by each
// input type in order, forming the builtin's full signature string.
892 std::string
RVVIntrinsic::getBuiltinTypeStr() const {
894 S
+= OutputType
->getBuiltinStr();
895 for (const auto &T
: InputTypes
) {
896 S
+= T
->getBuiltinStr();
// Build an intrinsic-name suffix: the short type string (e.g. "i32m1") of each
// prototype descriptor at the given basic type and log2 LMUL, joined by "_".
// NOTE(review): `*T` is dereferenced without a check. computeType returns
// std::optional, so an illegal descriptor here would be undefined behavior.
// Callers presumably pass only legal combinations — confirm, or add an assert.
901 std::string
RVVIntrinsic::getSuffixStr(
902 RVVTypeCache
&TypeCache
, BasicType Type
, int Log2LMUL
,
903 llvm::ArrayRef
<PrototypeDescriptor
> PrototypeDescriptors
) {
904 SmallVector
<std::string
> SuffixStrs
;
905 for (auto PD
: PrototypeDescriptors
) {
906 auto T
= TypeCache
.computeType(Type
, Log2LMUL
, PD
);
907 SuffixStrs
.push_back((*T
)->getShortStr());
909 return join(SuffixStrs
, "_");
// Derive the builtin's prototype from the intrinsic prototype by inserting its
// implicit operands. Element 0 is the result type.
//   - Masked-off / passthru operands are copies of the result type. For
//     segment operations they are the pointee types of the NF address
//     operands, with the Pointer modifier cleared.
//   - The mask operand (PrototypeDescriptor::Mask) is inserted.
//   - PrototypeDescriptor::VL is appended as the last operand.
// For NF > 1 the extra operands are inserted after the NF address operands.
// NOTE(review): several enclosing conditions (IsMasked, NF, HasVL) are not
// visible in this excerpt.
912 llvm::SmallVector
<PrototypeDescriptor
> RVVIntrinsic::computeBuiltinTypes(
913 llvm::ArrayRef
<PrototypeDescriptor
> Prototype
, bool IsMasked
,
914 bool HasMaskedOffOperand
, bool HasVL
, unsigned NF
,
915 PolicyScheme DefaultScheme
, Policy PolicyAttrs
) {
916 SmallVector
<PrototypeDescriptor
> NewPrototype(Prototype
.begin(),
918 bool HasPassthruOp
= DefaultScheme
== PolicyScheme::HasPassthruOperand
;
920 // If HasMaskedOffOperand, insert result type as first input operand if
922 if (HasMaskedOffOperand
&& !PolicyAttrs
.isTAMAPolicy()) {
924 NewPrototype
.insert(NewPrototype
.begin() + 1, NewPrototype
[0]);
927 // (void, op0 address, op1 address, ...)
929 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
930 PrototypeDescriptor MaskoffType
= NewPrototype
[1];
931 MaskoffType
.TM
&= ~static_cast<uint8_t>(TypeModifier::Pointer
);
932 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1, NF
, MaskoffType
);
935 if (HasMaskedOffOperand
&& NF
> 1) {
937 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
939 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
941 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1,
942 PrototypeDescriptor::Mask
);
944 // If IsMasked, insert PrototypeDescriptor::Mask as first input operand.
945 NewPrototype
.insert(NewPrototype
.begin() + 1, PrototypeDescriptor::Mask
);
949 if (PolicyAttrs
.isTUPolicy() && HasPassthruOp
)
950 NewPrototype
.insert(NewPrototype
.begin(), NewPrototype
[0]);
951 } else if (PolicyAttrs
.isTUPolicy() && HasPassthruOp
) {
952 // NF > 1 cases for segment load operations.
954 // (void, op0 address, op1 address, ...)
956 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
957 PrototypeDescriptor MaskoffType
= Prototype
[1];
958 MaskoffType
.TM
&= ~static_cast<uint8_t>(TypeModifier::Pointer
);
959 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1, NF
, MaskoffType
);
963 // If HasVL, append PrototypeDescriptor::VL to last operand
965 NewPrototype
.push_back(PrototypeDescriptor::VL
);
// Policies to generate for unmasked intrinsics: only tail-undisturbed (TU) is
// listed.
969 llvm::SmallVector
<Policy
> RVVIntrinsic::getSupportedUnMaskedPolicies() {
970 return {Policy(Policy::PolicyType::Undisturbed
)}; // TU
// Policies to generate for masked intrinsics, as Policy(tail, mask) pairs:
//   - With both tail and mask policies: TUM, TUMU and MU.
//   - With only a tail policy: TUM.
//   - With only a mask policy: MU.
// Having neither is treated as unreachable.
973 llvm::SmallVector
<Policy
>
974 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy
,
975 bool HasMaskPolicy
) {
976 if (HasTailPolicy
&& HasMaskPolicy
)
977 return {Policy(Policy::PolicyType::Undisturbed
,
978 Policy::PolicyType::Agnostic
), // TUM
979 Policy(Policy::PolicyType::Undisturbed
,
980 Policy::PolicyType::Undisturbed
), // TUMU
981 Policy(Policy::PolicyType::Agnostic
,
982 Policy::PolicyType::Undisturbed
)}; // MU
983 if (HasTailPolicy
&& !HasMaskPolicy
)
984 return {Policy(Policy::PolicyType::Undisturbed
,
985 Policy::PolicyType::Agnostic
)}; // TUM
986 if (!HasTailPolicy
&& HasMaskPolicy
)
987 return {Policy(Policy::PolicyType::Agnostic
,
988 Policy::PolicyType::Undisturbed
)}; // MU
989 llvm_unreachable("An RVV instruction should not be without both tail policy "
// Apply the RVV intrinsic naming scheme:
//   - Prefix Name and OverloadedName with "__riscv_".
//   - Append the policy suffix "_tumu", "_tum", "_mu" or "_tu" to BuiltinName
//     and OverloadedName.
//   - TAMA-masked and TA-unmasked intrinsics instead add "_tama"/"_m" or
//     "_ta" to BuiltinName only.
// Unrecognized policy combinations reach llvm_unreachable.
// NOTE(review): the Name parameter, the IsMasked/HasPolicy conditions, and one
// statement inside appendPolicySuffix are not visible in this excerpt.
993 void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked
, bool HasPolicy
,
995 std::string
&BuiltinName
,
996 std::string
&OverloadedName
,
997 Policy
&PolicyAttrs
) {
999 auto appendPolicySuffix
= [&](const std::string
&suffix
) {
1001 BuiltinName
+= suffix
;
1002 OverloadedName
+= suffix
;
1005 // This follows the naming guideline under riscv-c-api-doc to add the
1006 // `__riscv_` prefix for all RVV intrinsics.
1007 Name
= "__riscv_" + Name
;
1008 OverloadedName
= "__riscv_" + OverloadedName
;
1011 if (PolicyAttrs
.isTUMUPolicy())
1012 appendPolicySuffix("_tumu");
1013 else if (PolicyAttrs
.isTUMAPolicy())
1014 appendPolicySuffix("_tum");
1015 else if (PolicyAttrs
.isTAMUPolicy())
1016 appendPolicySuffix("_mu");
1017 else if (PolicyAttrs
.isTAMAPolicy()) {
1020 BuiltinName
+= "_tama";
1022 BuiltinName
+= "_m";
1024 llvm_unreachable("Unhandled policy condition");
1026 if (PolicyAttrs
.isTUPolicy())
1027 appendPolicySuffix("_tu");
1028 else if (PolicyAttrs
.isTAPolicy()) {
1030 BuiltinName
+= "_ta";
1032 llvm_unreachable("Unhandled policy condition");
// Split a concatenated prototype string into PrototypeDescriptors.
// Each descriptor ends at the next primitive type character, one of
// "evwqom0ztul". A leading "(...)" group is skipped first because it may
// contain such characters. Each slice is parsed with
// PrototypeDescriptor::parsePrototypeDescriptor; failures reach
// llvm_unreachable.
// NOTE(review): the declaration of Idx and the failure check before the
// llvm_unreachable are not visible in this excerpt.
1036 SmallVector
<PrototypeDescriptor
> parsePrototypes(StringRef Prototypes
) {
1037 SmallVector
<PrototypeDescriptor
> PrototypeDescriptors
;
1038 const StringRef
Primaries("evwqom0ztul");
1039 while (!Prototypes
.empty()) {
1041 // Skip over complex prototype because it could contain primitive type
1043 if (Prototypes
[0] == '(')
1044 Idx
= Prototypes
.find_first_of(')');
1045 Idx
= Prototypes
.find_first_of(Primaries
, Idx
);
1046 assert(Idx
!= StringRef::npos
);
1047 auto PD
= PrototypeDescriptor::parsePrototypeDescriptor(
1048 Prototypes
.slice(0, Idx
+ 1));
1050 llvm_unreachable("Error during parsing prototype.");
1051 PrototypeDescriptors
.push_back(*PD
);
1052 Prototypes
= Prototypes
.drop_front(Idx
+ 1);
1054 return PrototypeDescriptors
;
// Print an RVVIntrinsicRecord as a comma-separated initializer list.
// Output order: the quoted Name; the quoted OverloadedName (the output for a
// null or empty OverloadedName is not visible here); then the index, length,
// mask and flag fields in the order written below.
// The (int) casts presumably make 8-bit fields print as numbers rather than as
// characters through raw_ostream — confirm the field types.
1057 raw_ostream
&operator<<(raw_ostream
&OS
, const RVVIntrinsicRecord
&Record
) {
1059 OS
<< "\"" << Record
.Name
<< "\",";
1060 if (Record
.OverloadedName
== nullptr ||
1061 StringRef(Record
.OverloadedName
).empty())
1064 OS
<< "\"" << Record
.OverloadedName
<< "\",";
1065 OS
<< Record
.PrototypeIndex
<< ",";
1066 OS
<< Record
.SuffixIndex
<< ",";
1067 OS
<< Record
.OverloadedSuffixIndex
<< ",";
1068 OS
<< (int)Record
.PrototypeLength
<< ",";
1069 OS
<< (int)Record
.SuffixLength
<< ",";
1070 OS
<< (int)Record
.OverloadedSuffixSize
<< ",";
1071 OS
<< (int)Record
.RequiredExtensions
<< ",";
1072 OS
<< (int)Record
.TypeRangeMask
<< ",";
1073 OS
<< (int)Record
.Log2LMULMask
<< ",";
1074 OS
<< (int)Record
.NF
<< ",";
1075 OS
<< (int)Record
.HasMasked
<< ",";
1076 OS
<< (int)Record
.HasVL
<< ",";
1077 OS
<< (int)Record
.HasMaskedOffOperand
<< ",";
1078 OS
<< (int)Record
.HasTailPolicy
<< ",";
1079 OS
<< (int)Record
.HasMaskPolicy
<< ",";
1080 OS
<< (int)Record
.UnMaskedPolicyScheme
<< ",";
1081 OS
<< (int)Record
.MaskedPolicyScheme
<< ",";
1086 } // end namespace RISCV
1087 } // end namespace clang