[RISCV] Refactor RVV Policy by structure
[llvm-project.git] / clang / lib / Sema / SemaRISCVVectorLookup.cpp
blob c7709d3041bf653fde65c1cac23f4f8dda3b6940
1 //==- SemaRISCVVectorLookup.cpp - Name Lookup for RISC-V Vector Intrinsic -==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements name lookup for RISC-V vector intrinsic.
11 //===----------------------------------------------------------------------===//
13 #include "clang/AST/ASTContext.h"
14 #include "clang/AST/Decl.h"
15 #include "clang/Basic/Builtins.h"
16 #include "clang/Basic/TargetInfo.h"
17 #include "clang/Lex/Preprocessor.h"
18 #include "clang/Sema/Lookup.h"
19 #include "clang/Sema/RISCVIntrinsicManager.h"
20 #include "clang/Sema/Sema.h"
21 #include "clang/Support/RISCVVIntrinsicUtils.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include <string>
24 #include <vector>
26 using namespace llvm;
27 using namespace clang;
28 using namespace clang::RISCV;
30 namespace {
32 // Function definition of a RVV intrinsic.
33 struct RVVIntrinsicDef {
34 /// Full function name with suffix, e.g. vadd_vv_i32m1.
35 std::string Name;
37 /// Overloaded function name, e.g. vadd.
38 std::string OverloadName;
40 /// Mapping to which clang built-in function, e.g. __builtin_rvv_vadd.
41 std::string BuiltinName;
43 /// Function signature, first element is return type.
44 RVVTypes Signature;
47 struct RVVOverloadIntrinsicDef {
48 // Indexes of RISCVIntrinsicManagerImpl::IntrinsicList.
49 SmallVector<size_t, 8> Indexes;
52 } // namespace
54 static const PrototypeDescriptor RVVSignatureTable[] = {
55 #define DECL_SIGNATURE_TABLE
56 #include "clang/Basic/riscv_vector_builtin_sema.inc"
57 #undef DECL_SIGNATURE_TABLE
60 static const RVVIntrinsicRecord RVVIntrinsicRecords[] = {
61 #define DECL_INTRINSIC_RECORDS
62 #include "clang/Basic/riscv_vector_builtin_sema.inc"
63 #undef DECL_INTRINSIC_RECORDS
66 // Get subsequence of signature table.
67 static ArrayRef<PrototypeDescriptor> ProtoSeq2ArrayRef(uint16_t Index,
68 uint8_t Length) {
69 return makeArrayRef(&RVVSignatureTable[Index], Length);
72 static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) {
73 QualType QT;
74 switch (Type->getScalarType()) {
75 case ScalarTypeKind::Void:
76 QT = Context.VoidTy;
77 break;
78 case ScalarTypeKind::Size_t:
79 QT = Context.getSizeType();
80 break;
81 case ScalarTypeKind::Ptrdiff_t:
82 QT = Context.getPointerDiffType();
83 break;
84 case ScalarTypeKind::UnsignedLong:
85 QT = Context.UnsignedLongTy;
86 break;
87 case ScalarTypeKind::SignedLong:
88 QT = Context.LongTy;
89 break;
90 case ScalarTypeKind::Boolean:
91 QT = Context.BoolTy;
92 break;
93 case ScalarTypeKind::SignedInteger:
94 QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), true);
95 break;
96 case ScalarTypeKind::UnsignedInteger:
97 QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
98 break;
99 case ScalarTypeKind::Float:
100 switch (Type->getElementBitwidth()) {
101 case 64:
102 QT = Context.DoubleTy;
103 break;
104 case 32:
105 QT = Context.FloatTy;
106 break;
107 case 16:
108 QT = Context.Float16Ty;
109 break;
110 default:
111 llvm_unreachable("Unsupported floating point width.");
113 break;
114 case Invalid:
115 llvm_unreachable("Unhandled type.");
117 if (Type->isVector())
118 QT = Context.getScalableVectorType(QT, *Type->getScale());
120 if (Type->isConstant())
121 QT = Context.getConstType(QT);
123 // Transform the type to a pointer as the last step, if necessary.
124 if (Type->isPointer())
125 QT = Context.getPointerType(QT);
127 return QT;
130 namespace {
131 class RISCVIntrinsicManagerImpl : public sema::RISCVIntrinsicManager {
132 private:
133 Sema &S;
134 ASTContext &Context;
135 RVVTypeCache TypeCache;
137 // List of all RVV intrinsic.
138 std::vector<RVVIntrinsicDef> IntrinsicList;
139 // Mapping function name to index of IntrinsicList.
140 StringMap<size_t> Intrinsics;
141 // Mapping function name to RVVOverloadIntrinsicDef.
142 StringMap<RVVOverloadIntrinsicDef> OverloadIntrinsics;
144 // Create IntrinsicList
145 void InitIntrinsicList();
147 // Create RVVIntrinsicDef.
148 void InitRVVIntrinsic(const RVVIntrinsicRecord &Record, StringRef SuffixStr,
149 StringRef OverloadedSuffixStr, bool IsMask,
150 RVVTypes &Types, bool HasPolicy, Policy DefaultPolicy,
151 bool IsPrototypeDefaultTU);
153 // Create FunctionDecl for a vector intrinsic.
154 void CreateRVVIntrinsicDecl(LookupResult &LR, IdentifierInfo *II,
155 Preprocessor &PP, unsigned Index,
156 bool IsOverload);
158 public:
159 RISCVIntrinsicManagerImpl(clang::Sema &S) : S(S), Context(S.Context) {
160 InitIntrinsicList();
163 // Create RISC-V vector intrinsic and insert into symbol table if found, and
164 // return true, otherwise return false.
165 bool CreateIntrinsicIfFound(LookupResult &LR, IdentifierInfo *II,
166 Preprocessor &PP) override;
168 } // namespace
170 void RISCVIntrinsicManagerImpl::InitIntrinsicList() {
171 const TargetInfo &TI = Context.getTargetInfo();
172 bool HasVectorFloat32 = TI.hasFeature("zve32f");
173 bool HasVectorFloat64 = TI.hasFeature("zve64d");
174 bool HasZvfh = TI.hasFeature("experimental-zvfh");
175 bool HasRV64 = TI.hasFeature("64bit");
176 bool HasFullMultiply = TI.hasFeature("v");
178 // Construction of RVVIntrinsicRecords need to sync with createRVVIntrinsics
179 // in RISCVVEmitter.cpp.
180 for (auto &Record : RVVIntrinsicRecords) {
181 // Create Intrinsics for each type and LMUL.
182 BasicType BaseType = BasicType::Unknown;
183 ArrayRef<PrototypeDescriptor> BasicProtoSeq =
184 ProtoSeq2ArrayRef(Record.PrototypeIndex, Record.PrototypeLength);
185 ArrayRef<PrototypeDescriptor> SuffixProto =
186 ProtoSeq2ArrayRef(Record.SuffixIndex, Record.SuffixLength);
187 ArrayRef<PrototypeDescriptor> OverloadedSuffixProto = ProtoSeq2ArrayRef(
188 Record.OverloadedSuffixIndex, Record.OverloadedSuffixSize);
190 PolicyScheme UnMaskedPolicyScheme =
191 static_cast<PolicyScheme>(Record.UnMaskedPolicyScheme);
192 PolicyScheme MaskedPolicyScheme =
193 static_cast<PolicyScheme>(Record.MaskedPolicyScheme);
195 llvm::SmallVector<PrototypeDescriptor> ProtoSeq =
196 RVVIntrinsic::computeBuiltinTypes(
197 BasicProtoSeq, /*IsMasked=*/false,
198 /*HasMaskedOffOperand=*/false, Record.HasVL, Record.NF,
199 Record.IsPrototypeDefaultTU, UnMaskedPolicyScheme, Policy());
201 llvm::SmallVector<PrototypeDescriptor> ProtoMaskSeq =
202 RVVIntrinsic::computeBuiltinTypes(
203 BasicProtoSeq, /*IsMasked=*/true, Record.HasMaskedOffOperand,
204 Record.HasVL, Record.NF, Record.IsPrototypeDefaultTU,
205 MaskedPolicyScheme, Policy());
207 bool UnMaskedHasPolicy = UnMaskedPolicyScheme != PolicyScheme::SchemeNone;
208 bool MaskedHasPolicy = MaskedPolicyScheme != PolicyScheme::SchemeNone;
209 // If unmasked builtin supports policy, they should be TU or TA.
210 llvm::SmallVector<Policy> SupportedUnMaskedPolicies;
211 SupportedUnMaskedPolicies.emplace_back(Policy(
212 Policy::PolicyType::Undisturbed, Policy::PolicyType::Omit)); // TU
213 SupportedUnMaskedPolicies.emplace_back(
214 Policy(Policy::PolicyType::Agnostic, Policy::PolicyType::Omit)); // TA
215 llvm::SmallVector<Policy> SupportedMaskedPolicies =
216 RVVIntrinsic::getSupportedMaskedPolicies(Record.HasTailPolicy,
217 Record.HasMaskPolicy);
219 for (unsigned int TypeRangeMaskShift = 0;
220 TypeRangeMaskShift <= static_cast<unsigned int>(BasicType::MaxOffset);
221 ++TypeRangeMaskShift) {
222 unsigned int BaseTypeI = 1 << TypeRangeMaskShift;
223 BaseType = static_cast<BasicType>(BaseTypeI);
225 if ((BaseTypeI & Record.TypeRangeMask) != BaseTypeI)
226 continue;
228 // Check requirement.
229 if (BaseType == BasicType::Float16 && !HasZvfh)
230 continue;
232 if (BaseType == BasicType::Float32 && !HasVectorFloat32)
233 continue;
235 if (BaseType == BasicType::Float64 && !HasVectorFloat64)
236 continue;
238 if (((Record.RequiredExtensions & RVV_REQ_RV64) == RVV_REQ_RV64) &&
239 !HasRV64)
240 continue;
242 if ((BaseType == BasicType::Int64) &&
243 ((Record.RequiredExtensions & RVV_REQ_FullMultiply) ==
244 RVV_REQ_FullMultiply) &&
245 !HasFullMultiply)
246 continue;
248 // Expanded with different LMUL.
249 for (int Log2LMUL = -3; Log2LMUL <= 3; Log2LMUL++) {
250 if (!(Record.Log2LMULMask & (1 << (Log2LMUL + 3))))
251 continue;
253 Optional<RVVTypes> Types =
254 TypeCache.computeTypes(BaseType, Log2LMUL, Record.NF, ProtoSeq);
256 // Ignored to create new intrinsic if there are any illegal types.
257 if (!Types.has_value())
258 continue;
260 std::string SuffixStr = RVVIntrinsic::getSuffixStr(
261 TypeCache, BaseType, Log2LMUL, SuffixProto);
262 std::string OverloadedSuffixStr = RVVIntrinsic::getSuffixStr(
263 TypeCache, BaseType, Log2LMUL, OverloadedSuffixProto);
265 // Create non-masked intrinsic.
266 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, false, *Types,
267 UnMaskedHasPolicy, Policy(),
268 Record.IsPrototypeDefaultTU);
270 // Create non-masked policy intrinsic.
271 if (Record.UnMaskedPolicyScheme != PolicyScheme::SchemeNone) {
272 for (auto P : SupportedUnMaskedPolicies) {
273 llvm::SmallVector<PrototypeDescriptor> PolicyPrototype =
274 RVVIntrinsic::computeBuiltinTypes(
275 BasicProtoSeq, /*IsMasked=*/false,
276 /*HasMaskedOffOperand=*/false, Record.HasVL, Record.NF,
277 Record.IsPrototypeDefaultTU, UnMaskedPolicyScheme, P);
278 Optional<RVVTypes> PolicyTypes = TypeCache.computeTypes(
279 BaseType, Log2LMUL, Record.NF, PolicyPrototype);
280 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr,
281 /*IsMask=*/false, *PolicyTypes, UnMaskedHasPolicy,
282 P, Record.IsPrototypeDefaultTU);
285 if (!Record.HasMasked)
286 continue;
287 // Create masked intrinsic.
288 Optional<RVVTypes> MaskTypes =
289 TypeCache.computeTypes(BaseType, Log2LMUL, Record.NF, ProtoMaskSeq);
290 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, true,
291 *MaskTypes, MaskedHasPolicy, Policy(),
292 Record.IsPrototypeDefaultTU);
293 if (Record.MaskedPolicyScheme == PolicyScheme::SchemeNone)
294 continue;
295 // Create masked policy intrinsic.
296 for (auto P : SupportedMaskedPolicies) {
297 llvm::SmallVector<PrototypeDescriptor> PolicyPrototype =
298 RVVIntrinsic::computeBuiltinTypes(
299 BasicProtoSeq, /*IsMasked=*/true, Record.HasMaskedOffOperand,
300 Record.HasVL, Record.NF, Record.IsPrototypeDefaultTU,
301 MaskedPolicyScheme, P);
302 Optional<RVVTypes> PolicyTypes = TypeCache.computeTypes(
303 BaseType, Log2LMUL, Record.NF, PolicyPrototype);
304 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr,
305 /*IsMask=*/true, *PolicyTypes, MaskedHasPolicy, P,
306 Record.IsPrototypeDefaultTU);
308 } // End for different LMUL
309 } // End for different TypeRange
313 // Compute name and signatures for intrinsic with practical types.
314 void RISCVIntrinsicManagerImpl::InitRVVIntrinsic(
315 const RVVIntrinsicRecord &Record, StringRef SuffixStr,
316 StringRef OverloadedSuffixStr, bool IsMasked, RVVTypes &Signature,
317 bool HasPolicy, Policy DefaultPolicy, bool IsPrototypeDefaultTU) {
318 // Function name, e.g. vadd_vv_i32m1.
319 std::string Name = Record.Name;
320 if (!SuffixStr.empty())
321 Name += "_" + SuffixStr.str();
323 // Overloaded function name, e.g. vadd.
324 std::string OverloadedName;
325 if (!Record.OverloadedName)
326 OverloadedName = StringRef(Record.Name).split("_").first.str();
327 else
328 OverloadedName = Record.OverloadedName;
329 if (!OverloadedSuffixStr.empty())
330 OverloadedName += "_" + OverloadedSuffixStr.str();
332 // clang built-in function name, e.g. __builtin_rvv_vadd.
333 std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name);
335 RVVIntrinsic::updateNamesAndPolicy(IsMasked, HasPolicy, IsPrototypeDefaultTU,
336 Name, BuiltinName, OverloadedName,
337 DefaultPolicy);
339 // Put into IntrinsicList.
340 size_t Index = IntrinsicList.size();
341 IntrinsicList.push_back({Name, OverloadedName, BuiltinName, Signature});
343 // Creating mapping to Intrinsics.
344 Intrinsics.insert({Name, Index});
346 // Get the RVVOverloadIntrinsicDef.
347 RVVOverloadIntrinsicDef &OverloadIntrinsicDef =
348 OverloadIntrinsics[OverloadedName];
350 // And added the index.
351 OverloadIntrinsicDef.Indexes.push_back(Index);
354 void RISCVIntrinsicManagerImpl::CreateRVVIntrinsicDecl(LookupResult &LR,
355 IdentifierInfo *II,
356 Preprocessor &PP,
357 unsigned Index,
358 bool IsOverload) {
359 ASTContext &Context = S.Context;
360 RVVIntrinsicDef &IDef = IntrinsicList[Index];
361 RVVTypes Sigs = IDef.Signature;
362 size_t SigLength = Sigs.size();
363 RVVType *ReturnType = Sigs[0];
364 QualType RetType = RVVType2Qual(Context, ReturnType);
365 SmallVector<QualType, 8> ArgTypes;
366 QualType BuiltinFuncType;
368 // Skip return type, and convert RVVType to QualType for arguments.
369 for (size_t i = 1; i < SigLength; ++i)
370 ArgTypes.push_back(RVVType2Qual(Context, Sigs[i]));
372 FunctionProtoType::ExtProtoInfo PI(
373 Context.getDefaultCallingConvention(false, false, true));
375 PI.Variadic = false;
377 SourceLocation Loc = LR.getNameLoc();
378 BuiltinFuncType = Context.getFunctionType(RetType, ArgTypes, PI);
379 DeclContext *Parent = Context.getTranslationUnitDecl();
381 FunctionDecl *RVVIntrinsicDecl = FunctionDecl::Create(
382 Context, Parent, Loc, Loc, II, BuiltinFuncType, /*TInfo=*/nullptr,
383 SC_Extern, S.getCurFPFeatures().isFPConstrained(),
384 /*isInlineSpecified*/ false,
385 /*hasWrittenPrototype*/ true);
387 // Create Decl objects for each parameter, adding them to the
388 // FunctionDecl.
389 const auto *FP = cast<FunctionProtoType>(BuiltinFuncType);
390 SmallVector<ParmVarDecl *, 8> ParmList;
391 for (unsigned IParm = 0, E = FP->getNumParams(); IParm != E; ++IParm) {
392 ParmVarDecl *Parm =
393 ParmVarDecl::Create(Context, RVVIntrinsicDecl, Loc, Loc, nullptr,
394 FP->getParamType(IParm), nullptr, SC_None, nullptr);
395 Parm->setScopeInfo(0, IParm);
396 ParmList.push_back(Parm);
398 RVVIntrinsicDecl->setParams(ParmList);
400 // Add function attributes.
401 if (IsOverload)
402 RVVIntrinsicDecl->addAttr(OverloadableAttr::CreateImplicit(Context));
404 // Setup alias to __builtin_rvv_*
405 IdentifierInfo &IntrinsicII = PP.getIdentifierTable().get(IDef.BuiltinName);
406 RVVIntrinsicDecl->addAttr(
407 BuiltinAliasAttr::CreateImplicit(S.Context, &IntrinsicII));
409 // Add to symbol table.
410 LR.addDecl(RVVIntrinsicDecl);
413 bool RISCVIntrinsicManagerImpl::CreateIntrinsicIfFound(LookupResult &LR,
414 IdentifierInfo *II,
415 Preprocessor &PP) {
416 StringRef Name = II->getName();
418 // Lookup the function name from the overload intrinsics first.
419 auto OvIItr = OverloadIntrinsics.find(Name);
420 if (OvIItr != OverloadIntrinsics.end()) {
421 const RVVOverloadIntrinsicDef &OvIntrinsicDef = OvIItr->second;
422 for (auto Index : OvIntrinsicDef.Indexes)
423 CreateRVVIntrinsicDecl(LR, II, PP, Index,
424 /*IsOverload*/ true);
426 // If we added overloads, need to resolve the lookup result.
427 LR.resolveKind();
428 return true;
431 // Lookup the function name from the intrinsics.
432 auto Itr = Intrinsics.find(Name);
433 if (Itr != Intrinsics.end()) {
434 CreateRVVIntrinsicDecl(LR, II, PP, Itr->second,
435 /*IsOverload*/ false);
436 return true;
439 // It's not an RVV intrinsics.
440 return false;
443 namespace clang {
444 std::unique_ptr<clang::sema::RISCVIntrinsicManager>
445 CreateRISCVIntrinsicManager(Sema &S) {
446 return std::make_unique<RISCVIntrinsicManagerImpl>(S);
448 } // namespace clang