1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "clang/Support/RISCVVIntrinsicUtils.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/SmallSet.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringMap.h"
14 #include "llvm/ADT/StringSet.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Support/ErrorHandling.h"
17 #include "llvm/Support/raw_ostream.h"
// Definitions of three predefined PrototypeDescriptor constants:
//   Mask   - BaseTypeModifier::Vector combined with
//            VectorTypeModifier::MaskVector,
//   VL     - BaseTypeModifier::SizeT,
//   Vector - BaseTypeModifier::Vector.
26 const PrototypeDescriptor
PrototypeDescriptor::Mask
= PrototypeDescriptor(
27 BaseTypeModifier::Vector
, VectorTypeModifier::MaskVector
);
28 const PrototypeDescriptor
PrototypeDescriptor::VL
=
29 PrototypeDescriptor(BaseTypeModifier::SizeT
);
30 const PrototypeDescriptor
PrototypeDescriptor::Vector
=
31 PrototypeDescriptor(BaseTypeModifier::Vector
);
33 //===----------------------------------------------------------------------===//
34 // Type implementation
35 //===----------------------------------------------------------------------===//
// Constructs an LMULType from a log2(LMUL) value. The assertion restricts
// NewLog2LMUL to the range [-3, 3]; the value is then stored in Log2LMUL.
37 LMULType::LMULType(int NewLog2LMUL
) {
38 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
39 assert(NewLog2LMUL
<= 3 && NewLog2LMUL
>= -3 && "Bad LMUL number!");
40 Log2LMUL
= NewLog2LMUL
;
// Returns the textual form of the LMUL. Two results are visible:
//   "mf" followed by 2^(-Log2LMUL), and
//   "m"  followed by 2^Log2LMUL.
// NOTE(review): the condition choosing between the two returns is not visible
// in this listing; presumably the "mf" form is used when Log2LMUL is
// negative -- confirm against the full source.
43 std::string
LMULType::str() const {
45 return "mf" + utostr(1ULL << (-Log2LMUL
));
46 return "m" + utostr(1ULL << Log2LMUL
);
// Computes the scale value for the given element bit width.
// The switch on ElementBitwidth sets Log2ScaleResult to Log2LMUL plus an
// offset of 3, 2, 1 or 0, and the function returns 1 << Log2ScaleResult.
// NOTE(review): the case labels and the statement executed when
// Log2ScaleResult is negative are not visible in this listing; presumably the
// offsets correspond to element widths 8/16/32/64 and a negative result
// yields "no value" -- confirm against the full source.
49 VScaleVal
LMULType::getScale(unsigned ElementBitwidth
) const {
50 int Log2ScaleResult
= 0;
51 switch (ElementBitwidth
) {
55 Log2ScaleResult
= Log2LMUL
+ 3;
58 Log2ScaleResult
= Log2LMUL
+ 2;
61 Log2ScaleResult
= Log2LMUL
+ 1;
64 Log2ScaleResult
= Log2LMUL
;
67 // Illegal vscale result would be less than 1
68 if (Log2ScaleResult
< 0)
70 return 1 << Log2ScaleResult
;
// Adds log2LMUL to the stored Log2LMUL, i.e. multiplies LMUL by 2^log2LMUL.
// No range check is performed on the result here.
73 void LMULType::MulLog2LMUL(int log2LMUL
) { Log2LMUL
+= log2LMUL
; }
// Constructs an RVVType from a basic type, a log2(LMUL) value and a prototype
// descriptor. BT and LMUL are set in the initializer list; the body applies
// the descriptor through applyModifier() and calls initClangBuiltinStr().
// NOTE(review): several statements of the body (listing lines 78, 80-84,
// 86+) are not visible here, so the exact initialisation order and any
// validity guard around the init*Str calls must be confirmed in the full
// source.
75 RVVType::RVVType(BasicType BT
, int Log2LMUL
,
76 const PrototypeDescriptor
&prototype
)
77 : BT(BT
), LMUL(LMULType(Log2LMUL
)) {
79 applyModifier(prototype
);
85 initClangBuiltinStr();
91 // Boolean types are encoded by the ratio n = SEW/LMUL.
92 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64
93 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
94 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
96 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
97 // -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
98 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
99 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
100 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
101 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
102 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
103 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
104 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
// Checks whether the current type configuration is legal.
// Visible checks, in order:
//   - ScalarType compared against Invalid,
//   - a floating-point type with an 8-bit element,
//   - a tuple whose NF is 1 or greater than 8,
//   - a tuple for which (1 << max(0, LMUL.Log2LMUL)) * NF exceeds 8.
// The final switch on ElementBitwidth requires V to be a power of two bounded
// by 64, 32, 16 or 8 depending on the case.
// NOTE(review): the statements executed by each `if`, the case labels and the
// definition of V are not visible in this listing.
107 bool RVVType::verifyType() const {
108 if (ScalarType
== Invalid
)
114 if (isFloat() && ElementBitwidth
== 8)
116 if (IsTuple
&& (NF
== 1 || NF
> 8))
118 if (IsTuple
&& (1 << std::max(0, LMUL
.Log2LMUL
)) * NF
> 8)
121 switch (ElementBitwidth
) {
124 // Check Scale is 1,2,4,8,16,32,64
125 return (V
<= 64 && isPowerOf2_32(V
));
127 // Check Scale is 1,2,4,8,16,32
128 return (V
<= 32 && isPowerOf2_32(V
));
130 // Check Scale is 1,2,4,8,16
131 return (V
<= 16 && isPowerOf2_32(V
));
133 // Check Scale is 1,2,4,8
134 return (V
<= 8 && isPowerOf2_32(V
));
// Builds BuiltinStr, the builtin type-encoding string for this type.
// Requires a valid type (asserted). A switch on ScalarType selects the base
// encoding; the integer and float kinds switch again on ElementBitwidth and
// hit llvm_unreachable for unhandled widths. Integer kinds are then prefixed
// with "S" (signed) or "U". The visible follow-up statements prepend "I",
// "q" + *Scale, and "T" + NF to the string.
// NOTE(review): most base-string assignments and the conditions guarding the
// "I", "q" and "T" prefixes are not visible in this listing.
139 void RVVType::initBuiltinStr() {
140 assert(isValid() && "RVVType is invalid");
141 switch (ScalarType
) {
142 case ScalarTypeKind::Void
:
145 case ScalarTypeKind::Size_t
:
148 BuiltinStr
= "I" + BuiltinStr
;
152 case ScalarTypeKind::Ptrdiff_t
:
155 case ScalarTypeKind::UnsignedLong
:
158 case ScalarTypeKind::SignedLong
:
161 case ScalarTypeKind::Boolean
:
162 assert(ElementBitwidth
== 1);
165 case ScalarTypeKind::SignedInteger
:
166 case ScalarTypeKind::UnsignedInteger
:
167 switch (ElementBitwidth
) {
181 llvm_unreachable("Unhandled ElementBitwidth!");
183 if (isSignedInteger())
184 BuiltinStr
= "S" + BuiltinStr
;
186 BuiltinStr
= "U" + BuiltinStr
;
188 case ScalarTypeKind::Float
:
189 switch (ElementBitwidth
) {
200 llvm_unreachable("Unhandled ElementBitwidth!");
204 llvm_unreachable("ScalarType is invalid!");
207 BuiltinStr
= "I" + BuiltinStr
;
215 BuiltinStr
= "q" + utostr(*Scale
) + BuiltinStr
;
216 // Pointer to vector types. Defined for segment load intrinsics.
217 // segment load intrinsics have pointer type arguments to store the loaded
223 BuiltinStr
= "T" + utostr(NF
) + BuiltinStr
;
// Builds ClangBuiltinStr, which always starts with "__rvv_".
// Both assertions require a valid vector type. For Boolean the string
// "bool" + (64 / *Scale) + "_t" is appended. For Float / SignedInteger /
// UnsignedInteger the keyword "float" / "int" / "uint" is appended, followed
// by ElementBitwidth, LMUL.str(), "x" + NF when IsTuple is set, and "_t".
// Any other ScalarType reaches llvm_unreachable.
// NOTE(review): the statements that end each case (break/return) are not
// visible in this listing.
226 void RVVType::initClangBuiltinStr() {
227 assert(isValid() && "RVVType is invalid");
228 assert(isVector() && "Handle Vector type only");
230 ClangBuiltinStr
= "__rvv_";
231 switch (ScalarType
) {
232 case ScalarTypeKind::Boolean
:
233 ClangBuiltinStr
+= "bool" + utostr(64 / *Scale
) + "_t";
235 case ScalarTypeKind::Float
:
236 ClangBuiltinStr
+= "float";
238 case ScalarTypeKind::SignedInteger
:
239 ClangBuiltinStr
+= "int";
241 case ScalarTypeKind::UnsignedInteger
:
242 ClangBuiltinStr
+= "uint";
245 llvm_unreachable("ScalarTypeKind is invalid");
247 ClangBuiltinStr
+= utostr(ElementBitwidth
) + LMUL
.str() +
248 (IsTuple
? "x" + utostr(NF
) : "") + "_t";
// Builds the type-name string (Str) for this type; requires a valid type.
// The getTypeString lambda produces one of two forms:
//   "<TypeStr><ElementBitwidth>_t", or
//   "v<TypeStr><ElementBitwidth><LMUL>[x<NF>]_t" (the "x<NF>" part only when
//   IsTuple is set).
// The switch on ScalarType then assigns or appends the name; Float, signed
// and unsigned integers go through getTypeString with "float", "int", "uint".
// NOTE(review): the condition selecting between the two lambda forms
// (presumably scalar vs. vector) and most per-case assignments are not
// visible in this listing.
251 void RVVType::initTypeStr() {
252 assert(isValid() && "RVVType is invalid");
257 auto getTypeString
= [&](StringRef TypeStr
) {
259 return Twine(TypeStr
+ Twine(ElementBitwidth
) + "_t").str();
260 return Twine("v" + TypeStr
+ Twine(ElementBitwidth
) + LMUL
.str() +
261 (IsTuple
? "x" + utostr(NF
) : "") + "_t")
265 switch (ScalarType
) {
266 case ScalarTypeKind::Void
:
269 case ScalarTypeKind::Size_t
:
274 case ScalarTypeKind::Ptrdiff_t
:
277 case ScalarTypeKind::UnsignedLong
:
278 Str
= "unsigned long";
280 case ScalarTypeKind::SignedLong
:
283 case ScalarTypeKind::Boolean
:
287 // Vector bool is a special case; the formula is
288 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
289 Str
+= "vbool" + utostr(64 / *Scale
) + "_t";
291 case ScalarTypeKind::Float
:
293 if (ElementBitwidth
== 64)
295 else if (ElementBitwidth
== 32)
297 else if (ElementBitwidth
== 16)
300 llvm_unreachable("Unhandled floating type.");
302 Str
+= getTypeString("float");
304 case ScalarTypeKind::SignedInteger
:
305 Str
+= getTypeString("int");
307 case ScalarTypeKind::UnsignedInteger
:
308 Str
+= getTypeString("uint");
311 llvm_unreachable("ScalarType is invalid!");
// Builds ShortStr, the short type mnemonic.
// The switch on ScalarType assigns a kind letter followed by a number:
//   Boolean         -> "b" + (64 / *Scale)
//   Float           -> "f" + ElementBitwidth
//   SignedInteger   -> "i" + ElementBitwidth
//   UnsignedInteger -> "u" + ElementBitwidth
// Other kinds reach llvm_unreachable. Afterwards LMUL.str() and "x" + NF are
// appended.
// NOTE(review): the conditions guarding the LMUL and NF suffixes and the
// statements ending each case are not visible in this listing.
317 void RVVType::initShortStr() {
318 switch (ScalarType
) {
319 case ScalarTypeKind::Boolean
:
321 ShortStr
= "b" + utostr(64 / *Scale
);
323 case ScalarTypeKind::Float
:
324 ShortStr
= "f" + utostr(ElementBitwidth
);
326 case ScalarTypeKind::SignedInteger
:
327 ShortStr
= "i" + utostr(ElementBitwidth
);
329 case ScalarTypeKind::UnsignedInteger
:
330 ShortStr
= "u" + utostr(ElementBitwidth
);
333 llvm_unreachable("Unhandled case!");
336 ShortStr
+= LMUL
.str();
338 ShortStr
+= "x" + utostr(NF
);
// Maps a tuple field count NF (asserted to be in [2, 8]) to a
// VectorTypeModifier by adding NF - 2 to the uint8_t value of
// VectorTypeModifier::Tuple2.
// NOTE(review): this relies on Tuple2..Tuple8 having consecutive enumerator
// values; the enum is declared elsewhere -- confirm.
341 static VectorTypeModifier
getTupleVTM(unsigned NF
) {
342 assert(2 <= NF
&& NF
<= 8 && "2 <= NF <= 8");
343 return static_cast<VectorTypeModifier
>(
344 static_cast<uint8_t>(VectorTypeModifier::Tuple2
) + (NF
- 2));
// Initialises ElementBitwidth and ScalarType from the basic type.
// Visible cases: Int8/Int16/Int32/Int64 select ScalarTypeKind::SignedInteger
// and Float16/Float32/Float64 select ScalarTypeKind::Float, with
// ElementBitwidth set to 16, 32 or 64 as named by the case. Unknown type
// codes reach llvm_unreachable. The trailing assertion requires a non-zero
// ElementBitwidth.
// NOTE(review): the switch header, the width assignment of the Int8 case and
// the statements ending each case are not visible in this listing.
347 void RVVType::applyBasicType() {
349 case BasicType::Int8
:
351 ScalarType
= ScalarTypeKind::SignedInteger
;
353 case BasicType::Int16
:
354 ElementBitwidth
= 16;
355 ScalarType
= ScalarTypeKind::SignedInteger
;
357 case BasicType::Int32
:
358 ElementBitwidth
= 32;
359 ScalarType
= ScalarTypeKind::SignedInteger
;
361 case BasicType::Int64
:
362 ElementBitwidth
= 64;
363 ScalarType
= ScalarTypeKind::SignedInteger
;
365 case BasicType::Float16
:
366 ElementBitwidth
= 16;
367 ScalarType
= ScalarTypeKind::Float
;
369 case BasicType::Float32
:
370 ElementBitwidth
= 32;
371 ScalarType
= ScalarTypeKind::Float
;
373 case BasicType::Float64
:
374 ElementBitwidth
= 64;
375 ScalarType
= ScalarTypeKind::Float
;
378 llvm_unreachable("Unhandled type code!");
380 assert(ElementBitwidth
!= 0 && "Bad element bitwidth!");
// Parses one prototype descriptor string into a PrototypeDescriptor.
//
// Visible stages:
//   1. Base type: the LAST character of the string selects the
//      BaseTypeModifier (and, for the widening and mask forms, also a
//      VectorTypeModifier). It is then removed with drop_back().
//   2. Complex modifier: if the remainder starts with "(", the text up to the
//      first ")" is split on ":" into a name and an integer value. Handled
//      names are Log2EEW, FixedSEW, LFixedLog2LMUL, SFixedLog2LMUL,
//      SEFixedLog2LMUL and Tuple. Only one vector type modifier is allowed
//      (asserted), and no second "(" may follow.
//   3. Remaining characters are folded into a TypeModifier bit mask; a
//      pointer modifier after const, or a second pointer modifier, is
//      rejected.
// The results are stored as uint8_t values in PD.PT, PD.VTM and PD.TM.
//
// Malformed input is reported through llvm_unreachable / assert rather than
// through the optional return value in all visible error paths.
// NOTE(review): case labels, declarations of Log2EEW/NewSEW/Log2LMUL/NF and
// all return statements are not visible in this listing.
383 std::optional
<PrototypeDescriptor
>
384 PrototypeDescriptor::parsePrototypeDescriptor(
385 llvm::StringRef PrototypeDescriptorStr
) {
386 PrototypeDescriptor PD
;
387 BaseTypeModifier PT
= BaseTypeModifier::Invalid
;
388 VectorTypeModifier VTM
= VectorTypeModifier::NoModifier
;
390 if (PrototypeDescriptorStr
.empty())
393 // Handle base type modifier
394 auto PType
= PrototypeDescriptorStr
.back();
397 PT
= BaseTypeModifier::Scalar
;
400 PT
= BaseTypeModifier::Vector
;
403 PT
= BaseTypeModifier::Vector
;
404 VTM
= VectorTypeModifier::Widening2XVector
;
407 PT
= BaseTypeModifier::Vector
;
408 VTM
= VectorTypeModifier::Widening4XVector
;
411 PT
= BaseTypeModifier::Vector
;
412 VTM
= VectorTypeModifier::Widening8XVector
;
415 PT
= BaseTypeModifier::Vector
;
416 VTM
= VectorTypeModifier::MaskVector
;
419 PT
= BaseTypeModifier::Void
;
422 PT
= BaseTypeModifier::SizeT
;
425 PT
= BaseTypeModifier::Ptrdiff
;
428 PT
= BaseTypeModifier::UnsignedLong
;
431 PT
= BaseTypeModifier::SignedLong
;
434 llvm_unreachable("Illegal primitive type transformers!");
436 PD
.PT
= static_cast<uint8_t>(PT
);
437 PrototypeDescriptorStr
= PrototypeDescriptorStr
.drop_back();
439 // Compute the vector type transformers, it can only appear one time.
440 if (PrototypeDescriptorStr
.startswith("(")) {
441 assert(VTM
== VectorTypeModifier::NoModifier
&&
442 "VectorTypeModifier should only have one modifier");
443 size_t Idx
= PrototypeDescriptorStr
.find(')');
444 assert(Idx
!= StringRef::npos
);
445 StringRef ComplexType
= PrototypeDescriptorStr
.slice(1, Idx
);
446 PrototypeDescriptorStr
= PrototypeDescriptorStr
.drop_front(Idx
+ 1);
447 assert(!PrototypeDescriptorStr
.contains('(') &&
448 "Only allow one vector type modifier");
450 auto ComplexTT
= ComplexType
.split(":");
451 if (ComplexTT
.first
== "Log2EEW") {
453 if (ComplexTT
.second
.getAsInteger(10, Log2EEW
)) {
454 llvm_unreachable("Invalid Log2EEW value!");
459 VTM
= VectorTypeModifier::Log2EEW3
;
462 VTM
= VectorTypeModifier::Log2EEW4
;
465 VTM
= VectorTypeModifier::Log2EEW5
;
468 VTM
= VectorTypeModifier::Log2EEW6
;
471 llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
474 } else if (ComplexTT
.first
== "FixedSEW") {
476 if (ComplexTT
.second
.getAsInteger(10, NewSEW
)) {
477 llvm_unreachable("Invalid FixedSEW value!");
482 VTM
= VectorTypeModifier::FixedSEW8
;
485 VTM
= VectorTypeModifier::FixedSEW16
;
488 VTM
= VectorTypeModifier::FixedSEW32
;
491 VTM
= VectorTypeModifier::FixedSEW64
;
494 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
497 } else if (ComplexTT
.first
== "LFixedLog2LMUL") {
499 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
500 llvm_unreachable("Invalid LFixedLog2LMUL value!");
505 VTM
= VectorTypeModifier::LFixedLog2LMULN3
;
508 VTM
= VectorTypeModifier::LFixedLog2LMULN2
;
511 VTM
= VectorTypeModifier::LFixedLog2LMULN1
;
514 VTM
= VectorTypeModifier::LFixedLog2LMUL0
;
517 VTM
= VectorTypeModifier::LFixedLog2LMUL1
;
520 VTM
= VectorTypeModifier::LFixedLog2LMUL2
;
523 VTM
= VectorTypeModifier::LFixedLog2LMUL3
;
526 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
529 } else if (ComplexTT
.first
== "SFixedLog2LMUL") {
531 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
532 llvm_unreachable("Invalid SFixedLog2LMUL value!");
537 VTM
= VectorTypeModifier::SFixedLog2LMULN3
;
540 VTM
= VectorTypeModifier::SFixedLog2LMULN2
;
543 VTM
= VectorTypeModifier::SFixedLog2LMULN1
;
546 VTM
= VectorTypeModifier::SFixedLog2LMUL0
;
549 VTM
= VectorTypeModifier::SFixedLog2LMUL1
;
552 VTM
= VectorTypeModifier::SFixedLog2LMUL2
;
555 VTM
= VectorTypeModifier::SFixedLog2LMUL3
;
// NOTE(review): this is the SFixedLog2LMUL branch, but the message below
// names LFixedLog2LMUL -- looks like a copy-paste slip in the diagnostic.
558 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
562 } else if (ComplexTT
.first
== "SEFixedLog2LMUL") {
564 if (ComplexTT
.second
.getAsInteger(10, Log2LMUL
)) {
565 llvm_unreachable("Invalid SEFixedLog2LMUL value!");
570 VTM
= VectorTypeModifier::SEFixedLog2LMULN3
;
573 VTM
= VectorTypeModifier::SEFixedLog2LMULN2
;
576 VTM
= VectorTypeModifier::SEFixedLog2LMULN1
;
579 VTM
= VectorTypeModifier::SEFixedLog2LMUL0
;
582 VTM
= VectorTypeModifier::SEFixedLog2LMUL1
;
585 VTM
= VectorTypeModifier::SEFixedLog2LMUL2
;
588 VTM
= VectorTypeModifier::SEFixedLog2LMUL3
;
// NOTE(review): this is the SEFixedLog2LMUL branch, but the message below
// names LFixedLog2LMUL -- looks like a copy-paste slip in the diagnostic.
591 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
594 } else if (ComplexTT
.first
== "Tuple") {
596 if (ComplexTT
.second
.getAsInteger(10, NF
)) {
597 llvm_unreachable("Invalid NF value!");
600 VTM
= getTupleVTM(NF
);
602 llvm_unreachable("Illegal complex type transformers!");
605 PD
.VTM
= static_cast<uint8_t>(VTM
);
607 // Compute the remain type transformers
608 TypeModifier TM
= TypeModifier::NoModifier
;
609 for (char I
: PrototypeDescriptorStr
) {
612 if ((TM
& TypeModifier::Const
) == TypeModifier::Const
)
613 llvm_unreachable("'P' transformer cannot be used after 'C'");
614 if ((TM
& TypeModifier::Pointer
) == TypeModifier::Pointer
)
615 llvm_unreachable("'P' transformer cannot be used twice");
616 TM
|= TypeModifier::Pointer
;
619 TM
|= TypeModifier::Const
;
622 TM
|= TypeModifier::Immediate
;
625 TM
|= TypeModifier::UnsignedInteger
;
628 TM
|= TypeModifier::SignedInteger
;
631 TM
|= TypeModifier::Float
;
634 TM
|= TypeModifier::LMUL1
;
637 llvm_unreachable("Illegal non-primitive type transformer!");
640 PD
.TM
= static_cast<uint8_t>(TM
);
// Applies a PrototypeDescriptor to this type in three stages.
//
//   1. Base type (Transformer.PT): Vector computes Scale from LMUL and
//      ElementBitwidth; Void, SizeT, Ptrdiff, UnsignedLong, SignedLong and
//      Invalid set ScalarType to the matching ScalarTypeKind.
//   2. Vector type (Transformer.VTM):
//        - Widening2X/4X/8XVector multiply ElementBitwidth by 2/4/8 and
//          recompute Scale;
//        - MaskVector sets ScalarType to Boolean and recomputes Scale;
//        - LFixed/SFixed/SEFixedLog2LMUL* call applyFixedLog2LMUL with the
//          encoded value (-3..3) and LargerThan / SmallerThan /
//          SmallerOrEqual respectively;
//        - Tuple2..Tuple8 derive NF as 2 + (VTM - Tuple2).
//   3. Type modifiers (Transformer.TM): every bit from 0 up to
//      TypeModifier::MaxOffset is tested; UnsignedInteger, SignedInteger and
//      Float overwrite ScalarType, and LMUL1 recomputes Scale.
//      Unknown bits reach llvm_unreachable.
// Stage 3 is skipped when ScalarType is already Invalid (see the existing
// "Early return" comment).
// NOTE(review): the bodies of the Scalar, Log2EEW*, FixedSEW*, Pointer, Const
// and Immediate cases, the LMUL reset of the LMUL1 case, and all break
// statements are not visible in this listing.
645 void RVVType::applyModifier(const PrototypeDescriptor
&Transformer
) {
646 // Handle primitive type transformer
647 switch (static_cast<BaseTypeModifier
>(Transformer
.PT
)) {
648 case BaseTypeModifier::Scalar
:
651 case BaseTypeModifier::Vector
:
652 Scale
= LMUL
.getScale(ElementBitwidth
);
654 case BaseTypeModifier::Void
:
655 ScalarType
= ScalarTypeKind::Void
;
657 case BaseTypeModifier::SizeT
:
658 ScalarType
= ScalarTypeKind::Size_t
;
660 case BaseTypeModifier::Ptrdiff
:
661 ScalarType
= ScalarTypeKind::Ptrdiff_t
;
663 case BaseTypeModifier::UnsignedLong
:
664 ScalarType
= ScalarTypeKind::UnsignedLong
;
666 case BaseTypeModifier::SignedLong
:
667 ScalarType
= ScalarTypeKind::SignedLong
;
669 case BaseTypeModifier::Invalid
:
670 ScalarType
= ScalarTypeKind::Invalid
;
674 switch (static_cast<VectorTypeModifier
>(Transformer
.VTM
)) {
675 case VectorTypeModifier::Widening2XVector
:
676 ElementBitwidth
*= 2;
678 Scale
= LMUL
.getScale(ElementBitwidth
);
680 case VectorTypeModifier::Widening4XVector
:
681 ElementBitwidth
*= 4;
683 Scale
= LMUL
.getScale(ElementBitwidth
);
685 case VectorTypeModifier::Widening8XVector
:
686 ElementBitwidth
*= 8;
688 Scale
= LMUL
.getScale(ElementBitwidth
);
690 case VectorTypeModifier::MaskVector
:
691 ScalarType
= ScalarTypeKind::Boolean
;
692 Scale
= LMUL
.getScale(ElementBitwidth
);
695 case VectorTypeModifier::Log2EEW3
:
698 case VectorTypeModifier::Log2EEW4
:
701 case VectorTypeModifier::Log2EEW5
:
704 case VectorTypeModifier::Log2EEW6
:
707 case VectorTypeModifier::FixedSEW8
:
710 case VectorTypeModifier::FixedSEW16
:
713 case VectorTypeModifier::FixedSEW32
:
716 case VectorTypeModifier::FixedSEW64
:
719 case VectorTypeModifier::LFixedLog2LMULN3
:
720 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan
);
722 case VectorTypeModifier::LFixedLog2LMULN2
:
723 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan
);
725 case VectorTypeModifier::LFixedLog2LMULN1
:
726 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan
);
728 case VectorTypeModifier::LFixedLog2LMUL0
:
729 applyFixedLog2LMUL(0, FixedLMULType::LargerThan
);
731 case VectorTypeModifier::LFixedLog2LMUL1
:
732 applyFixedLog2LMUL(1, FixedLMULType::LargerThan
);
734 case VectorTypeModifier::LFixedLog2LMUL2
:
735 applyFixedLog2LMUL(2, FixedLMULType::LargerThan
);
737 case VectorTypeModifier::LFixedLog2LMUL3
:
738 applyFixedLog2LMUL(3, FixedLMULType::LargerThan
);
740 case VectorTypeModifier::SFixedLog2LMULN3
:
741 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan
);
743 case VectorTypeModifier::SFixedLog2LMULN2
:
744 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan
);
746 case VectorTypeModifier::SFixedLog2LMULN1
:
747 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan
);
749 case VectorTypeModifier::SFixedLog2LMUL0
:
750 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan
);
752 case VectorTypeModifier::SFixedLog2LMUL1
:
753 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan
);
755 case VectorTypeModifier::SFixedLog2LMUL2
:
756 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan
);
758 case VectorTypeModifier::SFixedLog2LMUL3
:
759 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan
);
761 case VectorTypeModifier::SEFixedLog2LMULN3
:
762 applyFixedLog2LMUL(-3, FixedLMULType::SmallerOrEqual
);
764 case VectorTypeModifier::SEFixedLog2LMULN2
:
765 applyFixedLog2LMUL(-2, FixedLMULType::SmallerOrEqual
);
767 case VectorTypeModifier::SEFixedLog2LMULN1
:
768 applyFixedLog2LMUL(-1, FixedLMULType::SmallerOrEqual
);
770 case VectorTypeModifier::SEFixedLog2LMUL0
:
771 applyFixedLog2LMUL(0, FixedLMULType::SmallerOrEqual
);
773 case VectorTypeModifier::SEFixedLog2LMUL1
:
774 applyFixedLog2LMUL(1, FixedLMULType::SmallerOrEqual
);
776 case VectorTypeModifier::SEFixedLog2LMUL2
:
777 applyFixedLog2LMUL(2, FixedLMULType::SmallerOrEqual
);
779 case VectorTypeModifier::SEFixedLog2LMUL3
:
780 applyFixedLog2LMUL(3, FixedLMULType::SmallerOrEqual
);
782 case VectorTypeModifier::Tuple2
:
783 case VectorTypeModifier::Tuple3
:
784 case VectorTypeModifier::Tuple4
:
785 case VectorTypeModifier::Tuple5
:
786 case VectorTypeModifier::Tuple6
:
787 case VectorTypeModifier::Tuple7
:
788 case VectorTypeModifier::Tuple8
: {
790 NF
= 2 + static_cast<uint8_t>(Transformer
.VTM
) -
791 static_cast<uint8_t>(VectorTypeModifier::Tuple2
);
794 case VectorTypeModifier::NoModifier
:
798 // Early return if the current type modifier is already invalid.
799 if (ScalarType
== Invalid
)
802 for (unsigned TypeModifierMaskShift
= 0;
803 TypeModifierMaskShift
<= static_cast<unsigned>(TypeModifier::MaxOffset
);
804 ++TypeModifierMaskShift
) {
805 unsigned TypeModifierMask
= 1 << TypeModifierMaskShift
;
806 if ((static_cast<unsigned>(Transformer
.TM
) & TypeModifierMask
) !=
809 switch (static_cast<TypeModifier
>(TypeModifierMask
)) {
810 case TypeModifier::Pointer
:
813 case TypeModifier::Const
:
816 case TypeModifier::Immediate
:
820 case TypeModifier::UnsignedInteger
:
821 ScalarType
= ScalarTypeKind::UnsignedInteger
;
823 case TypeModifier::SignedInteger
:
824 ScalarType
= ScalarTypeKind::SignedInteger
;
826 case TypeModifier::Float
:
827 ScalarType
= ScalarTypeKind::Float
;
829 case TypeModifier::LMUL1
:
831 // Update ElementBitwidth need to update Scale too.
832 Scale
= LMUL
.getScale(ElementBitwidth
);
835 llvm_unreachable("Unknown type modifier mask!");
// Switches the type to an effective element width of 2^Log2EEW.
// LMUL is scaled by Log2EEW - log2(ElementBitwidth), i.e.
// EMUL = (EEW / SEW) * LMUL as stated in the existing comment. Then
// ElementBitwidth becomes 1 << Log2EEW, ScalarType is forced to
// SignedInteger, and Scale is recomputed for the new width.
// The LMUL update must precede the ElementBitwidth assignment because it
// reads the old width.
840 void RVVType::applyLog2EEW(unsigned Log2EEW
) {
841 // update new elmul = (eew/sew) * lmul
842 LMUL
.MulLog2LMUL(Log2EEW
- Log2_32(ElementBitwidth
));
844 ElementBitwidth
= 1 << Log2EEW
;
845 ScalarType
= ScalarTypeKind::SignedInteger
;
846 Scale
= LMUL
.getScale(ElementBitwidth
);
// Forces the element width to NewSEW and recomputes Scale.
// If the type already has that width, ScalarType is set to
// ScalarTypeKind::Invalid instead (source and destination SEW must differ).
// NOTE(review): the statement following the Invalid assignment (presumably an
// early return) is not visible in this listing.
849 void RVVType::applyFixedSEW(unsigned NewSEW
) {
850 // Set invalid type if src and dst SEW are same.
851 if (ElementBitwidth
== NewSEW
) {
852 ScalarType
= ScalarTypeKind::Invalid
;
856 ElementBitwidth
= NewSEW
;
857 Scale
= LMUL
.getScale(ElementBitwidth
);
// Replaces LMUL with the fixed value Log2LMUL after a constraint check chosen
// by Type. The type is marked ScalarTypeKind::Invalid when:
//   LargerThan     - Log2LMUL <= current LMUL.Log2LMUL,
//   SmallerThan    - Log2LMUL >= current LMUL.Log2LMUL,
//   SmallerOrEqual - Log2LMUL >  current LMUL.Log2LMUL.
// Finally LMUL is rebuilt from Log2LMUL and Scale is recomputed.
// NOTE(review): the switch header and the statements following each Invalid
// assignment (presumably early returns) are not visible in this listing.
860 void RVVType::applyFixedLog2LMUL(int Log2LMUL
, enum FixedLMULType Type
) {
862 case FixedLMULType::LargerThan
:
863 if (Log2LMUL
<= LMUL
.Log2LMUL
) {
864 ScalarType
= ScalarTypeKind::Invalid
;
868 case FixedLMULType::SmallerThan
:
869 if (Log2LMUL
>= LMUL
.Log2LMUL
) {
870 ScalarType
= ScalarTypeKind::Invalid
;
874 case FixedLMULType::SmallerOrEqual
:
875 if (Log2LMUL
> LMUL
.Log2LMUL
) {
876 ScalarType
= ScalarTypeKind::Invalid
;
883 LMUL
= LMULType(Log2LMUL
);
884 Scale
= LMUL
.getScale(ElementBitwidth
);
// Computes one RVVType per descriptor in Prototype by calling computeType()
// with the same BT and Log2LMUL, and returns the collected types as an
// optional RVVTypes.
// NOTE(review): the handling of a failed computeType() result, the container
// the results are appended to, and any use of NF are not visible in this
// listing.
887 std::optional
<RVVTypes
>
888 RVVTypeCache::computeTypes(BasicType BT
, int Log2LMUL
, unsigned NF
,
889 ArrayRef
<PrototypeDescriptor
> Prototype
) {
891 for (const PrototypeDescriptor
&Proto
: Prototype
) {
892 auto T
= computeType(BT
, Log2LMUL
, Proto
);
895 // Record legal type index
901 // Compute the hash value of an RVVType, used to cache the result of computeType.
// The key packs one byte per field, lowest byte first:
//   bits  0-7  : Log2LMUL + 3 (asserted to be in [0, 6])
//   bits  8-15 : BT        (masked to 8 bits)
//   bits 16-23 : Proto.PT
//   bits 24-31 : Proto.TM
//   bits 32-39 : Proto.VTM
// `<<` binds tighter than `|`, so the BT term is shifted before being OR-ed.
// NOTE(review): BT is truncated to 8 bits here; if BasicType ever uses values
// above 0xff, distinct types would share a key -- confirm against the enum.
902 static uint64_t computeRVVTypeHashValue(BasicType BT
, int Log2LMUL
,
903 PrototypeDescriptor Proto
) {
904 // Layout of hash value:
906 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
907 assert(Log2LMUL
>= -3 && Log2LMUL
<= 3);
908 return (Log2LMUL
+ 3) | (static_cast<uint64_t>(BT
) & 0xff) << 8 |
909 ((uint64_t)(Proto
.PT
& 0xff) << 16) |
910 ((uint64_t)(Proto
.TM
& 0xff) << 24) |
911 ((uint64_t)(Proto
.VTM
& 0xff) << 32);
// Returns the RVVType for (BT, Log2LMUL, Proto), computing it on first use.
// The lookup key comes from computeRVVTypeHashValue(). A hit in LegalTypes
// returns a pointer to the cached entry; otherwise IllegalTypes is consulted.
// On a miss a new RVVType is constructed and either inserted into LegalTypes
// (returning a pointer to the stored copy) or its key is recorded in
// IllegalTypes.
// The returned pointer refers to an element owned by LegalTypes.
// NOTE(review): the validity test on the new type and the values returned
// for illegal types are not visible in this listing.
914 std::optional
<RVVTypePtr
> RVVTypeCache::computeType(BasicType BT
, int Log2LMUL
,
915 PrototypeDescriptor Proto
) {
916 uint64_t Idx
= computeRVVTypeHashValue(BT
, Log2LMUL
, Proto
);
918 auto It
= LegalTypes
.find(Idx
);
919 if (It
!= LegalTypes
.end())
920 return &(It
->second
);
922 if (IllegalTypes
.count(Idx
))
925 // Compute type and record the result.
926 RVVType
T(BT
, Log2LMUL
, Proto
);
928 // Record legal type index and value.
929 std::pair
<std::unordered_map
<uint64_t, RVVType
>::iterator
, bool>
930 InsertResult
= LegalTypes
.insert({Idx
, T
});
931 return &(InsertResult
.first
->second
);
933 // Record illegal type index.
934 IllegalTypes
.insert(Idx
);
938 //===----------------------------------------------------------------------===//
939 // RVVIntrinsic implementation
940 //===----------------------------------------------------------------------===//
// Constructs an RVVIntrinsic description.
// Visible steps:
//   - the initializer list copies the flags, scheme, manual codegen text, NF
//     and policy attributes;
//   - BuiltinName is NewName; OverloadedName is NewOverloadedName, or the
//     part of NewName before the first "_" when NewOverloadedName is empty;
//   - "_" + Suffix is appended to Name and, when OverloadedSuffix is not
//     empty, "_" + OverloadedSuffix to OverloadedName;
//   - updateNamesAndPolicy() finalises the three names and the policy;
//   - OutputType is OutInTypes[0] and InputTypes the remaining entries, so
//     OutInTypes must not be empty;
//   - IntrinsicTypes is copied from NewIntrinsicTypes and post-processed in a
//     loop when a masked-off or passthru operand is present.
// NOTE(review): the initial assignment of Name, the condition guarding the
// Suffix append, the body of the IntrinsicTypes loop and any use of
// RequiredFeatures are not visible in this listing.
941 RVVIntrinsic::RVVIntrinsic(
942 StringRef NewName
, StringRef Suffix
, StringRef NewOverloadedName
,
943 StringRef OverloadedSuffix
, StringRef IRName
, bool IsMasked
,
944 bool HasMaskedOffOperand
, bool HasVL
, PolicyScheme Scheme
,
945 bool SupportOverloading
, bool HasBuiltinAlias
, StringRef ManualCodegen
,
946 const RVVTypes
&OutInTypes
, const std::vector
<int64_t> &NewIntrinsicTypes
,
947 const std::vector
<StringRef
> &RequiredFeatures
, unsigned NF
,
948 Policy NewPolicyAttrs
, bool HasFRMRoundModeOp
)
949 : IRName(IRName
), IsMasked(IsMasked
),
950 HasMaskedOffOperand(HasMaskedOffOperand
), HasVL(HasVL
), Scheme(Scheme
),
951 SupportOverloading(SupportOverloading
), HasBuiltinAlias(HasBuiltinAlias
),
952 ManualCodegen(ManualCodegen
.str()), NF(NF
), PolicyAttrs(NewPolicyAttrs
) {
954 // Init BuiltinName, Name and OverloadedName
955 BuiltinName
= NewName
.str();
957 if (NewOverloadedName
.empty())
958 OverloadedName
= NewName
.split("_").first
.str();
960 OverloadedName
= NewOverloadedName
.str();
962 Name
+= "_" + Suffix
.str();
963 if (!OverloadedSuffix
.empty())
964 OverloadedName
+= "_" + OverloadedSuffix
.str();
966 updateNamesAndPolicy(IsMasked
, hasPolicy(), Name
, BuiltinName
, OverloadedName
,
967 PolicyAttrs
, HasFRMRoundModeOp
);
969 // Init OutputType and InputTypes
970 OutputType
= OutInTypes
[0];
971 InputTypes
.assign(OutInTypes
.begin() + 1, OutInTypes
.end());
973 // IntrinsicTypes is unmasked TA version index. Need to update it
974 // if there is merge operand (It is always in first operand).
975 IntrinsicTypes
= NewIntrinsicTypes
;
976 if ((IsMasked
&& hasMaskedOffOperand()) ||
977 (!IsMasked
&& hasPassthruOperand())) {
978 for (auto &I
: IntrinsicTypes
) {
// Builds the builtin type signature string: the builtin string of OutputType
// followed by the builtin string of every entry in InputTypes, in order.
// NOTE(review): the declaration of S and the return statement are not
// visible in this listing.
985 std::string
RVVIntrinsic::getBuiltinTypeStr() const {
987 S
+= OutputType
->getBuiltinStr();
988 for (const auto &T
: InputTypes
) {
989 S
+= T
->getBuiltinStr();
// Builds an intrinsic name suffix: for each descriptor the type is computed
// through TypeCache.computeType(), its short string is collected, and the
// pieces are joined with "_".
// NOTE(review): the optional returned by computeType() is dereferenced with
// `*T` and no check is visible here; callers presumably only pass
// descriptors that yield legal types -- confirm.
994 std::string
RVVIntrinsic::getSuffixStr(
995 RVVTypeCache
&TypeCache
, BasicType Type
, int Log2LMUL
,
996 llvm::ArrayRef
<PrototypeDescriptor
> PrototypeDescriptors
) {
997 SmallVector
<std::string
> SuffixStrs
;
998 for (auto PD
: PrototypeDescriptors
) {
999 auto T
= TypeCache
.computeType(Type
, Log2LMUL
, PD
);
1000 SuffixStrs
.push_back((*T
)->getShortStr());
1002 return join(SuffixStrs
, "_");
// Derives the full builtin prototype from the user-facing Prototype by
// inserting the implicit operands. NewPrototype starts as a copy of
// Prototype (index 0 is the result type) and is returned at the end.
//
// Visible insertions:
//   - masked-off operand (HasMaskedOffOperand and policy is not TAMA):
//       a copy of the result type at index 1; or a tuple type built from
//       getTupleVTM(NF) and Prototype[1]'s modifiers without Pointer at
//       index 1; or NF copies of NewPrototype[1] without Pointer at index
//       NF + 1 (segment loads);
//   - mask operand: PrototypeDescriptor::Mask at index 1, or at index NF + 1
//       for the segment-load layout;
//   - passthru operand (TU policy and HasPassthruOp): the result type, a
//       tuple type, or NF pointer-stripped copies, mirroring the masked case;
//   - VL operand: PrototypeDescriptor::VL appended last.
// NOTE(review): the enclosing IsMasked / NF == 1 / IsTuple / HasVL conditions
// and several closing braces are not visible in this listing, so the exact
// nesting of the branches above must be confirmed in the full source.
1005 llvm::SmallVector
<PrototypeDescriptor
> RVVIntrinsic::computeBuiltinTypes(
1006 llvm::ArrayRef
<PrototypeDescriptor
> Prototype
, bool IsMasked
,
1007 bool HasMaskedOffOperand
, bool HasVL
, unsigned NF
,
1008 PolicyScheme DefaultScheme
, Policy PolicyAttrs
, bool IsTuple
) {
1009 SmallVector
<PrototypeDescriptor
> NewPrototype(Prototype
.begin(),
1011 bool HasPassthruOp
= DefaultScheme
== PolicyScheme::HasPassthruOperand
;
1013 // If HasMaskedOffOperand, insert result type as first input operand if
1015 if (HasMaskedOffOperand
&& !PolicyAttrs
.isTAMAPolicy()) {
1017 NewPrototype
.insert(NewPrototype
.begin() + 1, NewPrototype
[0]);
1018 } else if (NF
> 1) {
1020 PrototypeDescriptor BasePtrOperand
= Prototype
[1];
1021 PrototypeDescriptor MaskoffType
= PrototypeDescriptor(
1022 static_cast<uint8_t>(BaseTypeModifier::Vector
),
1023 static_cast<uint8_t>(getTupleVTM(NF
)),
1024 BasePtrOperand
.TM
& ~static_cast<uint8_t>(TypeModifier::Pointer
));
1025 NewPrototype
.insert(NewPrototype
.begin() + 1, MaskoffType
);
1028 // (void, op0 address, op1 address, ...)
1030 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1031 PrototypeDescriptor MaskoffType
= NewPrototype
[1];
1032 MaskoffType
.TM
&= ~static_cast<uint8_t>(TypeModifier::Pointer
);
1033 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1, NF
, MaskoffType
);
1037 if (HasMaskedOffOperand
&& NF
> 1) {
1039 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1041 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1044 NewPrototype
.insert(NewPrototype
.begin() + 1,
1045 PrototypeDescriptor::Mask
);
1047 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1,
1048 PrototypeDescriptor::Mask
);
1050 // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
1051 NewPrototype
.insert(NewPrototype
.begin() + 1, PrototypeDescriptor::Mask
);
1055 if (PolicyAttrs
.isTUPolicy() && HasPassthruOp
)
1056 NewPrototype
.insert(NewPrototype
.begin(), NewPrototype
[0]);
1057 } else if (PolicyAttrs
.isTUPolicy() && HasPassthruOp
) {
1059 PrototypeDescriptor BasePtrOperand
= Prototype
[0];
1060 PrototypeDescriptor MaskoffType
= PrototypeDescriptor(
1061 static_cast<uint8_t>(BaseTypeModifier::Vector
),
1062 static_cast<uint8_t>(getTupleVTM(NF
)),
1063 BasePtrOperand
.TM
& ~static_cast<uint8_t>(TypeModifier::Pointer
));
1064 NewPrototype
.insert(NewPrototype
.begin(), MaskoffType
);
1066 // NF > 1 cases for segment load operations.
1068 // (void, op0 address, op1 address, ...)
1070 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
1071 PrototypeDescriptor MaskoffType
= Prototype
[1];
1072 MaskoffType
.TM
&= ~static_cast<uint8_t>(TypeModifier::Pointer
);
1073 NewPrototype
.insert(NewPrototype
.begin() + NF
+ 1, NF
, MaskoffType
);
1078 // If HasVL, append PrototypeDescriptor:VL to last operand
1080 NewPrototype
.push_back(PrototypeDescriptor::VL
);
1082 return NewPrototype
;
// Returns the policies supported by unmasked intrinsics: a single entry
// constructed from Policy::PolicyType::Undisturbed (the "TU" variant, per the
// trailing comment).
1085 llvm::SmallVector
<Policy
> RVVIntrinsic::getSupportedUnMaskedPolicies() {
1086 return {Policy(Policy::PolicyType::Undisturbed
)}; // TU
// Returns the policies supported by masked intrinsics. Each Policy is built
// from two PolicyType arguments; the trailing comments name the variants:
//   tail policy and mask policy : (Undisturbed, Agnostic)    TUM
//                                 (Undisturbed, Undisturbed) TUMU
//                                 (Agnostic,    Undisturbed) MU
//   tail policy only            : (Undisturbed, Agnostic)    TU
//   mask policy only            : (Agnostic,    Undisturbed) MU
// Having neither policy is treated as a programming error and reaches
// llvm_unreachable.
1089 llvm::SmallVector
<Policy
>
1090 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy
,
1091 bool HasMaskPolicy
) {
1092 if (HasTailPolicy
&& HasMaskPolicy
)
1093 return {Policy(Policy::PolicyType::Undisturbed
,
1094 Policy::PolicyType::Agnostic
), // TUM
1095 Policy(Policy::PolicyType::Undisturbed
,
1096 Policy::PolicyType::Undisturbed
), // TUMU
1097 Policy(Policy::PolicyType::Agnostic
,
1098 Policy::PolicyType::Undisturbed
)}; // MU
1099 if (HasTailPolicy
&& !HasMaskPolicy
)
1100 return {Policy(Policy::PolicyType::Undisturbed
,
1101 Policy::PolicyType::Agnostic
)}; // TU
1102 if (!HasTailPolicy
&& HasMaskPolicy
)
1103 return {Policy(Policy::PolicyType::Agnostic
,
1104 Policy::PolicyType::Undisturbed
)}; // MU
1105 llvm_unreachable("An RVV instruction should not be without both tail policy "
// Finalises the three intrinsic names (all passed by reference) according to
// the policy.
//   - Name and OverloadedName are prefixed with "__riscv_".
//   - With HasFRMRoundModeOp, "_rm" is appended to BuiltinName.
//   - Policy suffixes appended through appendPolicySuffix:
//       TUMU -> "_tumu", TUMA -> "_tum", TAMU -> "_mu", TU -> "_tu";
//     for TAMA, "_m" is appended to BuiltinName directly, and TA needs no
//     suffix. Any other combination reaches llvm_unreachable.
// appendPolicySuffix appends the suffix to BuiltinName and OverloadedName.
// NOTE(review): the IsMasked / HasPolicy conditions selecting between the
// masked and unmasked suffix chains, the first statement of the lambda
// (listing line 1114) and any writes to PolicyAttrs are not visible in this
// listing.
1109 void RVVIntrinsic::updateNamesAndPolicy(
1110 bool IsMasked
, bool HasPolicy
, std::string
&Name
, std::string
&BuiltinName
,
1111 std::string
&OverloadedName
, Policy
&PolicyAttrs
, bool HasFRMRoundModeOp
) {
1113 auto appendPolicySuffix
= [&](const std::string
&suffix
) {
1115 BuiltinName
+= suffix
;
1116 OverloadedName
+= suffix
;
1119 // This follows the naming guideline under riscv-c-api-doc to add the
1120 // `__riscv_` prefix for all RVV intrinsics.
1121 Name
= "__riscv_" + Name
;
1122 OverloadedName
= "__riscv_" + OverloadedName
;
1124 if (HasFRMRoundModeOp
) {
1126 BuiltinName
+= "_rm";
1130 if (PolicyAttrs
.isTUMUPolicy())
1131 appendPolicySuffix("_tumu");
1132 else if (PolicyAttrs
.isTUMAPolicy())
1133 appendPolicySuffix("_tum");
1134 else if (PolicyAttrs
.isTAMUPolicy())
1135 appendPolicySuffix("_mu");
1136 else if (PolicyAttrs
.isTAMAPolicy()) {
1138 BuiltinName
+= "_m";
1140 llvm_unreachable("Unhandled policy condition");
1142 if (PolicyAttrs
.isTUPolicy())
1143 appendPolicySuffix("_tu");
1144 else if (PolicyAttrs
.isTAPolicy()) // no suffix needed
1147 llvm_unreachable("Unhandled policy condition");
// Splits a concatenated prototype string into individual descriptors.
// Each descriptor ends at the first character from Primaries
// ("evwqom0ztul"). When the remaining text starts with "(", the search begins
// at the matching ")" so that characters inside the parenthesised modifier
// are not mistaken for a primary type. Every slice is parsed with
// PrototypeDescriptor::parsePrototypeDescriptor() and then dropped from the
// front of the input.
// A missing primary character trips the assertion; a parse failure reaches
// llvm_unreachable.
// NOTE(review): the declaration/initialisation of Idx and the condition
// guarding the llvm_unreachable call are not visible in this listing.
1151 SmallVector
<PrototypeDescriptor
> parsePrototypes(StringRef Prototypes
) {
1152 SmallVector
<PrototypeDescriptor
> PrototypeDescriptors
;
1153 const StringRef
Primaries("evwqom0ztul");
1154 while (!Prototypes
.empty()) {
1156 // Skip over complex prototype because it could contain primitive type
1158 if (Prototypes
[0] == '(')
1159 Idx
= Prototypes
.find_first_of(')');
1160 Idx
= Prototypes
.find_first_of(Primaries
, Idx
);
1161 assert(Idx
!= StringRef::npos
);
1162 auto PD
= PrototypeDescriptor::parsePrototypeDescriptor(
1163 Prototypes
.slice(0, Idx
+ 1));
1165 llvm_unreachable("Error during parsing prototype.");
1166 PrototypeDescriptors
.push_back(*PD
);
1167 Prototypes
= Prototypes
.drop_front(Idx
+ 1);
1169 return PrototypeDescriptors
;
// Streams an RVVIntrinsicRecord to OS as a comma-separated field list.
// Name and OverloadedName are emitted inside double quotes; PrototypeIndex,
// SuffixIndex and OverloadedSuffixIndex are streamed as-is; all remaining
// fields are cast to int first so that narrow integer / bool fields print as
// numbers. Every visible field is followed by ",".
// NOTE(review): the opening text, the output produced when OverloadedName is
// null or empty, the closing text and the return statement are not visible
// in this listing.
1172 raw_ostream
&operator<<(raw_ostream
&OS
, const RVVIntrinsicRecord
&Record
) {
1174 OS
<< "\"" << Record
.Name
<< "\",";
1175 if (Record
.OverloadedName
== nullptr ||
1176 StringRef(Record
.OverloadedName
).empty())
1179 OS
<< "\"" << Record
.OverloadedName
<< "\",";
1180 OS
<< Record
.PrototypeIndex
<< ",";
1181 OS
<< Record
.SuffixIndex
<< ",";
1182 OS
<< Record
.OverloadedSuffixIndex
<< ",";
1183 OS
<< (int)Record
.PrototypeLength
<< ",";
1184 OS
<< (int)Record
.SuffixLength
<< ",";
1185 OS
<< (int)Record
.OverloadedSuffixSize
<< ",";
1186 OS
<< (int)Record
.RequiredExtensions
<< ",";
1187 OS
<< (int)Record
.TypeRangeMask
<< ",";
1188 OS
<< (int)Record
.Log2LMULMask
<< ",";
1189 OS
<< (int)Record
.NF
<< ",";
1190 OS
<< (int)Record
.HasMasked
<< ",";
1191 OS
<< (int)Record
.HasVL
<< ",";
1192 OS
<< (int)Record
.HasMaskedOffOperand
<< ",";
1193 OS
<< (int)Record
.HasTailPolicy
<< ",";
1194 OS
<< (int)Record
.HasMaskPolicy
<< ",";
1195 OS
<< (int)Record
.HasFRMRoundModeOp
<< ",";
1196 OS
<< (int)Record
.IsTuple
<< ",";
1197 OS
<< (int)Record
.UnMaskedPolicyScheme
<< ",";
1198 OS
<< (int)Record
.MaskedPolicyScheme
<< ",";
1203 } // end namespace RISCV
1204 } // end namespace clang