[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / mlir / lib / AsmParser / AsmParserState.cpp
blob47cfb5288629093357c5fe1869f6f7b2f4fbaf36
1 //===- AsmParserState.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/AsmParser/AsmParserState.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/IR/Types.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 #include "mlir/Support/LogicalResult.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/iterator.h"
22 #include "llvm/Support/ErrorHandling.h"
23 #include <cassert>
24 #include <cctype>
25 #include <memory>
26 #include <utility>
28 using namespace mlir;
30 //===----------------------------------------------------------------------===//
31 // AsmParserState::Impl
32 //===----------------------------------------------------------------------===//
34 struct AsmParserState::Impl {
35 /// A map from a SymbolRefAttr to a range of uses.
36 using SymbolUseMap =
37 DenseMap<Attribute, SmallVector<SmallVector<SMRange>, 0>>;
39 struct PartialOpDef {
40 explicit PartialOpDef(const OperationName &opName) {
41 if (opName.hasTrait<OpTrait::SymbolTable>())
42 symbolTable = std::make_unique<SymbolUseMap>();
45 /// Return if this operation is a symbol table.
46 bool isSymbolTable() const { return symbolTable.get(); }
48 /// If this operation is a symbol table, the following contains symbol uses
49 /// within this operation.
50 std::unique_ptr<SymbolUseMap> symbolTable;
53 /// Resolve any symbol table uses in the IR.
54 void resolveSymbolUses();
56 /// A mapping from operations in the input source file to their parser state.
57 SmallVector<std::unique_ptr<OperationDefinition>> operations;
58 DenseMap<Operation *, unsigned> operationToIdx;
60 /// A mapping from blocks in the input source file to their parser state.
61 SmallVector<std::unique_ptr<BlockDefinition>> blocks;
62 DenseMap<Block *, unsigned> blocksToIdx;
64 /// A mapping from aliases in the input source file to their parser state.
65 SmallVector<std::unique_ptr<AttributeAliasDefinition>> attrAliases;
66 SmallVector<std::unique_ptr<TypeAliasDefinition>> typeAliases;
67 llvm::StringMap<unsigned> attrAliasToIdx;
68 llvm::StringMap<unsigned> typeAliasToIdx;
70 /// A set of value definitions that are placeholders for forward references.
71 /// This map should be empty if the parser finishes successfully.
72 DenseMap<Value, SmallVector<SMLoc>> placeholderValueUses;
74 /// The symbol table operations within the IR.
75 SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
76 symbolTableOperations;
78 /// A stack of partial operation definitions that have been started but not
79 /// yet finalized.
80 SmallVector<PartialOpDef> partialOperations;
82 /// A stack of symbol use scopes. This is used when collecting symbol table
83 /// uses during parsing.
84 SmallVector<SymbolUseMap *> symbolUseScopes;
86 /// A symbol table containing all of the symbol table operations in the IR.
87 SymbolTableCollection symbolTable;
90 void AsmParserState::Impl::resolveSymbolUses() {
91 SmallVector<Operation *> symbolOps;
92 for (auto &opAndUseMapIt : symbolTableOperations) {
93 for (auto &it : *opAndUseMapIt.second) {
94 symbolOps.clear();
95 if (failed(symbolTable.lookupSymbolIn(
96 opAndUseMapIt.first, cast<SymbolRefAttr>(it.first), symbolOps)))
97 continue;
99 for (ArrayRef<SMRange> useRange : it.second) {
100 for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
101 auto opIt = operationToIdx.find(std::get<0>(symIt));
102 if (opIt != operationToIdx.end())
103 operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
110 //===----------------------------------------------------------------------===//
111 // AsmParserState
112 //===----------------------------------------------------------------------===//
114 AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
115 AsmParserState::~AsmParserState() = default;
116 AsmParserState &AsmParserState::operator=(AsmParserState &&other) {
117 impl = std::move(other.impl);
118 return *this;
121 //===----------------------------------------------------------------------===//
122 // Access State
124 auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> {
125 return llvm::make_pointee_range(llvm::ArrayRef(impl->blocks));
128 auto AsmParserState::getBlockDef(Block *block) const
129 -> const BlockDefinition * {
130 auto it = impl->blocksToIdx.find(block);
131 return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
134 auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
135 return llvm::make_pointee_range(llvm::ArrayRef(impl->operations));
138 auto AsmParserState::getOpDef(Operation *op) const
139 -> const OperationDefinition * {
140 auto it = impl->operationToIdx.find(op);
141 return it == impl->operationToIdx.end() ? nullptr
142 : &*impl->operations[it->second];
145 auto AsmParserState::getAttributeAliasDefs() const
146 -> iterator_range<AttributeDefIterator> {
147 return llvm::make_pointee_range(ArrayRef(impl->attrAliases));
150 auto AsmParserState::getAttributeAliasDef(StringRef name) const
151 -> const AttributeAliasDefinition * {
152 auto it = impl->attrAliasToIdx.find(name);
153 return it == impl->attrAliasToIdx.end() ? nullptr
154 : &*impl->attrAliases[it->second];
157 auto AsmParserState::getTypeAliasDefs() const
158 -> iterator_range<TypeDefIterator> {
159 return llvm::make_pointee_range(ArrayRef(impl->typeAliases));
162 auto AsmParserState::getTypeAliasDef(StringRef name) const
163 -> const TypeAliasDefinition * {
164 auto it = impl->typeAliasToIdx.find(name);
165 return it == impl->typeAliasToIdx.end() ? nullptr
166 : &*impl->typeAliases[it->second];
169 /// Lex a string token whose contents start at the given `curPtr`. Returns the
170 /// position at the end of the string, after a terminal or invalid character
171 /// (e.g. `"` or `\0`).
172 static const char *lexLocStringTok(const char *curPtr) {
173 while (char c = *curPtr++) {
174 // Check for various terminal characters.
175 if (StringRef("\"\n\v\f").contains(c))
176 return curPtr;
178 // Check for escape sequences.
179 if (c == '\\') {
180 // Check a few known escapes and \xx hex digits.
181 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't')
182 ++curPtr;
183 else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1]))
184 curPtr += 2;
185 else
186 return curPtr;
190 // If we hit this point, we've reached the end of the buffer. Update the end
191 // pointer to not point past the buffer.
192 return curPtr - 1;
195 SMRange AsmParserState::convertIdLocToRange(SMLoc loc) {
196 if (!loc.isValid())
197 return SMRange();
198 const char *curPtr = loc.getPointer();
200 // Check if this is a string token.
201 if (*curPtr == '"') {
202 curPtr = lexLocStringTok(curPtr + 1);
204 // Otherwise, default to handling an identifier.
205 } else {
206 // Return if the given character is a valid identifier character.
207 auto isIdentifierChar = [](char c) {
208 return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
211 while (*curPtr && isIdentifierChar(*(++curPtr)))
212 continue;
215 return SMRange(loc, SMLoc::getFromPointer(curPtr));
218 //===----------------------------------------------------------------------===//
219 // Populate State
221 void AsmParserState::initialize(Operation *topLevelOp) {
222 startOperationDefinition(topLevelOp->getName());
224 // If the top-level operation is a symbol table, push a new symbol scope.
225 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
226 if (partialOpDef.isSymbolTable())
227 impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
230 void AsmParserState::finalize(Operation *topLevelOp) {
231 assert(!impl->partialOperations.empty() &&
232 "expected valid partial operation definition");
233 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
235 // If this operation is a symbol table, resolve any symbol uses.
236 if (partialOpDef.isSymbolTable()) {
237 impl->symbolTableOperations.emplace_back(
238 topLevelOp, std::move(partialOpDef.symbolTable));
240 impl->resolveSymbolUses();
243 void AsmParserState::startOperationDefinition(const OperationName &opName) {
244 impl->partialOperations.emplace_back(opName);
247 void AsmParserState::finalizeOperationDefinition(
248 Operation *op, SMRange nameLoc, SMLoc endLoc,
249 ArrayRef<std::pair<unsigned, SMLoc>> resultGroups) {
250 assert(!impl->partialOperations.empty() &&
251 "expected valid partial operation definition");
252 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
254 // Build the full operation definition.
255 std::unique_ptr<OperationDefinition> def =
256 std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
257 for (auto &resultGroup : resultGroups)
258 def->resultGroups.emplace_back(resultGroup.first,
259 convertIdLocToRange(resultGroup.second));
260 impl->operationToIdx.try_emplace(op, impl->operations.size());
261 impl->operations.emplace_back(std::move(def));
263 // If this operation is a symbol table, resolve any symbol uses.
264 if (partialOpDef.isSymbolTable()) {
265 impl->symbolTableOperations.emplace_back(
266 op, std::move(partialOpDef.symbolTable));
270 void AsmParserState::startRegionDefinition() {
271 assert(!impl->partialOperations.empty() &&
272 "expected valid partial operation definition");
274 // If the parent operation of this region is a symbol table, we also push a
275 // new symbol scope.
276 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
277 if (partialOpDef.isSymbolTable())
278 impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
281 void AsmParserState::finalizeRegionDefinition() {
282 assert(!impl->partialOperations.empty() &&
283 "expected valid partial operation definition");
285 // If the parent operation of this region is a symbol table, pop the symbol
286 // scope for this region.
287 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
288 if (partialOpDef.isSymbolTable())
289 impl->symbolUseScopes.pop_back();
292 void AsmParserState::addDefinition(Block *block, SMLoc location) {
293 auto it = impl->blocksToIdx.find(block);
294 if (it == impl->blocksToIdx.end()) {
295 impl->blocksToIdx.try_emplace(block, impl->blocks.size());
296 impl->blocks.emplace_back(std::make_unique<BlockDefinition>(
297 block, convertIdLocToRange(location)));
298 return;
301 // If an entry already exists, this was a forward declaration that now has a
302 // proper definition.
303 impl->blocks[it->second]->definition.loc = convertIdLocToRange(location);
306 void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) {
307 auto it = impl->blocksToIdx.find(blockArg.getOwner());
308 assert(it != impl->blocksToIdx.end() &&
309 "expected owner block to have an entry");
310 BlockDefinition &def = *impl->blocks[it->second];
311 unsigned argIdx = blockArg.getArgNumber();
313 if (def.arguments.size() <= argIdx)
314 def.arguments.resize(argIdx + 1);
315 def.arguments[argIdx] = SMDefinition(convertIdLocToRange(location));
318 void AsmParserState::addAttrAliasDefinition(StringRef name, SMRange location,
319 Attribute value) {
320 auto [it, inserted] =
321 impl->attrAliasToIdx.try_emplace(name, impl->attrAliases.size());
322 // Location aliases may be referenced before they are defined.
323 if (inserted) {
324 impl->attrAliases.push_back(
325 std::make_unique<AttributeAliasDefinition>(name, location, value));
326 } else {
327 AttributeAliasDefinition &attr = *impl->attrAliases[it->second];
328 attr.definition.loc = location;
329 attr.value = value;
333 void AsmParserState::addTypeAliasDefinition(StringRef name, SMRange location,
334 Type value) {
335 [[maybe_unused]] auto [it, inserted] =
336 impl->typeAliasToIdx.try_emplace(name, impl->typeAliases.size());
337 assert(inserted && "unexpected attribute alias redefinition");
338 impl->typeAliases.push_back(
339 std::make_unique<TypeAliasDefinition>(name, location, value));
342 void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) {
343 // Handle the case where the value is an operation result.
344 if (OpResult result = dyn_cast<OpResult>(value)) {
345 // Check to see if a definition for the parent operation has been recorded.
346 // If one hasn't, we treat the provided value as a placeholder value that
347 // will be refined further later.
348 Operation *parentOp = result.getOwner();
349 auto existingIt = impl->operationToIdx.find(parentOp);
350 if (existingIt == impl->operationToIdx.end()) {
351 impl->placeholderValueUses[value].append(locations.begin(),
352 locations.end());
353 return;
356 // If a definition does exist, locate the value's result group and add the
357 // use. The result groups are ordered by increasing start index, so we just
358 // need to find the last group that has a smaller/equal start index.
359 unsigned resultNo = result.getResultNumber();
360 OperationDefinition &def = *impl->operations[existingIt->second];
361 for (auto &resultGroup : llvm::reverse(def.resultGroups)) {
362 if (resultNo >= resultGroup.startIndex) {
363 for (SMLoc loc : locations)
364 resultGroup.definition.uses.push_back(convertIdLocToRange(loc));
365 return;
368 llvm_unreachable("expected valid result group for value use");
371 // Otherwise, this is a block argument.
372 BlockArgument arg = cast<BlockArgument>(value);
373 auto existingIt = impl->blocksToIdx.find(arg.getOwner());
374 assert(existingIt != impl->blocksToIdx.end() &&
375 "expected valid block definition for block argument");
376 BlockDefinition &blockDef = *impl->blocks[existingIt->second];
377 SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()];
378 for (SMLoc loc : locations)
379 argDef.uses.emplace_back(convertIdLocToRange(loc));
382 void AsmParserState::addUses(Block *block, ArrayRef<SMLoc> locations) {
383 auto it = impl->blocksToIdx.find(block);
384 if (it == impl->blocksToIdx.end()) {
385 it = impl->blocksToIdx.try_emplace(block, impl->blocks.size()).first;
386 impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block));
389 BlockDefinition &def = *impl->blocks[it->second];
390 for (SMLoc loc : locations)
391 def.definition.uses.push_back(convertIdLocToRange(loc));
394 void AsmParserState::addUses(SymbolRefAttr refAttr,
395 ArrayRef<SMRange> locations) {
396 // Ignore this symbol if no scopes are active.
397 if (impl->symbolUseScopes.empty())
398 return;
400 assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
401 "expected the same number of references as provided locations");
402 (*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(),
403 locations.end());
406 void AsmParserState::addAttrAliasUses(StringRef name, SMRange location) {
407 auto it = impl->attrAliasToIdx.find(name);
408 // Location aliases may be referenced before they are defined.
409 if (it == impl->attrAliasToIdx.end()) {
410 it = impl->attrAliasToIdx.try_emplace(name, impl->attrAliases.size()).first;
411 impl->attrAliases.push_back(
412 std::make_unique<AttributeAliasDefinition>(name));
414 AttributeAliasDefinition &def = *impl->attrAliases[it->second];
415 def.definition.uses.push_back(location);
418 void AsmParserState::addTypeAliasUses(StringRef name, SMRange location) {
419 auto it = impl->typeAliasToIdx.find(name);
420 // Location aliases may be referenced before they are defined.
421 assert(it != impl->typeAliasToIdx.end() &&
422 "expected valid type alias definition");
423 TypeAliasDefinition &def = *impl->typeAliases[it->second];
424 def.definition.uses.push_back(location);
427 void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
428 auto it = impl->placeholderValueUses.find(oldValue);
429 assert(it != impl->placeholderValueUses.end() &&
430 "expected `oldValue` to be a placeholder");
431 addUses(newValue, it->second);
432 impl->placeholderValueUses.erase(oldValue);