[lld][WebAssembly] Add `--table-base` setting
[llvm-project.git] / clang / lib / Support / RISCVVIntrinsicUtils.cpp
blobc105db434dc43c9b0f9353cd2df52561b3154cb5
1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "clang/Support/RISCVVIntrinsicUtils.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/SmallSet.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringMap.h"
14 #include "llvm/ADT/StringSet.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Support/ErrorHandling.h"
17 #include "llvm/Support/raw_ostream.h"
18 #include <numeric>
19 #include <optional>
21 using namespace llvm;
23 namespace clang {
24 namespace RISCV {
26 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
27 BaseTypeModifier::Vector, VectorTypeModifier::MaskVector);
28 const PrototypeDescriptor PrototypeDescriptor::VL =
29 PrototypeDescriptor(BaseTypeModifier::SizeT);
30 const PrototypeDescriptor PrototypeDescriptor::Vector =
31 PrototypeDescriptor(BaseTypeModifier::Vector);
33 //===----------------------------------------------------------------------===//
34 // Type implementation
35 //===----------------------------------------------------------------------===//
37 LMULType::LMULType(int NewLog2LMUL) {
38 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
39 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
40 Log2LMUL = NewLog2LMUL;
43 std::string LMULType::str() const {
44 if (Log2LMUL < 0)
45 return "mf" + utostr(1ULL << (-Log2LMUL));
46 return "m" + utostr(1ULL << Log2LMUL);
49 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
50 int Log2ScaleResult = 0;
51 switch (ElementBitwidth) {
52 default:
53 break;
54 case 8:
55 Log2ScaleResult = Log2LMUL + 3;
56 break;
57 case 16:
58 Log2ScaleResult = Log2LMUL + 2;
59 break;
60 case 32:
61 Log2ScaleResult = Log2LMUL + 1;
62 break;
63 case 64:
64 Log2ScaleResult = Log2LMUL;
65 break;
67 // Illegal vscale result would be less than 1
68 if (Log2ScaleResult < 0)
69 return std::nullopt;
70 return 1 << Log2ScaleResult;
73 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
75 RVVType::RVVType(BasicType BT, int Log2LMUL,
76 const PrototypeDescriptor &prototype)
77 : BT(BT), LMUL(LMULType(Log2LMUL)) {
78 applyBasicType();
79 applyModifier(prototype);
80 Valid = verifyType();
81 if (Valid) {
82 initBuiltinStr();
83 initTypeStr();
84 if (isVector()) {
85 initClangBuiltinStr();
90 // clang-format off
// Boolean (mask) types are encoded by the ratio n = SEW/LMUL:
92 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64
93 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
94 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
96 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
97 // -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
98 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
99 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
100 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
101 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
102 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
103 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
104 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
105 // clang-format on
107 bool RVVType::verifyType() const {
108 if (ScalarType == Invalid)
109 return false;
110 if (isScalar())
111 return true;
112 if (!Scale)
113 return false;
114 if (isFloat() && ElementBitwidth == 8)
115 return false;
116 if (IsTuple && (NF == 1 || NF > 8))
117 return false;
118 if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8)
119 return false;
120 unsigned V = *Scale;
121 switch (ElementBitwidth) {
122 case 1:
123 case 8:
124 // Check Scale is 1,2,4,8,16,32,64
125 return (V <= 64 && isPowerOf2_32(V));
126 case 16:
127 // Check Scale is 1,2,4,8,16,32
128 return (V <= 32 && isPowerOf2_32(V));
129 case 32:
130 // Check Scale is 1,2,4,8,16
131 return (V <= 16 && isPowerOf2_32(V));
132 case 64:
133 // Check Scale is 1,2,4,8
134 return (V <= 8 && isPowerOf2_32(V));
136 return false;
139 void RVVType::initBuiltinStr() {
140 assert(isValid() && "RVVType is invalid");
141 switch (ScalarType) {
142 case ScalarTypeKind::Void:
143 BuiltinStr = "v";
144 return;
145 case ScalarTypeKind::Size_t:
146 BuiltinStr = "z";
147 if (IsImmediate)
148 BuiltinStr = "I" + BuiltinStr;
149 if (IsPointer)
150 BuiltinStr += "*";
151 return;
152 case ScalarTypeKind::Ptrdiff_t:
153 BuiltinStr = "Y";
154 return;
155 case ScalarTypeKind::UnsignedLong:
156 BuiltinStr = "ULi";
157 return;
158 case ScalarTypeKind::SignedLong:
159 BuiltinStr = "Li";
160 return;
161 case ScalarTypeKind::Boolean:
162 assert(ElementBitwidth == 1);
163 BuiltinStr += "b";
164 break;
165 case ScalarTypeKind::SignedInteger:
166 case ScalarTypeKind::UnsignedInteger:
167 switch (ElementBitwidth) {
168 case 8:
169 BuiltinStr += "c";
170 break;
171 case 16:
172 BuiltinStr += "s";
173 break;
174 case 32:
175 BuiltinStr += "i";
176 break;
177 case 64:
178 BuiltinStr += "Wi";
179 break;
180 default:
181 llvm_unreachable("Unhandled ElementBitwidth!");
183 if (isSignedInteger())
184 BuiltinStr = "S" + BuiltinStr;
185 else
186 BuiltinStr = "U" + BuiltinStr;
187 break;
188 case ScalarTypeKind::Float:
189 switch (ElementBitwidth) {
190 case 16:
191 BuiltinStr += "x";
192 break;
193 case 32:
194 BuiltinStr += "f";
195 break;
196 case 64:
197 BuiltinStr += "d";
198 break;
199 default:
200 llvm_unreachable("Unhandled ElementBitwidth!");
202 break;
203 default:
204 llvm_unreachable("ScalarType is invalid!");
206 if (IsImmediate)
207 BuiltinStr = "I" + BuiltinStr;
208 if (isScalar()) {
209 if (IsConstant)
210 BuiltinStr += "C";
211 if (IsPointer)
212 BuiltinStr += "*";
213 return;
215 BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
216 // Pointer to vector types. Defined for segment load intrinsics.
217 // segment load intrinsics have pointer type arguments to store the loaded
218 // vector values.
219 if (IsPointer)
220 BuiltinStr += "*";
222 if (IsTuple)
223 BuiltinStr = "T" + utostr(NF) + BuiltinStr;
226 void RVVType::initClangBuiltinStr() {
227 assert(isValid() && "RVVType is invalid");
228 assert(isVector() && "Handle Vector type only");
230 ClangBuiltinStr = "__rvv_";
231 switch (ScalarType) {
232 case ScalarTypeKind::Boolean:
233 ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
234 return;
235 case ScalarTypeKind::Float:
236 ClangBuiltinStr += "float";
237 break;
238 case ScalarTypeKind::SignedInteger:
239 ClangBuiltinStr += "int";
240 break;
241 case ScalarTypeKind::UnsignedInteger:
242 ClangBuiltinStr += "uint";
243 break;
244 default:
245 llvm_unreachable("ScalarTypeKind is invalid");
247 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() +
248 (IsTuple ? "x" + utostr(NF) : "") + "_t";
251 void RVVType::initTypeStr() {
252 assert(isValid() && "RVVType is invalid");
254 if (IsConstant)
255 Str += "const ";
257 auto getTypeString = [&](StringRef TypeStr) {
258 if (isScalar())
259 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
260 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() +
261 (IsTuple ? "x" + utostr(NF) : "") + "_t")
262 .str();
265 switch (ScalarType) {
266 case ScalarTypeKind::Void:
267 Str = "void";
268 return;
269 case ScalarTypeKind::Size_t:
270 Str = "size_t";
271 if (IsPointer)
272 Str += " *";
273 return;
274 case ScalarTypeKind::Ptrdiff_t:
275 Str = "ptrdiff_t";
276 return;
277 case ScalarTypeKind::UnsignedLong:
278 Str = "unsigned long";
279 return;
280 case ScalarTypeKind::SignedLong:
281 Str = "long";
282 return;
283 case ScalarTypeKind::Boolean:
284 if (isScalar())
285 Str += "bool";
286 else
287 // Vector bool is special case, the formulate is
288 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
289 Str += "vbool" + utostr(64 / *Scale) + "_t";
290 break;
291 case ScalarTypeKind::Float:
292 if (isScalar()) {
293 if (ElementBitwidth == 64)
294 Str += "double";
295 else if (ElementBitwidth == 32)
296 Str += "float";
297 else if (ElementBitwidth == 16)
298 Str += "_Float16";
299 else
300 llvm_unreachable("Unhandled floating type.");
301 } else
302 Str += getTypeString("float");
303 break;
304 case ScalarTypeKind::SignedInteger:
305 Str += getTypeString("int");
306 break;
307 case ScalarTypeKind::UnsignedInteger:
308 Str += getTypeString("uint");
309 break;
310 default:
311 llvm_unreachable("ScalarType is invalid!");
313 if (IsPointer)
314 Str += " *";
317 void RVVType::initShortStr() {
318 switch (ScalarType) {
319 case ScalarTypeKind::Boolean:
320 assert(isVector());
321 ShortStr = "b" + utostr(64 / *Scale);
322 return;
323 case ScalarTypeKind::Float:
324 ShortStr = "f" + utostr(ElementBitwidth);
325 break;
326 case ScalarTypeKind::SignedInteger:
327 ShortStr = "i" + utostr(ElementBitwidth);
328 break;
329 case ScalarTypeKind::UnsignedInteger:
330 ShortStr = "u" + utostr(ElementBitwidth);
331 break;
332 default:
333 llvm_unreachable("Unhandled case!");
335 if (isVector())
336 ShortStr += LMUL.str();
337 if (isTuple())
338 ShortStr += "x" + utostr(NF);
341 static VectorTypeModifier getTupleVTM(unsigned NF) {
342 assert(2 <= NF && NF <= 8 && "2 <= NF <= 8");
343 return static_cast<VectorTypeModifier>(
344 static_cast<uint8_t>(VectorTypeModifier::Tuple2) + (NF - 2));
347 void RVVType::applyBasicType() {
348 switch (BT) {
349 case BasicType::Int8:
350 ElementBitwidth = 8;
351 ScalarType = ScalarTypeKind::SignedInteger;
352 break;
353 case BasicType::Int16:
354 ElementBitwidth = 16;
355 ScalarType = ScalarTypeKind::SignedInteger;
356 break;
357 case BasicType::Int32:
358 ElementBitwidth = 32;
359 ScalarType = ScalarTypeKind::SignedInteger;
360 break;
361 case BasicType::Int64:
362 ElementBitwidth = 64;
363 ScalarType = ScalarTypeKind::SignedInteger;
364 break;
365 case BasicType::Float16:
366 ElementBitwidth = 16;
367 ScalarType = ScalarTypeKind::Float;
368 break;
369 case BasicType::Float32:
370 ElementBitwidth = 32;
371 ScalarType = ScalarTypeKind::Float;
372 break;
373 case BasicType::Float64:
374 ElementBitwidth = 64;
375 ScalarType = ScalarTypeKind::Float;
376 break;
377 default:
378 llvm_unreachable("Unhandled type code!");
380 assert(ElementBitwidth != 0 && "Bad element bitwidth!");
383 std::optional<PrototypeDescriptor>
384 PrototypeDescriptor::parsePrototypeDescriptor(
385 llvm::StringRef PrototypeDescriptorStr) {
386 PrototypeDescriptor PD;
387 BaseTypeModifier PT = BaseTypeModifier::Invalid;
388 VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
390 if (PrototypeDescriptorStr.empty())
391 return PD;
393 // Handle base type modifier
394 auto PType = PrototypeDescriptorStr.back();
395 switch (PType) {
396 case 'e':
397 PT = BaseTypeModifier::Scalar;
398 break;
399 case 'v':
400 PT = BaseTypeModifier::Vector;
401 break;
402 case 'w':
403 PT = BaseTypeModifier::Vector;
404 VTM = VectorTypeModifier::Widening2XVector;
405 break;
406 case 'q':
407 PT = BaseTypeModifier::Vector;
408 VTM = VectorTypeModifier::Widening4XVector;
409 break;
410 case 'o':
411 PT = BaseTypeModifier::Vector;
412 VTM = VectorTypeModifier::Widening8XVector;
413 break;
414 case 'm':
415 PT = BaseTypeModifier::Vector;
416 VTM = VectorTypeModifier::MaskVector;
417 break;
418 case '0':
419 PT = BaseTypeModifier::Void;
420 break;
421 case 'z':
422 PT = BaseTypeModifier::SizeT;
423 break;
424 case 't':
425 PT = BaseTypeModifier::Ptrdiff;
426 break;
427 case 'u':
428 PT = BaseTypeModifier::UnsignedLong;
429 break;
430 case 'l':
431 PT = BaseTypeModifier::SignedLong;
432 break;
433 default:
434 llvm_unreachable("Illegal primitive type transformers!");
436 PD.PT = static_cast<uint8_t>(PT);
437 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
439 // Compute the vector type transformers, it can only appear one time.
440 if (PrototypeDescriptorStr.startswith("(")) {
441 assert(VTM == VectorTypeModifier::NoModifier &&
442 "VectorTypeModifier should only have one modifier");
443 size_t Idx = PrototypeDescriptorStr.find(')');
444 assert(Idx != StringRef::npos);
445 StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
446 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
447 assert(!PrototypeDescriptorStr.contains('(') &&
448 "Only allow one vector type modifier");
450 auto ComplexTT = ComplexType.split(":");
451 if (ComplexTT.first == "Log2EEW") {
452 uint32_t Log2EEW;
453 if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
454 llvm_unreachable("Invalid Log2EEW value!");
455 return std::nullopt;
457 switch (Log2EEW) {
458 case 3:
459 VTM = VectorTypeModifier::Log2EEW3;
460 break;
461 case 4:
462 VTM = VectorTypeModifier::Log2EEW4;
463 break;
464 case 5:
465 VTM = VectorTypeModifier::Log2EEW5;
466 break;
467 case 6:
468 VTM = VectorTypeModifier::Log2EEW6;
469 break;
470 default:
471 llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
472 return std::nullopt;
474 } else if (ComplexTT.first == "FixedSEW") {
475 uint32_t NewSEW;
476 if (ComplexTT.second.getAsInteger(10, NewSEW)) {
477 llvm_unreachable("Invalid FixedSEW value!");
478 return std::nullopt;
480 switch (NewSEW) {
481 case 8:
482 VTM = VectorTypeModifier::FixedSEW8;
483 break;
484 case 16:
485 VTM = VectorTypeModifier::FixedSEW16;
486 break;
487 case 32:
488 VTM = VectorTypeModifier::FixedSEW32;
489 break;
490 case 64:
491 VTM = VectorTypeModifier::FixedSEW64;
492 break;
493 default:
494 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
495 return std::nullopt;
497 } else if (ComplexTT.first == "LFixedLog2LMUL") {
498 int32_t Log2LMUL;
499 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
500 llvm_unreachable("Invalid LFixedLog2LMUL value!");
501 return std::nullopt;
503 switch (Log2LMUL) {
504 case -3:
505 VTM = VectorTypeModifier::LFixedLog2LMULN3;
506 break;
507 case -2:
508 VTM = VectorTypeModifier::LFixedLog2LMULN2;
509 break;
510 case -1:
511 VTM = VectorTypeModifier::LFixedLog2LMULN1;
512 break;
513 case 0:
514 VTM = VectorTypeModifier::LFixedLog2LMUL0;
515 break;
516 case 1:
517 VTM = VectorTypeModifier::LFixedLog2LMUL1;
518 break;
519 case 2:
520 VTM = VectorTypeModifier::LFixedLog2LMUL2;
521 break;
522 case 3:
523 VTM = VectorTypeModifier::LFixedLog2LMUL3;
524 break;
525 default:
526 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
527 return std::nullopt;
529 } else if (ComplexTT.first == "SFixedLog2LMUL") {
530 int32_t Log2LMUL;
531 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
532 llvm_unreachable("Invalid SFixedLog2LMUL value!");
533 return std::nullopt;
535 switch (Log2LMUL) {
536 case -3:
537 VTM = VectorTypeModifier::SFixedLog2LMULN3;
538 break;
539 case -2:
540 VTM = VectorTypeModifier::SFixedLog2LMULN2;
541 break;
542 case -1:
543 VTM = VectorTypeModifier::SFixedLog2LMULN1;
544 break;
545 case 0:
546 VTM = VectorTypeModifier::SFixedLog2LMUL0;
547 break;
548 case 1:
549 VTM = VectorTypeModifier::SFixedLog2LMUL1;
550 break;
551 case 2:
552 VTM = VectorTypeModifier::SFixedLog2LMUL2;
553 break;
554 case 3:
555 VTM = VectorTypeModifier::SFixedLog2LMUL3;
556 break;
557 default:
558 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
559 return std::nullopt;
562 } else if (ComplexTT.first == "SEFixedLog2LMUL") {
563 int32_t Log2LMUL;
564 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
565 llvm_unreachable("Invalid SEFixedLog2LMUL value!");
566 return std::nullopt;
568 switch (Log2LMUL) {
569 case -3:
570 VTM = VectorTypeModifier::SEFixedLog2LMULN3;
571 break;
572 case -2:
573 VTM = VectorTypeModifier::SEFixedLog2LMULN2;
574 break;
575 case -1:
576 VTM = VectorTypeModifier::SEFixedLog2LMULN1;
577 break;
578 case 0:
579 VTM = VectorTypeModifier::SEFixedLog2LMUL0;
580 break;
581 case 1:
582 VTM = VectorTypeModifier::SEFixedLog2LMUL1;
583 break;
584 case 2:
585 VTM = VectorTypeModifier::SEFixedLog2LMUL2;
586 break;
587 case 3:
588 VTM = VectorTypeModifier::SEFixedLog2LMUL3;
589 break;
590 default:
591 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
592 return std::nullopt;
594 } else if (ComplexTT.first == "Tuple") {
595 unsigned NF = 0;
596 if (ComplexTT.second.getAsInteger(10, NF)) {
597 llvm_unreachable("Invalid NF value!");
598 return std::nullopt;
600 VTM = getTupleVTM(NF);
601 } else {
602 llvm_unreachable("Illegal complex type transformers!");
605 PD.VTM = static_cast<uint8_t>(VTM);
607 // Compute the remain type transformers
608 TypeModifier TM = TypeModifier::NoModifier;
609 for (char I : PrototypeDescriptorStr) {
610 switch (I) {
611 case 'P':
612 if ((TM & TypeModifier::Const) == TypeModifier::Const)
613 llvm_unreachable("'P' transformer cannot be used after 'C'");
614 if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
615 llvm_unreachable("'P' transformer cannot be used twice");
616 TM |= TypeModifier::Pointer;
617 break;
618 case 'C':
619 TM |= TypeModifier::Const;
620 break;
621 case 'K':
622 TM |= TypeModifier::Immediate;
623 break;
624 case 'U':
625 TM |= TypeModifier::UnsignedInteger;
626 break;
627 case 'I':
628 TM |= TypeModifier::SignedInteger;
629 break;
630 case 'F':
631 TM |= TypeModifier::Float;
632 break;
633 case 'S':
634 TM |= TypeModifier::LMUL1;
635 break;
636 default:
637 llvm_unreachable("Illegal non-primitive type transformer!");
640 PD.TM = static_cast<uint8_t>(TM);
642 return PD;
645 void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
646 // Handle primitive type transformer
647 switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
648 case BaseTypeModifier::Scalar:
649 Scale = 0;
650 break;
651 case BaseTypeModifier::Vector:
652 Scale = LMUL.getScale(ElementBitwidth);
653 break;
654 case BaseTypeModifier::Void:
655 ScalarType = ScalarTypeKind::Void;
656 break;
657 case BaseTypeModifier::SizeT:
658 ScalarType = ScalarTypeKind::Size_t;
659 break;
660 case BaseTypeModifier::Ptrdiff:
661 ScalarType = ScalarTypeKind::Ptrdiff_t;
662 break;
663 case BaseTypeModifier::UnsignedLong:
664 ScalarType = ScalarTypeKind::UnsignedLong;
665 break;
666 case BaseTypeModifier::SignedLong:
667 ScalarType = ScalarTypeKind::SignedLong;
668 break;
669 case BaseTypeModifier::Invalid:
670 ScalarType = ScalarTypeKind::Invalid;
671 return;
674 switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
675 case VectorTypeModifier::Widening2XVector:
676 ElementBitwidth *= 2;
677 LMUL.MulLog2LMUL(1);
678 Scale = LMUL.getScale(ElementBitwidth);
679 break;
680 case VectorTypeModifier::Widening4XVector:
681 ElementBitwidth *= 4;
682 LMUL.MulLog2LMUL(2);
683 Scale = LMUL.getScale(ElementBitwidth);
684 break;
685 case VectorTypeModifier::Widening8XVector:
686 ElementBitwidth *= 8;
687 LMUL.MulLog2LMUL(3);
688 Scale = LMUL.getScale(ElementBitwidth);
689 break;
690 case VectorTypeModifier::MaskVector:
691 ScalarType = ScalarTypeKind::Boolean;
692 Scale = LMUL.getScale(ElementBitwidth);
693 ElementBitwidth = 1;
694 break;
695 case VectorTypeModifier::Log2EEW3:
696 applyLog2EEW(3);
697 break;
698 case VectorTypeModifier::Log2EEW4:
699 applyLog2EEW(4);
700 break;
701 case VectorTypeModifier::Log2EEW5:
702 applyLog2EEW(5);
703 break;
704 case VectorTypeModifier::Log2EEW6:
705 applyLog2EEW(6);
706 break;
707 case VectorTypeModifier::FixedSEW8:
708 applyFixedSEW(8);
709 break;
710 case VectorTypeModifier::FixedSEW16:
711 applyFixedSEW(16);
712 break;
713 case VectorTypeModifier::FixedSEW32:
714 applyFixedSEW(32);
715 break;
716 case VectorTypeModifier::FixedSEW64:
717 applyFixedSEW(64);
718 break;
719 case VectorTypeModifier::LFixedLog2LMULN3:
720 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
721 break;
722 case VectorTypeModifier::LFixedLog2LMULN2:
723 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
724 break;
725 case VectorTypeModifier::LFixedLog2LMULN1:
726 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
727 break;
728 case VectorTypeModifier::LFixedLog2LMUL0:
729 applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
730 break;
731 case VectorTypeModifier::LFixedLog2LMUL1:
732 applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
733 break;
734 case VectorTypeModifier::LFixedLog2LMUL2:
735 applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
736 break;
737 case VectorTypeModifier::LFixedLog2LMUL3:
738 applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
739 break;
740 case VectorTypeModifier::SFixedLog2LMULN3:
741 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
742 break;
743 case VectorTypeModifier::SFixedLog2LMULN2:
744 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
745 break;
746 case VectorTypeModifier::SFixedLog2LMULN1:
747 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
748 break;
749 case VectorTypeModifier::SFixedLog2LMUL0:
750 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
751 break;
752 case VectorTypeModifier::SFixedLog2LMUL1:
753 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
754 break;
755 case VectorTypeModifier::SFixedLog2LMUL2:
756 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
757 break;
758 case VectorTypeModifier::SFixedLog2LMUL3:
759 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
760 break;
761 case VectorTypeModifier::SEFixedLog2LMULN3:
762 applyFixedLog2LMUL(-3, FixedLMULType::SmallerOrEqual);
763 break;
764 case VectorTypeModifier::SEFixedLog2LMULN2:
765 applyFixedLog2LMUL(-2, FixedLMULType::SmallerOrEqual);
766 break;
767 case VectorTypeModifier::SEFixedLog2LMULN1:
768 applyFixedLog2LMUL(-1, FixedLMULType::SmallerOrEqual);
769 break;
770 case VectorTypeModifier::SEFixedLog2LMUL0:
771 applyFixedLog2LMUL(0, FixedLMULType::SmallerOrEqual);
772 break;
773 case VectorTypeModifier::SEFixedLog2LMUL1:
774 applyFixedLog2LMUL(1, FixedLMULType::SmallerOrEqual);
775 break;
776 case VectorTypeModifier::SEFixedLog2LMUL2:
777 applyFixedLog2LMUL(2, FixedLMULType::SmallerOrEqual);
778 break;
779 case VectorTypeModifier::SEFixedLog2LMUL3:
780 applyFixedLog2LMUL(3, FixedLMULType::SmallerOrEqual);
781 break;
782 case VectorTypeModifier::Tuple2:
783 case VectorTypeModifier::Tuple3:
784 case VectorTypeModifier::Tuple4:
785 case VectorTypeModifier::Tuple5:
786 case VectorTypeModifier::Tuple6:
787 case VectorTypeModifier::Tuple7:
788 case VectorTypeModifier::Tuple8: {
789 IsTuple = true;
790 NF = 2 + static_cast<uint8_t>(Transformer.VTM) -
791 static_cast<uint8_t>(VectorTypeModifier::Tuple2);
792 break;
794 case VectorTypeModifier::NoModifier:
795 break;
798 // Early return if the current type modifier is already invalid.
799 if (ScalarType == Invalid)
800 return;
802 for (unsigned TypeModifierMaskShift = 0;
803 TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
804 ++TypeModifierMaskShift) {
805 unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
806 if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
807 TypeModifierMask)
808 continue;
809 switch (static_cast<TypeModifier>(TypeModifierMask)) {
810 case TypeModifier::Pointer:
811 IsPointer = true;
812 break;
813 case TypeModifier::Const:
814 IsConstant = true;
815 break;
816 case TypeModifier::Immediate:
817 IsImmediate = true;
818 IsConstant = true;
819 break;
820 case TypeModifier::UnsignedInteger:
821 ScalarType = ScalarTypeKind::UnsignedInteger;
822 break;
823 case TypeModifier::SignedInteger:
824 ScalarType = ScalarTypeKind::SignedInteger;
825 break;
826 case TypeModifier::Float:
827 ScalarType = ScalarTypeKind::Float;
828 break;
829 case TypeModifier::LMUL1:
830 LMUL = LMULType(0);
831 // Update ElementBitwidth need to update Scale too.
832 Scale = LMUL.getScale(ElementBitwidth);
833 break;
834 default:
835 llvm_unreachable("Unknown type modifier mask!");
840 void RVVType::applyLog2EEW(unsigned Log2EEW) {
841 // update new elmul = (eew/sew) * lmul
842 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
843 // update new eew
844 ElementBitwidth = 1 << Log2EEW;
845 ScalarType = ScalarTypeKind::SignedInteger;
846 Scale = LMUL.getScale(ElementBitwidth);
849 void RVVType::applyFixedSEW(unsigned NewSEW) {
850 // Set invalid type if src and dst SEW are same.
851 if (ElementBitwidth == NewSEW) {
852 ScalarType = ScalarTypeKind::Invalid;
853 return;
855 // Update new SEW
856 ElementBitwidth = NewSEW;
857 Scale = LMUL.getScale(ElementBitwidth);
860 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
861 switch (Type) {
862 case FixedLMULType::LargerThan:
863 if (Log2LMUL <= LMUL.Log2LMUL) {
864 ScalarType = ScalarTypeKind::Invalid;
865 return;
867 break;
868 case FixedLMULType::SmallerThan:
869 if (Log2LMUL >= LMUL.Log2LMUL) {
870 ScalarType = ScalarTypeKind::Invalid;
871 return;
873 break;
874 case FixedLMULType::SmallerOrEqual:
875 if (Log2LMUL > LMUL.Log2LMUL) {
876 ScalarType = ScalarTypeKind::Invalid;
877 return;
879 break;
882 // Update new LMUL
883 LMUL = LMULType(Log2LMUL);
884 Scale = LMUL.getScale(ElementBitwidth);
887 std::optional<RVVTypes>
888 RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
889 ArrayRef<PrototypeDescriptor> Prototype) {
890 RVVTypes Types;
891 for (const PrototypeDescriptor &Proto : Prototype) {
892 auto T = computeType(BT, Log2LMUL, Proto);
893 if (!T)
894 return std::nullopt;
895 // Record legal type index
896 Types.push_back(*T);
898 return Types;
901 // Compute the hash value of RVVType, used for cache the result of computeType.
902 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
903 PrototypeDescriptor Proto) {
904 // Layout of hash value:
905 // 0 8 16 24 32 40
906 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
907 assert(Log2LMUL >= -3 && Log2LMUL <= 3);
908 return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
909 ((uint64_t)(Proto.PT & 0xff) << 16) |
910 ((uint64_t)(Proto.TM & 0xff) << 24) |
911 ((uint64_t)(Proto.VTM & 0xff) << 32);
914 std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL,
915 PrototypeDescriptor Proto) {
916 uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
917 // Search first
918 auto It = LegalTypes.find(Idx);
919 if (It != LegalTypes.end())
920 return &(It->second);
922 if (IllegalTypes.count(Idx))
923 return std::nullopt;
925 // Compute type and record the result.
926 RVVType T(BT, Log2LMUL, Proto);
927 if (T.isValid()) {
928 // Record legal type index and value.
929 std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool>
930 InsertResult = LegalTypes.insert({Idx, T});
931 return &(InsertResult.first->second);
933 // Record illegal type index.
934 IllegalTypes.insert(Idx);
935 return std::nullopt;
938 //===----------------------------------------------------------------------===//
939 // RVVIntrinsic implementation
940 //===----------------------------------------------------------------------===//
941 RVVIntrinsic::RVVIntrinsic(
942 StringRef NewName, StringRef Suffix, StringRef NewOverloadedName,
943 StringRef OverloadedSuffix, StringRef IRName, bool IsMasked,
944 bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
945 bool SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen,
946 const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
947 const std::vector<StringRef> &RequiredFeatures, unsigned NF,
948 Policy NewPolicyAttrs, bool HasFRMRoundModeOp)
949 : IRName(IRName), IsMasked(IsMasked),
950 HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
951 SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
952 ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {
954 // Init BuiltinName, Name and OverloadedName
955 BuiltinName = NewName.str();
956 Name = BuiltinName;
957 if (NewOverloadedName.empty())
958 OverloadedName = NewName.split("_").first.str();
959 else
960 OverloadedName = NewOverloadedName.str();
961 if (!Suffix.empty())
962 Name += "_" + Suffix.str();
963 if (!OverloadedSuffix.empty())
964 OverloadedName += "_" + OverloadedSuffix.str();
966 updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
967 PolicyAttrs, HasFRMRoundModeOp);
969 // Init OutputType and InputTypes
970 OutputType = OutInTypes[0];
971 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
973 // IntrinsicTypes is unmasked TA version index. Need to update it
974 // if there is merge operand (It is always in first operand).
975 IntrinsicTypes = NewIntrinsicTypes;
976 if ((IsMasked && hasMaskedOffOperand()) ||
977 (!IsMasked && hasPassthruOperand())) {
978 for (auto &I : IntrinsicTypes) {
979 if (I >= 0)
980 I += NF;
985 std::string RVVIntrinsic::getBuiltinTypeStr() const {
986 std::string S;
987 S += OutputType->getBuiltinStr();
988 for (const auto &T : InputTypes) {
989 S += T->getBuiltinStr();
991 return S;
994 std::string RVVIntrinsic::getSuffixStr(
995 RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL,
996 llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
997 SmallVector<std::string> SuffixStrs;
998 for (auto PD : PrototypeDescriptors) {
999 auto T = TypeCache.computeType(Type, Log2LMUL, PD);
1000 SuffixStrs.push_back((*T)->getShortStr());
1002 return join(SuffixStrs, "_");
1005 llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
1006 llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
1007 bool HasMaskedOffOperand, bool HasVL, unsigned NF,
1008 PolicyScheme DefaultScheme, Policy PolicyAttrs, bool IsTuple) {
1009 SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
1010 Prototype.end());
1011 bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
1012 if (IsMasked) {
1013 // If HasMaskedOffOperand, insert result type as first input operand if
1014 // need.
1015 if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
1016 if (NF == 1) {
1017 NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
1018 } else if (NF > 1) {
1019 if (IsTuple) {
1020 PrototypeDescriptor BasePtrOperand = Prototype[1];
1021 PrototypeDescriptor MaskoffType = PrototypeDescriptor(
1022 static_cast<uint8_t>(BaseTypeModifier::Vector),
1023 static_cast<uint8_t>(getTupleVTM(NF)),
1024 BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
1025 NewPrototype.insert(NewPrototype.begin() + 1, MaskoffType);
1026 } else {
1027 // Convert
1028 // (void, op0 address, op1 address, ...)
1029 // to
1030 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1031 PrototypeDescriptor MaskoffType = NewPrototype[1];
1032 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
1033 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
1037 if (HasMaskedOffOperand && NF > 1) {
1038 // Convert
1039 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1040 // to
1041 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1042 // ...)
1043 if (IsTuple)
1044 NewPrototype.insert(NewPrototype.begin() + 1,
1045 PrototypeDescriptor::Mask);
1046 else
1047 NewPrototype.insert(NewPrototype.begin() + NF + 1,
1048 PrototypeDescriptor::Mask);
1049 } else {
1050 // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
1051 NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
1053 } else {
1054 if (NF == 1) {
1055 if (PolicyAttrs.isTUPolicy() && HasPassthruOp)
1056 NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
1057 } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
1058 if (IsTuple) {
1059 PrototypeDescriptor BasePtrOperand = Prototype[0];
1060 PrototypeDescriptor MaskoffType = PrototypeDescriptor(
1061 static_cast<uint8_t>(BaseTypeModifier::Vector),
1062 static_cast<uint8_t>(getTupleVTM(NF)),
1063 BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
1064 NewPrototype.insert(NewPrototype.begin(), MaskoffType);
1065 } else {
1066 // NF > 1 cases for segment load operations.
1067 // Convert
1068 // (void, op0 address, op1 address, ...)
1069 // to
1070 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
1071 PrototypeDescriptor MaskoffType = Prototype[1];
1072 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
1073 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
1078 // If HasVL, append PrototypeDescriptor:VL to last operand
1079 if (HasVL)
1080 NewPrototype.push_back(PrototypeDescriptor::VL);
1082 return NewPrototype;
1085 llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() {
1086 return {Policy(Policy::PolicyType::Undisturbed)}; // TU
1089 llvm::SmallVector<Policy>
1090 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy,
1091 bool HasMaskPolicy) {
1092 if (HasTailPolicy && HasMaskPolicy)
1093 return {Policy(Policy::PolicyType::Undisturbed,
1094 Policy::PolicyType::Agnostic), // TUM
1095 Policy(Policy::PolicyType::Undisturbed,
1096 Policy::PolicyType::Undisturbed), // TUMU
1097 Policy(Policy::PolicyType::Agnostic,
1098 Policy::PolicyType::Undisturbed)}; // MU
1099 if (HasTailPolicy && !HasMaskPolicy)
1100 return {Policy(Policy::PolicyType::Undisturbed,
1101 Policy::PolicyType::Agnostic)}; // TU
1102 if (!HasTailPolicy && HasMaskPolicy)
1103 return {Policy(Policy::PolicyType::Agnostic,
1104 Policy::PolicyType::Undisturbed)}; // MU
1105 llvm_unreachable("An RVV instruction should not be without both tail policy "
1106 "and mask policy");
1109 void RVVIntrinsic::updateNamesAndPolicy(
1110 bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName,
1111 std::string &OverloadedName, Policy &PolicyAttrs, bool HasFRMRoundModeOp) {
1113 auto appendPolicySuffix = [&](const std::string &suffix) {
1114 Name += suffix;
1115 BuiltinName += suffix;
1116 OverloadedName += suffix;
1119 // This follows the naming guideline under riscv-c-api-doc to add the
1120 // `__riscv_` suffix for all RVV intrinsics.
1121 Name = "__riscv_" + Name;
1122 OverloadedName = "__riscv_" + OverloadedName;
1124 if (HasFRMRoundModeOp) {
1125 Name += "_rm";
1126 BuiltinName += "_rm";
1129 if (IsMasked) {
1130 if (PolicyAttrs.isTUMUPolicy())
1131 appendPolicySuffix("_tumu");
1132 else if (PolicyAttrs.isTUMAPolicy())
1133 appendPolicySuffix("_tum");
1134 else if (PolicyAttrs.isTAMUPolicy())
1135 appendPolicySuffix("_mu");
1136 else if (PolicyAttrs.isTAMAPolicy()) {
1137 Name += "_m";
1138 BuiltinName += "_m";
1139 } else
1140 llvm_unreachable("Unhandled policy condition");
1141 } else {
1142 if (PolicyAttrs.isTUPolicy())
1143 appendPolicySuffix("_tu");
1144 else if (PolicyAttrs.isTAPolicy()) // no suffix needed
1145 return;
1146 else
1147 llvm_unreachable("Unhandled policy condition");
1151 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
1152 SmallVector<PrototypeDescriptor> PrototypeDescriptors;
1153 const StringRef Primaries("evwqom0ztul");
1154 while (!Prototypes.empty()) {
1155 size_t Idx = 0;
1156 // Skip over complex prototype because it could contain primitive type
1157 // character.
1158 if (Prototypes[0] == '(')
1159 Idx = Prototypes.find_first_of(')');
1160 Idx = Prototypes.find_first_of(Primaries, Idx);
1161 assert(Idx != StringRef::npos);
1162 auto PD = PrototypeDescriptor::parsePrototypeDescriptor(
1163 Prototypes.slice(0, Idx + 1));
1164 if (!PD)
1165 llvm_unreachable("Error during parsing prototype.");
1166 PrototypeDescriptors.push_back(*PD);
1167 Prototypes = Prototypes.drop_front(Idx + 1);
1169 return PrototypeDescriptors;
1172 raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
1173 OS << "{";
1174 OS << "\"" << Record.Name << "\",";
1175 if (Record.OverloadedName == nullptr ||
1176 StringRef(Record.OverloadedName).empty())
1177 OS << "nullptr,";
1178 else
1179 OS << "\"" << Record.OverloadedName << "\",";
1180 OS << Record.PrototypeIndex << ",";
1181 OS << Record.SuffixIndex << ",";
1182 OS << Record.OverloadedSuffixIndex << ",";
1183 OS << (int)Record.PrototypeLength << ",";
1184 OS << (int)Record.SuffixLength << ",";
1185 OS << (int)Record.OverloadedSuffixSize << ",";
1186 OS << (int)Record.RequiredExtensions << ",";
1187 OS << (int)Record.TypeRangeMask << ",";
1188 OS << (int)Record.Log2LMULMask << ",";
1189 OS << (int)Record.NF << ",";
1190 OS << (int)Record.HasMasked << ",";
1191 OS << (int)Record.HasVL << ",";
1192 OS << (int)Record.HasMaskedOffOperand << ",";
1193 OS << (int)Record.HasTailPolicy << ",";
1194 OS << (int)Record.HasMaskPolicy << ",";
1195 OS << (int)Record.HasFRMRoundModeOp << ",";
1196 OS << (int)Record.IsTuple << ",";
1197 OS << (int)Record.UnMaskedPolicyScheme << ",";
1198 OS << (int)Record.MaskedPolicyScheme << ",";
1199 OS << "},\n";
1200 return OS;
1203 } // end namespace RISCV
1204 } // end namespace clang