[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / mlir / lib / AsmParser / Lexer.cpp
blobb4189181a849590a3568eb2add6dddf9da9a9435
1 //===- Lexer.cpp - MLIR Lexer Implementation ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lexer for the MLIR textual form.
11 //===----------------------------------------------------------------------===//
13 #include "Lexer.h"
14 #include "Token.h"
15 #include "mlir/AsmParser/CodeComplete.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/IR/Location.h"
18 #include "mlir/IR/MLIRContext.h"
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringSwitch.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include <cassert>
26 #include <cctype>
28 using namespace mlir;
/// Returns true if 'c' is an allowable identifier punctuation character,
/// i.e. one of `$`, `.`, `_` or `-`. Returns false for any other character.
static bool isPunct(char c) {
  return c == '$' || c == '.' || c == '_' || c == '-';
}
36 Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context,
37 AsmParserCodeCompleteContext *codeCompleteContext)
38 : sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) {
39 auto bufferID = sourceMgr.getMainFileID();
40 curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
41 curPtr = curBuffer.begin();
43 // Set the code completion location if it was provided.
44 if (codeCompleteContext)
45 codeCompleteLoc = codeCompleteContext->getCodeCompleteLoc().getPointer();
48 /// Encode the specified source location information into an attribute for
49 /// attachment to the IR.
50 Location Lexer::getEncodedSourceLocation(SMLoc loc) {
51 auto &sourceMgr = getSourceMgr();
52 unsigned mainFileID = sourceMgr.getMainFileID();
54 // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can
55 // use it here.
56 auto &bufferInfo = sourceMgr.getBufferInfo(mainFileID);
57 unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
58 unsigned column =
59 (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
60 auto *buffer = sourceMgr.getMemoryBuffer(mainFileID);
62 return FileLineColLoc::get(context, buffer->getBufferIdentifier(), lineNo,
63 column);
66 /// emitError - Emit an error message and return an Token::error token.
67 Token Lexer::emitError(const char *loc, const Twine &message) {
68 mlir::emitError(getEncodedSourceLocation(SMLoc::getFromPointer(loc)),
69 message);
70 return formToken(Token::error, loc);
73 Token Lexer::lexToken() {
74 while (true) {
75 const char *tokStart = curPtr;
77 // Check to see if the current token is at the code completion location.
78 if (tokStart == codeCompleteLoc)
79 return formToken(Token::code_complete, tokStart);
81 // Lex the next token.
82 switch (*curPtr++) {
83 default:
84 // Handle bare identifiers.
85 if (isalpha(curPtr[-1]))
86 return lexBareIdentifierOrKeyword(tokStart);
88 // Unknown character, emit an error.
89 return emitError(tokStart, "unexpected character");
91 case ' ':
92 case '\t':
93 case '\n':
94 case '\r':
95 // Handle whitespace.
96 continue;
98 case '_':
99 // Handle bare identifiers.
100 return lexBareIdentifierOrKeyword(tokStart);
102 case 0:
103 // This may either be a nul character in the source file or may be the EOF
104 // marker that llvm::MemoryBuffer guarantees will be there.
105 if (curPtr - 1 == curBuffer.end())
106 return formToken(Token::eof, tokStart);
107 continue;
109 case ':':
110 return formToken(Token::colon, tokStart);
111 case ',':
112 return formToken(Token::comma, tokStart);
113 case '.':
114 return lexEllipsis(tokStart);
115 case '(':
116 return formToken(Token::l_paren, tokStart);
117 case ')':
118 return formToken(Token::r_paren, tokStart);
119 case '{':
120 if (*curPtr == '-' && *(curPtr + 1) == '#') {
121 curPtr += 2;
122 return formToken(Token::file_metadata_begin, tokStart);
124 return formToken(Token::l_brace, tokStart);
125 case '}':
126 return formToken(Token::r_brace, tokStart);
127 case '[':
128 return formToken(Token::l_square, tokStart);
129 case ']':
130 return formToken(Token::r_square, tokStart);
131 case '<':
132 return formToken(Token::less, tokStart);
133 case '>':
134 return formToken(Token::greater, tokStart);
135 case '=':
136 return formToken(Token::equal, tokStart);
138 case '+':
139 return formToken(Token::plus, tokStart);
140 case '*':
141 return formToken(Token::star, tokStart);
142 case '-':
143 if (*curPtr == '>') {
144 ++curPtr;
145 return formToken(Token::arrow, tokStart);
147 return formToken(Token::minus, tokStart);
149 case '?':
150 return formToken(Token::question, tokStart);
152 case '|':
153 return formToken(Token::vertical_bar, tokStart);
155 case '/':
156 if (*curPtr == '/') {
157 skipComment();
158 continue;
160 return emitError(tokStart, "unexpected character");
162 case '@':
163 return lexAtIdentifier(tokStart);
165 case '#':
166 if (*curPtr == '-' && *(curPtr + 1) == '}') {
167 curPtr += 2;
168 return formToken(Token::file_metadata_end, tokStart);
170 [[fallthrough]];
171 case '!':
172 case '^':
173 case '%':
174 return lexPrefixedIdentifier(tokStart);
175 case '"':
176 return lexString(tokStart);
178 case '0':
179 case '1':
180 case '2':
181 case '3':
182 case '4':
183 case '5':
184 case '6':
185 case '7':
186 case '8':
187 case '9':
188 return lexNumber(tokStart);
193 /// Lex an '@foo' identifier.
195 /// symbol-ref-id ::= `@` (bare-id | string-literal)
197 Token Lexer::lexAtIdentifier(const char *tokStart) {
198 char cur = *curPtr++;
200 // Try to parse a string literal, if present.
201 if (cur == '"') {
202 Token stringIdentifier = lexString(curPtr);
203 if (stringIdentifier.is(Token::error))
204 return stringIdentifier;
205 return formToken(Token::at_identifier, tokStart);
208 // Otherwise, these always start with a letter or underscore.
209 if (!isalpha(cur) && cur != '_')
210 return emitError(curPtr - 1,
211 "@ identifier expected to start with letter or '_'");
213 while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
214 *curPtr == '$' || *curPtr == '.')
215 ++curPtr;
216 return formToken(Token::at_identifier, tokStart);
219 /// Lex a bare identifier or keyword that starts with a letter.
221 /// bare-id ::= (letter|[_]) (letter|digit|[_$.])*
222 /// integer-type ::= `[su]?i[1-9][0-9]*`
224 Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
225 // Match the rest of the identifier regex: [0-9a-zA-Z_.$]*
226 while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' ||
227 *curPtr == '$' || *curPtr == '.')
228 ++curPtr;
230 // Check to see if this identifier is a keyword.
231 StringRef spelling(tokStart, curPtr - tokStart);
233 auto isAllDigit = [](StringRef str) {
234 return llvm::all_of(str, llvm::isDigit);
237 // Check for i123, si456, ui789.
238 if ((spelling.size() > 1 && tokStart[0] == 'i' &&
239 isAllDigit(spelling.drop_front())) ||
240 ((spelling.size() > 2 && tokStart[1] == 'i' &&
241 (tokStart[0] == 's' || tokStart[0] == 'u')) &&
242 isAllDigit(spelling.drop_front(2))))
243 return Token(Token::inttype, spelling);
245 Token::Kind kind = StringSwitch<Token::Kind>(spelling)
246 #define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING)
247 #include "TokenKinds.def"
248 .Default(Token::bare_identifier);
250 return Token(kind, spelling);
253 /// Skip a comment line, starting with a '//'.
255 /// TODO: add a regex for comments here and to the spec.
257 void Lexer::skipComment() {
258 // Advance over the second '/' in a '//' comment.
259 assert(*curPtr == '/');
260 ++curPtr;
262 while (true) {
263 switch (*curPtr++) {
264 case '\n':
265 case '\r':
266 // Newline is end of comment.
267 return;
268 case 0:
269 // If this is the end of the buffer, end the comment.
270 if (curPtr - 1 == curBuffer.end()) {
271 --curPtr;
272 return;
274 [[fallthrough]];
275 default:
276 // Skip over other characters.
277 break;
282 /// Lex an ellipsis.
284 /// ellipsis ::= '...'
286 Token Lexer::lexEllipsis(const char *tokStart) {
287 assert(curPtr[-1] == '.');
289 if (curPtr == curBuffer.end() || *curPtr != '.' || *(curPtr + 1) != '.')
290 return emitError(curPtr, "expected three consecutive dots for an ellipsis");
292 curPtr += 2;
293 return formToken(Token::ellipsis, tokStart);
296 /// Lex a number literal.
298 /// integer-literal ::= digit+ | `0x` hex_digit+
299 /// float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
301 Token Lexer::lexNumber(const char *tokStart) {
302 assert(isdigit(curPtr[-1]));
304 // Handle the hexadecimal case.
305 if (curPtr[-1] == '0' && *curPtr == 'x') {
306 // If we see stuff like 0xi32, this is a literal `0` followed by an
307 // identifier `xi32`, stop after `0`.
308 if (!isxdigit(curPtr[1]))
309 return formToken(Token::integer, tokStart);
311 curPtr += 2;
312 while (isxdigit(*curPtr))
313 ++curPtr;
315 return formToken(Token::integer, tokStart);
318 // Handle the normal decimal case.
319 while (isdigit(*curPtr))
320 ++curPtr;
322 if (*curPtr != '.')
323 return formToken(Token::integer, tokStart);
324 ++curPtr;
326 // Skip over [0-9]*([eE][-+]?[0-9]+)?
327 while (isdigit(*curPtr))
328 ++curPtr;
330 if (*curPtr == 'e' || *curPtr == 'E') {
331 if (isdigit(static_cast<unsigned char>(curPtr[1])) ||
332 ((curPtr[1] == '-' || curPtr[1] == '+') &&
333 isdigit(static_cast<unsigned char>(curPtr[2])))) {
334 curPtr += 2;
335 while (isdigit(*curPtr))
336 ++curPtr;
339 return formToken(Token::floatliteral, tokStart);
342 /// Lex an identifier that starts with a prefix followed by suffix-id.
344 /// attribute-id ::= `#` suffix-id
345 /// ssa-id ::= '%' suffix-id
346 /// block-id ::= '^' suffix-id
347 /// type-id ::= '!' suffix-id
348 /// suffix-id ::= digit+ | (letter|id-punct) (letter|id-punct|digit)*
349 /// id-punct ::= `$` | `.` | `_` | `-`
351 Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
352 Token::Kind kind;
353 StringRef errorKind;
354 switch (*tokStart) {
355 case '#':
356 kind = Token::hash_identifier;
357 errorKind = "invalid attribute name";
358 break;
359 case '%':
360 kind = Token::percent_identifier;
361 errorKind = "invalid SSA name";
362 break;
363 case '^':
364 kind = Token::caret_identifier;
365 errorKind = "invalid block name";
366 break;
367 case '!':
368 kind = Token::exclamation_identifier;
369 errorKind = "invalid type identifier";
370 break;
371 default:
372 llvm_unreachable("invalid caller");
375 // Parse suffix-id.
376 if (isdigit(*curPtr)) {
377 // If suffix-id starts with a digit, the rest must be digits.
378 while (isdigit(*curPtr))
379 ++curPtr;
380 } else if (isalpha(*curPtr) || isPunct(*curPtr)) {
381 do {
382 ++curPtr;
383 } while (isalpha(*curPtr) || isdigit(*curPtr) || isPunct(*curPtr));
384 } else if (curPtr == codeCompleteLoc) {
385 return formToken(Token::code_complete, tokStart);
386 } else {
387 return emitError(curPtr - 1, errorKind);
390 // Check for a code completion within the identifier.
391 if (codeCompleteLoc && codeCompleteLoc >= tokStart &&
392 codeCompleteLoc <= curPtr) {
393 return Token(Token::code_complete,
394 StringRef(tokStart, codeCompleteLoc - tokStart));
397 return formToken(kind, tokStart);
400 /// Lex a string literal.
402 /// string-literal ::= '"' [^"\n\f\v\r]* '"'
404 /// TODO: define escaping rules.
405 Token Lexer::lexString(const char *tokStart) {
406 assert(curPtr[-1] == '"');
408 while (true) {
409 // Check to see if there is a code completion location within the string. In
410 // these cases we generate a completion location and place the currently
411 // lexed string within the token. This allows for the parser to use the
412 // partially lexed string when computing the completion results.
413 if (curPtr == codeCompleteLoc)
414 return formToken(Token::code_complete, tokStart);
416 switch (*curPtr++) {
417 case '"':
418 return formToken(Token::string, tokStart);
419 case 0:
420 // If this is a random nul character in the middle of a string, just
421 // include it. If it is the end of file, then it is an error.
422 if (curPtr - 1 != curBuffer.end())
423 continue;
424 [[fallthrough]];
425 case '\n':
426 case '\v':
427 case '\f':
428 return emitError(curPtr - 1, "expected '\"' in string literal");
429 case '\\':
430 // Handle explicitly a few escapes.
431 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't')
432 ++curPtr;
433 else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1]))
434 // Support \xx for two hex digits.
435 curPtr += 2;
436 else
437 return emitError(curPtr - 1, "unknown escape in string literal");
438 continue;
440 default:
441 continue;