[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / clang / lib / Support / RISCVVIntrinsicUtils.cpp
blob597ee194fc8d4b14f9ee051e6ba3ccedb497c8e0
1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "clang/Support/RISCVVIntrinsicUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include <optional>
#include <utility>

using namespace llvm;
22 namespace clang {
23 namespace RISCV {
25 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
26 BaseTypeModifier::Vector, VectorTypeModifier::MaskVector);
27 const PrototypeDescriptor PrototypeDescriptor::VL =
28 PrototypeDescriptor(BaseTypeModifier::SizeT);
29 const PrototypeDescriptor PrototypeDescriptor::Vector =
30 PrototypeDescriptor(BaseTypeModifier::Vector);
32 //===----------------------------------------------------------------------===//
33 // Type implementation
34 //===----------------------------------------------------------------------===//
36 LMULType::LMULType(int NewLog2LMUL) {
37 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
38 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
39 Log2LMUL = NewLog2LMUL;
42 std::string LMULType::str() const {
43 if (Log2LMUL < 0)
44 return "mf" + utostr(1ULL << (-Log2LMUL));
45 return "m" + utostr(1ULL << Log2LMUL);
48 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
49 int Log2ScaleResult = 0;
50 switch (ElementBitwidth) {
51 default:
52 break;
53 case 8:
54 Log2ScaleResult = Log2LMUL + 3;
55 break;
56 case 16:
57 Log2ScaleResult = Log2LMUL + 2;
58 break;
59 case 32:
60 Log2ScaleResult = Log2LMUL + 1;
61 break;
62 case 64:
63 Log2ScaleResult = Log2LMUL;
64 break;
66 // Illegal vscale result would be less than 1
67 if (Log2ScaleResult < 0)
68 return std::nullopt;
69 return 1 << Log2ScaleResult;
72 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
74 RVVType::RVVType(BasicType BT, int Log2LMUL,
75 const PrototypeDescriptor &prototype)
76 : BT(BT), LMUL(LMULType(Log2LMUL)) {
77 applyBasicType();
78 applyModifier(prototype);
79 Valid = verifyType();
80 if (Valid) {
81 initBuiltinStr();
82 initTypeStr();
83 if (isVector()) {
84 initClangBuiltinStr();
// clang-format off
// Boolean (mask) types are encoded by the ratio n = SEW/LMUL:
// SEW/LMUL | 64        | 32        | 16        | 8        | 4        | 2        | 1
// c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
// IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1  | nxv32i1  | nxv64i1
//
// type\lmul | 1/8    | 1/4     | 1/2     | 1       | 2        | 4        | 8
// --------- | ------ | ------- | ------- | ------- | -------- | -------- | --------
// i64       | N/A    | N/A     | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
// i32       | N/A    | N/A     | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
// i16       | N/A    | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
// i8        | nxv1i8 | nxv2i8  | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
// double    | N/A    | N/A     | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
// float     | N/A    | N/A     | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
// half      | N/A    | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
// clang-format on
106 bool RVVType::verifyType() const {
107 if (ScalarType == Invalid)
108 return false;
109 if (isScalar())
110 return true;
111 if (!Scale)
112 return false;
113 if (isFloat() && ElementBitwidth == 8)
114 return false;
115 if (IsTuple && (NF == 1 || NF > 8))
116 return false;
117 if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8)
118 return false;
119 unsigned V = *Scale;
120 switch (ElementBitwidth) {
121 case 1:
122 case 8:
123 // Check Scale is 1,2,4,8,16,32,64
124 return (V <= 64 && isPowerOf2_32(V));
125 case 16:
126 // Check Scale is 1,2,4,8,16,32
127 return (V <= 32 && isPowerOf2_32(V));
128 case 32:
129 // Check Scale is 1,2,4,8,16
130 return (V <= 16 && isPowerOf2_32(V));
131 case 64:
132 // Check Scale is 1,2,4,8
133 return (V <= 8 && isPowerOf2_32(V));
135 return false;
138 void RVVType::initBuiltinStr() {
139 assert(isValid() && "RVVType is invalid");
140 switch (ScalarType) {
141 case ScalarTypeKind::Void:
142 BuiltinStr = "v";
143 return;
144 case ScalarTypeKind::Size_t:
145 BuiltinStr = "z";
146 if (IsImmediate)
147 BuiltinStr = "I" + BuiltinStr;
148 if (IsPointer)
149 BuiltinStr += "*";
150 return;
151 case ScalarTypeKind::Ptrdiff_t:
152 BuiltinStr = "Y";
153 return;
154 case ScalarTypeKind::UnsignedLong:
155 BuiltinStr = "ULi";
156 return;
157 case ScalarTypeKind::SignedLong:
158 BuiltinStr = "Li";
159 return;
160 case ScalarTypeKind::Boolean:
161 assert(ElementBitwidth == 1);
162 BuiltinStr += "b";
163 break;
164 case ScalarTypeKind::SignedInteger:
165 case ScalarTypeKind::UnsignedInteger:
166 switch (ElementBitwidth) {
167 case 8:
168 BuiltinStr += "c";
169 break;
170 case 16:
171 BuiltinStr += "s";
172 break;
173 case 32:
174 BuiltinStr += "i";
175 break;
176 case 64:
177 BuiltinStr += "Wi";
178 break;
179 default:
180 llvm_unreachable("Unhandled ElementBitwidth!");
182 if (isSignedInteger())
183 BuiltinStr = "S" + BuiltinStr;
184 else
185 BuiltinStr = "U" + BuiltinStr;
186 break;
187 case ScalarTypeKind::Float:
188 switch (ElementBitwidth) {
189 case 16:
190 BuiltinStr += "x";
191 break;
192 case 32:
193 BuiltinStr += "f";
194 break;
195 case 64:
196 BuiltinStr += "d";
197 break;
198 default:
199 llvm_unreachable("Unhandled ElementBitwidth!");
201 break;
202 default:
203 llvm_unreachable("ScalarType is invalid!");
205 if (IsImmediate)
206 BuiltinStr = "I" + BuiltinStr;
207 if (isScalar()) {
208 if (IsConstant)
209 BuiltinStr += "C";
210 if (IsPointer)
211 BuiltinStr += "*";
212 return;
214 BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
215 // Pointer to vector types. Defined for segment load intrinsics.
216 // segment load intrinsics have pointer type arguments to store the loaded
217 // vector values.
218 if (IsPointer)
219 BuiltinStr += "*";
221 if (IsTuple)
222 BuiltinStr = "T" + utostr(NF) + BuiltinStr;
225 void RVVType::initClangBuiltinStr() {
226 assert(isValid() && "RVVType is invalid");
227 assert(isVector() && "Handle Vector type only");
229 ClangBuiltinStr = "__rvv_";
230 switch (ScalarType) {
231 case ScalarTypeKind::Boolean:
232 ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
233 return;
234 case ScalarTypeKind::Float:
235 ClangBuiltinStr += "float";
236 break;
237 case ScalarTypeKind::SignedInteger:
238 ClangBuiltinStr += "int";
239 break;
240 case ScalarTypeKind::UnsignedInteger:
241 ClangBuiltinStr += "uint";
242 break;
243 default:
244 llvm_unreachable("ScalarTypeKind is invalid");
246 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() +
247 (IsTuple ? "x" + utostr(NF) : "") + "_t";
250 void RVVType::initTypeStr() {
251 assert(isValid() && "RVVType is invalid");
253 if (IsConstant)
254 Str += "const ";
256 auto getTypeString = [&](StringRef TypeStr) {
257 if (isScalar())
258 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
259 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() +
260 (IsTuple ? "x" + utostr(NF) : "") + "_t")
261 .str();
264 switch (ScalarType) {
265 case ScalarTypeKind::Void:
266 Str = "void";
267 return;
268 case ScalarTypeKind::Size_t:
269 Str = "size_t";
270 if (IsPointer)
271 Str += " *";
272 return;
273 case ScalarTypeKind::Ptrdiff_t:
274 Str = "ptrdiff_t";
275 return;
276 case ScalarTypeKind::UnsignedLong:
277 Str = "unsigned long";
278 return;
279 case ScalarTypeKind::SignedLong:
280 Str = "long";
281 return;
282 case ScalarTypeKind::Boolean:
283 if (isScalar())
284 Str += "bool";
285 else
286 // Vector bool is special case, the formulate is
287 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
288 Str += "vbool" + utostr(64 / *Scale) + "_t";
289 break;
290 case ScalarTypeKind::Float:
291 if (isScalar()) {
292 if (ElementBitwidth == 64)
293 Str += "double";
294 else if (ElementBitwidth == 32)
295 Str += "float";
296 else if (ElementBitwidth == 16)
297 Str += "_Float16";
298 else
299 llvm_unreachable("Unhandled floating type.");
300 } else
301 Str += getTypeString("float");
302 break;
303 case ScalarTypeKind::SignedInteger:
304 Str += getTypeString("int");
305 break;
306 case ScalarTypeKind::UnsignedInteger:
307 Str += getTypeString("uint");
308 break;
309 default:
310 llvm_unreachable("ScalarType is invalid!");
312 if (IsPointer)
313 Str += " *";
316 void RVVType::initShortStr() {
317 switch (ScalarType) {
318 case ScalarTypeKind::Boolean:
319 assert(isVector());
320 ShortStr = "b" + utostr(64 / *Scale);
321 return;
322 case ScalarTypeKind::Float:
323 ShortStr = "f" + utostr(ElementBitwidth);
324 break;
325 case ScalarTypeKind::SignedInteger:
326 ShortStr = "i" + utostr(ElementBitwidth);
327 break;
328 case ScalarTypeKind::UnsignedInteger:
329 ShortStr = "u" + utostr(ElementBitwidth);
330 break;
331 default:
332 llvm_unreachable("Unhandled case!");
334 if (isVector())
335 ShortStr += LMUL.str();
336 if (isTuple())
337 ShortStr += "x" + utostr(NF);
340 static VectorTypeModifier getTupleVTM(unsigned NF) {
341 assert(2 <= NF && NF <= 8 && "2 <= NF <= 8");
342 return static_cast<VectorTypeModifier>(
343 static_cast<uint8_t>(VectorTypeModifier::Tuple2) + (NF - 2));
346 void RVVType::applyBasicType() {
347 switch (BT) {
348 case BasicType::Int8:
349 ElementBitwidth = 8;
350 ScalarType = ScalarTypeKind::SignedInteger;
351 break;
352 case BasicType::Int16:
353 ElementBitwidth = 16;
354 ScalarType = ScalarTypeKind::SignedInteger;
355 break;
356 case BasicType::Int32:
357 ElementBitwidth = 32;
358 ScalarType = ScalarTypeKind::SignedInteger;
359 break;
360 case BasicType::Int64:
361 ElementBitwidth = 64;
362 ScalarType = ScalarTypeKind::SignedInteger;
363 break;
364 case BasicType::Float16:
365 ElementBitwidth = 16;
366 ScalarType = ScalarTypeKind::Float;
367 break;
368 case BasicType::Float32:
369 ElementBitwidth = 32;
370 ScalarType = ScalarTypeKind::Float;
371 break;
372 case BasicType::Float64:
373 ElementBitwidth = 64;
374 ScalarType = ScalarTypeKind::Float;
375 break;
376 default:
377 llvm_unreachable("Unhandled type code!");
379 assert(ElementBitwidth != 0 && "Bad element bitwidth!");
382 std::optional<PrototypeDescriptor>
383 PrototypeDescriptor::parsePrototypeDescriptor(
384 llvm::StringRef PrototypeDescriptorStr) {
385 PrototypeDescriptor PD;
386 BaseTypeModifier PT = BaseTypeModifier::Invalid;
387 VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
389 if (PrototypeDescriptorStr.empty())
390 return PD;
392 // Handle base type modifier
393 auto PType = PrototypeDescriptorStr.back();
394 switch (PType) {
395 case 'e':
396 PT = BaseTypeModifier::Scalar;
397 break;
398 case 'v':
399 PT = BaseTypeModifier::Vector;
400 break;
401 case 'w':
402 PT = BaseTypeModifier::Vector;
403 VTM = VectorTypeModifier::Widening2XVector;
404 break;
405 case 'q':
406 PT = BaseTypeModifier::Vector;
407 VTM = VectorTypeModifier::Widening4XVector;
408 break;
409 case 'o':
410 PT = BaseTypeModifier::Vector;
411 VTM = VectorTypeModifier::Widening8XVector;
412 break;
413 case 'm':
414 PT = BaseTypeModifier::Vector;
415 VTM = VectorTypeModifier::MaskVector;
416 break;
417 case '0':
418 PT = BaseTypeModifier::Void;
419 break;
420 case 'z':
421 PT = BaseTypeModifier::SizeT;
422 break;
423 case 't':
424 PT = BaseTypeModifier::Ptrdiff;
425 break;
426 case 'u':
427 PT = BaseTypeModifier::UnsignedLong;
428 break;
429 case 'l':
430 PT = BaseTypeModifier::SignedLong;
431 break;
432 default:
433 llvm_unreachable("Illegal primitive type transformers!");
435 PD.PT = static_cast<uint8_t>(PT);
436 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
438 // Compute the vector type transformers, it can only appear one time.
439 if (PrototypeDescriptorStr.startswith("(")) {
440 assert(VTM == VectorTypeModifier::NoModifier &&
441 "VectorTypeModifier should only have one modifier");
442 size_t Idx = PrototypeDescriptorStr.find(')');
443 assert(Idx != StringRef::npos);
444 StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
445 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
446 assert(!PrototypeDescriptorStr.contains('(') &&
447 "Only allow one vector type modifier");
449 auto ComplexTT = ComplexType.split(":");
450 if (ComplexTT.first == "Log2EEW") {
451 uint32_t Log2EEW;
452 if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
453 llvm_unreachable("Invalid Log2EEW value!");
454 return std::nullopt;
456 switch (Log2EEW) {
457 case 3:
458 VTM = VectorTypeModifier::Log2EEW3;
459 break;
460 case 4:
461 VTM = VectorTypeModifier::Log2EEW4;
462 break;
463 case 5:
464 VTM = VectorTypeModifier::Log2EEW5;
465 break;
466 case 6:
467 VTM = VectorTypeModifier::Log2EEW6;
468 break;
469 default:
470 llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
471 return std::nullopt;
473 } else if (ComplexTT.first == "FixedSEW") {
474 uint32_t NewSEW;
475 if (ComplexTT.second.getAsInteger(10, NewSEW)) {
476 llvm_unreachable("Invalid FixedSEW value!");
477 return std::nullopt;
479 switch (NewSEW) {
480 case 8:
481 VTM = VectorTypeModifier::FixedSEW8;
482 break;
483 case 16:
484 VTM = VectorTypeModifier::FixedSEW16;
485 break;
486 case 32:
487 VTM = VectorTypeModifier::FixedSEW32;
488 break;
489 case 64:
490 VTM = VectorTypeModifier::FixedSEW64;
491 break;
492 default:
493 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
494 return std::nullopt;
496 } else if (ComplexTT.first == "LFixedLog2LMUL") {
497 int32_t Log2LMUL;
498 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
499 llvm_unreachable("Invalid LFixedLog2LMUL value!");
500 return std::nullopt;
502 switch (Log2LMUL) {
503 case -3:
504 VTM = VectorTypeModifier::LFixedLog2LMULN3;
505 break;
506 case -2:
507 VTM = VectorTypeModifier::LFixedLog2LMULN2;
508 break;
509 case -1:
510 VTM = VectorTypeModifier::LFixedLog2LMULN1;
511 break;
512 case 0:
513 VTM = VectorTypeModifier::LFixedLog2LMUL0;
514 break;
515 case 1:
516 VTM = VectorTypeModifier::LFixedLog2LMUL1;
517 break;
518 case 2:
519 VTM = VectorTypeModifier::LFixedLog2LMUL2;
520 break;
521 case 3:
522 VTM = VectorTypeModifier::LFixedLog2LMUL3;
523 break;
524 default:
525 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
526 return std::nullopt;
528 } else if (ComplexTT.first == "SFixedLog2LMUL") {
529 int32_t Log2LMUL;
530 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
531 llvm_unreachable("Invalid SFixedLog2LMUL value!");
532 return std::nullopt;
534 switch (Log2LMUL) {
535 case -3:
536 VTM = VectorTypeModifier::SFixedLog2LMULN3;
537 break;
538 case -2:
539 VTM = VectorTypeModifier::SFixedLog2LMULN2;
540 break;
541 case -1:
542 VTM = VectorTypeModifier::SFixedLog2LMULN1;
543 break;
544 case 0:
545 VTM = VectorTypeModifier::SFixedLog2LMUL0;
546 break;
547 case 1:
548 VTM = VectorTypeModifier::SFixedLog2LMUL1;
549 break;
550 case 2:
551 VTM = VectorTypeModifier::SFixedLog2LMUL2;
552 break;
553 case 3:
554 VTM = VectorTypeModifier::SFixedLog2LMUL3;
555 break;
556 default:
557 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
558 return std::nullopt;
561 } else if (ComplexTT.first == "SEFixedLog2LMUL") {
562 int32_t Log2LMUL;
563 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
564 llvm_unreachable("Invalid SEFixedLog2LMUL value!");
565 return std::nullopt;
567 switch (Log2LMUL) {
568 case -3:
569 VTM = VectorTypeModifier::SEFixedLog2LMULN3;
570 break;
571 case -2:
572 VTM = VectorTypeModifier::SEFixedLog2LMULN2;
573 break;
574 case -1:
575 VTM = VectorTypeModifier::SEFixedLog2LMULN1;
576 break;
577 case 0:
578 VTM = VectorTypeModifier::SEFixedLog2LMUL0;
579 break;
580 case 1:
581 VTM = VectorTypeModifier::SEFixedLog2LMUL1;
582 break;
583 case 2:
584 VTM = VectorTypeModifier::SEFixedLog2LMUL2;
585 break;
586 case 3:
587 VTM = VectorTypeModifier::SEFixedLog2LMUL3;
588 break;
589 default:
590 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
591 return std::nullopt;
593 } else if (ComplexTT.first == "Tuple") {
594 unsigned NF = 0;
595 if (ComplexTT.second.getAsInteger(10, NF)) {
596 llvm_unreachable("Invalid NF value!");
597 return std::nullopt;
599 VTM = getTupleVTM(NF);
600 } else {
601 llvm_unreachable("Illegal complex type transformers!");
604 PD.VTM = static_cast<uint8_t>(VTM);
606 // Compute the remain type transformers
607 TypeModifier TM = TypeModifier::NoModifier;
608 for (char I : PrototypeDescriptorStr) {
609 switch (I) {
610 case 'P':
611 if ((TM & TypeModifier::Const) == TypeModifier::Const)
612 llvm_unreachable("'P' transformer cannot be used after 'C'");
613 if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
614 llvm_unreachable("'P' transformer cannot be used twice");
615 TM |= TypeModifier::Pointer;
616 break;
617 case 'C':
618 TM |= TypeModifier::Const;
619 break;
620 case 'K':
621 TM |= TypeModifier::Immediate;
622 break;
623 case 'U':
624 TM |= TypeModifier::UnsignedInteger;
625 break;
626 case 'I':
627 TM |= TypeModifier::SignedInteger;
628 break;
629 case 'F':
630 TM |= TypeModifier::Float;
631 break;
632 case 'S':
633 TM |= TypeModifier::LMUL1;
634 break;
635 default:
636 llvm_unreachable("Illegal non-primitive type transformer!");
639 PD.TM = static_cast<uint8_t>(TM);
641 return PD;
644 void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
645 // Handle primitive type transformer
646 switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
647 case BaseTypeModifier::Scalar:
648 Scale = 0;
649 break;
650 case BaseTypeModifier::Vector:
651 Scale = LMUL.getScale(ElementBitwidth);
652 break;
653 case BaseTypeModifier::Void:
654 ScalarType = ScalarTypeKind::Void;
655 break;
656 case BaseTypeModifier::SizeT:
657 ScalarType = ScalarTypeKind::Size_t;
658 break;
659 case BaseTypeModifier::Ptrdiff:
660 ScalarType = ScalarTypeKind::Ptrdiff_t;
661 break;
662 case BaseTypeModifier::UnsignedLong:
663 ScalarType = ScalarTypeKind::UnsignedLong;
664 break;
665 case BaseTypeModifier::SignedLong:
666 ScalarType = ScalarTypeKind::SignedLong;
667 break;
668 case BaseTypeModifier::Invalid:
669 ScalarType = ScalarTypeKind::Invalid;
670 return;
673 switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
674 case VectorTypeModifier::Widening2XVector:
675 ElementBitwidth *= 2;
676 LMUL.MulLog2LMUL(1);
677 Scale = LMUL.getScale(ElementBitwidth);
678 break;
679 case VectorTypeModifier::Widening4XVector:
680 ElementBitwidth *= 4;
681 LMUL.MulLog2LMUL(2);
682 Scale = LMUL.getScale(ElementBitwidth);
683 break;
684 case VectorTypeModifier::Widening8XVector:
685 ElementBitwidth *= 8;
686 LMUL.MulLog2LMUL(3);
687 Scale = LMUL.getScale(ElementBitwidth);
688 break;
689 case VectorTypeModifier::MaskVector:
690 ScalarType = ScalarTypeKind::Boolean;
691 Scale = LMUL.getScale(ElementBitwidth);
692 ElementBitwidth = 1;
693 break;
694 case VectorTypeModifier::Log2EEW3:
695 applyLog2EEW(3);
696 break;
697 case VectorTypeModifier::Log2EEW4:
698 applyLog2EEW(4);
699 break;
700 case VectorTypeModifier::Log2EEW5:
701 applyLog2EEW(5);
702 break;
703 case VectorTypeModifier::Log2EEW6:
704 applyLog2EEW(6);
705 break;
706 case VectorTypeModifier::FixedSEW8:
707 applyFixedSEW(8);
708 break;
709 case VectorTypeModifier::FixedSEW16:
710 applyFixedSEW(16);
711 break;
712 case VectorTypeModifier::FixedSEW32:
713 applyFixedSEW(32);
714 break;
715 case VectorTypeModifier::FixedSEW64:
716 applyFixedSEW(64);
717 break;
718 case VectorTypeModifier::LFixedLog2LMULN3:
719 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
720 break;
721 case VectorTypeModifier::LFixedLog2LMULN2:
722 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
723 break;
724 case VectorTypeModifier::LFixedLog2LMULN1:
725 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
726 break;
727 case VectorTypeModifier::LFixedLog2LMUL0:
728 applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
729 break;
730 case VectorTypeModifier::LFixedLog2LMUL1:
731 applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
732 break;
733 case VectorTypeModifier::LFixedLog2LMUL2:
734 applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
735 break;
736 case VectorTypeModifier::LFixedLog2LMUL3:
737 applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
738 break;
739 case VectorTypeModifier::SFixedLog2LMULN3:
740 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
741 break;
742 case VectorTypeModifier::SFixedLog2LMULN2:
743 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
744 break;
745 case VectorTypeModifier::SFixedLog2LMULN1:
746 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
747 break;
748 case VectorTypeModifier::SFixedLog2LMUL0:
749 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
750 break;
751 case VectorTypeModifier::SFixedLog2LMUL1:
752 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
753 break;
754 case VectorTypeModifier::SFixedLog2LMUL2:
755 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
756 break;
757 case VectorTypeModifier::SFixedLog2LMUL3:
758 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
759 break;
760 case VectorTypeModifier::SEFixedLog2LMULN3:
761 applyFixedLog2LMUL(-3, FixedLMULType::SmallerOrEqual);
762 break;
763 case VectorTypeModifier::SEFixedLog2LMULN2:
764 applyFixedLog2LMUL(-2, FixedLMULType::SmallerOrEqual);
765 break;
766 case VectorTypeModifier::SEFixedLog2LMULN1:
767 applyFixedLog2LMUL(-1, FixedLMULType::SmallerOrEqual);
768 break;
769 case VectorTypeModifier::SEFixedLog2LMUL0:
770 applyFixedLog2LMUL(0, FixedLMULType::SmallerOrEqual);
771 break;
772 case VectorTypeModifier::SEFixedLog2LMUL1:
773 applyFixedLog2LMUL(1, FixedLMULType::SmallerOrEqual);
774 break;
775 case VectorTypeModifier::SEFixedLog2LMUL2:
776 applyFixedLog2LMUL(2, FixedLMULType::SmallerOrEqual);
777 break;
778 case VectorTypeModifier::SEFixedLog2LMUL3:
779 applyFixedLog2LMUL(3, FixedLMULType::SmallerOrEqual);
780 break;
781 case VectorTypeModifier::Tuple2:
782 case VectorTypeModifier::Tuple3:
783 case VectorTypeModifier::Tuple4:
784 case VectorTypeModifier::Tuple5:
785 case VectorTypeModifier::Tuple6:
786 case VectorTypeModifier::Tuple7:
787 case VectorTypeModifier::Tuple8: {
788 IsTuple = true;
789 NF = 2 + static_cast<uint8_t>(Transformer.VTM) -
790 static_cast<uint8_t>(VectorTypeModifier::Tuple2);
791 break;
793 case VectorTypeModifier::NoModifier:
794 break;
797 // Early return if the current type modifier is already invalid.
798 if (ScalarType == Invalid)
799 return;
801 for (unsigned TypeModifierMaskShift = 0;
802 TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
803 ++TypeModifierMaskShift) {
804 unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
805 if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
806 TypeModifierMask)
807 continue;
808 switch (static_cast<TypeModifier>(TypeModifierMask)) {
809 case TypeModifier::Pointer:
810 IsPointer = true;
811 break;
812 case TypeModifier::Const:
813 IsConstant = true;
814 break;
815 case TypeModifier::Immediate:
816 IsImmediate = true;
817 IsConstant = true;
818 break;
819 case TypeModifier::UnsignedInteger:
820 ScalarType = ScalarTypeKind::UnsignedInteger;
821 break;
822 case TypeModifier::SignedInteger:
823 ScalarType = ScalarTypeKind::SignedInteger;
824 break;
825 case TypeModifier::Float:
826 ScalarType = ScalarTypeKind::Float;
827 break;
828 case TypeModifier::LMUL1:
829 LMUL = LMULType(0);
830 // Update ElementBitwidth need to update Scale too.
831 Scale = LMUL.getScale(ElementBitwidth);
832 break;
833 default:
834 llvm_unreachable("Unknown type modifier mask!");
839 void RVVType::applyLog2EEW(unsigned Log2EEW) {
840 // update new elmul = (eew/sew) * lmul
841 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
842 // update new eew
843 ElementBitwidth = 1 << Log2EEW;
844 ScalarType = ScalarTypeKind::SignedInteger;
845 Scale = LMUL.getScale(ElementBitwidth);
848 void RVVType::applyFixedSEW(unsigned NewSEW) {
849 // Set invalid type if src and dst SEW are same.
850 if (ElementBitwidth == NewSEW) {
851 ScalarType = ScalarTypeKind::Invalid;
852 return;
854 // Update new SEW
855 ElementBitwidth = NewSEW;
856 Scale = LMUL.getScale(ElementBitwidth);
859 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
860 switch (Type) {
861 case FixedLMULType::LargerThan:
862 if (Log2LMUL <= LMUL.Log2LMUL) {
863 ScalarType = ScalarTypeKind::Invalid;
864 return;
866 break;
867 case FixedLMULType::SmallerThan:
868 if (Log2LMUL >= LMUL.Log2LMUL) {
869 ScalarType = ScalarTypeKind::Invalid;
870 return;
872 break;
873 case FixedLMULType::SmallerOrEqual:
874 if (Log2LMUL > LMUL.Log2LMUL) {
875 ScalarType = ScalarTypeKind::Invalid;
876 return;
878 break;
881 // Update new LMUL
882 LMUL = LMULType(Log2LMUL);
883 Scale = LMUL.getScale(ElementBitwidth);
886 std::optional<RVVTypes>
887 RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
888 ArrayRef<PrototypeDescriptor> Prototype) {
889 RVVTypes Types;
890 for (const PrototypeDescriptor &Proto : Prototype) {
891 auto T = computeType(BT, Log2LMUL, Proto);
892 if (!T)
893 return std::nullopt;
894 // Record legal type index
895 Types.push_back(*T);
897 return Types;
900 // Compute the hash value of RVVType, used for cache the result of computeType.
901 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
902 PrototypeDescriptor Proto) {
903 // Layout of hash value:
904 // 0 8 16 24 32 40
905 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
906 assert(Log2LMUL >= -3 && Log2LMUL <= 3);
907 return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
908 ((uint64_t)(Proto.PT & 0xff) << 16) |
909 ((uint64_t)(Proto.TM & 0xff) << 24) |
910 ((uint64_t)(Proto.VTM & 0xff) << 32);
913 std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL,
914 PrototypeDescriptor Proto) {
915 uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
916 // Search first
917 auto It = LegalTypes.find(Idx);
918 if (It != LegalTypes.end())
919 return &(It->second);
921 if (IllegalTypes.count(Idx))
922 return std::nullopt;
924 // Compute type and record the result.
925 RVVType T(BT, Log2LMUL, Proto);
926 if (T.isValid()) {
927 // Record legal type index and value.
928 std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool>
929 InsertResult = LegalTypes.insert({Idx, T});
930 return &(InsertResult.first->second);
932 // Record illegal type index.
933 IllegalTypes.insert(Idx);
934 return std::nullopt;
937 //===----------------------------------------------------------------------===//
938 // RVVIntrinsic implementation
939 //===----------------------------------------------------------------------===//
940 RVVIntrinsic::RVVIntrinsic(
941 StringRef NewName, StringRef Suffix, StringRef NewOverloadedName,
942 StringRef OverloadedSuffix, StringRef IRName, bool IsMasked,
943 bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
944 bool SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen,
945 const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
946 const std::vector<StringRef> &RequiredFeatures, unsigned NF,
947 Policy NewPolicyAttrs, bool HasFRMRoundModeOp)
948 : IRName(IRName), IsMasked(IsMasked),
949 HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
950 SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
951 ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {
953 // Init BuiltinName, Name and OverloadedName
954 BuiltinName = NewName.str();
955 Name = BuiltinName;
956 if (NewOverloadedName.empty())
957 OverloadedName = NewName.split("_").first.str();
958 else
959 OverloadedName = NewOverloadedName.str();
960 if (!Suffix.empty())
961 Name += "_" + Suffix.str();
962 if (!OverloadedSuffix.empty())
963 OverloadedName += "_" + OverloadedSuffix.str();
965 updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
966 PolicyAttrs, HasFRMRoundModeOp);
968 // Init OutputType and InputTypes
969 OutputType = OutInTypes[0];
970 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
972 // IntrinsicTypes is unmasked TA version index. Need to update it
973 // if there is merge operand (It is always in first operand).
974 IntrinsicTypes = NewIntrinsicTypes;
975 if ((IsMasked && hasMaskedOffOperand()) ||
976 (!IsMasked && hasPassthruOperand())) {
977 for (auto &I : IntrinsicTypes) {
978 if (I >= 0)
979 I += NF;
984 std::string RVVIntrinsic::getBuiltinTypeStr() const {
985 std::string S;
986 S += OutputType->getBuiltinStr();
987 for (const auto &T : InputTypes) {
988 S += T->getBuiltinStr();
990 return S;
993 std::string RVVIntrinsic::getSuffixStr(
994 RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL,
995 llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
996 SmallVector<std::string> SuffixStrs;
997 for (auto PD : PrototypeDescriptors) {
998 auto T = TypeCache.computeType(Type, Log2LMUL, PD);
999 SuffixStrs.push_back((*T)->getShortStr());
1001 return join(SuffixStrs, "_");
1004 llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
1005 llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
1006 bool HasMaskedOffOperand, bool HasVL, unsigned NF,
1007 PolicyScheme DefaultScheme, Policy PolicyAttrs, bool IsTuple) {
1008 SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
1009 Prototype.end());
1010 bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
1011 if (IsMasked) {
1012 // If HasMaskedOffOperand, insert result type as first input operand if
1013 // need.
1014 if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
1015 if (NF == 1) {
1016 NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
1017 } else if (NF > 1) {
1018 if (IsTuple) {
1019 PrototypeDescriptor BasePtrOperand = Prototype[1];
1020 PrototypeDescriptor MaskoffType = PrototypeDescriptor(
1021 static_cast<uint8_t>(BaseTypeModifier::Vector),
1022 static_cast<uint8_t>(getTupleVTM(NF)),
1023 BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
1024 NewPrototype.insert(NewPrototype.begin() + 1, MaskoffType);
1025 } else {
1026 // Convert
1027 // (void, op0 address, op1 address, ...)
1028 // to
1029 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1030 PrototypeDescriptor MaskoffType = NewPrototype[1];
1031 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
1032 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
1036 if (HasMaskedOffOperand && NF > 1) {
1037 // Convert
1038 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1039 // to
1040 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1041 // ...)
1042 if (IsTuple)
1043 NewPrototype.insert(NewPrototype.begin() + 1,
1044 PrototypeDescriptor::Mask);
1045 else
1046 NewPrototype.insert(NewPrototype.begin() + NF + 1,
1047 PrototypeDescriptor::Mask);
1048 } else {
1049 // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
1050 NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
1052 } else {
1053 if (NF == 1) {
1054 if (PolicyAttrs.isTUPolicy() && HasPassthruOp)
1055 NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
1056 } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
1057 if (IsTuple) {
1058 PrototypeDescriptor BasePtrOperand = Prototype[0];
1059 PrototypeDescriptor MaskoffType = PrototypeDescriptor(
1060 static_cast<uint8_t>(BaseTypeModifier::Vector),
1061 static_cast<uint8_t>(getTupleVTM(NF)),
1062 BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
1063 NewPrototype.insert(NewPrototype.begin(), MaskoffType);
1064 } else {
1065 // NF > 1 cases for segment load operations.
1066 // Convert
1067 // (void, op0 address, op1 address, ...)
1068 // to
1069 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
1070 PrototypeDescriptor MaskoffType = Prototype[1];
1071 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
1072 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
1077 // If HasVL, append PrototypeDescriptor:VL to last operand
1078 if (HasVL)
1079 NewPrototype.push_back(PrototypeDescriptor::VL);
1081 return NewPrototype;
1084 llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() {
1085 return {Policy(Policy::PolicyType::Undisturbed)}; // TU
1088 llvm::SmallVector<Policy>
1089 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy,
1090 bool HasMaskPolicy) {
1091 if (HasTailPolicy && HasMaskPolicy)
1092 return {Policy(Policy::PolicyType::Undisturbed,
1093 Policy::PolicyType::Agnostic), // TUM
1094 Policy(Policy::PolicyType::Undisturbed,
1095 Policy::PolicyType::Undisturbed), // TUMU
1096 Policy(Policy::PolicyType::Agnostic,
1097 Policy::PolicyType::Undisturbed)}; // MU
1098 if (HasTailPolicy && !HasMaskPolicy)
1099 return {Policy(Policy::PolicyType::Undisturbed,
1100 Policy::PolicyType::Agnostic)}; // TU
1101 if (!HasTailPolicy && HasMaskPolicy)
1102 return {Policy(Policy::PolicyType::Agnostic,
1103 Policy::PolicyType::Undisturbed)}; // MU
1104 llvm_unreachable("An RVV instruction should not be without both tail policy "
1105 "and mask policy");
1108 void RVVIntrinsic::updateNamesAndPolicy(
1109 bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName,
1110 std::string &OverloadedName, Policy &PolicyAttrs, bool HasFRMRoundModeOp) {
1112 auto appendPolicySuffix = [&](const std::string &suffix) {
1113 Name += suffix;
1114 BuiltinName += suffix;
1115 OverloadedName += suffix;
1118 // This follows the naming guideline under riscv-c-api-doc to add the
1119 // `__riscv_` suffix for all RVV intrinsics.
1120 Name = "__riscv_" + Name;
1121 OverloadedName = "__riscv_" + OverloadedName;
1123 if (HasFRMRoundModeOp) {
1124 Name += "_rm";
1125 BuiltinName += "_rm";
1128 if (IsMasked) {
1129 if (PolicyAttrs.isTUMUPolicy())
1130 appendPolicySuffix("_tumu");
1131 else if (PolicyAttrs.isTUMAPolicy())
1132 appendPolicySuffix("_tum");
1133 else if (PolicyAttrs.isTAMUPolicy())
1134 appendPolicySuffix("_mu");
1135 else if (PolicyAttrs.isTAMAPolicy()) {
1136 Name += "_m";
1137 BuiltinName += "_m";
1138 } else
1139 llvm_unreachable("Unhandled policy condition");
1140 } else {
1141 if (PolicyAttrs.isTUPolicy())
1142 appendPolicySuffix("_tu");
1143 else if (PolicyAttrs.isTAPolicy()) // no suffix needed
1144 return;
1145 else
1146 llvm_unreachable("Unhandled policy condition");
1150 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
1151 SmallVector<PrototypeDescriptor> PrototypeDescriptors;
1152 const StringRef Primaries("evwqom0ztul");
1153 while (!Prototypes.empty()) {
1154 size_t Idx = 0;
1155 // Skip over complex prototype because it could contain primitive type
1156 // character.
1157 if (Prototypes[0] == '(')
1158 Idx = Prototypes.find_first_of(')');
1159 Idx = Prototypes.find_first_of(Primaries, Idx);
1160 assert(Idx != StringRef::npos);
1161 auto PD = PrototypeDescriptor::parsePrototypeDescriptor(
1162 Prototypes.slice(0, Idx + 1));
1163 if (!PD)
1164 llvm_unreachable("Error during parsing prototype.");
1165 PrototypeDescriptors.push_back(*PD);
1166 Prototypes = Prototypes.drop_front(Idx + 1);
1168 return PrototypeDescriptors;
1171 raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
1172 OS << "{";
1173 OS << "\"" << Record.Name << "\",";
1174 if (Record.OverloadedName == nullptr ||
1175 StringRef(Record.OverloadedName).empty())
1176 OS << "nullptr,";
1177 else
1178 OS << "\"" << Record.OverloadedName << "\",";
1179 OS << Record.PrototypeIndex << ",";
1180 OS << Record.SuffixIndex << ",";
1181 OS << Record.OverloadedSuffixIndex << ",";
1182 OS << (int)Record.PrototypeLength << ",";
1183 OS << (int)Record.SuffixLength << ",";
1184 OS << (int)Record.OverloadedSuffixSize << ",";
1185 OS << (int)Record.RequiredExtensions << ",";
1186 OS << (int)Record.TypeRangeMask << ",";
1187 OS << (int)Record.Log2LMULMask << ",";
1188 OS << (int)Record.NF << ",";
1189 OS << (int)Record.HasMasked << ",";
1190 OS << (int)Record.HasVL << ",";
1191 OS << (int)Record.HasMaskedOffOperand << ",";
1192 OS << (int)Record.HasTailPolicy << ",";
1193 OS << (int)Record.HasMaskPolicy << ",";
1194 OS << (int)Record.HasFRMRoundModeOp << ",";
1195 OS << (int)Record.IsTuple << ",";
1196 OS << (int)Record.UnMaskedPolicyScheme << ",";
1197 OS << (int)Record.MaskedPolicyScheme << ",";
1198 OS << "},\n";
1199 return OS;
1202 } // end namespace RISCV
1203 } // end namespace clang