1 //===- AsmParserState.cpp -------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/AsmParser/AsmParserState.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/IR/Types.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 #include "mlir/Support/LogicalResult.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/iterator.h"
22 #include "llvm/Support/ErrorHandling.h"
30 //===----------------------------------------------------------------------===//
31 // AsmParserState::Impl
32 //===----------------------------------------------------------------------===//
34 struct AsmParserState::Impl
{
35 /// A map from a SymbolRefAttr to a range of uses.
37 DenseMap
<Attribute
, SmallVector
<SmallVector
<SMRange
>, 0>>;
40 explicit PartialOpDef(const OperationName
&opName
) {
41 if (opName
.hasTrait
<OpTrait::SymbolTable
>())
42 symbolTable
= std::make_unique
<SymbolUseMap
>();
45 /// Return if this operation is a symbol table.
46 bool isSymbolTable() const { return symbolTable
.get(); }
48 /// If this operation is a symbol table, the following contains symbol uses
49 /// within this operation.
50 std::unique_ptr
<SymbolUseMap
> symbolTable
;
53 /// Resolve any symbol table uses in the IR.
54 void resolveSymbolUses();
56 /// A mapping from operations in the input source file to their parser state.
57 SmallVector
<std::unique_ptr
<OperationDefinition
>> operations
;
58 DenseMap
<Operation
*, unsigned> operationToIdx
;
60 /// A mapping from blocks in the input source file to their parser state.
61 SmallVector
<std::unique_ptr
<BlockDefinition
>> blocks
;
62 DenseMap
<Block
*, unsigned> blocksToIdx
;
64 /// A mapping from aliases in the input source file to their parser state.
65 SmallVector
<std::unique_ptr
<AttributeAliasDefinition
>> attrAliases
;
66 SmallVector
<std::unique_ptr
<TypeAliasDefinition
>> typeAliases
;
67 llvm::StringMap
<unsigned> attrAliasToIdx
;
68 llvm::StringMap
<unsigned> typeAliasToIdx
;
70 /// A set of value definitions that are placeholders for forward references.
71 /// This map should be empty if the parser finishes successfully.
72 DenseMap
<Value
, SmallVector
<SMLoc
>> placeholderValueUses
;
74 /// The symbol table operations within the IR.
75 SmallVector
<std::pair
<Operation
*, std::unique_ptr
<SymbolUseMap
>>>
76 symbolTableOperations
;
78 /// A stack of partial operation definitions that have been started but not
80 SmallVector
<PartialOpDef
> partialOperations
;
82 /// A stack of symbol use scopes. This is used when collecting symbol table
83 /// uses during parsing.
84 SmallVector
<SymbolUseMap
*> symbolUseScopes
;
86 /// A symbol table containing all of the symbol table operations in the IR.
87 SymbolTableCollection symbolTable
;
90 void AsmParserState::Impl::resolveSymbolUses() {
91 SmallVector
<Operation
*> symbolOps
;
92 for (auto &opAndUseMapIt
: symbolTableOperations
) {
93 for (auto &it
: *opAndUseMapIt
.second
) {
95 if (failed(symbolTable
.lookupSymbolIn(
96 opAndUseMapIt
.first
, cast
<SymbolRefAttr
>(it
.first
), symbolOps
)))
99 for (ArrayRef
<SMRange
> useRange
: it
.second
) {
100 for (const auto &symIt
: llvm::zip(symbolOps
, useRange
)) {
101 auto opIt
= operationToIdx
.find(std::get
<0>(symIt
));
102 if (opIt
!= operationToIdx
.end())
103 operations
[opIt
->second
]->symbolUses
.push_back(std::get
<1>(symIt
));
110 //===----------------------------------------------------------------------===//
112 //===----------------------------------------------------------------------===//
114 AsmParserState::AsmParserState() : impl(std::make_unique
<Impl
>()) {}
115 AsmParserState::~AsmParserState() = default;
116 AsmParserState
&AsmParserState::operator=(AsmParserState
&&other
) {
117 impl
= std::move(other
.impl
);
121 //===----------------------------------------------------------------------===//
124 auto AsmParserState::getBlockDefs() const -> iterator_range
<BlockDefIterator
> {
125 return llvm::make_pointee_range(llvm::ArrayRef(impl
->blocks
));
128 auto AsmParserState::getBlockDef(Block
*block
) const
129 -> const BlockDefinition
* {
130 auto it
= impl
->blocksToIdx
.find(block
);
131 return it
== impl
->blocksToIdx
.end() ? nullptr : &*impl
->blocks
[it
->second
];
134 auto AsmParserState::getOpDefs() const -> iterator_range
<OperationDefIterator
> {
135 return llvm::make_pointee_range(llvm::ArrayRef(impl
->operations
));
138 auto AsmParserState::getOpDef(Operation
*op
) const
139 -> const OperationDefinition
* {
140 auto it
= impl
->operationToIdx
.find(op
);
141 return it
== impl
->operationToIdx
.end() ? nullptr
142 : &*impl
->operations
[it
->second
];
145 auto AsmParserState::getAttributeAliasDefs() const
146 -> iterator_range
<AttributeDefIterator
> {
147 return llvm::make_pointee_range(ArrayRef(impl
->attrAliases
));
150 auto AsmParserState::getAttributeAliasDef(StringRef name
) const
151 -> const AttributeAliasDefinition
* {
152 auto it
= impl
->attrAliasToIdx
.find(name
);
153 return it
== impl
->attrAliasToIdx
.end() ? nullptr
154 : &*impl
->attrAliases
[it
->second
];
157 auto AsmParserState::getTypeAliasDefs() const
158 -> iterator_range
<TypeDefIterator
> {
159 return llvm::make_pointee_range(ArrayRef(impl
->typeAliases
));
162 auto AsmParserState::getTypeAliasDef(StringRef name
) const
163 -> const TypeAliasDefinition
* {
164 auto it
= impl
->typeAliasToIdx
.find(name
);
165 return it
== impl
->typeAliasToIdx
.end() ? nullptr
166 : &*impl
->typeAliases
[it
->second
];
169 /// Lex a string token whose contents start at the given `curPtr`. Returns the
170 /// position at the end of the string, after a terminal or invalid character
171 /// (e.g. `"` or `\0`).
172 static const char *lexLocStringTok(const char *curPtr
) {
173 while (char c
= *curPtr
++) {
174 // Check for various terminal characters.
175 if (StringRef("\"\n\v\f").contains(c
))
178 // Check for escape sequences.
180 // Check a few known escapes and \xx hex digits.
181 if (*curPtr
== '"' || *curPtr
== '\\' || *curPtr
== 'n' || *curPtr
== 't')
183 else if (llvm::isHexDigit(*curPtr
) && llvm::isHexDigit(curPtr
[1]))
190 // If we hit this point, we've reached the end of the buffer. Update the end
191 // pointer to not point past the buffer.
195 SMRange
AsmParserState::convertIdLocToRange(SMLoc loc
) {
198 const char *curPtr
= loc
.getPointer();
200 // Check if this is a string token.
201 if (*curPtr
== '"') {
202 curPtr
= lexLocStringTok(curPtr
+ 1);
204 // Otherwise, default to handling an identifier.
206 // Return if the given character is a valid identifier character.
207 auto isIdentifierChar
= [](char c
) {
208 return isalnum(c
) || c
== '$' || c
== '.' || c
== '_' || c
== '-';
211 while (*curPtr
&& isIdentifierChar(*(++curPtr
)))
215 return SMRange(loc
, SMLoc::getFromPointer(curPtr
));
218 //===----------------------------------------------------------------------===//
221 void AsmParserState::initialize(Operation
*topLevelOp
) {
222 startOperationDefinition(topLevelOp
->getName());
224 // If the top-level operation is a symbol table, push a new symbol scope.
225 Impl::PartialOpDef
&partialOpDef
= impl
->partialOperations
.back();
226 if (partialOpDef
.isSymbolTable())
227 impl
->symbolUseScopes
.push_back(partialOpDef
.symbolTable
.get());
230 void AsmParserState::finalize(Operation
*topLevelOp
) {
231 assert(!impl
->partialOperations
.empty() &&
232 "expected valid partial operation definition");
233 Impl::PartialOpDef partialOpDef
= impl
->partialOperations
.pop_back_val();
235 // If this operation is a symbol table, resolve any symbol uses.
236 if (partialOpDef
.isSymbolTable()) {
237 impl
->symbolTableOperations
.emplace_back(
238 topLevelOp
, std::move(partialOpDef
.symbolTable
));
240 impl
->resolveSymbolUses();
243 void AsmParserState::startOperationDefinition(const OperationName
&opName
) {
244 impl
->partialOperations
.emplace_back(opName
);
247 void AsmParserState::finalizeOperationDefinition(
248 Operation
*op
, SMRange nameLoc
, SMLoc endLoc
,
249 ArrayRef
<std::pair
<unsigned, SMLoc
>> resultGroups
) {
250 assert(!impl
->partialOperations
.empty() &&
251 "expected valid partial operation definition");
252 Impl::PartialOpDef partialOpDef
= impl
->partialOperations
.pop_back_val();
254 // Build the full operation definition.
255 std::unique_ptr
<OperationDefinition
> def
=
256 std::make_unique
<OperationDefinition
>(op
, nameLoc
, endLoc
);
257 for (auto &resultGroup
: resultGroups
)
258 def
->resultGroups
.emplace_back(resultGroup
.first
,
259 convertIdLocToRange(resultGroup
.second
));
260 impl
->operationToIdx
.try_emplace(op
, impl
->operations
.size());
261 impl
->operations
.emplace_back(std::move(def
));
263 // If this operation is a symbol table, resolve any symbol uses.
264 if (partialOpDef
.isSymbolTable()) {
265 impl
->symbolTableOperations
.emplace_back(
266 op
, std::move(partialOpDef
.symbolTable
));
270 void AsmParserState::startRegionDefinition() {
271 assert(!impl
->partialOperations
.empty() &&
272 "expected valid partial operation definition");
274 // If the parent operation of this region is a symbol table, we also push a
276 Impl::PartialOpDef
&partialOpDef
= impl
->partialOperations
.back();
277 if (partialOpDef
.isSymbolTable())
278 impl
->symbolUseScopes
.push_back(partialOpDef
.symbolTable
.get());
281 void AsmParserState::finalizeRegionDefinition() {
282 assert(!impl
->partialOperations
.empty() &&
283 "expected valid partial operation definition");
285 // If the parent operation of this region is a symbol table, pop the symbol
286 // scope for this region.
287 Impl::PartialOpDef
&partialOpDef
= impl
->partialOperations
.back();
288 if (partialOpDef
.isSymbolTable())
289 impl
->symbolUseScopes
.pop_back();
292 void AsmParserState::addDefinition(Block
*block
, SMLoc location
) {
293 auto it
= impl
->blocksToIdx
.find(block
);
294 if (it
== impl
->blocksToIdx
.end()) {
295 impl
->blocksToIdx
.try_emplace(block
, impl
->blocks
.size());
296 impl
->blocks
.emplace_back(std::make_unique
<BlockDefinition
>(
297 block
, convertIdLocToRange(location
)));
301 // If an entry already exists, this was a forward declaration that now has a
302 // proper definition.
303 impl
->blocks
[it
->second
]->definition
.loc
= convertIdLocToRange(location
);
306 void AsmParserState::addDefinition(BlockArgument blockArg
, SMLoc location
) {
307 auto it
= impl
->blocksToIdx
.find(blockArg
.getOwner());
308 assert(it
!= impl
->blocksToIdx
.end() &&
309 "expected owner block to have an entry");
310 BlockDefinition
&def
= *impl
->blocks
[it
->second
];
311 unsigned argIdx
= blockArg
.getArgNumber();
313 if (def
.arguments
.size() <= argIdx
)
314 def
.arguments
.resize(argIdx
+ 1);
315 def
.arguments
[argIdx
] = SMDefinition(convertIdLocToRange(location
));
318 void AsmParserState::addAttrAliasDefinition(StringRef name
, SMRange location
,
320 auto [it
, inserted
] =
321 impl
->attrAliasToIdx
.try_emplace(name
, impl
->attrAliases
.size());
322 // Location aliases may be referenced before they are defined.
324 impl
->attrAliases
.push_back(
325 std::make_unique
<AttributeAliasDefinition
>(name
, location
, value
));
327 AttributeAliasDefinition
&attr
= *impl
->attrAliases
[it
->second
];
328 attr
.definition
.loc
= location
;
333 void AsmParserState::addTypeAliasDefinition(StringRef name
, SMRange location
,
335 [[maybe_unused
]] auto [it
, inserted
] =
336 impl
->typeAliasToIdx
.try_emplace(name
, impl
->typeAliases
.size());
337 assert(inserted
&& "unexpected attribute alias redefinition");
338 impl
->typeAliases
.push_back(
339 std::make_unique
<TypeAliasDefinition
>(name
, location
, value
));
342 void AsmParserState::addUses(Value value
, ArrayRef
<SMLoc
> locations
) {
343 // Handle the case where the value is an operation result.
344 if (OpResult result
= dyn_cast
<OpResult
>(value
)) {
345 // Check to see if a definition for the parent operation has been recorded.
346 // If one hasn't, we treat the provided value as a placeholder value that
347 // will be refined further later.
348 Operation
*parentOp
= result
.getOwner();
349 auto existingIt
= impl
->operationToIdx
.find(parentOp
);
350 if (existingIt
== impl
->operationToIdx
.end()) {
351 impl
->placeholderValueUses
[value
].append(locations
.begin(),
356 // If a definition does exist, locate the value's result group and add the
357 // use. The result groups are ordered by increasing start index, so we just
358 // need to find the last group that has a smaller/equal start index.
359 unsigned resultNo
= result
.getResultNumber();
360 OperationDefinition
&def
= *impl
->operations
[existingIt
->second
];
361 for (auto &resultGroup
: llvm::reverse(def
.resultGroups
)) {
362 if (resultNo
>= resultGroup
.startIndex
) {
363 for (SMLoc loc
: locations
)
364 resultGroup
.definition
.uses
.push_back(convertIdLocToRange(loc
));
368 llvm_unreachable("expected valid result group for value use");
371 // Otherwise, this is a block argument.
372 BlockArgument arg
= cast
<BlockArgument
>(value
);
373 auto existingIt
= impl
->blocksToIdx
.find(arg
.getOwner());
374 assert(existingIt
!= impl
->blocksToIdx
.end() &&
375 "expected valid block definition for block argument");
376 BlockDefinition
&blockDef
= *impl
->blocks
[existingIt
->second
];
377 SMDefinition
&argDef
= blockDef
.arguments
[arg
.getArgNumber()];
378 for (SMLoc loc
: locations
)
379 argDef
.uses
.emplace_back(convertIdLocToRange(loc
));
382 void AsmParserState::addUses(Block
*block
, ArrayRef
<SMLoc
> locations
) {
383 auto it
= impl
->blocksToIdx
.find(block
);
384 if (it
== impl
->blocksToIdx
.end()) {
385 it
= impl
->blocksToIdx
.try_emplace(block
, impl
->blocks
.size()).first
;
386 impl
->blocks
.emplace_back(std::make_unique
<BlockDefinition
>(block
));
389 BlockDefinition
&def
= *impl
->blocks
[it
->second
];
390 for (SMLoc loc
: locations
)
391 def
.definition
.uses
.push_back(convertIdLocToRange(loc
));
394 void AsmParserState::addUses(SymbolRefAttr refAttr
,
395 ArrayRef
<SMRange
> locations
) {
396 // Ignore this symbol if no scopes are active.
397 if (impl
->symbolUseScopes
.empty())
400 assert((refAttr
.getNestedReferences().size() + 1) == locations
.size() &&
401 "expected the same number of references as provided locations");
402 (*impl
->symbolUseScopes
.back())[refAttr
].emplace_back(locations
.begin(),
406 void AsmParserState::addAttrAliasUses(StringRef name
, SMRange location
) {
407 auto it
= impl
->attrAliasToIdx
.find(name
);
408 // Location aliases may be referenced before they are defined.
409 if (it
== impl
->attrAliasToIdx
.end()) {
410 it
= impl
->attrAliasToIdx
.try_emplace(name
, impl
->attrAliases
.size()).first
;
411 impl
->attrAliases
.push_back(
412 std::make_unique
<AttributeAliasDefinition
>(name
));
414 AttributeAliasDefinition
&def
= *impl
->attrAliases
[it
->second
];
415 def
.definition
.uses
.push_back(location
);
418 void AsmParserState::addTypeAliasUses(StringRef name
, SMRange location
) {
419 auto it
= impl
->typeAliasToIdx
.find(name
);
420 // Location aliases may be referenced before they are defined.
421 assert(it
!= impl
->typeAliasToIdx
.end() &&
422 "expected valid type alias definition");
423 TypeAliasDefinition
&def
= *impl
->typeAliases
[it
->second
];
424 def
.definition
.uses
.push_back(location
);
427 void AsmParserState::refineDefinition(Value oldValue
, Value newValue
) {
428 auto it
= impl
->placeholderValueUses
.find(oldValue
);
429 assert(it
!= impl
->placeholderValueUses
.end() &&
430 "expected `oldValue` to be a placeholder");
431 addUses(newValue
, it
->second
);
432 impl
->placeholderValueUses
.erase(oldValue
);