1 //===-- FIRAttr.cpp -------------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
13 #include "flang/Optimizer/Dialect/FIRAttr.h"
14 #include "flang/Optimizer/Dialect/FIRDialect.h"
15 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
16 #include "mlir/IR/AttributeSupport.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
24 #include "flang/Optimizer/Dialect/FIREnumAttr.cpp.inc"
25 #define GET_ATTRDEF_CLASSES
26 #include "flang/Optimizer/Dialect/FIRAttr.cpp.inc"
30 namespace fir::detail
{
/// Uniqued storage for a FIR real attribute: a KIND paired with an APFloat
/// value. fir::RealAttr::getFKind/getValue (further down in this file) call
/// the same-named accessors here through getImpl(); presumably this is
/// RealAttr's ImplType — confirm against FIRAttr.h.
32 struct RealAttributeStorage
: public mlir::AttributeStorage
{
/// Uniquing key: (KIND, value).
33 using KeyTy
= std::pair
<int, llvm::APFloat
>;
35 RealAttributeStorage(int kind
, const llvm::APFloat
&value
)
36 : kind(kind
), value(value
) {}
/// Convenience constructor that unpacks a KeyTy.
37 RealAttributeStorage(const KeyTy
&key
)
38 : RealAttributeStorage(key
.first
, key
.second
) {}
/// Hashes the whole (KIND, value) key with llvm::hash_value.
40 static unsigned hashKey(const KeyTy
&key
) { return llvm::hash_value(key
); }
/// Key equality: the KINDs match and APFloat::compare reports cmpEqual.
/// NOTE(review): APFloat::compare returns cmpUnordered when either operand is
/// a NaN, so a NaN key never matches an existing storage object and every
/// get() of a NaN creates a fresh one. If bit-exact uniquing is intended,
/// APFloat::bitwiseIsEqual is probably the right test — confirm.
42 bool operator==(const KeyTy
&key
) const {
43 return key
.first
== kind
&&
44 key
.second
.compare(value
) == llvm::APFloatBase::cmpEqual
;
/// Placement-constructs a new storage object in memory obtained from the
/// uniquer-provided allocator.
47 static RealAttributeStorage
*
48 construct(mlir::AttributeStorageAllocator
&allocator
, const KeyTy
&key
) {
49 return new (allocator
.allocate
<RealAttributeStorage
>())
50 RealAttributeStorage(key
);
/// Accessors for the stored KIND and floating-point value.
53 KindTy
getFKind() const { return kind
; }
54 llvm::APFloat
getValue() const { return value
; }
61 /// An attribute representing a reference to a type.
/// The uniquing key is the mlir::Type itself. ExactTypeAttr::getType and
/// SubclassAttr::getType (further down) read it back via getImpl()->getType();
/// presumably both attributes use this storage — confirm against FIRAttr.h.
62 struct TypeAttributeStorage
: public mlir::AttributeStorage
{
63 using KeyTy
= mlir::Type
;
/// Wraps \p value; a null type is rejected by assertion (checked builds only).
65 TypeAttributeStorage(mlir::Type value
) : value(value
) {
66 assert(value
&& "must not be of Type null");
69 /// Key equality function.
70 bool operator==(const KeyTy
&key
) const { return key
== value
; }
72 /// Construct a new storage instance.
73 static TypeAttributeStorage
*
74 construct(mlir::AttributeStorageAllocator
&allocator
, KeyTy key
) {
75 return new (allocator
.allocate
<TypeAttributeStorage
>())
76 TypeAttributeStorage(key
);
/// Returns the wrapped type.
79 mlir::Type
getType() const { return value
; }
84 } // namespace fir::detail
86 //===----------------------------------------------------------------------===//
87 // Attributes for SELECT TYPE
88 //===----------------------------------------------------------------------===//
90 ExactTypeAttr
fir::ExactTypeAttr::get(mlir::Type value
) {
91 return Base::get(value
.getContext(), value
);
94 mlir::Type
fir::ExactTypeAttr::getType() const { return getImpl()->getType(); }
96 SubclassAttr
fir::SubclassAttr::get(mlir::Type value
) {
97 return Base::get(value
.getContext(), value
);
100 mlir::Type
fir::SubclassAttr::getType() const { return getImpl()->getType(); }
102 //===----------------------------------------------------------------------===//
103 // Attributes for SELECT CASE
104 //===----------------------------------------------------------------------===//
106 using AttributeUniquer
= mlir::detail::AttributeUniquer
;
108 ClosedIntervalAttr
fir::ClosedIntervalAttr::get(mlir::MLIRContext
*ctxt
) {
109 return AttributeUniquer::get
<ClosedIntervalAttr
>(ctxt
);
112 UpperBoundAttr
fir::UpperBoundAttr::get(mlir::MLIRContext
*ctxt
) {
113 return AttributeUniquer::get
<UpperBoundAttr
>(ctxt
);
116 LowerBoundAttr
fir::LowerBoundAttr::get(mlir::MLIRContext
*ctxt
) {
117 return AttributeUniquer::get
<LowerBoundAttr
>(ctxt
);
120 PointIntervalAttr
fir::PointIntervalAttr::get(mlir::MLIRContext
*ctxt
) {
121 return AttributeUniquer::get
<PointIntervalAttr
>(ctxt
);
124 //===----------------------------------------------------------------------===//
126 //===----------------------------------------------------------------------===//
128 RealAttr
fir::RealAttr::get(mlir::MLIRContext
*ctxt
,
129 const RealAttr::ValueType
&key
) {
130 return Base::get(ctxt
, key
);
133 KindTy
fir::RealAttr::getFKind() const { return getImpl()->getFKind(); }
135 llvm::APFloat
fir::RealAttr::getValue() const { return getImpl()->getValue(); }
137 //===----------------------------------------------------------------------===//
138 // FIR attribute parsing
139 //===----------------------------------------------------------------------===//
// Parses the parameters of a FIR real attribute, in one of two spellings:
//   `<` KIND `,` <floating-point literal> `>`
//   `<` KIND `,` `i` x<hex digits> `>`
// The second (bit-pattern) form is the one printFirAttribute emits. The value
// is interpreted with the floating-point semantics that KindMapping assigns
// to KIND, and the result is a RealAttr uniqued in the dialect's context.
141 static mlir::Attribute
parseFirRealAttr(FIROpsDialect
*dialect
,
142 mlir::DialectAsmParser
&parser
,
// Leading `<` KIND `,`.
145 if (parser
.parseLess() || parser
.parseInteger(kind
) || parser
.parseComma()) {
146 parser
.emitError(parser
.getNameLoc(), "expected '<' kind ','");
// Maps KIND to its floating-point semantics for this context.
149 KindMapping
kindMap(dialect
->getContext());
150 llvm::APFloat
value(0.);
151 if (parser
.parseOptionalKeyword("i")) {
152 // `i` not present, so literal float must be present
// parseFloat only checks that a literal is present. Its double result
// (dontCare) is discarded and the literal's text is re-extracted below,
// presumably so the conversion uses KIND's semantics instead of being
// rounded through double precision.
154 if (parser
.parseFloat(dontCare
) || parser
.parseGreater()) {
155 parser
.emitError(parser
.getNameLoc(), "expected real constant '>'");
// Re-extract the literal from the attribute's full spelling: skip to the
// ',' after KIND, skip blanks, and stop at '>' or a blank.
// NOTE(review): drop_until stops *at* the ',' without consuming it; confirm
// the ',' itself is dropped (e.g. with drop_front()) before drop_while,
// otherwise fltStr would begin with ','.
158 auto fltStr
= parser
.getFullSymbolSpec()
159 .drop_until([](char c
) { return c
== ','; })
161 .drop_while([](char c
) { return c
== ' ' || c
== '\t'; })
162 .take_until([](char c
) {
163 return c
== '>' || c
== ' ' || c
== '\t';
// NOTE(review): the APFloat(semantics, StringRef) constructor only asserts
// (in builds with assertions) when it cannot convert the text; it does not
// report a parse error. Consider APFloat::convertFromString plus emitError.
165 value
= llvm::APFloat(kindMap
.getFloatSemantics(kind
), fltStr
);
167 // `i` is present, so literal bitstring (hex) must be present
169 if (parser
.parseKeyword(&hex
) || parser
.parseGreater()) {
170 parser
.emitError(parser
.getNameLoc(), "expected real constant '>'");
// The keyword is the bit pattern with a leading 'x' (printFirAttribute
// writes "i x" followed by the digits; presumably the 'x' lets the token lex
// as a keyword). drop_front() removes the 'x'. The remaining digits are read
// as a radix-16 integer as wide as KIND's storage size and then
// reinterpreted, bit for bit, as a float with KIND's semantics.
// NOTE(review): the APInt(numBits, StringRef, radix) constructor asserts on
// invalid digits or on a value too wide for numBits instead of diagnosing
// them, so malformed input is not reported through the parser.
173 const llvm::fltSemantics
&sem
= kindMap
.getFloatSemantics(kind
);
174 unsigned int numBits
= llvm::APFloat::semanticsSizeInBits(sem
);
175 auto bits
= llvm::APInt(numBits
, hex
.drop_front(), 16);
176 value
= llvm::APFloat(sem
, bits
);
// Unique the (KIND, value) pair in the dialect's context.
178 return RealAttr::get(dialect
->getContext(), {kind
, value
});
// Parses a FortranVariableFlagsAttr body: either `<` `>` (no flags) or `<`
// followed by a comma-separated list of flag keywords and then `>`. Each
// keyword must name a FortranVariableFlagsEnum case, and the named cases are
// OR-ed together to form the attribute's flags.
181 mlir::Attribute
fir::FortranVariableFlagsAttr::parse(mlir::AsmParser
&parser
,
183 if (mlir::failed(parser
.parseLess()))
// Start from the empty flag set; an immediate '>' leaves it empty.
186 fir::FortranVariableFlagsEnum flags
= {};
187 if (mlir::failed(parser
.parseOptionalGreater())) {
// Parses one flag keyword and merges it into `flags`.
188 auto parseFlags
= [&]() -> mlir::ParseResult
{
189 llvm::StringRef elemName
;
190 if (mlir::failed(parser
.parseKeyword(&elemName
)))
191 return mlir::failure();
193 auto elem
= fir::symbolizeFortranVariableFlagsEnum(elemName
);
// Keywords that do not name an enum case are reported with this error.
195 return parser
.emitError(parser
.getNameLoc(),
196 "Unknown fortran variable attribute: ")
199 flags
= flags
| *elem
;
200 return mlir::success();
// One or more flags, then the closing '>'.
202 if (mlir::failed(parser
.parseCommaSeparatedList(parseFlags
)) ||
203 parser
.parseGreater())
// Build the attribute in the parser's context.
207 return FortranVariableFlagsAttr::get(parser
.getContext(), flags
);
// Dialect hook for parsing a FIR attribute. The generated (ODS/TableGen)
// attribute parser is tried first. Otherwise the attribute name is matched
// against the hand-written attributes defined in this file (SELECT TYPE,
// SELECT CASE and real-constant attributes). Any other name is diagnosed as
// "unknown FIR attribute".
210 mlir::Attribute
fir::parseFirAttribute(FIROpsDialect
*dialect
,
211 mlir::DialectAsmParser
&parser
,
// Location of the attribute name, used for the diagnostics below.
213 auto loc
= parser
.getNameLoc();
214 llvm::StringRef attrName
;
215 mlir::Attribute attr
;
216 mlir::OptionalParseResult result
=
217 generatedAttributeParser(parser
, &attrName
, type
, attr
);
// Presumably a value here means the generated parser recognized the name and
// has already produced (or diagnosed) the attribute.
218 if (result
.has_value())
220 if (attrName
.empty())
221 return {}; // error reported by generatedAttributeParser
// SELECT TYPE attributes: the name followed by `<` type `>`.
223 if (attrName
== ExactTypeAttr::getAttrName()) {
225 if (parser
.parseLess() || parser
.parseType(type
) || parser
.parseGreater()) {
226 parser
.emitError(loc
, "expected a type");
229 return ExactTypeAttr::get(type
);
231 if (attrName
== SubclassAttr::getAttrName()) {
233 if (parser
.parseLess() || parser
.parseType(type
) || parser
.parseGreater()) {
234 parser
.emitError(loc
, "expected a subtype");
237 return SubclassAttr::get(type
);
// SELECT CASE markers take no parameters.
239 if (attrName
== PointIntervalAttr::getAttrName())
240 return PointIntervalAttr::get(dialect
->getContext());
241 if (attrName
== LowerBoundAttr::getAttrName())
242 return LowerBoundAttr::get(dialect
->getContext());
243 if (attrName
== UpperBoundAttr::getAttrName())
244 return UpperBoundAttr::get(dialect
->getContext());
245 if (attrName
== ClosedIntervalAttr::getAttrName())
246 return ClosedIntervalAttr::get(dialect
->getContext());
// Real constants parse their own `<` ... `>` parameters.
247 if (attrName
== RealAttr::getAttrName())
248 return parseFirRealAttr(dialect
, parser
, type
);
// Nothing matched: this is not a FIR attribute.
250 parser
.emitError(loc
, "unknown FIR attribute: ") << attrName
;
254 //===----------------------------------------------------------------------===//
255 // FIR attribute pretty printer
256 //===----------------------------------------------------------------------===//
// Prints the flag set using stringifyFortranVariableFlagsEnum.
// NOTE(review): FortranVariableFlagsAttr::parse expects the flags wrapped in
// `<` ... `>` as comma-separated keywords; confirm this printer emits those
// delimiters around the stringified flags so printed IR round-trips.
258 void fir::FortranVariableFlagsAttr::print(mlir::AsmPrinter
&printer
) const {
260 printer
<< fir::stringifyFortranVariableFlagsEnum(this->getFlags());
// Dialect hook for printing FIR attributes; the counterpart of
// fir::parseFirAttribute. The hand-written attributes of this file are
// printed here. Everything else goes to generatedAttributePrinter, and a
// placeholder text is printed if that also fails. `dialect` does not appear
// to be used by this body.
264 void fir::printFirAttribute(FIROpsDialect
*dialect
, mlir::Attribute attr
,
265 mlir::DialectAsmPrinter
&p
) {
266 auto &os
= p
.getStream();
// NOTE(review): the first two checks use mlir::dyn_cast, which asserts on a
// null attribute, so the later dyn_cast_or_null calls never actually see a
// null `attr`. The two spellings could be made consistent.
267 if (auto exact
= mlir::dyn_cast
<fir::ExactTypeAttr
>(attr
)) {
268 os
<< fir::ExactTypeAttr::getAttrName() << '<';
269 p
.printType(exact
.getType());
271 } else if (auto sub
= mlir::dyn_cast
<fir::SubclassAttr
>(attr
)) {
272 os
<< fir::SubclassAttr::getAttrName() << '<';
273 p
.printType(sub
.getType());
// SELECT CASE markers print as their bare name.
275 } else if (mlir::dyn_cast_or_null
<fir::PointIntervalAttr
>(attr
)) {
276 os
<< fir::PointIntervalAttr::getAttrName();
277 } else if (mlir::dyn_cast_or_null
<fir::ClosedIntervalAttr
>(attr
)) {
278 os
<< fir::ClosedIntervalAttr::getAttrName();
279 } else if (mlir::dyn_cast_or_null
<fir::LowerBoundAttr
>(attr
)) {
280 os
<< fir::LowerBoundAttr::getAttrName();
281 } else if (mlir::dyn_cast_or_null
<fir::UpperBoundAttr
>(attr
)) {
282 os
<< fir::UpperBoundAttr::getAttrName();
// Reals are always printed in the bit-exact form: KIND, then `i x` and the
// value's bit pattern in hexadecimal. This is the `i` spelling that
// parseFirRealAttr accepts, so the exact bits survive a round trip.
283 } else if (auto a
= mlir::dyn_cast_or_null
<fir::RealAttr
>(attr
)) {
284 os
<< fir::RealAttr::getAttrName() << '<' << a
.getFKind() << ", i x";
285 llvm::SmallString
<40> ss
;
286 a
.getValue().bitcastToAPInt().toStringUnsigned(ss
, 16);
// Remaining attributes are handled by the generated printer.
288 } else if (mlir::failed(generatedAttributePrinter(attr
, p
))) {
289 // don't know how to print the attribute, so use a default
290 os
<< "<(unknown attribute)>";
294 //===----------------------------------------------------------------------===//
296 //===----------------------------------------------------------------------===//
298 void FIROpsDialect::registerAttributes() {
299 addAttributes
<ClosedIntervalAttr
, ExactTypeAttr
,
300 FortranProcedureFlagsEnumAttr
, FortranVariableFlagsAttr
,
301 LowerBoundAttr
, PointIntervalAttr
, RealAttr
, ReduceAttr
,
302 SubclassAttr
, UpperBoundAttr
, LocationKindAttr
,
303 LocationKindArrayAttr
>();