[DFAJumpThreading] Remove incoming StartBlock from all phis when unfolding select...
[llvm-project.git] / clang / lib / Support / RISCVVIntrinsicUtils.cpp
blob751d0aedacc9a1f50b387734d20ad2126c174697
1 //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "clang/Support/RISCVVIntrinsicUtils.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/SmallSet.h"
12 #include "llvm/ADT/StringExtras.h"
13 #include "llvm/ADT/StringSet.h"
14 #include "llvm/ADT/Twine.h"
15 #include "llvm/Support/ErrorHandling.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <numeric>
18 #include <optional>
20 using namespace llvm;
22 namespace clang {
23 namespace RISCV {
25 const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
26 BaseTypeModifier::Vector, VectorTypeModifier::MaskVector);
27 const PrototypeDescriptor PrototypeDescriptor::VL =
28 PrototypeDescriptor(BaseTypeModifier::SizeT);
29 const PrototypeDescriptor PrototypeDescriptor::Vector =
30 PrototypeDescriptor(BaseTypeModifier::Vector);
32 //===----------------------------------------------------------------------===//
33 // Type implementation
34 //===----------------------------------------------------------------------===//
36 LMULType::LMULType(int NewLog2LMUL) {
37 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
38 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
39 Log2LMUL = NewLog2LMUL;
42 std::string LMULType::str() const {
43 if (Log2LMUL < 0)
44 return "mf" + utostr(1ULL << (-Log2LMUL));
45 return "m" + utostr(1ULL << Log2LMUL);
48 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
49 int Log2ScaleResult = 0;
50 switch (ElementBitwidth) {
51 default:
52 break;
53 case 8:
54 Log2ScaleResult = Log2LMUL + 3;
55 break;
56 case 16:
57 Log2ScaleResult = Log2LMUL + 2;
58 break;
59 case 32:
60 Log2ScaleResult = Log2LMUL + 1;
61 break;
62 case 64:
63 Log2ScaleResult = Log2LMUL;
64 break;
66 // Illegal vscale result would be less than 1
67 if (Log2ScaleResult < 0)
68 return std::nullopt;
69 return 1 << Log2ScaleResult;
72 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
74 RVVType::RVVType(BasicType BT, int Log2LMUL,
75 const PrototypeDescriptor &prototype)
76 : BT(BT), LMUL(LMULType(Log2LMUL)) {
77 applyBasicType();
78 applyModifier(prototype);
79 Valid = verifyType();
80 if (Valid) {
81 initBuiltinStr();
82 initTypeStr();
83 if (isVector()) {
84 initClangBuiltinStr();
89 // clang-format off
90 // Boolean types are encoded by the ratio n = SEW/LMUL.
91 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64
92 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
93 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
95 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
96 // -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
97 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
98 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
99 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
100 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
101 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
102 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
103 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
104 // clang-format on
106 bool RVVType::verifyType() const {
107 if (ScalarType == Invalid)
108 return false;
109 if (isScalar())
110 return true;
111 if (!Scale)
112 return false;
113 if (isFloat() && ElementBitwidth == 8)
114 return false;
115 if (IsTuple && (NF == 1 || NF > 8))
116 return false;
117 if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8)
118 return false;
119 unsigned V = *Scale;
120 switch (ElementBitwidth) {
121 case 1:
122 case 8:
123 // Check Scale is 1,2,4,8,16,32,64
124 return (V <= 64 && isPowerOf2_32(V));
125 case 16:
126 // Check Scale is 1,2,4,8,16,32
127 return (V <= 32 && isPowerOf2_32(V));
128 case 32:
129 // Check Scale is 1,2,4,8,16
130 return (V <= 16 && isPowerOf2_32(V));
131 case 64:
132 // Check Scale is 1,2,4,8
133 return (V <= 8 && isPowerOf2_32(V));
135 return false;
138 void RVVType::initBuiltinStr() {
139 assert(isValid() && "RVVType is invalid");
140 switch (ScalarType) {
141 case ScalarTypeKind::Void:
142 BuiltinStr = "v";
143 return;
144 case ScalarTypeKind::Size_t:
145 BuiltinStr = "z";
146 if (IsImmediate)
147 BuiltinStr = "I" + BuiltinStr;
148 if (IsPointer)
149 BuiltinStr += "*";
150 return;
151 case ScalarTypeKind::Ptrdiff_t:
152 BuiltinStr = "Y";
153 return;
154 case ScalarTypeKind::UnsignedLong:
155 BuiltinStr = "ULi";
156 return;
157 case ScalarTypeKind::SignedLong:
158 BuiltinStr = "Li";
159 return;
160 case ScalarTypeKind::Boolean:
161 assert(ElementBitwidth == 1);
162 BuiltinStr += "b";
163 break;
164 case ScalarTypeKind::SignedInteger:
165 case ScalarTypeKind::UnsignedInteger:
166 switch (ElementBitwidth) {
167 case 8:
168 BuiltinStr += "c";
169 break;
170 case 16:
171 BuiltinStr += "s";
172 break;
173 case 32:
174 BuiltinStr += "i";
175 break;
176 case 64:
177 BuiltinStr += "Wi";
178 break;
179 default:
180 llvm_unreachable("Unhandled ElementBitwidth!");
182 if (isSignedInteger())
183 BuiltinStr = "S" + BuiltinStr;
184 else
185 BuiltinStr = "U" + BuiltinStr;
186 break;
187 case ScalarTypeKind::Float:
188 switch (ElementBitwidth) {
189 case 16:
190 BuiltinStr += "x";
191 break;
192 case 32:
193 BuiltinStr += "f";
194 break;
195 case 64:
196 BuiltinStr += "d";
197 break;
198 default:
199 llvm_unreachable("Unhandled ElementBitwidth!");
201 break;
202 default:
203 llvm_unreachable("ScalarType is invalid!");
205 if (IsImmediate)
206 BuiltinStr = "I" + BuiltinStr;
207 if (isScalar()) {
208 if (IsConstant)
209 BuiltinStr += "C";
210 if (IsPointer)
211 BuiltinStr += "*";
212 return;
214 BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
215 // Pointer to vector types. Defined for segment load intrinsics.
216 // segment load intrinsics have pointer type arguments to store the loaded
217 // vector values.
218 if (IsPointer)
219 BuiltinStr += "*";
221 if (IsTuple)
222 BuiltinStr = "T" + utostr(NF) + BuiltinStr;
225 void RVVType::initClangBuiltinStr() {
226 assert(isValid() && "RVVType is invalid");
227 assert(isVector() && "Handle Vector type only");
229 ClangBuiltinStr = "__rvv_";
230 switch (ScalarType) {
231 case ScalarTypeKind::Boolean:
232 ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
233 return;
234 case ScalarTypeKind::Float:
235 ClangBuiltinStr += "float";
236 break;
237 case ScalarTypeKind::SignedInteger:
238 ClangBuiltinStr += "int";
239 break;
240 case ScalarTypeKind::UnsignedInteger:
241 ClangBuiltinStr += "uint";
242 break;
243 default:
244 llvm_unreachable("ScalarTypeKind is invalid");
246 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() +
247 (IsTuple ? "x" + utostr(NF) : "") + "_t";
250 void RVVType::initTypeStr() {
251 assert(isValid() && "RVVType is invalid");
253 if (IsConstant)
254 Str += "const ";
256 auto getTypeString = [&](StringRef TypeStr) {
257 if (isScalar())
258 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
259 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() +
260 (IsTuple ? "x" + utostr(NF) : "") + "_t")
261 .str();
264 switch (ScalarType) {
265 case ScalarTypeKind::Void:
266 Str = "void";
267 return;
268 case ScalarTypeKind::Size_t:
269 Str = "size_t";
270 if (IsPointer)
271 Str += " *";
272 return;
273 case ScalarTypeKind::Ptrdiff_t:
274 Str = "ptrdiff_t";
275 return;
276 case ScalarTypeKind::UnsignedLong:
277 Str = "unsigned long";
278 return;
279 case ScalarTypeKind::SignedLong:
280 Str = "long";
281 return;
282 case ScalarTypeKind::Boolean:
283 if (isScalar())
284 Str += "bool";
285 else
286 // Vector bool is special case, the formulate is
287 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
288 Str += "vbool" + utostr(64 / *Scale) + "_t";
289 break;
290 case ScalarTypeKind::Float:
291 if (isScalar()) {
292 if (ElementBitwidth == 64)
293 Str += "double";
294 else if (ElementBitwidth == 32)
295 Str += "float";
296 else if (ElementBitwidth == 16)
297 Str += "_Float16";
298 else
299 llvm_unreachable("Unhandled floating type.");
300 } else
301 Str += getTypeString("float");
302 break;
303 case ScalarTypeKind::SignedInteger:
304 Str += getTypeString("int");
305 break;
306 case ScalarTypeKind::UnsignedInteger:
307 Str += getTypeString("uint");
308 break;
309 default:
310 llvm_unreachable("ScalarType is invalid!");
312 if (IsPointer)
313 Str += " *";
316 void RVVType::initShortStr() {
317 switch (ScalarType) {
318 case ScalarTypeKind::Boolean:
319 assert(isVector());
320 ShortStr = "b" + utostr(64 / *Scale);
321 return;
322 case ScalarTypeKind::Float:
323 ShortStr = "f" + utostr(ElementBitwidth);
324 break;
325 case ScalarTypeKind::SignedInteger:
326 ShortStr = "i" + utostr(ElementBitwidth);
327 break;
328 case ScalarTypeKind::UnsignedInteger:
329 ShortStr = "u" + utostr(ElementBitwidth);
330 break;
331 default:
332 llvm_unreachable("Unhandled case!");
334 if (isVector())
335 ShortStr += LMUL.str();
336 if (isTuple())
337 ShortStr += "x" + utostr(NF);
340 static VectorTypeModifier getTupleVTM(unsigned NF) {
341 assert(2 <= NF && NF <= 8 && "2 <= NF <= 8");
342 return static_cast<VectorTypeModifier>(
343 static_cast<uint8_t>(VectorTypeModifier::Tuple2) + (NF - 2));
346 void RVVType::applyBasicType() {
347 switch (BT) {
348 case BasicType::Int8:
349 ElementBitwidth = 8;
350 ScalarType = ScalarTypeKind::SignedInteger;
351 break;
352 case BasicType::Int16:
353 ElementBitwidth = 16;
354 ScalarType = ScalarTypeKind::SignedInteger;
355 break;
356 case BasicType::Int32:
357 ElementBitwidth = 32;
358 ScalarType = ScalarTypeKind::SignedInteger;
359 break;
360 case BasicType::Int64:
361 ElementBitwidth = 64;
362 ScalarType = ScalarTypeKind::SignedInteger;
363 break;
364 case BasicType::Float16:
365 ElementBitwidth = 16;
366 ScalarType = ScalarTypeKind::Float;
367 break;
368 case BasicType::Float32:
369 ElementBitwidth = 32;
370 ScalarType = ScalarTypeKind::Float;
371 break;
372 case BasicType::Float64:
373 ElementBitwidth = 64;
374 ScalarType = ScalarTypeKind::Float;
375 break;
376 default:
377 llvm_unreachable("Unhandled type code!");
379 assert(ElementBitwidth != 0 && "Bad element bitwidth!");
382 std::optional<PrototypeDescriptor>
383 PrototypeDescriptor::parsePrototypeDescriptor(
384 llvm::StringRef PrototypeDescriptorStr) {
385 PrototypeDescriptor PD;
386 BaseTypeModifier PT = BaseTypeModifier::Invalid;
387 VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
389 if (PrototypeDescriptorStr.empty())
390 return PD;
392 // Handle base type modifier
393 auto PType = PrototypeDescriptorStr.back();
394 switch (PType) {
395 case 'e':
396 PT = BaseTypeModifier::Scalar;
397 break;
398 case 'v':
399 PT = BaseTypeModifier::Vector;
400 break;
401 case 'w':
402 PT = BaseTypeModifier::Vector;
403 VTM = VectorTypeModifier::Widening2XVector;
404 break;
405 case 'q':
406 PT = BaseTypeModifier::Vector;
407 VTM = VectorTypeModifier::Widening4XVector;
408 break;
409 case 'o':
410 PT = BaseTypeModifier::Vector;
411 VTM = VectorTypeModifier::Widening8XVector;
412 break;
413 case 'm':
414 PT = BaseTypeModifier::Vector;
415 VTM = VectorTypeModifier::MaskVector;
416 break;
417 case '0':
418 PT = BaseTypeModifier::Void;
419 break;
420 case 'z':
421 PT = BaseTypeModifier::SizeT;
422 break;
423 case 't':
424 PT = BaseTypeModifier::Ptrdiff;
425 break;
426 case 'u':
427 PT = BaseTypeModifier::UnsignedLong;
428 break;
429 case 'l':
430 PT = BaseTypeModifier::SignedLong;
431 break;
432 case 'f':
433 PT = BaseTypeModifier::Float32;
434 break;
435 default:
436 llvm_unreachable("Illegal primitive type transformers!");
438 PD.PT = static_cast<uint8_t>(PT);
439 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
441 // Compute the vector type transformers, it can only appear one time.
442 if (PrototypeDescriptorStr.startswith("(")) {
443 assert(VTM == VectorTypeModifier::NoModifier &&
444 "VectorTypeModifier should only have one modifier");
445 size_t Idx = PrototypeDescriptorStr.find(')');
446 assert(Idx != StringRef::npos);
447 StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
448 PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
449 assert(!PrototypeDescriptorStr.contains('(') &&
450 "Only allow one vector type modifier");
452 auto ComplexTT = ComplexType.split(":");
453 if (ComplexTT.first == "Log2EEW") {
454 uint32_t Log2EEW;
455 if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
456 llvm_unreachable("Invalid Log2EEW value!");
457 return std::nullopt;
459 switch (Log2EEW) {
460 case 3:
461 VTM = VectorTypeModifier::Log2EEW3;
462 break;
463 case 4:
464 VTM = VectorTypeModifier::Log2EEW4;
465 break;
466 case 5:
467 VTM = VectorTypeModifier::Log2EEW5;
468 break;
469 case 6:
470 VTM = VectorTypeModifier::Log2EEW6;
471 break;
472 default:
473 llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
474 return std::nullopt;
476 } else if (ComplexTT.first == "FixedSEW") {
477 uint32_t NewSEW;
478 if (ComplexTT.second.getAsInteger(10, NewSEW)) {
479 llvm_unreachable("Invalid FixedSEW value!");
480 return std::nullopt;
482 switch (NewSEW) {
483 case 8:
484 VTM = VectorTypeModifier::FixedSEW8;
485 break;
486 case 16:
487 VTM = VectorTypeModifier::FixedSEW16;
488 break;
489 case 32:
490 VTM = VectorTypeModifier::FixedSEW32;
491 break;
492 case 64:
493 VTM = VectorTypeModifier::FixedSEW64;
494 break;
495 default:
496 llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
497 return std::nullopt;
499 } else if (ComplexTT.first == "LFixedLog2LMUL") {
500 int32_t Log2LMUL;
501 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
502 llvm_unreachable("Invalid LFixedLog2LMUL value!");
503 return std::nullopt;
505 switch (Log2LMUL) {
506 case -3:
507 VTM = VectorTypeModifier::LFixedLog2LMULN3;
508 break;
509 case -2:
510 VTM = VectorTypeModifier::LFixedLog2LMULN2;
511 break;
512 case -1:
513 VTM = VectorTypeModifier::LFixedLog2LMULN1;
514 break;
515 case 0:
516 VTM = VectorTypeModifier::LFixedLog2LMUL0;
517 break;
518 case 1:
519 VTM = VectorTypeModifier::LFixedLog2LMUL1;
520 break;
521 case 2:
522 VTM = VectorTypeModifier::LFixedLog2LMUL2;
523 break;
524 case 3:
525 VTM = VectorTypeModifier::LFixedLog2LMUL3;
526 break;
527 default:
528 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
529 return std::nullopt;
531 } else if (ComplexTT.first == "SFixedLog2LMUL") {
532 int32_t Log2LMUL;
533 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
534 llvm_unreachable("Invalid SFixedLog2LMUL value!");
535 return std::nullopt;
537 switch (Log2LMUL) {
538 case -3:
539 VTM = VectorTypeModifier::SFixedLog2LMULN3;
540 break;
541 case -2:
542 VTM = VectorTypeModifier::SFixedLog2LMULN2;
543 break;
544 case -1:
545 VTM = VectorTypeModifier::SFixedLog2LMULN1;
546 break;
547 case 0:
548 VTM = VectorTypeModifier::SFixedLog2LMUL0;
549 break;
550 case 1:
551 VTM = VectorTypeModifier::SFixedLog2LMUL1;
552 break;
553 case 2:
554 VTM = VectorTypeModifier::SFixedLog2LMUL2;
555 break;
556 case 3:
557 VTM = VectorTypeModifier::SFixedLog2LMUL3;
558 break;
559 default:
560 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
561 return std::nullopt;
564 } else if (ComplexTT.first == "SEFixedLog2LMUL") {
565 int32_t Log2LMUL;
566 if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
567 llvm_unreachable("Invalid SEFixedLog2LMUL value!");
568 return std::nullopt;
570 switch (Log2LMUL) {
571 case -3:
572 VTM = VectorTypeModifier::SEFixedLog2LMULN3;
573 break;
574 case -2:
575 VTM = VectorTypeModifier::SEFixedLog2LMULN2;
576 break;
577 case -1:
578 VTM = VectorTypeModifier::SEFixedLog2LMULN1;
579 break;
580 case 0:
581 VTM = VectorTypeModifier::SEFixedLog2LMUL0;
582 break;
583 case 1:
584 VTM = VectorTypeModifier::SEFixedLog2LMUL1;
585 break;
586 case 2:
587 VTM = VectorTypeModifier::SEFixedLog2LMUL2;
588 break;
589 case 3:
590 VTM = VectorTypeModifier::SEFixedLog2LMUL3;
591 break;
592 default:
593 llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
594 return std::nullopt;
596 } else if (ComplexTT.first == "Tuple") {
597 unsigned NF = 0;
598 if (ComplexTT.second.getAsInteger(10, NF)) {
599 llvm_unreachable("Invalid NF value!");
600 return std::nullopt;
602 VTM = getTupleVTM(NF);
603 } else {
604 llvm_unreachable("Illegal complex type transformers!");
607 PD.VTM = static_cast<uint8_t>(VTM);
609 // Compute the remain type transformers
610 TypeModifier TM = TypeModifier::NoModifier;
611 for (char I : PrototypeDescriptorStr) {
612 switch (I) {
613 case 'P':
614 if ((TM & TypeModifier::Const) == TypeModifier::Const)
615 llvm_unreachable("'P' transformer cannot be used after 'C'");
616 if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
617 llvm_unreachable("'P' transformer cannot be used twice");
618 TM |= TypeModifier::Pointer;
619 break;
620 case 'C':
621 TM |= TypeModifier::Const;
622 break;
623 case 'K':
624 TM |= TypeModifier::Immediate;
625 break;
626 case 'U':
627 TM |= TypeModifier::UnsignedInteger;
628 break;
629 case 'I':
630 TM |= TypeModifier::SignedInteger;
631 break;
632 case 'F':
633 TM |= TypeModifier::Float;
634 break;
635 case 'S':
636 TM |= TypeModifier::LMUL1;
637 break;
638 default:
639 llvm_unreachable("Illegal non-primitive type transformer!");
642 PD.TM = static_cast<uint8_t>(TM);
644 return PD;
647 void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
648 // Handle primitive type transformer
649 switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
650 case BaseTypeModifier::Scalar:
651 Scale = 0;
652 break;
653 case BaseTypeModifier::Vector:
654 Scale = LMUL.getScale(ElementBitwidth);
655 break;
656 case BaseTypeModifier::Void:
657 ScalarType = ScalarTypeKind::Void;
658 break;
659 case BaseTypeModifier::SizeT:
660 ScalarType = ScalarTypeKind::Size_t;
661 break;
662 case BaseTypeModifier::Ptrdiff:
663 ScalarType = ScalarTypeKind::Ptrdiff_t;
664 break;
665 case BaseTypeModifier::UnsignedLong:
666 ScalarType = ScalarTypeKind::UnsignedLong;
667 break;
668 case BaseTypeModifier::SignedLong:
669 ScalarType = ScalarTypeKind::SignedLong;
670 break;
671 case BaseTypeModifier::Float32:
672 ElementBitwidth = 32;
673 ScalarType = ScalarTypeKind::Float;
674 break;
675 case BaseTypeModifier::Invalid:
676 ScalarType = ScalarTypeKind::Invalid;
677 return;
680 switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
681 case VectorTypeModifier::Widening2XVector:
682 ElementBitwidth *= 2;
683 LMUL.MulLog2LMUL(1);
684 Scale = LMUL.getScale(ElementBitwidth);
685 break;
686 case VectorTypeModifier::Widening4XVector:
687 ElementBitwidth *= 4;
688 LMUL.MulLog2LMUL(2);
689 Scale = LMUL.getScale(ElementBitwidth);
690 break;
691 case VectorTypeModifier::Widening8XVector:
692 ElementBitwidth *= 8;
693 LMUL.MulLog2LMUL(3);
694 Scale = LMUL.getScale(ElementBitwidth);
695 break;
696 case VectorTypeModifier::MaskVector:
697 ScalarType = ScalarTypeKind::Boolean;
698 Scale = LMUL.getScale(ElementBitwidth);
699 ElementBitwidth = 1;
700 break;
701 case VectorTypeModifier::Log2EEW3:
702 applyLog2EEW(3);
703 break;
704 case VectorTypeModifier::Log2EEW4:
705 applyLog2EEW(4);
706 break;
707 case VectorTypeModifier::Log2EEW5:
708 applyLog2EEW(5);
709 break;
710 case VectorTypeModifier::Log2EEW6:
711 applyLog2EEW(6);
712 break;
713 case VectorTypeModifier::FixedSEW8:
714 applyFixedSEW(8);
715 break;
716 case VectorTypeModifier::FixedSEW16:
717 applyFixedSEW(16);
718 break;
719 case VectorTypeModifier::FixedSEW32:
720 applyFixedSEW(32);
721 break;
722 case VectorTypeModifier::FixedSEW64:
723 applyFixedSEW(64);
724 break;
725 case VectorTypeModifier::LFixedLog2LMULN3:
726 applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
727 break;
728 case VectorTypeModifier::LFixedLog2LMULN2:
729 applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
730 break;
731 case VectorTypeModifier::LFixedLog2LMULN1:
732 applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
733 break;
734 case VectorTypeModifier::LFixedLog2LMUL0:
735 applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
736 break;
737 case VectorTypeModifier::LFixedLog2LMUL1:
738 applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
739 break;
740 case VectorTypeModifier::LFixedLog2LMUL2:
741 applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
742 break;
743 case VectorTypeModifier::LFixedLog2LMUL3:
744 applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
745 break;
746 case VectorTypeModifier::SFixedLog2LMULN3:
747 applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
748 break;
749 case VectorTypeModifier::SFixedLog2LMULN2:
750 applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
751 break;
752 case VectorTypeModifier::SFixedLog2LMULN1:
753 applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
754 break;
755 case VectorTypeModifier::SFixedLog2LMUL0:
756 applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
757 break;
758 case VectorTypeModifier::SFixedLog2LMUL1:
759 applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
760 break;
761 case VectorTypeModifier::SFixedLog2LMUL2:
762 applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
763 break;
764 case VectorTypeModifier::SFixedLog2LMUL3:
765 applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
766 break;
767 case VectorTypeModifier::SEFixedLog2LMULN3:
768 applyFixedLog2LMUL(-3, FixedLMULType::SmallerOrEqual);
769 break;
770 case VectorTypeModifier::SEFixedLog2LMULN2:
771 applyFixedLog2LMUL(-2, FixedLMULType::SmallerOrEqual);
772 break;
773 case VectorTypeModifier::SEFixedLog2LMULN1:
774 applyFixedLog2LMUL(-1, FixedLMULType::SmallerOrEqual);
775 break;
776 case VectorTypeModifier::SEFixedLog2LMUL0:
777 applyFixedLog2LMUL(0, FixedLMULType::SmallerOrEqual);
778 break;
779 case VectorTypeModifier::SEFixedLog2LMUL1:
780 applyFixedLog2LMUL(1, FixedLMULType::SmallerOrEqual);
781 break;
782 case VectorTypeModifier::SEFixedLog2LMUL2:
783 applyFixedLog2LMUL(2, FixedLMULType::SmallerOrEqual);
784 break;
785 case VectorTypeModifier::SEFixedLog2LMUL3:
786 applyFixedLog2LMUL(3, FixedLMULType::SmallerOrEqual);
787 break;
788 case VectorTypeModifier::Tuple2:
789 case VectorTypeModifier::Tuple3:
790 case VectorTypeModifier::Tuple4:
791 case VectorTypeModifier::Tuple5:
792 case VectorTypeModifier::Tuple6:
793 case VectorTypeModifier::Tuple7:
794 case VectorTypeModifier::Tuple8: {
795 IsTuple = true;
796 NF = 2 + static_cast<uint8_t>(Transformer.VTM) -
797 static_cast<uint8_t>(VectorTypeModifier::Tuple2);
798 break;
800 case VectorTypeModifier::NoModifier:
801 break;
804 // Early return if the current type modifier is already invalid.
805 if (ScalarType == Invalid)
806 return;
808 for (unsigned TypeModifierMaskShift = 0;
809 TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
810 ++TypeModifierMaskShift) {
811 unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
812 if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
813 TypeModifierMask)
814 continue;
815 switch (static_cast<TypeModifier>(TypeModifierMask)) {
816 case TypeModifier::Pointer:
817 IsPointer = true;
818 break;
819 case TypeModifier::Const:
820 IsConstant = true;
821 break;
822 case TypeModifier::Immediate:
823 IsImmediate = true;
824 IsConstant = true;
825 break;
826 case TypeModifier::UnsignedInteger:
827 ScalarType = ScalarTypeKind::UnsignedInteger;
828 break;
829 case TypeModifier::SignedInteger:
830 ScalarType = ScalarTypeKind::SignedInteger;
831 break;
832 case TypeModifier::Float:
833 ScalarType = ScalarTypeKind::Float;
834 break;
835 case TypeModifier::LMUL1:
836 LMUL = LMULType(0);
837 // Update ElementBitwidth need to update Scale too.
838 Scale = LMUL.getScale(ElementBitwidth);
839 break;
840 default:
841 llvm_unreachable("Unknown type modifier mask!");
846 void RVVType::applyLog2EEW(unsigned Log2EEW) {
847 // update new elmul = (eew/sew) * lmul
848 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
849 // update new eew
850 ElementBitwidth = 1 << Log2EEW;
851 ScalarType = ScalarTypeKind::SignedInteger;
852 Scale = LMUL.getScale(ElementBitwidth);
855 void RVVType::applyFixedSEW(unsigned NewSEW) {
856 // Set invalid type if src and dst SEW are same.
857 if (ElementBitwidth == NewSEW) {
858 ScalarType = ScalarTypeKind::Invalid;
859 return;
861 // Update new SEW
862 ElementBitwidth = NewSEW;
863 Scale = LMUL.getScale(ElementBitwidth);
866 void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
867 switch (Type) {
868 case FixedLMULType::LargerThan:
869 if (Log2LMUL <= LMUL.Log2LMUL) {
870 ScalarType = ScalarTypeKind::Invalid;
871 return;
873 break;
874 case FixedLMULType::SmallerThan:
875 if (Log2LMUL >= LMUL.Log2LMUL) {
876 ScalarType = ScalarTypeKind::Invalid;
877 return;
879 break;
880 case FixedLMULType::SmallerOrEqual:
881 if (Log2LMUL > LMUL.Log2LMUL) {
882 ScalarType = ScalarTypeKind::Invalid;
883 return;
885 break;
888 // Update new LMUL
889 LMUL = LMULType(Log2LMUL);
890 Scale = LMUL.getScale(ElementBitwidth);
893 std::optional<RVVTypes>
894 RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
895 ArrayRef<PrototypeDescriptor> Prototype) {
896 RVVTypes Types;
897 for (const PrototypeDescriptor &Proto : Prototype) {
898 auto T = computeType(BT, Log2LMUL, Proto);
899 if (!T)
900 return std::nullopt;
901 // Record legal type index
902 Types.push_back(*T);
904 return Types;
907 // Compute the hash value of RVVType, used for cache the result of computeType.
908 static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
909 PrototypeDescriptor Proto) {
910 // Layout of hash value:
911 // 0 8 16 24 32 40
912 // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
913 assert(Log2LMUL >= -3 && Log2LMUL <= 3);
914 return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
915 ((uint64_t)(Proto.PT & 0xff) << 16) |
916 ((uint64_t)(Proto.TM & 0xff) << 24) |
917 ((uint64_t)(Proto.VTM & 0xff) << 32);
920 std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL,
921 PrototypeDescriptor Proto) {
922 uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
923 // Search first
924 auto It = LegalTypes.find(Idx);
925 if (It != LegalTypes.end())
926 return &(It->second);
928 if (IllegalTypes.count(Idx))
929 return std::nullopt;
931 // Compute type and record the result.
932 RVVType T(BT, Log2LMUL, Proto);
933 if (T.isValid()) {
934 // Record legal type index and value.
935 std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool>
936 InsertResult = LegalTypes.insert({Idx, T});
937 return &(InsertResult.first->second);
939 // Record illegal type index.
940 IllegalTypes.insert(Idx);
941 return std::nullopt;
944 //===----------------------------------------------------------------------===//
945 // RVVIntrinsic implementation
946 //===----------------------------------------------------------------------===//
947 RVVIntrinsic::RVVIntrinsic(
948 StringRef NewName, StringRef Suffix, StringRef NewOverloadedName,
949 StringRef OverloadedSuffix, StringRef IRName, bool IsMasked,
950 bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
951 bool SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen,
952 const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
953 const std::vector<StringRef> &RequiredFeatures, unsigned NF,
954 Policy NewPolicyAttrs, bool HasFRMRoundModeOp)
955 : IRName(IRName), IsMasked(IsMasked),
956 HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
957 SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
958 ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {
960 // Init BuiltinName, Name and OverloadedName
961 BuiltinName = NewName.str();
962 Name = BuiltinName;
963 if (NewOverloadedName.empty())
964 OverloadedName = NewName.split("_").first.str();
965 else
966 OverloadedName = NewOverloadedName.str();
967 if (!Suffix.empty())
968 Name += "_" + Suffix.str();
969 if (!OverloadedSuffix.empty())
970 OverloadedName += "_" + OverloadedSuffix.str();
972 updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
973 PolicyAttrs, HasFRMRoundModeOp);
975 // Init OutputType and InputTypes
976 OutputType = OutInTypes[0];
977 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
979 // IntrinsicTypes is unmasked TA version index. Need to update it
980 // if there is merge operand (It is always in first operand).
981 IntrinsicTypes = NewIntrinsicTypes;
982 if ((IsMasked && hasMaskedOffOperand()) ||
983 (!IsMasked && hasPassthruOperand())) {
984 for (auto &I : IntrinsicTypes) {
985 if (I >= 0)
986 I += NF;
991 std::string RVVIntrinsic::getBuiltinTypeStr() const {
992 std::string S;
993 S += OutputType->getBuiltinStr();
994 for (const auto &T : InputTypes) {
995 S += T->getBuiltinStr();
997 return S;
1000 std::string RVVIntrinsic::getSuffixStr(
1001 RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL,
1002 llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
1003 SmallVector<std::string> SuffixStrs;
1004 for (auto PD : PrototypeDescriptors) {
1005 auto T = TypeCache.computeType(Type, Log2LMUL, PD);
1006 SuffixStrs.push_back((*T)->getShortStr());
1008 return join(SuffixStrs, "_");
1011 llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
1012 llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
1013 bool HasMaskedOffOperand, bool HasVL, unsigned NF,
1014 PolicyScheme DefaultScheme, Policy PolicyAttrs, bool IsTuple) {
1015 SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
1016 Prototype.end());
1017 bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
1018 if (IsMasked) {
1019 // If HasMaskedOffOperand, insert result type as first input operand if
1020 // need.
1021 if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
1022 if (NF == 1) {
1023 NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
1024 } else if (NF > 1) {
1025 if (IsTuple) {
1026 PrototypeDescriptor BasePtrOperand = Prototype[1];
1027 PrototypeDescriptor MaskoffType = PrototypeDescriptor(
1028 static_cast<uint8_t>(BaseTypeModifier::Vector),
1029 static_cast<uint8_t>(getTupleVTM(NF)),
1030 BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
1031 NewPrototype.insert(NewPrototype.begin() + 1, MaskoffType);
1032 } else {
1033 // Convert
1034 // (void, op0 address, op1 address, ...)
1035 // to
1036 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1037 PrototypeDescriptor MaskoffType = NewPrototype[1];
1038 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
1039 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
1043 if (HasMaskedOffOperand && NF > 1) {
1044 // Convert
1045 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1046 // to
1047 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1048 // ...)
1049 if (IsTuple)
1050 NewPrototype.insert(NewPrototype.begin() + 1,
1051 PrototypeDescriptor::Mask);
1052 else
1053 NewPrototype.insert(NewPrototype.begin() + NF + 1,
1054 PrototypeDescriptor::Mask);
1055 } else {
1056 // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
1057 NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
1059 } else {
1060 if (NF == 1) {
1061 if (PolicyAttrs.isTUPolicy() && HasPassthruOp)
1062 NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
1063 } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
1064 if (IsTuple) {
1065 PrototypeDescriptor BasePtrOperand = Prototype[0];
1066 PrototypeDescriptor MaskoffType = PrototypeDescriptor(
1067 static_cast<uint8_t>(BaseTypeModifier::Vector),
1068 static_cast<uint8_t>(getTupleVTM(NF)),
1069 BasePtrOperand.TM & ~static_cast<uint8_t>(TypeModifier::Pointer));
1070 NewPrototype.insert(NewPrototype.begin(), MaskoffType);
1071 } else {
1072 // NF > 1 cases for segment load operations.
1073 // Convert
1074 // (void, op0 address, op1 address, ...)
1075 // to
1076 // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
1077 PrototypeDescriptor MaskoffType = Prototype[1];
1078 MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
1079 NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
1084 // If HasVL, append PrototypeDescriptor:VL to last operand
1085 if (HasVL)
1086 NewPrototype.push_back(PrototypeDescriptor::VL);
1088 return NewPrototype;
1091 llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() {
1092 return {Policy(Policy::PolicyType::Undisturbed)}; // TU
1095 llvm::SmallVector<Policy>
1096 RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy,
1097 bool HasMaskPolicy) {
1098 if (HasTailPolicy && HasMaskPolicy)
1099 return {Policy(Policy::PolicyType::Undisturbed,
1100 Policy::PolicyType::Agnostic), // TUM
1101 Policy(Policy::PolicyType::Undisturbed,
1102 Policy::PolicyType::Undisturbed), // TUMU
1103 Policy(Policy::PolicyType::Agnostic,
1104 Policy::PolicyType::Undisturbed)}; // MU
1105 if (HasTailPolicy && !HasMaskPolicy)
1106 return {Policy(Policy::PolicyType::Undisturbed,
1107 Policy::PolicyType::Agnostic)}; // TU
1108 if (!HasTailPolicy && HasMaskPolicy)
1109 return {Policy(Policy::PolicyType::Agnostic,
1110 Policy::PolicyType::Undisturbed)}; // MU
1111 llvm_unreachable("An RVV instruction should not be without both tail policy "
1112 "and mask policy");
1115 void RVVIntrinsic::updateNamesAndPolicy(
1116 bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName,
1117 std::string &OverloadedName, Policy &PolicyAttrs, bool HasFRMRoundModeOp) {
1119 auto appendPolicySuffix = [&](const std::string &suffix) {
1120 Name += suffix;
1121 BuiltinName += suffix;
1122 OverloadedName += suffix;
1125 // This follows the naming guideline under riscv-c-api-doc to add the
1126 // `__riscv_` suffix for all RVV intrinsics.
1127 Name = "__riscv_" + Name;
1128 OverloadedName = "__riscv_" + OverloadedName;
1130 if (HasFRMRoundModeOp) {
1131 Name += "_rm";
1132 BuiltinName += "_rm";
1135 if (IsMasked) {
1136 if (PolicyAttrs.isTUMUPolicy())
1137 appendPolicySuffix("_tumu");
1138 else if (PolicyAttrs.isTUMAPolicy())
1139 appendPolicySuffix("_tum");
1140 else if (PolicyAttrs.isTAMUPolicy())
1141 appendPolicySuffix("_mu");
1142 else if (PolicyAttrs.isTAMAPolicy()) {
1143 Name += "_m";
1144 BuiltinName += "_m";
1145 } else
1146 llvm_unreachable("Unhandled policy condition");
1147 } else {
1148 if (PolicyAttrs.isTUPolicy())
1149 appendPolicySuffix("_tu");
1150 else if (PolicyAttrs.isTAPolicy()) // no suffix needed
1151 return;
1152 else
1153 llvm_unreachable("Unhandled policy condition");
1157 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
1158 SmallVector<PrototypeDescriptor> PrototypeDescriptors;
1159 const StringRef Primaries("evwqom0ztulf");
1160 while (!Prototypes.empty()) {
1161 size_t Idx = 0;
1162 // Skip over complex prototype because it could contain primitive type
1163 // character.
1164 if (Prototypes[0] == '(')
1165 Idx = Prototypes.find_first_of(')');
1166 Idx = Prototypes.find_first_of(Primaries, Idx);
1167 assert(Idx != StringRef::npos);
1168 auto PD = PrototypeDescriptor::parsePrototypeDescriptor(
1169 Prototypes.slice(0, Idx + 1));
1170 if (!PD)
1171 llvm_unreachable("Error during parsing prototype.");
1172 PrototypeDescriptors.push_back(*PD);
1173 Prototypes = Prototypes.drop_front(Idx + 1);
1175 return PrototypeDescriptors;
1178 raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
1179 OS << "{";
1180 OS << "\"" << Record.Name << "\",";
1181 if (Record.OverloadedName == nullptr ||
1182 StringRef(Record.OverloadedName).empty())
1183 OS << "nullptr,";
1184 else
1185 OS << "\"" << Record.OverloadedName << "\",";
1186 OS << Record.PrototypeIndex << ",";
1187 OS << Record.SuffixIndex << ",";
1188 OS << Record.OverloadedSuffixIndex << ",";
1189 OS << (int)Record.PrototypeLength << ",";
1190 OS << (int)Record.SuffixLength << ",";
1191 OS << (int)Record.OverloadedSuffixSize << ",";
1192 OS << (int)Record.RequiredExtensions << ",";
1193 OS << (int)Record.TypeRangeMask << ",";
1194 OS << (int)Record.Log2LMULMask << ",";
1195 OS << (int)Record.NF << ",";
1196 OS << (int)Record.HasMasked << ",";
1197 OS << (int)Record.HasVL << ",";
1198 OS << (int)Record.HasMaskedOffOperand << ",";
1199 OS << (int)Record.HasTailPolicy << ",";
1200 OS << (int)Record.HasMaskPolicy << ",";
1201 OS << (int)Record.HasFRMRoundModeOp << ",";
1202 OS << (int)Record.IsTuple << ",";
1203 OS << (int)Record.UnMaskedPolicyScheme << ",";
1204 OS << (int)Record.MaskedPolicyScheme << ",";
1205 OS << "},\n";
1206 return OS;
1209 } // end namespace RISCV
1210 } // end namespace clang