[mlir] Use StringRef::{starts,ends}_with (NFC)
[llvm-project.git] / mlir / lib / Tools / PDLL / Parser / Parser.cpp
blob cfbc4e4536fe8c454c2a1478c5619e7e8d712fcc
1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
10 #include "Lexer.h"
11 #include "mlir/Support/IndentedOstream.h"
12 #include "mlir/Support/LogicalResult.h"
13 #include "mlir/TableGen/Argument.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Constraint.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/Tools/PDLL/AST/Context.h"
19 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
20 #include "mlir/Tools/PDLL/AST/Nodes.h"
21 #include "mlir/Tools/PDLL/AST/Types.h"
22 #include "mlir/Tools/PDLL/ODS/Constraint.h"
23 #include "mlir/Tools/PDLL/ODS/Context.h"
24 #include "mlir/Tools/PDLL/ODS/Operation.h"
25 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/ManagedStatic.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/ScopedPrinter.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Parser.h"
34 #include <string>
35 #include <optional>
37 using namespace mlir;
38 using namespace mlir::pdll;
40 //===----------------------------------------------------------------------===//
41 // Parser
42 //===----------------------------------------------------------------------===//
44 namespace {
45 class Parser {
46 public:
47 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
48 bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
49 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
50 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
51 typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
52 typeRangeTy(ast::TypeRangeType::get(ctx)),
53 valueRangeTy(ast::ValueRangeType::get(ctx)),
54 attrTy(ast::AttributeType::get(ctx)),
55 codeCompleteContext(codeCompleteContext) {}
57 /// Try to parse a new module. Returns nullptr in the case of failure.
58 FailureOr<ast::Module *> parseModule();
60 private:
61 /// The current context of the parser. It allows for the parser to know a bit
62 /// about the construct it is nested within during parsing. This is used
63 /// specifically to provide additional verification during parsing, e.g. to
64 /// prevent using rewrites within a match context, matcher constraints within
65 /// a rewrite section, etc.
66 enum class ParserContext {
67 /// The parser is in the global context.
68 Global,
69 /// The parser is currently within a Constraint, which disallows all types
70 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
71 Constraint,
72 /// The parser is currently within the matcher portion of a Pattern, which
73 /// is allows a terminal operation rewrite statement but no other rewrite
74 /// transformations.
75 PatternMatch,
76 /// The parser is currently within a Rewrite, which disallows calls to
77 /// constraints, requires operation expressions to have names, etc.
78 Rewrite,
81 /// The current specification context of an operations result type. This
82 /// indicates how the result types of an operation may be inferred.
83 enum class OpResultTypeContext {
84 /// The result types of the operation are not known to be inferred.
85 Explicit,
86 /// The result types of the operation are inferred from the root input of a
87 /// `replace` statement.
88 Replacement,
89 /// The result types of the operation are inferred by using the
90 /// `InferTypeOpInterface` interface provided by the operation.
91 Interface,
94 //===--------------------------------------------------------------------===//
95 // Parsing
96 //===--------------------------------------------------------------------===//
98 /// Push a new decl scope onto the lexer.
99 ast::DeclScope *pushDeclScope() {
100 ast::DeclScope *newScope =
101 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
102 return (curDeclScope = newScope);
104 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
106 /// Pop the last decl scope from the lexer.
107 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
109 /// Parse the body of an AST module.
110 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
112 /// Try to convert the given expression to `type`. Returns failure and emits
113 /// an error if a conversion is not viable. On failure, `noteAttachFn` is
114 /// invoked to attach notes to the emitted error diagnostic. On success,
115 /// `expr` is updated to the expression used to convert to `type`.
116 LogicalResult convertExpressionTo(
117 ast::Expr *&expr, ast::Type type,
118 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
119 LogicalResult
120 convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
121 ast::Type type,
122 function_ref<ast::InFlightDiagnostic()> emitErrorFn);
123 LogicalResult convertTupleExpressionTo(
124 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
125 function_ref<ast::InFlightDiagnostic()> emitErrorFn,
126 function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
128 /// Given an operation expression, convert it to a Value or ValueRange
129 /// typed expression.
130 ast::Expr *convertOpToValue(const ast::Expr *opExpr);
132 /// Lookup ODS information for the given operation, returns nullptr if no
133 /// information is found.
134 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
135 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
138 /// Process the given documentation string, or return an empty string if
139 /// documentation isn't enabled.
140 StringRef processDoc(StringRef doc) {
141 return enableDocumentation ? doc : StringRef();
144 /// Process the given documentation string and format it, or return an empty
145 /// string if documentation isn't enabled.
146 std::string processAndFormatDoc(const Twine &doc) {
147 if (!enableDocumentation)
148 return "";
149 std::string docStr;
151 llvm::raw_string_ostream docOS(docStr);
152 std::string tmpDocStr = doc.str();
153 raw_indented_ostream(docOS).printReindented(
154 StringRef(tmpDocStr).rtrim(" \t"));
156 return docStr;
159 //===--------------------------------------------------------------------===//
160 // Directives
162 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
163 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
164 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
165 SmallVectorImpl<ast::Decl *> &decls);
167 /// Process the records of a parsed tablegen include file.
168 void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
169 SmallVectorImpl<ast::Decl *> &decls);
171 /// Create a user defined native constraint for a constraint imported from
172 /// ODS.
173 template <typename ConstraintT>
174 ast::Decl *
175 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
176 SMRange loc, ast::Type type,
177 StringRef nativeType, StringRef docString);
178 template <typename ConstraintT>
179 ast::Decl *
180 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
181 SMRange loc, ast::Type type,
182 StringRef nativeType);
184 //===--------------------------------------------------------------------===//
185 // Decls
187 /// This structure contains the set of pattern metadata that may be parsed.
188 struct ParsedPatternMetadata {
189 std::optional<uint16_t> benefit;
190 bool hasBoundedRecursion = false;
193 FailureOr<ast::Decl *> parseTopLevelDecl();
194 FailureOr<ast::NamedAttributeDecl *>
195 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
197 /// Parse an argument variable as part of the signature of a
198 /// UserConstraintDecl or UserRewriteDecl.
199 FailureOr<ast::VariableDecl *> parseArgumentDecl();
201 /// Parse a result variable as part of the signature of a UserConstraintDecl
202 /// or UserRewriteDecl.
203 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
205 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
206 /// defined in a non-global context.
207 FailureOr<ast::UserConstraintDecl *>
208 parseUserConstraintDecl(bool isInline = false);
210 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
211 /// non-global context, such as within a Pattern/Constraint/etc.
212 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
214 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
215 /// PDLL constructs.
216 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
217 const ast::Name &name, bool isInline,
218 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
219 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
221 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
222 /// defined in a non-global context.
223 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
225 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
226 /// non-global context, such as within a Pattern/Rewrite/etc.
227 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
229 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
230 /// PDLL constructs.
231 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
232 const ast::Name &name, bool isInline,
233 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
234 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
236 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
237 /// effectively the same syntax, and only differ on slight semantics (given
238 /// the different parsing contexts).
239 template <typename T, typename ParseUserPDLLDeclFnT>
240 FailureOr<T *> parseUserConstraintOrRewriteDecl(
241 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
242 StringRef anonymousNamePrefix, bool isInline);
244 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
245 /// These decls have effectively the same syntax.
246 template <typename T>
247 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
248 const ast::Name &name, bool isInline,
249 ArrayRef<ast::VariableDecl *> arguments,
250 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
252 /// Parse the functional signature (i.e. the arguments and results) of a
253 /// UserConstraintDecl or UserRewriteDecl.
254 LogicalResult parseUserConstraintOrRewriteSignature(
255 SmallVectorImpl<ast::VariableDecl *> &arguments,
256 SmallVectorImpl<ast::VariableDecl *> &results,
257 ast::DeclScope *&argumentScope, ast::Type &resultType);
259 /// Validate the return (which if present is specified by bodyIt) of a
260 /// UserConstraintDecl or UserRewriteDecl.
261 LogicalResult validateUserConstraintOrRewriteReturn(
262 StringRef declType, ast::CompoundStmt *body,
263 ArrayRef<ast::Stmt *>::iterator bodyIt,
264 ArrayRef<ast::Stmt *>::iterator bodyE,
265 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
267 FailureOr<ast::CompoundStmt *>
268 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
269 bool expectTerminalSemicolon = true);
270 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
271 FailureOr<ast::Decl *> parsePatternDecl();
272 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
274 /// Check to see if a decl has already been defined with the given name, if
275 /// one has emit and error and return failure. Returns success otherwise.
276 LogicalResult checkDefineNamedDecl(const ast::Name &name);
278 /// Try to define a variable decl with the given components, returns the
279 /// variable on success.
280 FailureOr<ast::VariableDecl *>
281 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
282 ast::Expr *initExpr,
283 ArrayRef<ast::ConstraintRef> constraints);
284 FailureOr<ast::VariableDecl *>
285 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
286 ArrayRef<ast::ConstraintRef> constraints);
288 /// Parse the constraint reference list for a variable decl.
289 LogicalResult parseVariableDeclConstraintList(
290 SmallVectorImpl<ast::ConstraintRef> &constraints);
292 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
293 FailureOr<ast::Expr *> parseTypeConstraintExpr();
295 /// Try to parse a single reference to a constraint. `typeConstraint` is the
296 /// location of a previously parsed type constraint for the entity that will
297 /// be constrained by the parsed constraint. `existingConstraints` are any
298 /// existing constraints that have already been parsed for the same entity
299 /// that will be constrained by this constraint. `allowInlineTypeConstraints`
300 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
301 FailureOr<ast::ConstraintRef>
302 parseConstraint(std::optional<SMRange> &typeConstraint,
303 ArrayRef<ast::ConstraintRef> existingConstraints,
304 bool allowInlineTypeConstraints);
306 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
307 /// argument or result variable. The constraints for these variables do not
308 /// allow inline type constraints, and only permit a single constraint.
309 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
311 //===--------------------------------------------------------------------===//
312 // Exprs
314 FailureOr<ast::Expr *> parseExpr();
316 /// Identifier expressions.
317 FailureOr<ast::Expr *> parseAttributeExpr();
318 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
319 bool isNegated = false);
320 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
321 FailureOr<ast::Expr *> parseIdentifierExpr();
322 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
323 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
324 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
325 FailureOr<ast::Expr *> parseNegatedExpr();
326 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
327 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
328 FailureOr<ast::Expr *>
329 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
330 OpResultTypeContext::Explicit);
331 FailureOr<ast::Expr *> parseTupleExpr();
332 FailureOr<ast::Expr *> parseTypeExpr();
333 FailureOr<ast::Expr *> parseUnderscoreExpr();
335 //===--------------------------------------------------------------------===//
336 // Stmts
338 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
339 FailureOr<ast::CompoundStmt *> parseCompoundStmt();
340 FailureOr<ast::EraseStmt *> parseEraseStmt();
341 FailureOr<ast::LetStmt *> parseLetStmt();
342 FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
343 FailureOr<ast::ReturnStmt *> parseReturnStmt();
344 FailureOr<ast::RewriteStmt *> parseRewriteStmt();
346 //===--------------------------------------------------------------------===//
347 // Creation+Analysis
348 //===--------------------------------------------------------------------===//
350 //===--------------------------------------------------------------------===//
351 // Decls
353 /// Try to extract a callable from the given AST node. Returns nullptr on
354 /// failure.
355 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
357 /// Try to create a pattern decl with the given components, returning the
358 /// Pattern on success.
359 FailureOr<ast::PatternDecl *>
360 createPatternDecl(SMRange loc, const ast::Name *name,
361 const ParsedPatternMetadata &metadata,
362 ast::CompoundStmt *body);
364 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
365 /// of results, defined as part of the signature.
366 ast::Type
367 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
369 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
370 template <typename T>
371 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
372 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
373 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
374 ast::CompoundStmt *body);
376 /// Try to create a variable decl with the given components, returning the
377 /// Variable on success.
378 FailureOr<ast::VariableDecl *>
379 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
380 ArrayRef<ast::ConstraintRef> constraints);
382 /// Create a variable for an argument or result defined as part of the
383 /// signature of a UserConstraintDecl/UserRewriteDecl.
384 FailureOr<ast::VariableDecl *>
385 createArgOrResultVariableDecl(StringRef name, SMRange loc,
386 const ast::ConstraintRef &constraint);
388 /// Validate the constraints used to constraint a variable decl.
389 /// `inferredType` is the type of the variable inferred by the constraints
390 /// within the list, and is updated to the most refined type as determined by
391 /// the constraints. Returns success if the constraint list is valid, failure
392 /// otherwise.
393 LogicalResult
394 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
395 ast::Type &inferredType);
396 /// Validate a single reference to a constraint. `inferredType` contains the
397 /// currently inferred variabled type and is refined within the type defined
398 /// by the constraint. Returns success if the constraint is valid, failure
399 /// otherwise.
400 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
401 ast::Type &inferredType);
402 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
403 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
405 //===--------------------------------------------------------------------===//
406 // Exprs
408 FailureOr<ast::CallExpr *>
409 createCallExpr(SMRange loc, ast::Expr *parentExpr,
410 MutableArrayRef<ast::Expr *> arguments,
411 bool isNegated = false);
412 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
413 FailureOr<ast::DeclRefExpr *>
414 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
415 ArrayRef<ast::ConstraintRef> constraints);
416 FailureOr<ast::MemberAccessExpr *>
417 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
419 /// Validate the member access `name` into the given parent expression. On
420 /// success, this also returns the type of the member accessed.
421 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
422 StringRef name, SMRange loc);
423 FailureOr<ast::OperationExpr *>
424 createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
425 OpResultTypeContext resultTypeContext,
426 SmallVectorImpl<ast::Expr *> &operands,
427 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
428 SmallVectorImpl<ast::Expr *> &results);
429 LogicalResult
430 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
431 const ods::Operation *odsOp,
432 SmallVectorImpl<ast::Expr *> &operands);
433 LogicalResult validateOperationResults(SMRange loc,
434 std::optional<StringRef> name,
435 const ods::Operation *odsOp,
436 SmallVectorImpl<ast::Expr *> &results);
437 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
438 const ods::Operation *odsOp);
439 LogicalResult validateOperationOperandsOrResults(
440 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
441 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
442 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
443 ast::RangeType rangeTy);
444 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
445 ArrayRef<ast::Expr *> elements,
446 ArrayRef<StringRef> elementNames);
448 //===--------------------------------------------------------------------===//
449 // Stmts
451 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
452 FailureOr<ast::ReplaceStmt *>
453 createReplaceStmt(SMRange loc, ast::Expr *rootOp,
454 MutableArrayRef<ast::Expr *> replValues);
455 FailureOr<ast::RewriteStmt *>
456 createRewriteStmt(SMRange loc, ast::Expr *rootOp,
457 ast::CompoundStmt *rewriteBody);
459 //===--------------------------------------------------------------------===//
460 // Code Completion
461 //===--------------------------------------------------------------------===//
463 /// The set of various code completion methods. Every completion method
464 /// returns `failure` to stop the parsing process after providing completion
465 /// results.
467 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
468 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
469 LogicalResult codeCompleteConstraintName(ast::Type inferredType,
470 bool allowInlineTypeConstraints);
471 LogicalResult codeCompleteDialectName();
472 LogicalResult codeCompleteOperationName(StringRef dialectName);
473 LogicalResult codeCompletePatternMetadata();
474 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
476 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
477 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
478 unsigned currentNumOperands);
479 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
480 unsigned currentNumResults);
482 //===--------------------------------------------------------------------===//
483 // Lexer Utilities
484 //===--------------------------------------------------------------------===//
486 /// If the current token has the specified kind, consume it and return true.
487 /// If not, return false.
488 bool consumeIf(Token::Kind kind) {
489 if (curToken.isNot(kind))
490 return false;
491 consumeToken(kind);
492 return true;
495 /// Advance the current lexer onto the next token.
496 void consumeToken() {
497 assert(curToken.isNot(Token::eof, Token::error) &&
498 "shouldn't advance past EOF or errors");
499 curToken = lexer.lexToken();
502 /// Advance the current lexer onto the next token, asserting what the expected
503 /// current token is. This is preferred to the above method because it leads
504 /// to more self-documenting code with better checking.
505 void consumeToken(Token::Kind kind) {
506 assert(curToken.is(kind) && "consumed an unexpected token");
507 consumeToken();
510 /// Reset the lexer to the location at the given position.
511 void resetToken(SMRange tokLoc) {
512 lexer.resetPointer(tokLoc.Start.getPointer());
513 curToken = lexer.lexToken();
516 /// Consume the specified token if present and return success. On failure,
517 /// output a diagnostic and return failure.
518 LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
519 if (curToken.getKind() != kind)
520 return emitError(curToken.getLoc(), msg);
521 consumeToken();
522 return success();
524 LogicalResult emitError(SMRange loc, const Twine &msg) {
525 lexer.emitError(loc, msg);
526 return failure();
528 LogicalResult emitError(const Twine &msg) {
529 return emitError(curToken.getLoc(), msg);
531 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
532 const Twine &note) {
533 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
534 return failure();
537 //===--------------------------------------------------------------------===//
538 // Fields
539 //===--------------------------------------------------------------------===//
541 /// The owning AST context.
542 ast::Context &ctx;
544 /// The lexer of this parser.
545 Lexer lexer;
547 /// The current token within the lexer.
548 Token curToken;
550 /// A flag indicating if the parser should add documentation to AST nodes when
551 /// viable.
552 bool enableDocumentation;
554 /// The most recently defined decl scope.
555 ast::DeclScope *curDeclScope = nullptr;
556 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
558 /// The current context of the parser.
559 ParserContext parserContext = ParserContext::Global;
561 /// Cached types to simplify verification and expression creation.
562 ast::Type typeTy, valueTy;
563 ast::RangeType typeRangeTy, valueRangeTy;
564 ast::Type attrTy;
566 /// A counter used when naming anonymous constraints and rewrites.
567 unsigned anonymousDeclNameCounter = 0;
569 /// The optional code completion context.
570 CodeCompleteContext *codeCompleteContext;
572 } // namespace
574 FailureOr<ast::Module *> Parser::parseModule() {
575 SMLoc moduleLoc = curToken.getStartLoc();
576 pushDeclScope();
578 // Parse the top-level decls of the module.
579 SmallVector<ast::Decl *> decls;
580 if (failed(parseModuleBody(decls)))
581 return popDeclScope(), failure();
583 popDeclScope();
584 return ast::Module::create(ctx, moduleLoc, decls);
587 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
588 while (curToken.isNot(Token::eof)) {
589 if (curToken.is(Token::directive)) {
590 if (failed(parseDirective(decls)))
591 return failure();
592 continue;
595 FailureOr<ast::Decl *> decl = parseTopLevelDecl();
596 if (failed(decl))
597 return failure();
598 decls.push_back(*decl);
600 return success();
603 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
604 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
605 valueRangeTy);
608 LogicalResult Parser::convertExpressionTo(
609 ast::Expr *&expr, ast::Type type,
610 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
611 ast::Type exprType = expr->getType();
612 if (exprType == type)
613 return success();
615 auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
616 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
617 expr->getLoc(), llvm::formatv("unable to convert expression of type "
618 "`{0}` to the expected type of "
619 "`{1}`",
620 exprType, type));
621 if (noteAttachFn)
622 noteAttachFn(*diag);
623 return diag;
626 if (auto exprOpType = exprType.dyn_cast<ast::OperationType>())
627 return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
629 // FIXME: Decide how to allow/support converting a single result to multiple,
630 // and multiple to a single result. For now, we just allow Single->Range,
631 // but this isn't something really supported in the PDL dialect. We should
632 // figure out some way to support both.
633 if ((exprType == valueTy || exprType == valueRangeTy) &&
634 (type == valueTy || type == valueRangeTy))
635 return success();
636 if ((exprType == typeTy || exprType == typeRangeTy) &&
637 (type == typeTy || type == typeRangeTy))
638 return success();
640 // Handle tuple types.
641 if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>())
642 return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
643 noteAttachFn);
645 return emitConvertError();
648 LogicalResult Parser::convertOpExpressionTo(
649 ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
650 function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
651 // Two operation types are compatible if they have the same name, or if the
652 // expected type is more general.
653 if (auto opType = type.dyn_cast<ast::OperationType>()) {
654 if (opType.getName())
655 return emitErrorFn();
656 return success();
659 // An operation can always convert to a ValueRange.
660 if (type == valueRangeTy) {
661 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
662 valueRangeTy);
663 return success();
666 // Allow conversion to a single value by constraining the result range.
667 if (type == valueTy) {
668 // If the operation is registered, we can verify if it can ever have a
669 // single result.
670 if (const ods::Operation *odsOp = exprType.getODSOperation()) {
671 if (odsOp->getResults().empty()) {
672 return emitErrorFn()->attachNote(
673 llvm::formatv("see the definition of `{0}`, which was defined "
674 "with zero results",
675 odsOp->getName()),
676 odsOp->getLoc());
679 unsigned numSingleResults = llvm::count_if(
680 odsOp->getResults(), [](const ods::OperandOrResult &result) {
681 return result.getVariableLengthKind() ==
682 ods::VariableLengthKind::Single;
684 if (numSingleResults > 1) {
685 return emitErrorFn()->attachNote(
686 llvm::formatv("see the definition of `{0}`, which was defined "
687 "with at least {1} results",
688 odsOp->getName(), numSingleResults),
689 odsOp->getLoc());
693 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
694 valueTy);
695 return success();
697 return emitErrorFn();
700 LogicalResult Parser::convertTupleExpressionTo(
701 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
702 function_ref<ast::InFlightDiagnostic()> emitErrorFn,
703 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
704 // Handle conversions between tuples.
705 if (auto tupleType = type.dyn_cast<ast::TupleType>()) {
706 if (tupleType.size() != exprType.size())
707 return emitErrorFn();
709 // Build a new tuple expression using each of the elements of the current
710 // tuple.
711 SmallVector<ast::Expr *> newExprs;
712 for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
713 newExprs.push_back(ast::MemberAccessExpr::create(
714 ctx, expr->getLoc(), expr, llvm::to_string(i),
715 exprType.getElementTypes()[i]));
717 auto diagFn = [&](ast::Diagnostic &diag) {
718 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
719 i, exprType));
720 if (noteAttachFn)
721 noteAttachFn(diag);
723 if (failed(convertExpressionTo(newExprs.back(),
724 tupleType.getElementTypes()[i], diagFn)))
725 return failure();
727 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
728 tupleType.getElementNames());
729 return success();
732 // Handle conversion to a range.
733 auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
734 ast::RangeType resultTy) -> LogicalResult {
735 // TODO: We currently only allow range conversion within a rewrite context.
736 if (parserContext != ParserContext::Rewrite) {
737 return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
738 "only allowed within a rewrite context");
741 // All of the tuple elements must be allowed types.
742 for (ast::Type elementType : exprType.getElementTypes())
743 if (!llvm::is_contained(allowedElementTypes, elementType))
744 return emitErrorFn();
746 // Build a new tuple expression using each of the elements of the current
747 // tuple.
748 SmallVector<ast::Expr *> newExprs;
749 for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
750 newExprs.push_back(ast::MemberAccessExpr::create(
751 ctx, expr->getLoc(), expr, llvm::to_string(i),
752 exprType.getElementTypes()[i]));
754 expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
755 return success();
757 if (type == valueRangeTy)
758 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
759 if (type == typeRangeTy)
760 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
762 return emitErrorFn();
765 //===----------------------------------------------------------------------===//
766 // Directives
768 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
769 StringRef directive = curToken.getSpelling();
770 if (directive == "#include")
771 return parseInclude(decls);
773 return emitError("unknown directive `" + directive + "`");
776 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
777 SMRange loc = curToken.getLoc();
778 consumeToken(Token::directive);
780 // Handle code completion of the include file path.
781 if (curToken.is(Token::code_complete_string))
782 return codeCompleteIncludeFilename(curToken.getStringValue());
784 // Parse the file being included.
785 if (!curToken.isString())
786 return emitError(loc,
787 "expected string file name after `include` directive");
788 SMRange fileLoc = curToken.getLoc();
789 std::string filenameStr = curToken.getStringValue();
790 StringRef filename = filenameStr;
791 consumeToken();
793 // Check the type of include. If ending with `.pdll`, this is another pdl file
794 // to be parsed along with the current module.
795 if (filename.ends_with(".pdll")) {
796 if (failed(lexer.pushInclude(filename, fileLoc)))
797 return emitError(fileLoc,
798 "unable to open include file `" + filename + "`");
800 // If we added the include successfully, parse it into the current module.
801 // Make sure to update to the next token after we finish parsing the nested
802 // file.
803 curToken = lexer.lexToken();
804 LogicalResult result = parseModuleBody(decls);
805 curToken = lexer.lexToken();
806 return result;
809 // Otherwise, this must be a `.td` include.
810 if (filename.ends_with(".td"))
811 return parseTdInclude(filename, fileLoc, decls);
813 return emitError(fileLoc,
814 "expected include filename to end with `.pdll` or `.td`");
817 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
818 SmallVectorImpl<ast::Decl *> &decls) {
819 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
821 // Use the source manager to open the file, but don't yet add it.
822 std::string includedFile;
823 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
824 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
825 if (!includeBuffer)
826 return emitError(fileLoc, "unable to open include file `" + filename + "`");
828 // Setup the source manager for parsing the tablegen file.
829 llvm::SourceMgr tdSrcMgr;
830 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
831 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
833 // This class provides a context argument for the llvm::SourceMgr diagnostic
834 // handler.
835 struct DiagHandlerContext {
836 Parser &parser;
837 StringRef filename;
838 llvm::SMRange loc;
839 } handlerContext{*this, filename, fileLoc};
841 // Set the diagnostic handler for the tablegen source manager.
842 tdSrcMgr.setDiagHandler(
843 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
844 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
845 (void)ctx->parser.emitError(
846 ctx->loc,
847 llvm::formatv("error while processing include file `{0}`: {1}",
848 ctx->filename, diag.getMessage()));
850 &handlerContext);
852 // Parse the tablegen file.
853 llvm::RecordKeeper tdRecords;
854 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
855 return failure();
857 // Process the parsed records.
858 processTdIncludeRecords(tdRecords, decls);
860 // After we are done processing, move all of the tablegen source buffers to
861 // the main parser source mgr. This allows for directly using source locations
862 // from the .td files without needing to remap them.
863 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
864 return success();
867 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
868 SmallVectorImpl<ast::Decl *> &decls) {
869 // Return the length kind of the given value.
870 auto getLengthKind = [](const auto &value) {
871 if (value.isOptional())
872 return ods::VariableLengthKind::Optional;
873 return value.isVariadic() ? ods::VariableLengthKind::Variadic
874 : ods::VariableLengthKind::Single;
877 // Insert a type constraint into the ODS context.
878 ods::Context &odsContext = ctx.getODSContext();
879 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
880 -> const ods::TypeConstraint & {
881 return odsContext.insertTypeConstraint(
882 cst.constraint.getUniqueDefName(),
883 processDoc(cst.constraint.getSummary()),
884 cst.constraint.getCPPClassName());
886 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
887 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
890 // Process the parsed tablegen records to build ODS information.
891 /// Operations.
892 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
893 tblgen::Operator op(def);
895 // Check to see if this operation is known to support type inferrence.
896 bool supportsResultTypeInferrence =
897 op.getTrait("::mlir::InferTypeOpInterface::Trait");
899 auto [odsOp, inserted] = odsContext.insertOperation(
900 op.getOperationName(), processDoc(op.getSummary()),
901 processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
902 supportsResultTypeInferrence, op.getLoc().front());
904 // Ignore operations that have already been added.
905 if (!inserted)
906 continue;
908 for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
909 odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
910 odsContext.insertAttributeConstraint(
911 attr.attr.getUniqueDefName(),
912 processDoc(attr.attr.getSummary()),
913 attr.attr.getStorageType()));
915 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
916 odsOp->appendOperand(operand.name, getLengthKind(operand),
917 addTypeConstraint(operand));
919 for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
920 odsOp->appendResult(result.name, getLengthKind(result),
921 addTypeConstraint(result));
925 auto shouldBeSkipped = [this](llvm::Record *def) {
926 return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
927 def->isSubClassOf("DeclareInterfaceMethods");
930 /// Attr constraints.
931 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
932 if (shouldBeSkipped(def))
933 continue;
935 tblgen::Attribute constraint(def);
936 decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
937 constraint, convertLocToRange(def->getLoc().front()), attrTy,
938 constraint.getStorageType()));
940 /// Type constraints.
941 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
942 if (shouldBeSkipped(def))
943 continue;
945 tblgen::TypeConstraint constraint(def);
946 decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
947 constraint, convertLocToRange(def->getLoc().front()), typeTy,
948 constraint.getCPPClassName()));
950 /// OpInterfaces.
951 ast::Type opTy = ast::OperationType::get(ctx);
952 for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
953 if (shouldBeSkipped(def))
954 continue;
956 SMRange loc = convertLocToRange(def->getLoc().front());
958 std::string cppClassName =
959 llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
960 def->getValueAsString("cppInterfaceName"))
961 .str();
962 std::string codeBlock =
963 llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
964 cppClassName)
965 .str();
967 std::string desc =
968 processAndFormatDoc(def->getValueAsString("description"));
969 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
970 def->getName(), codeBlock, loc, opTy, cppClassName, desc));
974 template <typename ConstraintT>
975 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
976 StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
977 StringRef nativeType, StringRef docString) {
978 // Build the single input parameter.
979 ast::DeclScope *argScope = pushDeclScope();
980 auto *paramVar = ast::VariableDecl::create(
981 ctx, ast::Name::create(ctx, "self", loc), type,
982 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
983 argScope->add(paramVar);
984 popDeclScope();
986 // Build the native constraint.
987 auto *constraintDecl = ast::UserConstraintDecl::createNative(
988 ctx, ast::Name::create(ctx, name, loc), paramVar,
989 /*results=*/std::nullopt, codeBlock, ast::TupleType::get(ctx),
990 nativeType);
991 constraintDecl->setDocComment(ctx, docString);
992 curDeclScope->add(constraintDecl);
993 return constraintDecl;
996 template <typename ConstraintT>
997 ast::Decl *
998 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
999 SMRange loc, ast::Type type,
1000 StringRef nativeType) {
1001 // Format the condition template.
1002 tblgen::FmtContext fmtContext;
1003 fmtContext.withSelf("self");
1004 std::string codeBlock = tblgen::tgfmt(
1005 "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1006 &fmtContext);
1008 // If documentation was enabled, build the doc string for the generated
1009 // constraint. It would be nice to do this lazily, but TableGen information is
1010 // destroyed after we finish parsing the file.
1011 std::string docString;
1012 if (enableDocumentation) {
1013 StringRef desc = constraint.getDescription();
1014 docString = processAndFormatDoc(
1015 constraint.getSummary() +
1016 (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1019 return createODSNativePDLLConstraintDecl<ConstraintT>(
1020 constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1021 docString);
1024 //===----------------------------------------------------------------------===//
1025 // Decls
1027 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1028 FailureOr<ast::Decl *> decl;
1029 switch (curToken.getKind()) {
1030 case Token::kw_Constraint:
1031 decl = parseUserConstraintDecl();
1032 break;
1033 case Token::kw_Pattern:
1034 decl = parsePatternDecl();
1035 break;
1036 case Token::kw_Rewrite:
1037 decl = parseUserRewriteDecl();
1038 break;
1039 default:
1040 return emitError("expected top-level declaration, such as a `Pattern`");
1042 if (failed(decl))
1043 return failure();
1045 // If the decl has a name, add it to the current scope.
1046 if (const ast::Name *name = (*decl)->getName()) {
1047 if (failed(checkDefineNamedDecl(*name)))
1048 return failure();
1049 curDeclScope->add(*decl);
1051 return decl;
1054 FailureOr<ast::NamedAttributeDecl *>
1055 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1056 // Check for name code completion.
1057 if (curToken.is(Token::code_complete))
1058 return codeCompleteAttributeName(parentOpName);
1060 std::string attrNameStr;
1061 if (curToken.isString())
1062 attrNameStr = curToken.getStringValue();
1063 else if (curToken.is(Token::identifier) || curToken.isKeyword())
1064 attrNameStr = curToken.getSpelling().str();
1065 else
1066 return emitError("expected identifier or string attribute name");
1067 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1068 consumeToken();
1070 // Check for a value of the attribute.
1071 ast::Expr *attrValue = nullptr;
1072 if (consumeIf(Token::equal)) {
1073 FailureOr<ast::Expr *> attrExpr = parseExpr();
1074 if (failed(attrExpr))
1075 return failure();
1076 attrValue = *attrExpr;
1077 } else {
1078 // If there isn't a concrete value, create an expression representing a
1079 // UnitAttr.
1080 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1083 return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1086 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1087 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1088 bool expectTerminalSemicolon) {
1089 consumeToken(Token::equal_arrow);
1091 // Parse the single statement of the lambda body.
1092 SMLoc bodyStartLoc = curToken.getStartLoc();
1093 pushDeclScope();
1094 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1095 bool failedToParse =
1096 failed(singleStatement) || failed(processStatementFn(*singleStatement));
1097 popDeclScope();
1098 if (failedToParse)
1099 return failure();
1101 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1102 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1105 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1106 // Ensure that the argument is named.
1107 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1108 return emitError("expected identifier argument name");
1110 // Parse the argument similarly to a normal variable.
1111 StringRef name = curToken.getSpelling();
1112 SMRange nameLoc = curToken.getLoc();
1113 consumeToken();
1115 if (failed(
1116 parseToken(Token::colon, "expected `:` before argument constraint")))
1117 return failure();
1119 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1120 if (failed(cst))
1121 return failure();
1123 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1126 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1127 // Check to see if this result is named.
1128 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1129 // Check to see if this name actually refers to a Constraint.
1130 if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
1131 // If it wasn't a constraint, parse the result similarly to a variable. If
1132 // there is already an existing decl, we will emit an error when defining
1133 // this variable later.
1134 StringRef name = curToken.getSpelling();
1135 SMRange nameLoc = curToken.getLoc();
1136 consumeToken();
1138 if (failed(parseToken(Token::colon,
1139 "expected `:` before result constraint")))
1140 return failure();
1142 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1143 if (failed(cst))
1144 return failure();
1146 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1150 // If it isn't named, we parse the constraint directly and create an unnamed
1151 // result variable.
1152 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1153 if (failed(cst))
1154 return failure();
1156 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1159 FailureOr<ast::UserConstraintDecl *>
1160 Parser::parseUserConstraintDecl(bool isInline) {
1161 // Constraints and rewrites have very similar formats, dispatch to a shared
1162 // interface for parsing.
1163 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1164 [&](auto &&...args) {
1165 return this->parseUserPDLLConstraintDecl(args...);
1167 ParserContext::Constraint, "constraint", isInline);
1170 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1171 FailureOr<ast::UserConstraintDecl *> decl =
1172 parseUserConstraintDecl(/*isInline=*/true);
1173 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1174 return failure();
1176 curDeclScope->add(*decl);
1177 return decl;
1180 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1181 const ast::Name &name, bool isInline,
1182 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1183 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1184 // Push the argument scope back onto the list, so that the body can
1185 // reference arguments.
1186 pushDeclScope(argumentScope);
1188 // Parse the body of the constraint. The body is either defined as a compound
1189 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1190 ast::CompoundStmt *body;
1191 if (curToken.is(Token::equal_arrow)) {
1192 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1193 [&](ast::Stmt *&stmt) -> LogicalResult {
1194 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1195 if (!stmtExpr) {
1196 return emitError(stmt->getLoc(),
1197 "expected `Constraint` lambda body to contain a "
1198 "single expression");
1200 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1201 return success();
1203 /*expectTerminalSemicolon=*/!isInline);
1204 if (failed(bodyResult))
1205 return failure();
1206 body = *bodyResult;
1207 } else {
1208 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1209 if (failed(bodyResult))
1210 return failure();
1211 body = *bodyResult;
1213 // Verify the structure of the body.
1214 auto bodyIt = body->begin(), bodyE = body->end();
1215 for (; bodyIt != bodyE; ++bodyIt)
1216 if (isa<ast::ReturnStmt>(*bodyIt))
1217 break;
1218 if (failed(validateUserConstraintOrRewriteReturn(
1219 "Constraint", body, bodyIt, bodyE, results, resultType)))
1220 return failure();
1222 popDeclScope();
1224 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1225 name, arguments, results, resultType, body);
1228 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1229 // Constraints and rewrites have very similar formats, dispatch to a shared
1230 // interface for parsing.
1231 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1232 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1233 ParserContext::Rewrite, "rewrite", isInline);
1236 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1237 FailureOr<ast::UserRewriteDecl *> decl =
1238 parseUserRewriteDecl(/*isInline=*/true);
1239 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1240 return failure();
1242 curDeclScope->add(*decl);
1243 return decl;
1246 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1247 const ast::Name &name, bool isInline,
1248 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1249 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1250 // Push the argument scope back onto the list, so that the body can
1251 // reference arguments.
1252 curDeclScope = argumentScope;
1253 ast::CompoundStmt *body;
1254 if (curToken.is(Token::equal_arrow)) {
1255 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1256 [&](ast::Stmt *&statement) -> LogicalResult {
1257 if (isa<ast::OpRewriteStmt>(statement))
1258 return success();
1260 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1261 if (!statementExpr) {
1262 return emitError(
1263 statement->getLoc(),
1264 "expected `Rewrite` lambda body to contain a single expression "
1265 "or an operation rewrite statement; such as `erase`, "
1266 "`replace`, or `rewrite`");
1268 statement =
1269 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1270 return success();
1272 /*expectTerminalSemicolon=*/!isInline);
1273 if (failed(bodyResult))
1274 return failure();
1275 body = *bodyResult;
1276 } else {
1277 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1278 if (failed(bodyResult))
1279 return failure();
1280 body = *bodyResult;
1282 popDeclScope();
1284 // Verify the structure of the body.
1285 auto bodyIt = body->begin(), bodyE = body->end();
1286 for (; bodyIt != bodyE; ++bodyIt)
1287 if (isa<ast::ReturnStmt>(*bodyIt))
1288 break;
1289 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1290 bodyE, results, resultType)))
1291 return failure();
1292 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1293 name, arguments, results, resultType, body);
1296 template <typename T, typename ParseUserPDLLDeclFnT>
1297 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1298 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1299 StringRef anonymousNamePrefix, bool isInline) {
1300 SMRange loc = curToken.getLoc();
1301 consumeToken();
1302 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1304 // Parse the name of the decl.
1305 const ast::Name *name = nullptr;
1306 if (curToken.isNot(Token::identifier)) {
1307 // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1308 // in C++, so being unnamed is fine.
1309 if (!isInline)
1310 return emitError("expected identifier name");
1312 // Create a unique anonymous name to use, as the name for this decl is not
1313 // important.
1314 std::string anonName =
1315 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1316 anonymousDeclNameCounter++)
1317 .str();
1318 name = &ast::Name::create(ctx, anonName, loc);
1319 } else {
1320 // If a name was provided, we can use it directly.
1321 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1322 consumeToken(Token::identifier);
1325 // Parse the functional signature of the decl.
1326 SmallVector<ast::VariableDecl *> arguments, results;
1327 ast::DeclScope *argumentScope;
1328 ast::Type resultType;
1329 if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1330 argumentScope, resultType)))
1331 return failure();
1333 // Check to see which type of constraint this is. If the constraint contains a
1334 // compound body, this is a PDLL decl.
1335 if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1336 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1337 resultType);
1339 // Otherwise, this is a native decl.
1340 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1341 results, resultType);
1344 template <typename T>
1345 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1346 const ast::Name &name, bool isInline,
1347 ArrayRef<ast::VariableDecl *> arguments,
1348 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1349 // If followed by a string, the native code body has also been specified.
1350 std::string codeStrStorage;
1351 std::optional<StringRef> optCodeStr;
1352 if (curToken.isString()) {
1353 codeStrStorage = curToken.getStringValue();
1354 optCodeStr = codeStrStorage;
1355 consumeToken();
1356 } else if (isInline) {
1357 return emitError(name.getLoc(),
1358 "external declarations must be declared in global scope");
1359 } else if (curToken.is(Token::error)) {
1360 return failure();
1362 if (failed(parseToken(Token::semicolon,
1363 "expected `;` after native declaration")))
1364 return failure();
1365 // TODO: PDL should be able to support constraint results in certain
1366 // situations, we should revise this.
1367 if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1368 return emitError(
1369 "native Constraints currently do not support returning results");
1371 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1374 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1375 SmallVectorImpl<ast::VariableDecl *> &arguments,
1376 SmallVectorImpl<ast::VariableDecl *> &results,
1377 ast::DeclScope *&argumentScope, ast::Type &resultType) {
1378 // Parse the argument list of the decl.
1379 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1380 return failure();
1382 argumentScope = pushDeclScope();
1383 if (curToken.isNot(Token::r_paren)) {
1384 do {
1385 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1386 if (failed(argument))
1387 return failure();
1388 arguments.emplace_back(*argument);
1389 } while (consumeIf(Token::comma));
1391 popDeclScope();
1392 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1393 return failure();
1395 // Parse the results of the decl.
1396 pushDeclScope();
1397 if (consumeIf(Token::arrow)) {
1398 auto parseResultFn = [&]() -> LogicalResult {
1399 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1400 if (failed(result))
1401 return failure();
1402 results.emplace_back(*result);
1403 return success();
1406 // Check for a list of results.
1407 if (consumeIf(Token::l_paren)) {
1408 do {
1409 if (failed(parseResultFn()))
1410 return failure();
1411 } while (consumeIf(Token::comma));
1412 if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1413 return failure();
1415 // Otherwise, there is only one result.
1416 } else if (failed(parseResultFn())) {
1417 return failure();
1420 popDeclScope();
1422 // Compute the result type of the decl.
1423 resultType = createUserConstraintRewriteResultType(results);
1425 // Verify that results are only named if there are more than one.
1426 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1427 return emitError(
1428 results.front()->getLoc(),
1429 "cannot create a single-element tuple with an element label");
1431 return success();
1434 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1435 StringRef declType, ast::CompoundStmt *body,
1436 ArrayRef<ast::Stmt *>::iterator bodyIt,
1437 ArrayRef<ast::Stmt *>::iterator bodyE,
1438 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1439 // Handle if a `return` was provided.
1440 if (bodyIt != bodyE) {
1441 // Emit an error if we have trailing statements after the return.
1442 if (std::next(bodyIt) != bodyE) {
1443 return emitError(
1444 (*std::next(bodyIt))->getLoc(),
1445 llvm::formatv("`return` terminated the `{0}` body, but found "
1446 "trailing statements afterwards",
1447 declType));
1450 // Otherwise if a return wasn't provided, check that no results are
1451 // expected.
1452 } else if (!results.empty()) {
1453 return emitError(
1454 {body->getLoc().End, body->getLoc().End},
1455 llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1456 declType, resultType));
1458 return success();
1461 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1462 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1463 if (isa<ast::OpRewriteStmt>(statement))
1464 return success();
1465 return emitError(
1466 statement->getLoc(),
1467 "expected Pattern lambda body to contain a single operation "
1468 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1472 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1473 SMRange loc = curToken.getLoc();
1474 consumeToken(Token::kw_Pattern);
1475 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1477 // Check for an optional identifier for the pattern name.
1478 const ast::Name *name = nullptr;
1479 if (curToken.is(Token::identifier)) {
1480 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1481 consumeToken(Token::identifier);
1484 // Parse any pattern metadata.
1485 ParsedPatternMetadata metadata;
1486 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1487 return failure();
1489 // Parse the pattern body.
1490 ast::CompoundStmt *body;
1492 // Handle a lambda body.
1493 if (curToken.is(Token::equal_arrow)) {
1494 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1495 if (failed(bodyResult))
1496 return failure();
1497 body = *bodyResult;
1498 } else {
1499 if (curToken.isNot(Token::l_brace))
1500 return emitError("expected `{` or `=>` to start pattern body");
1501 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1502 if (failed(bodyResult))
1503 return failure();
1504 body = *bodyResult;
1506 // Verify the body of the pattern.
1507 auto bodyIt = body->begin(), bodyE = body->end();
1508 for (; bodyIt != bodyE; ++bodyIt) {
1509 if (isa<ast::ReturnStmt>(*bodyIt)) {
1510 return emitError((*bodyIt)->getLoc(),
1511 "`return` statements are only permitted within a "
1512 "`Constraint` or `Rewrite` body");
1514 // Break when we've found the rewrite statement.
1515 if (isa<ast::OpRewriteStmt>(*bodyIt))
1516 break;
1518 if (bodyIt == bodyE) {
1519 return emitError(loc,
1520 "expected Pattern body to terminate with an operation "
1521 "rewrite statement, such as `erase`");
1523 if (std::next(bodyIt) != bodyE) {
1524 return emitError((*std::next(bodyIt))->getLoc(),
1525 "Pattern body was terminated by an operation "
1526 "rewrite statement, but found trailing statements");
1530 return createPatternDecl(loc, name, metadata, body);
1533 LogicalResult
1534 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1535 std::optional<SMRange> benefitLoc;
1536 std::optional<SMRange> hasBoundedRecursionLoc;
1538 do {
1539 // Handle metadata code completion.
1540 if (curToken.is(Token::code_complete))
1541 return codeCompletePatternMetadata();
1543 if (curToken.isNot(Token::identifier))
1544 return emitError("expected pattern metadata identifier");
1545 StringRef metadataStr = curToken.getSpelling();
1546 SMRange metadataLoc = curToken.getLoc();
1547 consumeToken(Token::identifier);
1549 // Parse the benefit metadata: benefit(<integer-value>)
1550 if (metadataStr == "benefit") {
1551 if (benefitLoc) {
1552 return emitErrorAndNote(metadataLoc,
1553 "pattern benefit has already been specified",
1554 *benefitLoc, "see previous definition here");
1556 if (failed(parseToken(Token::l_paren,
1557 "expected `(` before pattern benefit")))
1558 return failure();
1560 uint16_t benefitValue = 0;
1561 if (curToken.isNot(Token::integer))
1562 return emitError("expected integral pattern benefit");
1563 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1564 return emitError(
1565 "expected pattern benefit to fit within a 16-bit integer");
1566 consumeToken(Token::integer);
1568 metadata.benefit = benefitValue;
1569 benefitLoc = metadataLoc;
1571 if (failed(
1572 parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1573 return failure();
1574 continue;
1577 // Parse the bounded recursion metadata: recursion
1578 if (metadataStr == "recursion") {
1579 if (hasBoundedRecursionLoc) {
1580 return emitErrorAndNote(
1581 metadataLoc,
1582 "pattern recursion metadata has already been specified",
1583 *hasBoundedRecursionLoc, "see previous definition here");
1585 metadata.hasBoundedRecursion = true;
1586 hasBoundedRecursionLoc = metadataLoc;
1587 continue;
1590 return emitError(metadataLoc, "unknown pattern metadata");
1591 } while (consumeIf(Token::comma));
1593 return success();
1596 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1597 consumeToken(Token::less);
1599 FailureOr<ast::Expr *> typeExpr = parseExpr();
1600 if (failed(typeExpr) ||
1601 failed(parseToken(Token::greater,
1602 "expected `>` after variable type constraint")))
1603 return failure();
1604 return typeExpr;
1607 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1608 assert(curDeclScope && "defining decl outside of a decl scope");
1609 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1610 return emitErrorAndNote(
1611 name.getLoc(), "`" + name.getName() + "` has already been defined",
1612 lastDecl->getName()->getLoc(), "see previous definition here");
1614 return success();
1617 FailureOr<ast::VariableDecl *>
1618 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1619 ast::Expr *initExpr,
1620 ArrayRef<ast::ConstraintRef> constraints) {
1621 assert(curDeclScope && "defining variable outside of decl scope");
1622 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1624 // If the name of the variable indicates a special variable, we don't add it
1625 // to the scope. This variable is local to the definition point.
1626 if (name.empty() || name == "_") {
1627 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1628 constraints);
1630 if (failed(checkDefineNamedDecl(nameDecl)))
1631 return failure();
1633 auto *varDecl =
1634 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1635 curDeclScope->add(varDecl);
1636 return varDecl;
1639 FailureOr<ast::VariableDecl *>
1640 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1641 ArrayRef<ast::ConstraintRef> constraints) {
1642 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1643 constraints);
1646 LogicalResult Parser::parseVariableDeclConstraintList(
1647 SmallVectorImpl<ast::ConstraintRef> &constraints) {
1648 std::optional<SMRange> typeConstraint;
1649 auto parseSingleConstraint = [&] {
1650 FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1651 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1652 if (failed(constraint))
1653 return failure();
1654 constraints.push_back(*constraint);
1655 return success();
1658 // Check to see if this is a single constraint, or a list.
1659 if (!consumeIf(Token::l_square))
1660 return parseSingleConstraint();
1662 do {
1663 if (failed(parseSingleConstraint()))
1664 return failure();
1665 } while (consumeIf(Token::comma));
1666 return parseToken(Token::r_square, "expected `]` after constraint list");
1669 FailureOr<ast::ConstraintRef>
1670 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1671 ArrayRef<ast::ConstraintRef> existingConstraints,
1672 bool allowInlineTypeConstraints) {
1673 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1674 if (!allowInlineTypeConstraints) {
1675 return emitError(
1676 curToken.getLoc(),
1677 "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1678 "permitted on arguments or results");
1680 if (typeConstraint)
1681 return emitErrorAndNote(
1682 curToken.getLoc(),
1683 "the type of this variable has already been constrained",
1684 *typeConstraint, "see previous constraint location here");
1685 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1686 if (failed(constraintExpr))
1687 return failure();
1688 typeExpr = *constraintExpr;
1689 typeConstraint = typeExpr->getLoc();
1690 return success();
1693 SMRange loc = curToken.getLoc();
1694 switch (curToken.getKind()) {
1695 case Token::kw_Attr: {
1696 consumeToken(Token::kw_Attr);
1698 // Check for a type constraint.
1699 ast::Expr *typeExpr = nullptr;
1700 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1701 return failure();
1702 return ast::ConstraintRef(
1703 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1705 case Token::kw_Op: {
1706 consumeToken(Token::kw_Op);
1708 // Parse an optional operation name. If the name isn't provided, this refers
1709 // to "any" operation.
1710 FailureOr<ast::OpNameDecl *> opName =
1711 parseWrappedOperationName(/*allowEmptyName=*/true);
1712 if (failed(opName))
1713 return failure();
1715 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1716 loc);
1718 case Token::kw_Type:
1719 consumeToken(Token::kw_Type);
1720 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1721 case Token::kw_TypeRange:
1722 consumeToken(Token::kw_TypeRange);
1723 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1724 loc);
1725 case Token::kw_Value: {
1726 consumeToken(Token::kw_Value);
1728 // Check for a type constraint.
1729 ast::Expr *typeExpr = nullptr;
1730 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1731 return failure();
1733 return ast::ConstraintRef(
1734 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1736 case Token::kw_ValueRange: {
1737 consumeToken(Token::kw_ValueRange);
1739 // Check for a type constraint.
1740 ast::Expr *typeExpr = nullptr;
1741 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1742 return failure();
1744 return ast::ConstraintRef(
1745 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1748 case Token::kw_Constraint: {
1749 // Handle an inline constraint.
1750 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1751 if (failed(decl))
1752 return failure();
1753 return ast::ConstraintRef(*decl, loc);
1755 case Token::identifier: {
1756 StringRef constraintName = curToken.getSpelling();
1757 consumeToken(Token::identifier);
1759 // Lookup the referenced constraint.
1760 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1761 if (!cstDecl) {
1762 return emitError(loc, "unknown reference to constraint `" +
1763 constraintName + "`");
1766 // Handle a reference to a proper constraint.
1767 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1768 return ast::ConstraintRef(cst, loc);
1770 return emitErrorAndNote(
1771 loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1772 "see the definition of `" + constraintName + "` here");
1774 // Handle single entity constraint code completion.
1775 case Token::code_complete: {
1776 // Try to infer the current type for use by code completion.
1777 ast::Type inferredType;
1778 if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1779 return failure();
1781 return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1783 default:
1784 break;
1786 return emitError(loc, "expected identifier constraint");
1789 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1790 std::optional<SMRange> typeConstraint;
1791 return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
1792 /*allowInlineTypeConstraints=*/false);
1795 //===----------------------------------------------------------------------===//
1796 // Exprs
1798 FailureOr<ast::Expr *> Parser::parseExpr() {
1799 if (curToken.is(Token::underscore))
1800 return parseUnderscoreExpr();
1802 // Parse the LHS expression.
1803 FailureOr<ast::Expr *> lhsExpr;
1804 switch (curToken.getKind()) {
1805 case Token::kw_attr:
1806 lhsExpr = parseAttributeExpr();
1807 break;
1808 case Token::kw_Constraint:
1809 lhsExpr = parseInlineConstraintLambdaExpr();
1810 break;
1811 case Token::kw_not:
1812 lhsExpr = parseNegatedExpr();
1813 break;
1814 case Token::identifier:
1815 lhsExpr = parseIdentifierExpr();
1816 break;
1817 case Token::kw_op:
1818 lhsExpr = parseOperationExpr();
1819 break;
1820 case Token::kw_Rewrite:
1821 lhsExpr = parseInlineRewriteLambdaExpr();
1822 break;
1823 case Token::kw_type:
1824 lhsExpr = parseTypeExpr();
1825 break;
1826 case Token::l_paren:
1827 lhsExpr = parseTupleExpr();
1828 break;
1829 default:
1830 return emitError("expected expression");
1832 if (failed(lhsExpr))
1833 return failure();
1835 // Check for an operator expression.
1836 while (true) {
1837 switch (curToken.getKind()) {
1838 case Token::dot:
1839 lhsExpr = parseMemberAccessExpr(*lhsExpr);
1840 break;
1841 case Token::l_paren:
1842 lhsExpr = parseCallExpr(*lhsExpr);
1843 break;
1844 default:
1845 return lhsExpr;
1847 if (failed(lhsExpr))
1848 return failure();
1852 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1853 SMRange loc = curToken.getLoc();
1854 consumeToken(Token::kw_attr);
1856 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1857 // identifier.
1858 if (!consumeIf(Token::less)) {
1859 resetToken(loc);
1860 return parseIdentifierExpr();
1863 if (!curToken.isString())
1864 return emitError("expected string literal containing MLIR attribute");
1865 std::string attrExpr = curToken.getStringValue();
1866 consumeToken();
1868 loc.End = curToken.getEndLoc();
1869 if (failed(
1870 parseToken(Token::greater, "expected `>` after attribute literal")))
1871 return failure();
1872 return ast::AttributeExpr::create(ctx, loc, attrExpr);
1875 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1876 bool isNegated) {
1877 consumeToken(Token::l_paren);
1879 // Parse the arguments of the call.
1880 SmallVector<ast::Expr *> arguments;
1881 if (curToken.isNot(Token::r_paren)) {
1882 do {
1883 // Handle code completion for the call arguments.
1884 if (curToken.is(Token::code_complete)) {
1885 codeCompleteCallSignature(parentExpr, arguments.size());
1886 return failure();
1889 FailureOr<ast::Expr *> argument = parseExpr();
1890 if (failed(argument))
1891 return failure();
1892 arguments.push_back(*argument);
1893 } while (consumeIf(Token::comma));
1896 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1897 if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1898 return failure();
1900 return createCallExpr(loc, parentExpr, arguments, isNegated);
1903 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1904 ast::Decl *decl = curDeclScope->lookup(name);
1905 if (!decl)
1906 return emitError(loc, "undefined reference to `" + name + "`");
1908 return createDeclRefExpr(loc, decl);
1911 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1912 StringRef name = curToken.getSpelling();
1913 SMRange nameLoc = curToken.getLoc();
1914 consumeToken();
1916 // Check to see if this is a decl ref expression that defines a variable
1917 // inline.
1918 if (consumeIf(Token::colon)) {
1919 SmallVector<ast::ConstraintRef> constraints;
1920 if (failed(parseVariableDeclConstraintList(constraints)))
1921 return failure();
1922 ast::Type type;
1923 if (failed(validateVariableConstraints(constraints, type)))
1924 return failure();
1925 return createInlineVariableExpr(type, name, nameLoc, constraints);
1928 return parseDeclRefExpr(name, nameLoc);
1931 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1932 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1933 if (failed(decl))
1934 return failure();
1936 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1937 ast::ConstraintType::get(ctx));
1940 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1941 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1942 if (failed(decl))
1943 return failure();
1945 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1946 ast::RewriteType::get(ctx));
1949 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1950 SMRange dotLoc = curToken.getLoc();
1951 consumeToken(Token::dot);
1953 // Check for code completion of the member name.
1954 if (curToken.is(Token::code_complete))
1955 return codeCompleteMemberAccess(parentExpr);
1957 // Parse the member name.
1958 Token memberNameTok = curToken;
1959 if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1960 !memberNameTok.isKeyword())
1961 return emitError(dotLoc, "expected identifier or numeric member name");
1962 StringRef memberName = memberNameTok.getSpelling();
1963 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1964 consumeToken();
1966 return createMemberAccessExpr(parentExpr, memberName, loc);
1969 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1970 consumeToken(Token::kw_not);
1971 // Only native constraints are supported after negation
1972 if (!curToken.is(Token::identifier))
1973 return emitError("expected native constraint");
1974 FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1975 if (failed(identifierExpr))
1976 return failure();
1977 return parseCallExpr(*identifierExpr, /*isNegated = */ true);
1980 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1981 SMRange loc = curToken.getLoc();
1983 // Check for code completion for the dialect name.
1984 if (curToken.is(Token::code_complete))
1985 return codeCompleteDialectName();
1987 // Handle the case of an no operation name.
1988 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1989 if (allowEmptyName)
1990 return ast::OpNameDecl::create(ctx, SMRange());
1991 return emitError("expected dialect namespace");
1993 StringRef name = curToken.getSpelling();
1994 consumeToken();
1996 // Otherwise, this is a literal operation name.
1997 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1998 return failure();
2000 // Check for code completion for the operation name.
2001 if (curToken.is(Token::code_complete))
2002 return codeCompleteOperationName(name);
2004 if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
2005 return emitError("expected operation name after dialect namespace");
2007 name = StringRef(name.data(), name.size() + 1);
2008 do {
2009 name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2010 loc.End = curToken.getEndLoc();
2011 consumeToken();
2012 } while (curToken.isAny(Token::identifier, Token::dot) ||
2013 curToken.isKeyword());
2014 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
2017 FailureOr<ast::OpNameDecl *>
2018 Parser::parseWrappedOperationName(bool allowEmptyName) {
2019 if (!consumeIf(Token::less))
2020 return ast::OpNameDecl::create(ctx, SMRange());
2022 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2023 if (failed(opNameDecl))
2024 return failure();
2026 if (failed(parseToken(Token::greater, "expected `>` after operation name")))
2027 return failure();
2028 return opNameDecl;
2031 FailureOr<ast::Expr *>
2032 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2033 SMRange loc = curToken.getLoc();
2034 consumeToken(Token::kw_op);
2036 // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2037 // identifier.
2038 if (curToken.isNot(Token::less)) {
2039 resetToken(loc);
2040 return parseIdentifierExpr();
2043 // Parse the operation name. The name may be elided, in which case the
2044 // operation refers to "any" operation(i.e. a difference between `MyOp` and
2045 // `Operation*`). Operation names within a rewrite context must be named.
2046 bool allowEmptyName = parserContext != ParserContext::Rewrite;
2047 FailureOr<ast::OpNameDecl *> opNameDecl =
2048 parseWrappedOperationName(allowEmptyName);
2049 if (failed(opNameDecl))
2050 return failure();
2051 std::optional<StringRef> opName = (*opNameDecl)->getName();
2053 // Functor used to create an implicit range variable, used for implicit "all"
2054 // operand or results variables.
2055 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2056 FailureOr<ast::VariableDecl *> rangeVar =
2057 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2058 assert(succeeded(rangeVar) && "expected range variable to be valid");
2059 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2062 // Check for the optional list of operands.
2063 SmallVector<ast::Expr *> operands;
2064 if (!consumeIf(Token::l_paren)) {
2065 // If the operand list isn't specified and we are in a match context, define
2066 // an inplace unconstrained operand range corresponding to all of the
2067 // operands of the operation. This avoids treating zero operands the same
2068 // way as "unconstrained operands".
2069 if (parserContext != ParserContext::Rewrite) {
2070 operands.push_back(createImplicitRangeVar(
2071 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2073 } else if (!consumeIf(Token::r_paren)) {
2074 // If the operand list was specified and non-empty, parse the operands.
2075 do {
2076 // Check for operand signature code completion.
2077 if (curToken.is(Token::code_complete)) {
2078 codeCompleteOperationOperandsSignature(opName, operands.size());
2079 return failure();
2082 FailureOr<ast::Expr *> operand = parseExpr();
2083 if (failed(operand))
2084 return failure();
2085 operands.push_back(*operand);
2086 } while (consumeIf(Token::comma));
2088 if (failed(parseToken(Token::r_paren,
2089 "expected `)` after operation operand list")))
2090 return failure();
2093 // Check for the optional list of attributes.
2094 SmallVector<ast::NamedAttributeDecl *> attributes;
2095 if (consumeIf(Token::l_brace)) {
2096 do {
2097 FailureOr<ast::NamedAttributeDecl *> decl =
2098 parseNamedAttributeDecl(opName);
2099 if (failed(decl))
2100 return failure();
2101 attributes.emplace_back(*decl);
2102 } while (consumeIf(Token::comma));
2104 if (failed(parseToken(Token::r_brace,
2105 "expected `}` after operation attribute list")))
2106 return failure();
2109 // Handle the result types of the operation.
2110 SmallVector<ast::Expr *> resultTypes;
2111 OpResultTypeContext resultTypeContext = inputResultTypeContext;
2113 // Check for an explicit list of result types.
2114 if (consumeIf(Token::arrow)) {
2115 if (failed(parseToken(Token::l_paren,
2116 "expected `(` before operation result type list")))
2117 return failure();
2119 // If result types are provided, initially assume that the operation does
2120 // not rely on type inferrence. We don't assert that it isn't, because we
2121 // may be inferring the value of some type/type range variables, but given
2122 // that these variables may be defined in calls we can't always discern when
2123 // this is the case.
2124 resultTypeContext = OpResultTypeContext::Explicit;
2126 // Handle the case of an empty result list.
2127 if (!consumeIf(Token::r_paren)) {
2128 do {
2129 // Check for result signature code completion.
2130 if (curToken.is(Token::code_complete)) {
2131 codeCompleteOperationResultsSignature(opName, resultTypes.size());
2132 return failure();
2135 FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2136 if (failed(resultTypeExpr))
2137 return failure();
2138 resultTypes.push_back(*resultTypeExpr);
2139 } while (consumeIf(Token::comma));
2141 if (failed(parseToken(Token::r_paren,
2142 "expected `)` after operation result type list")))
2143 return failure();
2145 } else if (parserContext != ParserContext::Rewrite) {
2146 // If the result list isn't specified and we are in a match context, define
2147 // an inplace unconstrained result range corresponding to all of the results
2148 // of the operation. This avoids treating zero results the same way as
2149 // "unconstrained results".
2150 resultTypes.push_back(createImplicitRangeVar(
2151 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2152 } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2153 // If the result list isn't specified and we are in a rewrite, try to infer
2154 // them at runtime instead.
2155 resultTypeContext = OpResultTypeContext::Interface;
2158 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2159 attributes, resultTypes);
2162 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2163 SMRange loc = curToken.getLoc();
2164 consumeToken(Token::l_paren);
2166 DenseMap<StringRef, SMRange> usedNames;
2167 SmallVector<StringRef> elementNames;
2168 SmallVector<ast::Expr *> elements;
2169 if (curToken.isNot(Token::r_paren)) {
2170 do {
2171 // Check for the optional element name assignment before the value.
2172 StringRef elementName;
2173 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2174 Token elementNameTok = curToken;
2175 consumeToken();
2177 // The element name is only present if followed by an `=`.
2178 if (consumeIf(Token::equal)) {
2179 elementName = elementNameTok.getSpelling();
2181 // Check to see if this name is already used.
2182 auto elementNameIt =
2183 usedNames.try_emplace(elementName, elementNameTok.getLoc());
2184 if (!elementNameIt.second) {
2185 return emitErrorAndNote(
2186 elementNameTok.getLoc(),
2187 llvm::formatv("duplicate tuple element label `{0}`",
2188 elementName),
2189 elementNameIt.first->getSecond(),
2190 "see previous label use here");
2192 } else {
2193 // Otherwise, we treat this as part of an expression so reset the
2194 // lexer.
2195 resetToken(elementNameTok.getLoc());
2198 elementNames.push_back(elementName);
2200 // Parse the tuple element value.
2201 FailureOr<ast::Expr *> element = parseExpr();
2202 if (failed(element))
2203 return failure();
2204 elements.push_back(*element);
2205 } while (consumeIf(Token::comma));
2207 loc.End = curToken.getEndLoc();
2208 if (failed(
2209 parseToken(Token::r_paren, "expected `)` after tuple element list")))
2210 return failure();
2211 return createTupleExpr(loc, elements, elementNames);
2214 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2215 SMRange loc = curToken.getLoc();
2216 consumeToken(Token::kw_type);
2218 // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2219 // identifier.
2220 if (!consumeIf(Token::less)) {
2221 resetToken(loc);
2222 return parseIdentifierExpr();
2225 if (!curToken.isString())
2226 return emitError("expected string literal containing MLIR type");
2227 std::string attrExpr = curToken.getStringValue();
2228 consumeToken();
2230 loc.End = curToken.getEndLoc();
2231 if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2232 return failure();
2233 return ast::TypeExpr::create(ctx, loc, attrExpr);
2236 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2237 StringRef name = curToken.getSpelling();
2238 SMRange nameLoc = curToken.getLoc();
2239 consumeToken(Token::underscore);
2241 // Underscore expressions require a constraint list.
2242 if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2243 return failure();
2245 // Parse the constraints for the expression.
2246 SmallVector<ast::ConstraintRef> constraints;
2247 if (failed(parseVariableDeclConstraintList(constraints)))
2248 return failure();
2250 ast::Type type;
2251 if (failed(validateVariableConstraints(constraints, type)))
2252 return failure();
2253 return createInlineVariableExpr(type, name, nameLoc, constraints);
2256 //===----------------------------------------------------------------------===//
2257 // Stmts
2259 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2260 FailureOr<ast::Stmt *> stmt;
2261 switch (curToken.getKind()) {
2262 case Token::kw_erase:
2263 stmt = parseEraseStmt();
2264 break;
2265 case Token::kw_let:
2266 stmt = parseLetStmt();
2267 break;
2268 case Token::kw_replace:
2269 stmt = parseReplaceStmt();
2270 break;
2271 case Token::kw_return:
2272 stmt = parseReturnStmt();
2273 break;
2274 case Token::kw_rewrite:
2275 stmt = parseRewriteStmt();
2276 break;
2277 default:
2278 stmt = parseExpr();
2279 break;
2281 if (failed(stmt) ||
2282 (expectTerminalSemicolon &&
2283 failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2284 return failure();
2285 return stmt;
2288 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2289 SMLoc startLoc = curToken.getStartLoc();
2290 consumeToken(Token::l_brace);
2292 // Push a new block scope and parse any nested statements.
2293 pushDeclScope();
2294 SmallVector<ast::Stmt *> statements;
2295 while (curToken.isNot(Token::r_brace)) {
2296 FailureOr<ast::Stmt *> statement = parseStmt();
2297 if (failed(statement))
2298 return popDeclScope(), failure();
2299 statements.push_back(*statement);
2301 popDeclScope();
2303 // Consume the end brace.
2304 SMRange location(startLoc, curToken.getEndLoc());
2305 consumeToken(Token::r_brace);
2307 return ast::CompoundStmt::create(ctx, location, statements);
2310 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2311 if (parserContext == ParserContext::Constraint)
2312 return emitError("`erase` cannot be used within a Constraint");
2313 SMRange loc = curToken.getLoc();
2314 consumeToken(Token::kw_erase);
2316 // Parse the root operation expression.
2317 FailureOr<ast::Expr *> rootOp = parseExpr();
2318 if (failed(rootOp))
2319 return failure();
2321 return createEraseStmt(loc, *rootOp);
2324 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2325 SMRange loc = curToken.getLoc();
2326 consumeToken(Token::kw_let);
2328 // Parse the name of the new variable.
2329 SMRange varLoc = curToken.getLoc();
2330 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2331 // `_` is a reserved variable name.
2332 if (curToken.is(Token::underscore)) {
2333 return emitError(varLoc,
2334 "`_` may only be used to define \"inline\" variables");
2336 return emitError(varLoc,
2337 "expected identifier after `let` to name a new variable");
2339 StringRef varName = curToken.getSpelling();
2340 consumeToken();
2342 // Parse the optional set of constraints.
2343 SmallVector<ast::ConstraintRef> constraints;
2344 if (consumeIf(Token::colon) &&
2345 failed(parseVariableDeclConstraintList(constraints)))
2346 return failure();
2348 // Parse the optional initializer expression.
2349 ast::Expr *initializer = nullptr;
2350 if (consumeIf(Token::equal)) {
2351 FailureOr<ast::Expr *> initOrFailure = parseExpr();
2352 if (failed(initOrFailure))
2353 return failure();
2354 initializer = *initOrFailure;
2356 // Check that the constraints are compatible with having an initializer,
2357 // e.g. type constraints cannot be used with initializers.
2358 for (ast::ConstraintRef constraint : constraints) {
2359 LogicalResult result =
2360 TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2361 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2362 ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2363 if (cst->getTypeExpr()) {
2364 return this->emitError(
2365 constraint.referenceLoc,
2366 "type constraints are not permitted on variables with "
2367 "initializers");
2369 return success();
2371 .Default(success());
2372 if (failed(result))
2373 return failure();
2377 FailureOr<ast::VariableDecl *> varDecl =
2378 createVariableDecl(varName, varLoc, initializer, constraints);
2379 if (failed(varDecl))
2380 return failure();
2381 return ast::LetStmt::create(ctx, loc, *varDecl);
2384 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2385 if (parserContext == ParserContext::Constraint)
2386 return emitError("`replace` cannot be used within a Constraint");
2387 SMRange loc = curToken.getLoc();
2388 consumeToken(Token::kw_replace);
2390 // Parse the root operation expression.
2391 FailureOr<ast::Expr *> rootOp = parseExpr();
2392 if (failed(rootOp))
2393 return failure();
2395 if (failed(
2396 parseToken(Token::kw_with, "expected `with` after root operation")))
2397 return failure();
2399 // The replacement portion of this statement is within a rewrite context.
2400 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2402 // Parse the replacement values.
2403 SmallVector<ast::Expr *> replValues;
2404 if (consumeIf(Token::l_paren)) {
2405 if (consumeIf(Token::r_paren)) {
2406 return emitError(
2407 loc, "expected at least one replacement value, consider using "
2408 "`erase` if no replacement values are desired");
2411 do {
2412 FailureOr<ast::Expr *> replExpr = parseExpr();
2413 if (failed(replExpr))
2414 return failure();
2415 replValues.emplace_back(*replExpr);
2416 } while (consumeIf(Token::comma));
2418 if (failed(parseToken(Token::r_paren,
2419 "expected `)` after replacement values")))
2420 return failure();
2421 } else {
2422 // Handle replacement with an operation uniquely, as the replacement
2423 // operation supports type inferrence from the root operation.
2424 FailureOr<ast::Expr *> replExpr;
2425 if (curToken.is(Token::kw_op))
2426 replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2427 else
2428 replExpr = parseExpr();
2429 if (failed(replExpr))
2430 return failure();
2431 replValues.emplace_back(*replExpr);
2434 return createReplaceStmt(loc, *rootOp, replValues);
2437 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2438 SMRange loc = curToken.getLoc();
2439 consumeToken(Token::kw_return);
2441 // Parse the result value.
2442 FailureOr<ast::Expr *> resultExpr = parseExpr();
2443 if (failed(resultExpr))
2444 return failure();
2446 return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2449 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2450 if (parserContext == ParserContext::Constraint)
2451 return emitError("`rewrite` cannot be used within a Constraint");
2452 SMRange loc = curToken.getLoc();
2453 consumeToken(Token::kw_rewrite);
2455 // Parse the root operation.
2456 FailureOr<ast::Expr *> rootOp = parseExpr();
2457 if (failed(rootOp))
2458 return failure();
2460 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2461 return failure();
2463 if (curToken.isNot(Token::l_brace))
2464 return emitError("expected `{` to start rewrite body");
2466 // The rewrite body of this statement is within a rewrite context.
2467 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2469 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2470 if (failed(rewriteBody))
2471 return failure();
2473 // Verify the rewrite body.
2474 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2475 if (isa<ast::ReturnStmt>(stmt)) {
2476 return emitError(stmt->getLoc(),
2477 "`return` statements are only permitted within a "
2478 "`Constraint` or `Rewrite` body");
2482 return createRewriteStmt(loc, *rootOp, *rewriteBody);
2485 //===----------------------------------------------------------------------===//
2486 // Creation+Analysis
2487 //===----------------------------------------------------------------------===//
2489 //===----------------------------------------------------------------------===//
2490 // Decls
2492 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2493 // Unwrap reference expressions.
2494 if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2495 node = init->getDecl();
2496 return dyn_cast<ast::CallableDecl>(node);
2499 FailureOr<ast::PatternDecl *>
2500 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2501 const ParsedPatternMetadata &metadata,
2502 ast::CompoundStmt *body) {
2503 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2504 metadata.hasBoundedRecursion, body);
2507 ast::Type Parser::createUserConstraintRewriteResultType(
2508 ArrayRef<ast::VariableDecl *> results) {
2509 // Single result decls use the type of the single result.
2510 if (results.size() == 1)
2511 return results[0]->getType();
2513 // Multiple results use a tuple type, with the types and names grabbed from
2514 // the result variable decls.
2515 auto resultTypes = llvm::map_range(
2516 results, [&](const auto *result) { return result->getType(); });
2517 auto resultNames = llvm::map_range(
2518 results, [&](const auto *result) { return result->getName().getName(); });
2519 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2520 llvm::to_vector(resultNames));
2523 template <typename T>
2524 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2525 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2526 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2527 ast::CompoundStmt *body) {
2528 if (!body->getChildren().empty()) {
2529 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2530 ast::Expr *resultExpr = retStmt->getResultExpr();
2532 // Process the result of the decl. If no explicit signature results
2533 // were provided, check for return type inference. Otherwise, check that
2534 // the return expression can be converted to the expected type.
2535 if (results.empty())
2536 resultType = resultExpr->getType();
2537 else if (failed(convertExpressionTo(resultExpr, resultType)))
2538 return failure();
2539 else
2540 retStmt->setResultExpr(resultExpr);
2543 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2546 FailureOr<ast::VariableDecl *>
2547 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2548 ArrayRef<ast::ConstraintRef> constraints) {
2549 // The type of the variable, which is expected to be inferred by either a
2550 // constraint or an initializer expression.
2551 ast::Type type;
2552 if (failed(validateVariableConstraints(constraints, type)))
2553 return failure();
2555 if (initializer) {
2556 // Update the variable type based on the initializer, or try to convert the
2557 // initializer to the existing type.
2558 if (!type)
2559 type = initializer->getType();
2560 else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2561 type = mergedType;
2562 else if (failed(convertExpressionTo(initializer, type)))
2563 return failure();
2565 // Otherwise, if there is no initializer check that the type has already
2566 // been resolved from the constraint list.
2567 } else if (!type) {
2568 return emitErrorAndNote(
2569 loc, "unable to infer type for variable `" + name + "`", loc,
2570 "the type of a variable must be inferable from the constraint "
2571 "list or the initializer");
2574 // Constraint types cannot be used when defining variables.
2575 if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2576 return emitError(
2577 loc, llvm::formatv("unable to define variable of `{0}` type", type));
2580 // Try to define a variable with the given name.
2581 FailureOr<ast::VariableDecl *> varDecl =
2582 defineVariableDecl(name, loc, type, initializer, constraints);
2583 if (failed(varDecl))
2584 return failure();
2586 return *varDecl;
2589 FailureOr<ast::VariableDecl *>
2590 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2591 const ast::ConstraintRef &constraint) {
2592 ast::Type argType;
2593 if (failed(validateVariableConstraint(constraint, argType)))
2594 return failure();
2595 return defineVariableDecl(name, loc, argType, constraint);
2598 LogicalResult
2599 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2600 ast::Type &inferredType) {
2601 for (const ast::ConstraintRef &ref : constraints)
2602 if (failed(validateVariableConstraint(ref, inferredType)))
2603 return failure();
2604 return success();
2607 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2608 ast::Type &inferredType) {
2609 ast::Type constraintType;
2610 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2611 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2612 if (failed(validateTypeConstraintExpr(typeExpr)))
2613 return failure();
2615 constraintType = ast::AttributeType::get(ctx);
2616 } else if (const auto *cst =
2617 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2618 constraintType = ast::OperationType::get(
2619 ctx, cst->getName(), lookupODSOperation(cst->getName()));
2620 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2621 constraintType = typeTy;
2622 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2623 constraintType = typeRangeTy;
2624 } else if (const auto *cst =
2625 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2626 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2627 if (failed(validateTypeConstraintExpr(typeExpr)))
2628 return failure();
2630 constraintType = valueTy;
2631 } else if (const auto *cst =
2632 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2633 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2634 if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2635 return failure();
2637 constraintType = valueRangeTy;
2638 } else if (const auto *cst =
2639 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2640 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2641 if (inputs.size() != 1) {
2642 return emitErrorAndNote(ref.referenceLoc,
2643 "`Constraint`s applied via a variable constraint "
2644 "list must take a single input, but got " +
2645 Twine(inputs.size()),
2646 cst->getLoc(),
2647 "see definition of constraint here");
2649 constraintType = inputs.front()->getType();
2650 } else {
2651 llvm_unreachable("unknown constraint type");
2654 // Check that the constraint type is compatible with the current inferred
2655 // type.
2656 if (!inferredType) {
2657 inferredType = constraintType;
2658 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2659 inferredType = mergedTy;
2660 } else {
2661 return emitError(ref.referenceLoc,
2662 llvm::formatv("constraint type `{0}` is incompatible "
2663 "with the previously inferred type `{1}`",
2664 constraintType, inferredType));
2666 return success();
2669 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2670 ast::Type typeExprType = typeExpr->getType();
2671 if (typeExprType != typeTy) {
2672 return emitError(typeExpr->getLoc(),
2673 "expected expression of `Type` in type constraint");
2675 return success();
2678 LogicalResult
2679 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2680 ast::Type typeExprType = typeExpr->getType();
2681 if (typeExprType != typeRangeTy) {
2682 return emitError(typeExpr->getLoc(),
2683 "expected expression of `TypeRange` in type constraint");
2685 return success();
2688 //===----------------------------------------------------------------------===//
2689 // Exprs
2691 FailureOr<ast::CallExpr *>
2692 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2693 MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2694 ast::Type parentType = parentExpr->getType();
2696 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2697 if (!callableDecl) {
2698 return emitError(loc,
2699 llvm::formatv("expected a reference to a callable "
2700 "`Constraint` or `Rewrite`, but got: `{0}`",
2701 parentType));
2703 if (parserContext == ParserContext::Rewrite) {
2704 if (isa<ast::UserConstraintDecl>(callableDecl))
2705 return emitError(
2706 loc, "unable to invoke `Constraint` within a rewrite section");
2707 if (isNegated)
2708 return emitError(loc, "unable to negate a Rewrite");
2709 } else {
2710 if (isa<ast::UserRewriteDecl>(callableDecl))
2711 return emitError(loc,
2712 "unable to invoke `Rewrite` within a match section");
2713 if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2714 return emitError(loc, "unable to negate non native constraints");
2717 // Verify the arguments of the call.
2718 /// Handle size mismatch.
2719 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2720 if (callArgs.size() != arguments.size()) {
2721 return emitErrorAndNote(
2722 loc,
2723 llvm::formatv("invalid number of arguments for {0} call; expected "
2724 "{1}, but got {2}",
2725 callableDecl->getCallableType(), callArgs.size(),
2726 arguments.size()),
2727 callableDecl->getLoc(),
2728 llvm::formatv("see the definition of {0} here",
2729 callableDecl->getName()->getName()));
2732 /// Handle argument type mismatch.
2733 auto attachDiagFn = [&](ast::Diagnostic &diag) {
2734 diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2735 callableDecl->getName()->getName()),
2736 callableDecl->getLoc());
2738 for (auto it : llvm::zip(callArgs, arguments)) {
2739 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2740 attachDiagFn)))
2741 return failure();
2744 return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2745 callableDecl->getResultType(), isNegated);
2748 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2749 ast::Decl *decl) {
2750 // Check the type of decl being referenced.
2751 ast::Type declType;
2752 if (isa<ast::ConstraintDecl>(decl))
2753 declType = ast::ConstraintType::get(ctx);
2754 else if (isa<ast::UserRewriteDecl>(decl))
2755 declType = ast::RewriteType::get(ctx);
2756 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2757 declType = varDecl->getType();
2758 else
2759 return emitError(loc, "invalid reference to `" +
2760 decl->getName()->getName() + "`");
2762 return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2765 FailureOr<ast::DeclRefExpr *>
2766 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2767 ArrayRef<ast::ConstraintRef> constraints) {
2768 FailureOr<ast::VariableDecl *> decl =
2769 defineVariableDecl(name, loc, type, constraints);
2770 if (failed(decl))
2771 return failure();
2772 return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2775 FailureOr<ast::MemberAccessExpr *>
2776 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2777 SMRange loc) {
2778 // Validate the member name for the given parent expression.
2779 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2780 if (failed(memberType))
2781 return failure();
2783 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2786 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2787 StringRef name, SMRange loc) {
2788 ast::Type parentType = parentExpr->getType();
2789 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2790 if (name == ast::AllResultsMemberAccessExpr::getMemberName())
2791 return valueRangeTy;
2793 // Verify member access based on the operation type.
2794 if (const ods::Operation *odsOp = opType.getODSOperation()) {
2795 auto results = odsOp->getResults();
2797 // Handle indexed results.
2798 unsigned index = 0;
2799 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2800 index < results.size()) {
2801 return results[index].isVariadic() ? valueRangeTy : valueTy;
2804 // Handle named results.
2805 const auto *it = llvm::find_if(results, [&](const auto &result) {
2806 return result.getName() == name;
2808 if (it != results.end())
2809 return it->isVariadic() ? valueRangeTy : valueTy;
2810 } else if (llvm::isDigit(name[0])) {
2811 // Allow unchecked numeric indexing of the results of unregistered
2812 // operations. It returns a single value.
2813 return valueTy;
2815 } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2816 // Handle indexed results.
2817 unsigned index = 0;
2818 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2819 index < tupleType.size()) {
2820 return tupleType.getElementTypes()[index];
2823 // Handle named results.
2824 auto elementNames = tupleType.getElementNames();
2825 const auto *it = llvm::find(elementNames, name);
2826 if (it != elementNames.end())
2827 return tupleType.getElementTypes()[it - elementNames.begin()];
2829 return emitError(
2830 loc,
2831 llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2832 name, parentType));
2835 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2836 SMRange loc, const ast::OpNameDecl *name,
2837 OpResultTypeContext resultTypeContext,
2838 SmallVectorImpl<ast::Expr *> &operands,
2839 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2840 SmallVectorImpl<ast::Expr *> &results) {
2841 std::optional<StringRef> opNameRef = name->getName();
2842 const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2844 // Verify the inputs operands.
2845 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2846 return failure();
2848 // Verify the attribute list.
2849 for (ast::NamedAttributeDecl *attr : attributes) {
2850 // Check for an attribute type, or a type awaiting resolution.
2851 ast::Type attrType = attr->getValue()->getType();
2852 if (!attrType.isa<ast::AttributeType>()) {
2853 return emitError(
2854 attr->getValue()->getLoc(),
2855 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2859 assert(
2860 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2861 "unexpected inferrence when results were explicitly specified");
2863 // If we aren't relying on type inferrence, or explicit results were provided,
2864 // validate them.
2865 if (resultTypeContext == OpResultTypeContext::Explicit) {
2866 if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2867 return failure();
2869 // Validate the use of interface based type inferrence for this operation.
2870 } else if (resultTypeContext == OpResultTypeContext::Interface) {
2871 assert(opNameRef &&
2872 "expected valid operation name when inferring operation results");
2873 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2876 return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2877 attributes);
2880 LogicalResult
2881 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2882 const ods::Operation *odsOp,
2883 SmallVectorImpl<ast::Expr *> &operands) {
2884 return validateOperationOperandsOrResults(
2885 "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2886 operands, odsOp ? odsOp->getOperands() : std::nullopt, valueTy,
2887 valueRangeTy);
2890 LogicalResult
2891 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2892 const ods::Operation *odsOp,
2893 SmallVectorImpl<ast::Expr *> &results) {
2894 return validateOperationOperandsOrResults(
2895 "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2896 results, odsOp ? odsOp->getResults() : std::nullopt, typeTy, typeRangeTy);
2899 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2900 const ods::Operation *odsOp) {
2901 // If the operation might not have inferrence support, emit a warning to the
2902 // user. We don't emit an error because the interface might be added to the
2903 // operation at runtime. It's rare, but it could still happen. We emit a
2904 // warning here instead.
2906 // Handle inferrence warnings for unknown operations.
2907 if (!odsOp) {
2908 ctx.getDiagEngine().emitWarning(
2909 loc, llvm::formatv(
2910 "operation result types are marked to be inferred, but "
2911 "`{0}` is unknown. Ensure that `{0}` supports zero "
2912 "results or implements `InferTypeOpInterface`. Include "
2913 "the ODS definition of this operation to remove this warning.",
2914 opName));
2915 return;
2918 // Handle inferrence warnings for known operations that expected at least one
2919 // result, but don't have inference support. An elided results list can mean
2920 // "zero-results", and we don't want to warn when that is the expected
2921 // behavior.
2922 bool requiresInferrence =
2923 llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2924 return !result.isVariableLength();
2926 if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2927 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2928 loc,
2929 llvm::formatv("operation result types are marked to be inferred, but "
2930 "`{0}` does not provide an implementation of "
2931 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2932 "`InferTypeOpInterface` at runtime, or add support to "
2933 "the ODS definition to remove this warning.",
2934 opName));
2935 diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2936 odsOp->getLoc());
2937 return;
2941 LogicalResult Parser::validateOperationOperandsOrResults(
2942 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2943 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2944 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2945 ast::RangeType rangeTy) {
2946 // All operation types accept a single range parameter.
2947 if (values.size() == 1) {
2948 if (failed(convertExpressionTo(values[0], rangeTy)))
2949 return failure();
2950 return success();
2953 /// If the operation has ODS information, we can more accurately verify the
2954 /// values.
2955 if (odsOpLoc) {
2956 auto emitSizeMismatchError = [&] {
2957 return emitErrorAndNote(
2958 loc,
2959 llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2960 "{2}, but got {3}",
2961 groupName, *name, odsValues.size(), values.size()),
2962 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2965 // Handle the case where no values were provided.
2966 if (values.empty()) {
2967 // If we don't expect any on the ODS side, we are done.
2968 if (odsValues.empty())
2969 return success();
2971 // If we do, check if we actually need to provide values (i.e. if any of
2972 // the values are actually required).
2973 unsigned numVariadic = 0;
2974 for (const auto &odsValue : odsValues) {
2975 if (!odsValue.isVariableLength())
2976 return emitSizeMismatchError();
2977 ++numVariadic;
2980 // If we are in a non-rewrite context, we don't need to do anything more.
2981 // Zero-values is a valid constraint on the operation.
2982 if (parserContext != ParserContext::Rewrite)
2983 return success();
2985 // Otherwise, when in a rewrite we may need to provide values to match the
2986 // ODS signature of the operation to create.
2988 // If we only have one variadic value, just use an empty list.
2989 if (numVariadic == 1)
2990 return success();
2992 // Otherwise, create dummy values for each of the entries so that we
2993 // adhere to the ODS signature.
2994 for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2995 values.push_back(ast::RangeExpr::create(
2996 ctx, loc, /*elements=*/std::nullopt, rangeTy));
2998 return success();
3001 // Verify that the number of values provided matches the number of value
3002 // groups ODS expects.
3003 if (odsValues.size() != values.size())
3004 return emitSizeMismatchError();
3006 auto diagFn = [&](ast::Diagnostic &diag) {
3007 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
3008 *odsOpLoc);
3010 for (unsigned i = 0, e = values.size(); i < e; ++i) {
3011 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3012 if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3013 return failure();
3015 return success();
3018 // Otherwise, accept the value groups as they have been defined and just
3019 // ensure they are one of the expected types.
3020 for (ast::Expr *&valueExpr : values) {
3021 ast::Type valueExprType = valueExpr->getType();
3023 // Check if this is one of the expected types.
3024 if (valueExprType == rangeTy || valueExprType == singleTy)
3025 continue;
3027 // If the operand is an Operation, allow converting to a Value or
3028 // ValueRange. This situations arises quite often with nested operation
3029 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3030 if (singleTy == valueTy) {
3031 if (valueExprType.isa<ast::OperationType>()) {
3032 valueExpr = convertOpToValue(valueExpr);
3033 continue;
3037 // Otherwise, try to convert the expression to a range.
3038 if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3039 continue;
3041 return emitError(
3042 valueExpr->getLoc(),
3043 llvm::formatv(
3044 "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3045 singleTy, rangeTy, valueExprType));
3047 return success();
3050 FailureOr<ast::TupleExpr *>
3051 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3052 ArrayRef<StringRef> elementNames) {
3053 for (const ast::Expr *element : elements) {
3054 ast::Type eleTy = element->getType();
3055 if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
3056 return emitError(
3057 element->getLoc(),
3058 llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
3061 return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3064 //===----------------------------------------------------------------------===//
3065 // Stmts
3067 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3068 ast::Expr *rootOp) {
3069 // Check that root is an Operation.
3070 ast::Type rootType = rootOp->getType();
3071 if (!rootType.isa<ast::OperationType>())
3072 return emitError(rootOp->getLoc(), "expected `Op` expression");
3074 return ast::EraseStmt::create(ctx, loc, rootOp);
3077 FailureOr<ast::ReplaceStmt *>
3078 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3079 MutableArrayRef<ast::Expr *> replValues) {
3080 // Check that root is an Operation.
3081 ast::Type rootType = rootOp->getType();
3082 if (!rootType.isa<ast::OperationType>()) {
3083 return emitError(
3084 rootOp->getLoc(),
3085 llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3088 // If there are multiple replacement values, we implicitly convert any Op
3089 // expressions to the value form.
3090 bool shouldConvertOpToValues = replValues.size() > 1;
3091 for (ast::Expr *&replExpr : replValues) {
3092 ast::Type replType = replExpr->getType();
3094 // Check that replExpr is an Operation, Value, or ValueRange.
3095 if (replType.isa<ast::OperationType>()) {
3096 if (shouldConvertOpToValues)
3097 replExpr = convertOpToValue(replExpr);
3098 continue;
3101 if (replType != valueTy && replType != valueRangeTy) {
3102 return emitError(replExpr->getLoc(),
3103 llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3104 "expression, but got `{0}`",
3105 replType));
3109 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3112 FailureOr<ast::RewriteStmt *>
3113 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3114 ast::CompoundStmt *rewriteBody) {
3115 // Check that root is an Operation.
3116 ast::Type rootType = rootOp->getType();
3117 if (!rootType.isa<ast::OperationType>()) {
3118 return emitError(
3119 rootOp->getLoc(),
3120 llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3123 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3126 //===----------------------------------------------------------------------===//
3127 // Code Completion
3128 //===----------------------------------------------------------------------===//
3130 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3131 ast::Type parentType = parentExpr->getType();
3132 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3133 codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3134 else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3135 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3136 return failure();
3139 LogicalResult
3140 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3141 if (opName)
3142 codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3143 return failure();
3146 LogicalResult
3147 Parser::codeCompleteConstraintName(ast::Type inferredType,
3148 bool allowInlineTypeConstraints) {
3149 codeCompleteContext->codeCompleteConstraintName(
3150 inferredType, allowInlineTypeConstraints, curDeclScope);
3151 return failure();
3154 LogicalResult Parser::codeCompleteDialectName() {
3155 codeCompleteContext->codeCompleteDialectName();
3156 return failure();
3159 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3160 codeCompleteContext->codeCompleteOperationName(dialectName);
3161 return failure();
3164 LogicalResult Parser::codeCompletePatternMetadata() {
3165 codeCompleteContext->codeCompletePatternMetadata();
3166 return failure();
3169 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3170 codeCompleteContext->codeCompleteIncludeFilename(curPath);
3171 return failure();
3174 void Parser::codeCompleteCallSignature(ast::Node *parent,
3175 unsigned currentNumArgs) {
3176 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3177 if (!callableDecl)
3178 return;
3180 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3183 void Parser::codeCompleteOperationOperandsSignature(
3184 std::optional<StringRef> opName, unsigned currentNumOperands) {
3185 codeCompleteContext->codeCompleteOperationOperandsSignature(
3186 opName, currentNumOperands);
3189 void Parser::codeCompleteOperationResultsSignature(
3190 std::optional<StringRef> opName, unsigned currentNumResults) {
3191 codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3192 currentNumResults);
3195 //===----------------------------------------------------------------------===//
3196 // Parser
3197 //===----------------------------------------------------------------------===//
3199 FailureOr<ast::Module *>
3200 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3201 bool enableDocumentation,
3202 CodeCompleteContext *codeCompleteContext) {
3203 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3204 return parser.parseModule();