[gn build] Port 8ebc35f8d041
[llvm-project.git] / flang / lib / Optimizer / Dialect / FIRAttr.cpp
blob4c78e223b4178505df6257d45d455de550d08dbb
1 //===-- FIRAttr.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
13 #include "flang/Optimizer/Dialect/FIRAttr.h"
14 #include "flang/Optimizer/Dialect/FIRDialect.h"
15 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
16 #include "mlir/IR/AttributeSupport.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
24 #include "flang/Optimizer/Dialect/FIREnumAttr.cpp.inc"
25 #define GET_ATTRDEF_CLASSES
26 #include "flang/Optimizer/Dialect/FIRAttr.cpp.inc"
28 using namespace fir;
30 namespace fir::detail {
32 struct RealAttributeStorage : public mlir::AttributeStorage {
33 using KeyTy = std::pair<int, llvm::APFloat>;
35 RealAttributeStorage(int kind, const llvm::APFloat &value)
36 : kind(kind), value(value) {}
37 RealAttributeStorage(const KeyTy &key)
38 : RealAttributeStorage(key.first, key.second) {}
40 static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); }
42 bool operator==(const KeyTy &key) const {
43 return key.first == kind &&
44 key.second.compare(value) == llvm::APFloatBase::cmpEqual;
47 static RealAttributeStorage *
48 construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
49 return new (allocator.allocate<RealAttributeStorage>())
50 RealAttributeStorage(key);
53 KindTy getFKind() const { return kind; }
54 llvm::APFloat getValue() const { return value; }
56 private:
57 int kind;
58 llvm::APFloat value;
61 /// An attribute representing a reference to a type.
62 struct TypeAttributeStorage : public mlir::AttributeStorage {
63 using KeyTy = mlir::Type;
65 TypeAttributeStorage(mlir::Type value) : value(value) {
66 assert(value && "must not be of Type null");
69 /// Key equality function.
70 bool operator==(const KeyTy &key) const { return key == value; }
72 /// Construct a new storage instance.
73 static TypeAttributeStorage *
74 construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) {
75 return new (allocator.allocate<TypeAttributeStorage>())
76 TypeAttributeStorage(key);
79 mlir::Type getType() const { return value; }
81 private:
82 mlir::Type value;
84 } // namespace fir::detail
86 //===----------------------------------------------------------------------===//
87 // Attributes for SELECT TYPE
88 //===----------------------------------------------------------------------===//
90 ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) {
91 return Base::get(value.getContext(), value);
94 mlir::Type fir::ExactTypeAttr::getType() const { return getImpl()->getType(); }
96 SubclassAttr fir::SubclassAttr::get(mlir::Type value) {
97 return Base::get(value.getContext(), value);
100 mlir::Type fir::SubclassAttr::getType() const { return getImpl()->getType(); }
102 //===----------------------------------------------------------------------===//
103 // Attributes for SELECT CASE
104 //===----------------------------------------------------------------------===//
106 using AttributeUniquer = mlir::detail::AttributeUniquer;
108 ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
109 return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
112 UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
113 return AttributeUniquer::get<UpperBoundAttr>(ctxt);
116 LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
117 return AttributeUniquer::get<LowerBoundAttr>(ctxt);
120 PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
121 return AttributeUniquer::get<PointIntervalAttr>(ctxt);
124 //===----------------------------------------------------------------------===//
125 // RealAttr
126 //===----------------------------------------------------------------------===//
128 RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt,
129 const RealAttr::ValueType &key) {
130 return Base::get(ctxt, key);
133 KindTy fir::RealAttr::getFKind() const { return getImpl()->getFKind(); }
135 llvm::APFloat fir::RealAttr::getValue() const { return getImpl()->getValue(); }
137 //===----------------------------------------------------------------------===//
138 // FIR attribute parsing
139 //===----------------------------------------------------------------------===//
141 static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect,
142 mlir::DialectAsmParser &parser,
143 mlir::Type type) {
144 int kind = 0;
145 if (parser.parseLess() || parser.parseInteger(kind) || parser.parseComma()) {
146 parser.emitError(parser.getNameLoc(), "expected '<' kind ','");
147 return {};
149 KindMapping kindMap(dialect->getContext());
150 llvm::APFloat value(0.);
151 if (parser.parseOptionalKeyword("i")) {
152 // `i` not present, so literal float must be present
153 double dontCare;
154 if (parser.parseFloat(dontCare) || parser.parseGreater()) {
155 parser.emitError(parser.getNameLoc(), "expected real constant '>'");
156 return {};
158 auto fltStr = parser.getFullSymbolSpec()
159 .drop_until([](char c) { return c == ','; })
160 .drop_front()
161 .drop_while([](char c) { return c == ' ' || c == '\t'; })
162 .take_until([](char c) {
163 return c == '>' || c == ' ' || c == '\t';
165 value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr);
166 } else {
167 // `i` is present, so literal bitstring (hex) must be present
168 llvm::StringRef hex;
169 if (parser.parseKeyword(&hex) || parser.parseGreater()) {
170 parser.emitError(parser.getNameLoc(), "expected real constant '>'");
171 return {};
173 const llvm::fltSemantics &sem = kindMap.getFloatSemantics(kind);
174 unsigned int numBits = llvm::APFloat::semanticsSizeInBits(sem);
175 auto bits = llvm::APInt(numBits, hex.drop_front(), 16);
176 value = llvm::APFloat(sem, bits);
178 return RealAttr::get(dialect->getContext(), {kind, value});
181 mlir::Attribute fir::FortranVariableFlagsAttr::parse(mlir::AsmParser &parser,
182 mlir::Type type) {
183 if (mlir::failed(parser.parseLess()))
184 return {};
186 fir::FortranVariableFlagsEnum flags = {};
187 if (mlir::failed(parser.parseOptionalGreater())) {
188 auto parseFlags = [&]() -> mlir::ParseResult {
189 llvm::StringRef elemName;
190 if (mlir::failed(parser.parseKeyword(&elemName)))
191 return mlir::failure();
193 auto elem = fir::symbolizeFortranVariableFlagsEnum(elemName);
194 if (!elem)
195 return parser.emitError(parser.getNameLoc(),
196 "Unknown fortran variable attribute: ")
197 << elemName;
199 flags = flags | *elem;
200 return mlir::success();
202 if (mlir::failed(parser.parseCommaSeparatedList(parseFlags)) ||
203 parser.parseGreater())
204 return {};
207 return FortranVariableFlagsAttr::get(parser.getContext(), flags);
210 mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect,
211 mlir::DialectAsmParser &parser,
212 mlir::Type type) {
213 auto loc = parser.getNameLoc();
214 llvm::StringRef attrName;
215 mlir::Attribute attr;
216 mlir::OptionalParseResult result =
217 generatedAttributeParser(parser, &attrName, type, attr);
218 if (result.has_value())
219 return attr;
220 if (attrName.empty())
221 return {}; // error reported by generatedAttributeParser
223 if (attrName == ExactTypeAttr::getAttrName()) {
224 mlir::Type type;
225 if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
226 parser.emitError(loc, "expected a type");
227 return {};
229 return ExactTypeAttr::get(type);
231 if (attrName == SubclassAttr::getAttrName()) {
232 mlir::Type type;
233 if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
234 parser.emitError(loc, "expected a subtype");
235 return {};
237 return SubclassAttr::get(type);
239 if (attrName == PointIntervalAttr::getAttrName())
240 return PointIntervalAttr::get(dialect->getContext());
241 if (attrName == LowerBoundAttr::getAttrName())
242 return LowerBoundAttr::get(dialect->getContext());
243 if (attrName == UpperBoundAttr::getAttrName())
244 return UpperBoundAttr::get(dialect->getContext());
245 if (attrName == ClosedIntervalAttr::getAttrName())
246 return ClosedIntervalAttr::get(dialect->getContext());
247 if (attrName == RealAttr::getAttrName())
248 return parseFirRealAttr(dialect, parser, type);
250 parser.emitError(loc, "unknown FIR attribute: ") << attrName;
251 return {};
254 //===----------------------------------------------------------------------===//
255 // FIR attribute pretty printer
256 //===----------------------------------------------------------------------===//
258 void fir::FortranVariableFlagsAttr::print(mlir::AsmPrinter &printer) const {
259 printer << "<";
260 printer << fir::stringifyFortranVariableFlagsEnum(this->getFlags());
261 printer << ">";
264 void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
265 mlir::DialectAsmPrinter &p) {
266 auto &os = p.getStream();
267 if (auto exact = mlir::dyn_cast<fir::ExactTypeAttr>(attr)) {
268 os << fir::ExactTypeAttr::getAttrName() << '<';
269 p.printType(exact.getType());
270 os << '>';
271 } else if (auto sub = mlir::dyn_cast<fir::SubclassAttr>(attr)) {
272 os << fir::SubclassAttr::getAttrName() << '<';
273 p.printType(sub.getType());
274 os << '>';
275 } else if (mlir::dyn_cast_or_null<fir::PointIntervalAttr>(attr)) {
276 os << fir::PointIntervalAttr::getAttrName();
277 } else if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) {
278 os << fir::ClosedIntervalAttr::getAttrName();
279 } else if (mlir::dyn_cast_or_null<fir::LowerBoundAttr>(attr)) {
280 os << fir::LowerBoundAttr::getAttrName();
281 } else if (mlir::dyn_cast_or_null<fir::UpperBoundAttr>(attr)) {
282 os << fir::UpperBoundAttr::getAttrName();
283 } else if (auto a = mlir::dyn_cast_or_null<fir::RealAttr>(attr)) {
284 os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x";
285 llvm::SmallString<40> ss;
286 a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16);
287 os << ss << '>';
288 } else if (mlir::failed(generatedAttributePrinter(attr, p))) {
289 // don't know how to print the attribute, so use a default
290 os << "<(unknown attribute)>";
294 //===----------------------------------------------------------------------===//
295 // FIROpsDialect
296 //===----------------------------------------------------------------------===//
298 void FIROpsDialect::registerAttributes() {
299 addAttributes<ClosedIntervalAttr, ExactTypeAttr,
300 FortranProcedureFlagsEnumAttr, FortranVariableFlagsAttr,
301 LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
302 SubclassAttr, UpperBoundAttr, LocationKindAttr,
303 LocationKindArrayAttr>();