//===- AsmParserState.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "mlir/AsmParser/AsmParserState.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/IR/Types.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/ADT/StringMap.h"
20 #include "llvm/ADT/iterator.h"
21 #include "llvm/Support/ErrorHandling.h"
//===----------------------------------------------------------------------===//
// AsmParserState::Impl
//===----------------------------------------------------------------------===//
33 struct AsmParserState::Impl
{
34 /// A map from a SymbolRefAttr to a range of uses.
36 DenseMap
<Attribute
, SmallVector
<SmallVector
<SMRange
>, 0>>;
39 explicit PartialOpDef(const OperationName
&opName
) {
40 if (opName
.hasTrait
<OpTrait::SymbolTable
>())
41 symbolTable
= std::make_unique
<SymbolUseMap
>();
44 /// Return if this operation is a symbol table.
45 bool isSymbolTable() const { return symbolTable
.get(); }
47 /// If this operation is a symbol table, the following contains symbol uses
48 /// within this operation.
49 std::unique_ptr
<SymbolUseMap
> symbolTable
;
52 /// Resolve any symbol table uses in the IR.
53 void resolveSymbolUses();
55 /// A mapping from operations in the input source file to their parser state.
56 SmallVector
<std::unique_ptr
<OperationDefinition
>> operations
;
57 DenseMap
<Operation
*, unsigned> operationToIdx
;
59 /// A mapping from blocks in the input source file to their parser state.
60 SmallVector
<std::unique_ptr
<BlockDefinition
>> blocks
;
61 DenseMap
<Block
*, unsigned> blocksToIdx
;
63 /// A mapping from aliases in the input source file to their parser state.
64 SmallVector
<std::unique_ptr
<AttributeAliasDefinition
>> attrAliases
;
65 SmallVector
<std::unique_ptr
<TypeAliasDefinition
>> typeAliases
;
66 llvm::StringMap
<unsigned> attrAliasToIdx
;
67 llvm::StringMap
<unsigned> typeAliasToIdx
;
69 /// A set of value definitions that are placeholders for forward references.
70 /// This map should be empty if the parser finishes successfully.
71 DenseMap
<Value
, SmallVector
<SMLoc
>> placeholderValueUses
;
73 /// The symbol table operations within the IR.
74 SmallVector
<std::pair
<Operation
*, std::unique_ptr
<SymbolUseMap
>>>
75 symbolTableOperations
;
77 /// A stack of partial operation definitions that have been started but not
79 SmallVector
<PartialOpDef
> partialOperations
;
81 /// A stack of symbol use scopes. This is used when collecting symbol table
82 /// uses during parsing.
83 SmallVector
<SymbolUseMap
*> symbolUseScopes
;
85 /// A symbol table containing all of the symbol table operations in the IR.
86 SymbolTableCollection symbolTable
;
89 void AsmParserState::Impl::resolveSymbolUses() {
90 SmallVector
<Operation
*> symbolOps
;
91 for (auto &opAndUseMapIt
: symbolTableOperations
) {
92 for (auto &it
: *opAndUseMapIt
.second
) {
94 if (failed(symbolTable
.lookupSymbolIn(
95 opAndUseMapIt
.first
, cast
<SymbolRefAttr
>(it
.first
), symbolOps
)))
98 for (ArrayRef
<SMRange
> useRange
: it
.second
) {
99 for (const auto &symIt
: llvm::zip(symbolOps
, useRange
)) {
100 auto opIt
= operationToIdx
.find(std::get
<0>(symIt
));
101 if (opIt
!= operationToIdx
.end())
102 operations
[opIt
->second
]->symbolUses
.push_back(std::get
<1>(symIt
));
//===----------------------------------------------------------------------===//
// AsmParserState
//===----------------------------------------------------------------------===//
113 AsmParserState::AsmParserState() : impl(std::make_unique
<Impl
>()) {}
114 AsmParserState::~AsmParserState() = default;
115 AsmParserState
&AsmParserState::operator=(AsmParserState
&&other
) {
116 impl
= std::move(other
.impl
);
//===----------------------------------------------------------------------===//
// Access State
//===----------------------------------------------------------------------===//
123 auto AsmParserState::getBlockDefs() const -> iterator_range
<BlockDefIterator
> {
124 return llvm::make_pointee_range(llvm::ArrayRef(impl
->blocks
));
127 auto AsmParserState::getBlockDef(Block
*block
) const
128 -> const BlockDefinition
* {
129 auto it
= impl
->blocksToIdx
.find(block
);
130 return it
== impl
->blocksToIdx
.end() ? nullptr : &*impl
->blocks
[it
->second
];
133 auto AsmParserState::getOpDefs() const -> iterator_range
<OperationDefIterator
> {
134 return llvm::make_pointee_range(llvm::ArrayRef(impl
->operations
));
137 auto AsmParserState::getOpDef(Operation
*op
) const
138 -> const OperationDefinition
* {
139 auto it
= impl
->operationToIdx
.find(op
);
140 return it
== impl
->operationToIdx
.end() ? nullptr
141 : &*impl
->operations
[it
->second
];
144 auto AsmParserState::getAttributeAliasDefs() const
145 -> iterator_range
<AttributeDefIterator
> {
146 return llvm::make_pointee_range(ArrayRef(impl
->attrAliases
));
149 auto AsmParserState::getAttributeAliasDef(StringRef name
) const
150 -> const AttributeAliasDefinition
* {
151 auto it
= impl
->attrAliasToIdx
.find(name
);
152 return it
== impl
->attrAliasToIdx
.end() ? nullptr
153 : &*impl
->attrAliases
[it
->second
];
156 auto AsmParserState::getTypeAliasDefs() const
157 -> iterator_range
<TypeDefIterator
> {
158 return llvm::make_pointee_range(ArrayRef(impl
->typeAliases
));
161 auto AsmParserState::getTypeAliasDef(StringRef name
) const
162 -> const TypeAliasDefinition
* {
163 auto it
= impl
->typeAliasToIdx
.find(name
);
164 return it
== impl
->typeAliasToIdx
.end() ? nullptr
165 : &*impl
->typeAliases
[it
->second
];
168 /// Lex a string token whose contents start at the given `curPtr`. Returns the
169 /// position at the end of the string, after a terminal or invalid character
170 /// (e.g. `"` or `\0`).
171 static const char *lexLocStringTok(const char *curPtr
) {
172 while (char c
= *curPtr
++) {
173 // Check for various terminal characters.
174 if (StringRef("\"\n\v\f").contains(c
))
177 // Check for escape sequences.
179 // Check a few known escapes and \xx hex digits.
180 if (*curPtr
== '"' || *curPtr
== '\\' || *curPtr
== 'n' || *curPtr
== 't')
182 else if (llvm::isHexDigit(*curPtr
) && llvm::isHexDigit(curPtr
[1]))
189 // If we hit this point, we've reached the end of the buffer. Update the end
190 // pointer to not point past the buffer.
194 SMRange
AsmParserState::convertIdLocToRange(SMLoc loc
) {
197 const char *curPtr
= loc
.getPointer();
199 // Check if this is a string token.
200 if (*curPtr
== '"') {
201 curPtr
= lexLocStringTok(curPtr
+ 1);
203 // Otherwise, default to handling an identifier.
205 // Return if the given character is a valid identifier character.
206 auto isIdentifierChar
= [](char c
) {
207 return isalnum(c
) || c
== '$' || c
== '.' || c
== '_' || c
== '-';
210 while (*curPtr
&& isIdentifierChar(*(++curPtr
)))
214 return SMRange(loc
, SMLoc::getFromPointer(curPtr
));
//===----------------------------------------------------------------------===//
// Populate State
//===----------------------------------------------------------------------===//
220 void AsmParserState::initialize(Operation
*topLevelOp
) {
221 startOperationDefinition(topLevelOp
->getName());
223 // If the top-level operation is a symbol table, push a new symbol scope.
224 Impl::PartialOpDef
&partialOpDef
= impl
->partialOperations
.back();
225 if (partialOpDef
.isSymbolTable())
226 impl
->symbolUseScopes
.push_back(partialOpDef
.symbolTable
.get());
229 void AsmParserState::finalize(Operation
*topLevelOp
) {
230 assert(!impl
->partialOperations
.empty() &&
231 "expected valid partial operation definition");
232 Impl::PartialOpDef partialOpDef
= impl
->partialOperations
.pop_back_val();
234 // If this operation is a symbol table, resolve any symbol uses.
235 if (partialOpDef
.isSymbolTable()) {
236 impl
->symbolTableOperations
.emplace_back(
237 topLevelOp
, std::move(partialOpDef
.symbolTable
));
239 impl
->resolveSymbolUses();
242 void AsmParserState::startOperationDefinition(const OperationName
&opName
) {
243 impl
->partialOperations
.emplace_back(opName
);
246 void AsmParserState::finalizeOperationDefinition(
247 Operation
*op
, SMRange nameLoc
, SMLoc endLoc
,
248 ArrayRef
<std::pair
<unsigned, SMLoc
>> resultGroups
) {
249 assert(!impl
->partialOperations
.empty() &&
250 "expected valid partial operation definition");
251 Impl::PartialOpDef partialOpDef
= impl
->partialOperations
.pop_back_val();
253 // Build the full operation definition.
254 std::unique_ptr
<OperationDefinition
> def
=
255 std::make_unique
<OperationDefinition
>(op
, nameLoc
, endLoc
);
256 for (auto &resultGroup
: resultGroups
)
257 def
->resultGroups
.emplace_back(resultGroup
.first
,
258 convertIdLocToRange(resultGroup
.second
));
259 impl
->operationToIdx
.try_emplace(op
, impl
->operations
.size());
260 impl
->operations
.emplace_back(std::move(def
));
262 // If this operation is a symbol table, resolve any symbol uses.
263 if (partialOpDef
.isSymbolTable()) {
264 impl
->symbolTableOperations
.emplace_back(
265 op
, std::move(partialOpDef
.symbolTable
));
269 void AsmParserState::startRegionDefinition() {
270 assert(!impl
->partialOperations
.empty() &&
271 "expected valid partial operation definition");
273 // If the parent operation of this region is a symbol table, we also push a
275 Impl::PartialOpDef
&partialOpDef
= impl
->partialOperations
.back();
276 if (partialOpDef
.isSymbolTable())
277 impl
->symbolUseScopes
.push_back(partialOpDef
.symbolTable
.get());
280 void AsmParserState::finalizeRegionDefinition() {
281 assert(!impl
->partialOperations
.empty() &&
282 "expected valid partial operation definition");
284 // If the parent operation of this region is a symbol table, pop the symbol
285 // scope for this region.
286 Impl::PartialOpDef
&partialOpDef
= impl
->partialOperations
.back();
287 if (partialOpDef
.isSymbolTable())
288 impl
->symbolUseScopes
.pop_back();
291 void AsmParserState::addDefinition(Block
*block
, SMLoc location
) {
292 auto [it
, inserted
] =
293 impl
->blocksToIdx
.try_emplace(block
, impl
->blocks
.size());
295 impl
->blocks
.emplace_back(std::make_unique
<BlockDefinition
>(
296 block
, convertIdLocToRange(location
)));
300 // If an entry already exists, this was a forward declaration that now has a
301 // proper definition.
302 impl
->blocks
[it
->second
]->definition
.loc
= convertIdLocToRange(location
);
305 void AsmParserState::addDefinition(BlockArgument blockArg
, SMLoc location
) {
306 auto it
= impl
->blocksToIdx
.find(blockArg
.getOwner());
307 assert(it
!= impl
->blocksToIdx
.end() &&
308 "expected owner block to have an entry");
309 BlockDefinition
&def
= *impl
->blocks
[it
->second
];
310 unsigned argIdx
= blockArg
.getArgNumber();
312 if (def
.arguments
.size() <= argIdx
)
313 def
.arguments
.resize(argIdx
+ 1);
314 def
.arguments
[argIdx
] = SMDefinition(convertIdLocToRange(location
));
317 void AsmParserState::addAttrAliasDefinition(StringRef name
, SMRange location
,
319 auto [it
, inserted
] =
320 impl
->attrAliasToIdx
.try_emplace(name
, impl
->attrAliases
.size());
321 // Location aliases may be referenced before they are defined.
323 impl
->attrAliases
.push_back(
324 std::make_unique
<AttributeAliasDefinition
>(name
, location
, value
));
326 AttributeAliasDefinition
&attr
= *impl
->attrAliases
[it
->second
];
327 attr
.definition
.loc
= location
;
332 void AsmParserState::addTypeAliasDefinition(StringRef name
, SMRange location
,
334 [[maybe_unused
]] auto [it
, inserted
] =
335 impl
->typeAliasToIdx
.try_emplace(name
, impl
->typeAliases
.size());
336 assert(inserted
&& "unexpected attribute alias redefinition");
337 impl
->typeAliases
.push_back(
338 std::make_unique
<TypeAliasDefinition
>(name
, location
, value
));
341 void AsmParserState::addUses(Value value
, ArrayRef
<SMLoc
> locations
) {
342 // Handle the case where the value is an operation result.
343 if (OpResult result
= dyn_cast
<OpResult
>(value
)) {
344 // Check to see if a definition for the parent operation has been recorded.
345 // If one hasn't, we treat the provided value as a placeholder value that
346 // will be refined further later.
347 Operation
*parentOp
= result
.getOwner();
348 auto existingIt
= impl
->operationToIdx
.find(parentOp
);
349 if (existingIt
== impl
->operationToIdx
.end()) {
350 impl
->placeholderValueUses
[value
].append(locations
.begin(),
355 // If a definition does exist, locate the value's result group and add the
356 // use. The result groups are ordered by increasing start index, so we just
357 // need to find the last group that has a smaller/equal start index.
358 unsigned resultNo
= result
.getResultNumber();
359 OperationDefinition
&def
= *impl
->operations
[existingIt
->second
];
360 for (auto &resultGroup
: llvm::reverse(def
.resultGroups
)) {
361 if (resultNo
>= resultGroup
.startIndex
) {
362 for (SMLoc loc
: locations
)
363 resultGroup
.definition
.uses
.push_back(convertIdLocToRange(loc
));
367 llvm_unreachable("expected valid result group for value use");
370 // Otherwise, this is a block argument.
371 BlockArgument arg
= cast
<BlockArgument
>(value
);
372 auto existingIt
= impl
->blocksToIdx
.find(arg
.getOwner());
373 assert(existingIt
!= impl
->blocksToIdx
.end() &&
374 "expected valid block definition for block argument");
375 BlockDefinition
&blockDef
= *impl
->blocks
[existingIt
->second
];
376 SMDefinition
&argDef
= blockDef
.arguments
[arg
.getArgNumber()];
377 for (SMLoc loc
: locations
)
378 argDef
.uses
.emplace_back(convertIdLocToRange(loc
));
381 void AsmParserState::addUses(Block
*block
, ArrayRef
<SMLoc
> locations
) {
382 auto [it
, inserted
] =
383 impl
->blocksToIdx
.try_emplace(block
, impl
->blocks
.size());
385 impl
->blocks
.emplace_back(std::make_unique
<BlockDefinition
>(block
));
387 BlockDefinition
&def
= *impl
->blocks
[it
->second
];
388 for (SMLoc loc
: locations
)
389 def
.definition
.uses
.push_back(convertIdLocToRange(loc
));
392 void AsmParserState::addUses(SymbolRefAttr refAttr
,
393 ArrayRef
<SMRange
> locations
) {
394 // Ignore this symbol if no scopes are active.
395 if (impl
->symbolUseScopes
.empty())
398 assert((refAttr
.getNestedReferences().size() + 1) == locations
.size() &&
399 "expected the same number of references as provided locations");
400 (*impl
->symbolUseScopes
.back())[refAttr
].emplace_back(locations
.begin(),
404 void AsmParserState::addAttrAliasUses(StringRef name
, SMRange location
) {
405 auto it
= impl
->attrAliasToIdx
.find(name
);
406 // Location aliases may be referenced before they are defined.
407 if (it
== impl
->attrAliasToIdx
.end()) {
408 it
= impl
->attrAliasToIdx
.try_emplace(name
, impl
->attrAliases
.size()).first
;
409 impl
->attrAliases
.push_back(
410 std::make_unique
<AttributeAliasDefinition
>(name
));
412 AttributeAliasDefinition
&def
= *impl
->attrAliases
[it
->second
];
413 def
.definition
.uses
.push_back(location
);
416 void AsmParserState::addTypeAliasUses(StringRef name
, SMRange location
) {
417 auto it
= impl
->typeAliasToIdx
.find(name
);
418 // Location aliases may be referenced before they are defined.
419 assert(it
!= impl
->typeAliasToIdx
.end() &&
420 "expected valid type alias definition");
421 TypeAliasDefinition
&def
= *impl
->typeAliases
[it
->second
];
422 def
.definition
.uses
.push_back(location
);
425 void AsmParserState::refineDefinition(Value oldValue
, Value newValue
) {
426 auto it
= impl
->placeholderValueUses
.find(oldValue
);
427 assert(it
!= impl
->placeholderValueUses
.end() &&
428 "expected `oldValue` to be a placeholder");
429 addUses(newValue
, it
->second
);
430 impl
->placeholderValueUses
.erase(oldValue
);