1 //===- Parser.cpp ---------------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Tools/PDLL/Parser/Parser.h"
11 #include "mlir/Support/IndentedOstream.h"
12 #include "mlir/Support/LogicalResult.h"
13 #include "mlir/TableGen/Argument.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Constraint.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/Tools/PDLL/AST/Context.h"
19 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
20 #include "mlir/Tools/PDLL/AST/Nodes.h"
21 #include "mlir/Tools/PDLL/AST/Types.h"
22 #include "mlir/Tools/PDLL/ODS/Constraint.h"
23 #include "mlir/Tools/PDLL/ODS/Context.h"
24 #include "mlir/Tools/PDLL/ODS/Operation.h"
25 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/ManagedStatic.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/ScopedPrinter.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Parser.h"
38 using namespace mlir::pdll
;
40 //===----------------------------------------------------------------------===//
42 //===----------------------------------------------------------------------===//
47 Parser(ast::Context
&ctx
, llvm::SourceMgr
&sourceMgr
,
48 bool enableDocumentation
, CodeCompleteContext
*codeCompleteContext
)
49 : ctx(ctx
), lexer(sourceMgr
, ctx
.getDiagEngine(), codeCompleteContext
),
50 curToken(lexer
.lexToken()), enableDocumentation(enableDocumentation
),
51 typeTy(ast::TypeType::get(ctx
)), valueTy(ast::ValueType::get(ctx
)),
52 typeRangeTy(ast::TypeRangeType::get(ctx
)),
53 valueRangeTy(ast::ValueRangeType::get(ctx
)),
54 attrTy(ast::AttributeType::get(ctx
)),
55 codeCompleteContext(codeCompleteContext
) {}
57 /// Try to parse a new module. Returns nullptr in the case of failure.
58 FailureOr
<ast::Module
*> parseModule();
61 /// The current context of the parser. It allows for the parser to know a bit
62 /// about the construct it is nested within during parsing. This is used
63 /// specifically to provide additional verification during parsing, e.g. to
64 /// prevent using rewrites within a match context, matcher constraints within
65 /// a rewrite section, etc.
66 enum class ParserContext
{
67 /// The parser is in the global context.
69 /// The parser is currently within a Constraint, which disallows all types
70 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
72 /// The parser is currently within the matcher portion of a Pattern, which
73 /// is allows a terminal operation rewrite statement but no other rewrite
76 /// The parser is currently within a Rewrite, which disallows calls to
77 /// constraints, requires operation expressions to have names, etc.
81 /// The current specification context of an operations result type. This
82 /// indicates how the result types of an operation may be inferred.
83 enum class OpResultTypeContext
{
84 /// The result types of the operation are not known to be inferred.
86 /// The result types of the operation are inferred from the root input of a
87 /// `replace` statement.
89 /// The result types of the operation are inferred by using the
90 /// `InferTypeOpInterface` interface provided by the operation.
94 //===--------------------------------------------------------------------===//
96 //===--------------------------------------------------------------------===//
98 /// Push a new decl scope onto the lexer.
99 ast::DeclScope
*pushDeclScope() {
100 ast::DeclScope
*newScope
=
101 new (scopeAllocator
.Allocate()) ast::DeclScope(curDeclScope
);
102 return (curDeclScope
= newScope
);
104 void pushDeclScope(ast::DeclScope
*scope
) { curDeclScope
= scope
; }
106 /// Pop the last decl scope from the lexer.
107 void popDeclScope() { curDeclScope
= curDeclScope
->getParentScope(); }
109 /// Parse the body of an AST module.
110 LogicalResult
parseModuleBody(SmallVectorImpl
<ast::Decl
*> &decls
);
112 /// Try to convert the given expression to `type`. Returns failure and emits
113 /// an error if a conversion is not viable. On failure, `noteAttachFn` is
114 /// invoked to attach notes to the emitted error diagnostic. On success,
115 /// `expr` is updated to the expression used to convert to `type`.
116 LogicalResult
convertExpressionTo(
117 ast::Expr
*&expr
, ast::Type type
,
118 function_ref
<void(ast::Diagnostic
&diag
)> noteAttachFn
= {});
120 convertOpExpressionTo(ast::Expr
*&expr
, ast::OperationType exprType
,
122 function_ref
<ast::InFlightDiagnostic()> emitErrorFn
);
123 LogicalResult
convertTupleExpressionTo(
124 ast::Expr
*&expr
, ast::TupleType exprType
, ast::Type type
,
125 function_ref
<ast::InFlightDiagnostic()> emitErrorFn
,
126 function_ref
<void(ast::Diagnostic
&diag
)> noteAttachFn
);
128 /// Given an operation expression, convert it to a Value or ValueRange
129 /// typed expression.
130 ast::Expr
*convertOpToValue(const ast::Expr
*opExpr
);
132 /// Lookup ODS information for the given operation, returns nullptr if no
133 /// information is found.
134 const ods::Operation
*lookupODSOperation(std::optional
<StringRef
> opName
) {
135 return opName
? ctx
.getODSContext().lookupOperation(*opName
) : nullptr;
138 /// Process the given documentation string, or return an empty string if
139 /// documentation isn't enabled.
140 StringRef
processDoc(StringRef doc
) {
141 return enableDocumentation
? doc
: StringRef();
144 /// Process the given documentation string and format it, or return an empty
145 /// string if documentation isn't enabled.
146 std::string
processAndFormatDoc(const Twine
&doc
) {
147 if (!enableDocumentation
)
151 llvm::raw_string_ostream
docOS(docStr
);
152 std::string tmpDocStr
= doc
.str();
153 raw_indented_ostream(docOS
).printReindented(
154 StringRef(tmpDocStr
).rtrim(" \t"));
159 //===--------------------------------------------------------------------===//
162 LogicalResult
parseDirective(SmallVectorImpl
<ast::Decl
*> &decls
);
163 LogicalResult
parseInclude(SmallVectorImpl
<ast::Decl
*> &decls
);
164 LogicalResult
parseTdInclude(StringRef filename
, SMRange fileLoc
,
165 SmallVectorImpl
<ast::Decl
*> &decls
);
167 /// Process the records of a parsed tablegen include file.
168 void processTdIncludeRecords(llvm::RecordKeeper
&tdRecords
,
169 SmallVectorImpl
<ast::Decl
*> &decls
);
171 /// Create a user defined native constraint for a constraint imported from
173 template <typename ConstraintT
>
175 createODSNativePDLLConstraintDecl(StringRef name
, StringRef codeBlock
,
176 SMRange loc
, ast::Type type
,
177 StringRef nativeType
, StringRef docString
);
178 template <typename ConstraintT
>
180 createODSNativePDLLConstraintDecl(const tblgen::Constraint
&constraint
,
181 SMRange loc
, ast::Type type
,
182 StringRef nativeType
);
184 //===--------------------------------------------------------------------===//
187 /// This structure contains the set of pattern metadata that may be parsed.
188 struct ParsedPatternMetadata
{
189 std::optional
<uint16_t> benefit
;
190 bool hasBoundedRecursion
= false;
193 FailureOr
<ast::Decl
*> parseTopLevelDecl();
194 FailureOr
<ast::NamedAttributeDecl
*>
195 parseNamedAttributeDecl(std::optional
<StringRef
> parentOpName
);
197 /// Parse an argument variable as part of the signature of a
198 /// UserConstraintDecl or UserRewriteDecl.
199 FailureOr
<ast::VariableDecl
*> parseArgumentDecl();
201 /// Parse a result variable as part of the signature of a UserConstraintDecl
202 /// or UserRewriteDecl.
203 FailureOr
<ast::VariableDecl
*> parseResultDecl(unsigned resultNum
);
205 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
206 /// defined in a non-global context.
207 FailureOr
<ast::UserConstraintDecl
*>
208 parseUserConstraintDecl(bool isInline
= false);
210 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
211 /// non-global context, such as within a Pattern/Constraint/etc.
212 FailureOr
<ast::UserConstraintDecl
*> parseInlineUserConstraintDecl();
214 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
216 FailureOr
<ast::UserConstraintDecl
*> parseUserPDLLConstraintDecl(
217 const ast::Name
&name
, bool isInline
,
218 ArrayRef
<ast::VariableDecl
*> arguments
, ast::DeclScope
*argumentScope
,
219 ArrayRef
<ast::VariableDecl
*> results
, ast::Type resultType
);
221 /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
222 /// defined in a non-global context.
223 FailureOr
<ast::UserRewriteDecl
*> parseUserRewriteDecl(bool isInline
= false);
225 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
226 /// non-global context, such as within a Pattern/Rewrite/etc.
227 FailureOr
<ast::UserRewriteDecl
*> parseInlineUserRewriteDecl();
229 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
231 FailureOr
<ast::UserRewriteDecl
*> parseUserPDLLRewriteDecl(
232 const ast::Name
&name
, bool isInline
,
233 ArrayRef
<ast::VariableDecl
*> arguments
, ast::DeclScope
*argumentScope
,
234 ArrayRef
<ast::VariableDecl
*> results
, ast::Type resultType
);
236 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
237 /// effectively the same syntax, and only differ on slight semantics (given
238 /// the different parsing contexts).
239 template <typename T
, typename ParseUserPDLLDeclFnT
>
240 FailureOr
<T
*> parseUserConstraintOrRewriteDecl(
241 ParseUserPDLLDeclFnT
&&parseUserPDLLFn
, ParserContext declContext
,
242 StringRef anonymousNamePrefix
, bool isInline
);
244 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
245 /// These decls have effectively the same syntax.
246 template <typename T
>
247 FailureOr
<T
*> parseUserNativeConstraintOrRewriteDecl(
248 const ast::Name
&name
, bool isInline
,
249 ArrayRef
<ast::VariableDecl
*> arguments
,
250 ArrayRef
<ast::VariableDecl
*> results
, ast::Type resultType
);
252 /// Parse the functional signature (i.e. the arguments and results) of a
253 /// UserConstraintDecl or UserRewriteDecl.
254 LogicalResult
parseUserConstraintOrRewriteSignature(
255 SmallVectorImpl
<ast::VariableDecl
*> &arguments
,
256 SmallVectorImpl
<ast::VariableDecl
*> &results
,
257 ast::DeclScope
*&argumentScope
, ast::Type
&resultType
);
259 /// Validate the return (which if present is specified by bodyIt) of a
260 /// UserConstraintDecl or UserRewriteDecl.
261 LogicalResult
validateUserConstraintOrRewriteReturn(
262 StringRef declType
, ast::CompoundStmt
*body
,
263 ArrayRef
<ast::Stmt
*>::iterator bodyIt
,
264 ArrayRef
<ast::Stmt
*>::iterator bodyE
,
265 ArrayRef
<ast::VariableDecl
*> results
, ast::Type
&resultType
);
267 FailureOr
<ast::CompoundStmt
*>
268 parseLambdaBody(function_ref
<LogicalResult(ast::Stmt
*&)> processStatementFn
,
269 bool expectTerminalSemicolon
= true);
270 FailureOr
<ast::CompoundStmt
*> parsePatternLambdaBody();
271 FailureOr
<ast::Decl
*> parsePatternDecl();
272 LogicalResult
parsePatternDeclMetadata(ParsedPatternMetadata
&metadata
);
274 /// Check to see if a decl has already been defined with the given name, if
275 /// one has emit and error and return failure. Returns success otherwise.
276 LogicalResult
checkDefineNamedDecl(const ast::Name
&name
);
278 /// Try to define a variable decl with the given components, returns the
279 /// variable on success.
280 FailureOr
<ast::VariableDecl
*>
281 defineVariableDecl(StringRef name
, SMRange nameLoc
, ast::Type type
,
283 ArrayRef
<ast::ConstraintRef
> constraints
);
284 FailureOr
<ast::VariableDecl
*>
285 defineVariableDecl(StringRef name
, SMRange nameLoc
, ast::Type type
,
286 ArrayRef
<ast::ConstraintRef
> constraints
);
288 /// Parse the constraint reference list for a variable decl.
289 LogicalResult
parseVariableDeclConstraintList(
290 SmallVectorImpl
<ast::ConstraintRef
> &constraints
);
292 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
293 FailureOr
<ast::Expr
*> parseTypeConstraintExpr();
295 /// Try to parse a single reference to a constraint. `typeConstraint` is the
296 /// location of a previously parsed type constraint for the entity that will
297 /// be constrained by the parsed constraint. `existingConstraints` are any
298 /// existing constraints that have already been parsed for the same entity
299 /// that will be constrained by this constraint. `allowInlineTypeConstraints`
300 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
301 FailureOr
<ast::ConstraintRef
>
302 parseConstraint(std::optional
<SMRange
> &typeConstraint
,
303 ArrayRef
<ast::ConstraintRef
> existingConstraints
,
304 bool allowInlineTypeConstraints
);
306 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
307 /// argument or result variable. The constraints for these variables do not
308 /// allow inline type constraints, and only permit a single constraint.
309 FailureOr
<ast::ConstraintRef
> parseArgOrResultConstraint();
311 //===--------------------------------------------------------------------===//
314 FailureOr
<ast::Expr
*> parseExpr();
316 /// Identifier expressions.
317 FailureOr
<ast::Expr
*> parseAttributeExpr();
318 FailureOr
<ast::Expr
*> parseCallExpr(ast::Expr
*parentExpr
,
319 bool isNegated
= false);
320 FailureOr
<ast::Expr
*> parseDeclRefExpr(StringRef name
, SMRange loc
);
321 FailureOr
<ast::Expr
*> parseIdentifierExpr();
322 FailureOr
<ast::Expr
*> parseInlineConstraintLambdaExpr();
323 FailureOr
<ast::Expr
*> parseInlineRewriteLambdaExpr();
324 FailureOr
<ast::Expr
*> parseMemberAccessExpr(ast::Expr
*parentExpr
);
325 FailureOr
<ast::Expr
*> parseNegatedExpr();
326 FailureOr
<ast::OpNameDecl
*> parseOperationName(bool allowEmptyName
= false);
327 FailureOr
<ast::OpNameDecl
*> parseWrappedOperationName(bool allowEmptyName
);
328 FailureOr
<ast::Expr
*>
329 parseOperationExpr(OpResultTypeContext inputResultTypeContext
=
330 OpResultTypeContext::Explicit
);
331 FailureOr
<ast::Expr
*> parseTupleExpr();
332 FailureOr
<ast::Expr
*> parseTypeExpr();
333 FailureOr
<ast::Expr
*> parseUnderscoreExpr();
335 //===--------------------------------------------------------------------===//
338 FailureOr
<ast::Stmt
*> parseStmt(bool expectTerminalSemicolon
= true);
339 FailureOr
<ast::CompoundStmt
*> parseCompoundStmt();
340 FailureOr
<ast::EraseStmt
*> parseEraseStmt();
341 FailureOr
<ast::LetStmt
*> parseLetStmt();
342 FailureOr
<ast::ReplaceStmt
*> parseReplaceStmt();
343 FailureOr
<ast::ReturnStmt
*> parseReturnStmt();
344 FailureOr
<ast::RewriteStmt
*> parseRewriteStmt();
346 //===--------------------------------------------------------------------===//
348 //===--------------------------------------------------------------------===//
350 //===--------------------------------------------------------------------===//
353 /// Try to extract a callable from the given AST node. Returns nullptr on
355 ast::CallableDecl
*tryExtractCallableDecl(ast::Node
*node
);
357 /// Try to create a pattern decl with the given components, returning the
358 /// Pattern on success.
359 FailureOr
<ast::PatternDecl
*>
360 createPatternDecl(SMRange loc
, const ast::Name
*name
,
361 const ParsedPatternMetadata
&metadata
,
362 ast::CompoundStmt
*body
);
364 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
365 /// of results, defined as part of the signature.
367 createUserConstraintRewriteResultType(ArrayRef
<ast::VariableDecl
*> results
);
369 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
370 template <typename T
>
371 FailureOr
<T
*> createUserPDLLConstraintOrRewriteDecl(
372 const ast::Name
&name
, ArrayRef
<ast::VariableDecl
*> arguments
,
373 ArrayRef
<ast::VariableDecl
*> results
, ast::Type resultType
,
374 ast::CompoundStmt
*body
);
376 /// Try to create a variable decl with the given components, returning the
377 /// Variable on success.
378 FailureOr
<ast::VariableDecl
*>
379 createVariableDecl(StringRef name
, SMRange loc
, ast::Expr
*initializer
,
380 ArrayRef
<ast::ConstraintRef
> constraints
);
382 /// Create a variable for an argument or result defined as part of the
383 /// signature of a UserConstraintDecl/UserRewriteDecl.
384 FailureOr
<ast::VariableDecl
*>
385 createArgOrResultVariableDecl(StringRef name
, SMRange loc
,
386 const ast::ConstraintRef
&constraint
);
388 /// Validate the constraints used to constraint a variable decl.
389 /// `inferredType` is the type of the variable inferred by the constraints
390 /// within the list, and is updated to the most refined type as determined by
391 /// the constraints. Returns success if the constraint list is valid, failure
394 validateVariableConstraints(ArrayRef
<ast::ConstraintRef
> constraints
,
395 ast::Type
&inferredType
);
396 /// Validate a single reference to a constraint. `inferredType` contains the
397 /// currently inferred variabled type and is refined within the type defined
398 /// by the constraint. Returns success if the constraint is valid, failure
400 LogicalResult
validateVariableConstraint(const ast::ConstraintRef
&ref
,
401 ast::Type
&inferredType
);
402 LogicalResult
validateTypeConstraintExpr(const ast::Expr
*typeExpr
);
403 LogicalResult
validateTypeRangeConstraintExpr(const ast::Expr
*typeExpr
);
405 //===--------------------------------------------------------------------===//
408 FailureOr
<ast::CallExpr
*>
409 createCallExpr(SMRange loc
, ast::Expr
*parentExpr
,
410 MutableArrayRef
<ast::Expr
*> arguments
,
411 bool isNegated
= false);
412 FailureOr
<ast::DeclRefExpr
*> createDeclRefExpr(SMRange loc
, ast::Decl
*decl
);
413 FailureOr
<ast::DeclRefExpr
*>
414 createInlineVariableExpr(ast::Type type
, StringRef name
, SMRange loc
,
415 ArrayRef
<ast::ConstraintRef
> constraints
);
416 FailureOr
<ast::MemberAccessExpr
*>
417 createMemberAccessExpr(ast::Expr
*parentExpr
, StringRef name
, SMRange loc
);
419 /// Validate the member access `name` into the given parent expression. On
420 /// success, this also returns the type of the member accessed.
421 FailureOr
<ast::Type
> validateMemberAccess(ast::Expr
*parentExpr
,
422 StringRef name
, SMRange loc
);
423 FailureOr
<ast::OperationExpr
*>
424 createOperationExpr(SMRange loc
, const ast::OpNameDecl
*name
,
425 OpResultTypeContext resultTypeContext
,
426 SmallVectorImpl
<ast::Expr
*> &operands
,
427 MutableArrayRef
<ast::NamedAttributeDecl
*> attributes
,
428 SmallVectorImpl
<ast::Expr
*> &results
);
430 validateOperationOperands(SMRange loc
, std::optional
<StringRef
> name
,
431 const ods::Operation
*odsOp
,
432 SmallVectorImpl
<ast::Expr
*> &operands
);
433 LogicalResult
validateOperationResults(SMRange loc
,
434 std::optional
<StringRef
> name
,
435 const ods::Operation
*odsOp
,
436 SmallVectorImpl
<ast::Expr
*> &results
);
437 void checkOperationResultTypeInferrence(SMRange loc
, StringRef name
,
438 const ods::Operation
*odsOp
);
439 LogicalResult
validateOperationOperandsOrResults(
440 StringRef groupName
, SMRange loc
, std::optional
<SMRange
> odsOpLoc
,
441 std::optional
<StringRef
> name
, SmallVectorImpl
<ast::Expr
*> &values
,
442 ArrayRef
<ods::OperandOrResult
> odsValues
, ast::Type singleTy
,
443 ast::RangeType rangeTy
);
444 FailureOr
<ast::TupleExpr
*> createTupleExpr(SMRange loc
,
445 ArrayRef
<ast::Expr
*> elements
,
446 ArrayRef
<StringRef
> elementNames
);
448 //===--------------------------------------------------------------------===//
451 FailureOr
<ast::EraseStmt
*> createEraseStmt(SMRange loc
, ast::Expr
*rootOp
);
452 FailureOr
<ast::ReplaceStmt
*>
453 createReplaceStmt(SMRange loc
, ast::Expr
*rootOp
,
454 MutableArrayRef
<ast::Expr
*> replValues
);
455 FailureOr
<ast::RewriteStmt
*>
456 createRewriteStmt(SMRange loc
, ast::Expr
*rootOp
,
457 ast::CompoundStmt
*rewriteBody
);
459 //===--------------------------------------------------------------------===//
461 //===--------------------------------------------------------------------===//
463 /// The set of various code completion methods. Every completion method
464 /// returns `failure` to stop the parsing process after providing completion
467 LogicalResult
codeCompleteMemberAccess(ast::Expr
*parentExpr
);
468 LogicalResult
codeCompleteAttributeName(std::optional
<StringRef
> opName
);
469 LogicalResult
codeCompleteConstraintName(ast::Type inferredType
,
470 bool allowInlineTypeConstraints
);
471 LogicalResult
codeCompleteDialectName();
472 LogicalResult
codeCompleteOperationName(StringRef dialectName
);
473 LogicalResult
codeCompletePatternMetadata();
474 LogicalResult
codeCompleteIncludeFilename(StringRef curPath
);
476 void codeCompleteCallSignature(ast::Node
*parent
, unsigned currentNumArgs
);
477 void codeCompleteOperationOperandsSignature(std::optional
<StringRef
> opName
,
478 unsigned currentNumOperands
);
479 void codeCompleteOperationResultsSignature(std::optional
<StringRef
> opName
,
480 unsigned currentNumResults
);
482 //===--------------------------------------------------------------------===//
484 //===--------------------------------------------------------------------===//
486 /// If the current token has the specified kind, consume it and return true.
487 /// If not, return false.
488 bool consumeIf(Token::Kind kind
) {
489 if (curToken
.isNot(kind
))
495 /// Advance the current lexer onto the next token.
496 void consumeToken() {
497 assert(curToken
.isNot(Token::eof
, Token::error
) &&
498 "shouldn't advance past EOF or errors");
499 curToken
= lexer
.lexToken();
502 /// Advance the current lexer onto the next token, asserting what the expected
503 /// current token is. This is preferred to the above method because it leads
504 /// to more self-documenting code with better checking.
505 void consumeToken(Token::Kind kind
) {
506 assert(curToken
.is(kind
) && "consumed an unexpected token");
510 /// Reset the lexer to the location at the given position.
511 void resetToken(SMRange tokLoc
) {
512 lexer
.resetPointer(tokLoc
.Start
.getPointer());
513 curToken
= lexer
.lexToken();
516 /// Consume the specified token if present and return success. On failure,
517 /// output a diagnostic and return failure.
518 LogicalResult
parseToken(Token::Kind kind
, const Twine
&msg
) {
519 if (curToken
.getKind() != kind
)
520 return emitError(curToken
.getLoc(), msg
);
524 LogicalResult
emitError(SMRange loc
, const Twine
&msg
) {
525 lexer
.emitError(loc
, msg
);
528 LogicalResult
emitError(const Twine
&msg
) {
529 return emitError(curToken
.getLoc(), msg
);
531 LogicalResult
emitErrorAndNote(SMRange loc
, const Twine
&msg
, SMRange noteLoc
,
533 lexer
.emitErrorAndNote(loc
, msg
, noteLoc
, note
);
537 //===--------------------------------------------------------------------===//
539 //===--------------------------------------------------------------------===//
541 /// The owning AST context.
544 /// The lexer of this parser.
547 /// The current token within the lexer.
550 /// A flag indicating if the parser should add documentation to AST nodes when
552 bool enableDocumentation
;
554 /// The most recently defined decl scope.
555 ast::DeclScope
*curDeclScope
= nullptr;
556 llvm::SpecificBumpPtrAllocator
<ast::DeclScope
> scopeAllocator
;
558 /// The current context of the parser.
559 ParserContext parserContext
= ParserContext::Global
;
561 /// Cached types to simplify verification and expression creation.
562 ast::Type typeTy
, valueTy
;
563 ast::RangeType typeRangeTy
, valueRangeTy
;
566 /// A counter used when naming anonymous constraints and rewrites.
567 unsigned anonymousDeclNameCounter
= 0;
569 /// The optional code completion context.
570 CodeCompleteContext
*codeCompleteContext
;
574 FailureOr
<ast::Module
*> Parser::parseModule() {
575 SMLoc moduleLoc
= curToken
.getStartLoc();
578 // Parse the top-level decls of the module.
579 SmallVector
<ast::Decl
*> decls
;
580 if (failed(parseModuleBody(decls
)))
581 return popDeclScope(), failure();
584 return ast::Module::create(ctx
, moduleLoc
, decls
);
587 LogicalResult
Parser::parseModuleBody(SmallVectorImpl
<ast::Decl
*> &decls
) {
588 while (curToken
.isNot(Token::eof
)) {
589 if (curToken
.is(Token::directive
)) {
590 if (failed(parseDirective(decls
)))
595 FailureOr
<ast::Decl
*> decl
= parseTopLevelDecl();
598 decls
.push_back(*decl
);
603 ast::Expr
*Parser::convertOpToValue(const ast::Expr
*opExpr
) {
604 return ast::AllResultsMemberAccessExpr::create(ctx
, opExpr
->getLoc(), opExpr
,
608 LogicalResult
Parser::convertExpressionTo(
609 ast::Expr
*&expr
, ast::Type type
,
610 function_ref
<void(ast::Diagnostic
&diag
)> noteAttachFn
) {
611 ast::Type exprType
= expr
->getType();
612 if (exprType
== type
)
615 auto emitConvertError
= [&]() -> ast::InFlightDiagnostic
{
616 ast::InFlightDiagnostic diag
= ctx
.getDiagEngine().emitError(
617 expr
->getLoc(), llvm::formatv("unable to convert expression of type "
618 "`{0}` to the expected type of "
626 if (auto exprOpType
= exprType
.dyn_cast
<ast::OperationType
>())
627 return convertOpExpressionTo(expr
, exprOpType
, type
, emitConvertError
);
629 // FIXME: Decide how to allow/support converting a single result to multiple,
630 // and multiple to a single result. For now, we just allow Single->Range,
631 // but this isn't something really supported in the PDL dialect. We should
632 // figure out some way to support both.
633 if ((exprType
== valueTy
|| exprType
== valueRangeTy
) &&
634 (type
== valueTy
|| type
== valueRangeTy
))
636 if ((exprType
== typeTy
|| exprType
== typeRangeTy
) &&
637 (type
== typeTy
|| type
== typeRangeTy
))
640 // Handle tuple types.
641 if (auto exprTupleType
= exprType
.dyn_cast
<ast::TupleType
>())
642 return convertTupleExpressionTo(expr
, exprTupleType
, type
, emitConvertError
,
645 return emitConvertError();
648 LogicalResult
Parser::convertOpExpressionTo(
649 ast::Expr
*&expr
, ast::OperationType exprType
, ast::Type type
,
650 function_ref
<ast::InFlightDiagnostic()> emitErrorFn
) {
651 // Two operation types are compatible if they have the same name, or if the
652 // expected type is more general.
653 if (auto opType
= type
.dyn_cast
<ast::OperationType
>()) {
654 if (opType
.getName())
655 return emitErrorFn();
659 // An operation can always convert to a ValueRange.
660 if (type
== valueRangeTy
) {
661 expr
= ast::AllResultsMemberAccessExpr::create(ctx
, expr
->getLoc(), expr
,
666 // Allow conversion to a single value by constraining the result range.
667 if (type
== valueTy
) {
668 // If the operation is registered, we can verify if it can ever have a
670 if (const ods::Operation
*odsOp
= exprType
.getODSOperation()) {
671 if (odsOp
->getResults().empty()) {
672 return emitErrorFn()->attachNote(
673 llvm::formatv("see the definition of `{0}`, which was defined "
679 unsigned numSingleResults
= llvm::count_if(
680 odsOp
->getResults(), [](const ods::OperandOrResult
&result
) {
681 return result
.getVariableLengthKind() ==
682 ods::VariableLengthKind::Single
;
684 if (numSingleResults
> 1) {
685 return emitErrorFn()->attachNote(
686 llvm::formatv("see the definition of `{0}`, which was defined "
687 "with at least {1} results",
688 odsOp
->getName(), numSingleResults
),
693 expr
= ast::AllResultsMemberAccessExpr::create(ctx
, expr
->getLoc(), expr
,
697 return emitErrorFn();
700 LogicalResult
Parser::convertTupleExpressionTo(
701 ast::Expr
*&expr
, ast::TupleType exprType
, ast::Type type
,
702 function_ref
<ast::InFlightDiagnostic()> emitErrorFn
,
703 function_ref
<void(ast::Diagnostic
&diag
)> noteAttachFn
) {
704 // Handle conversions between tuples.
705 if (auto tupleType
= type
.dyn_cast
<ast::TupleType
>()) {
706 if (tupleType
.size() != exprType
.size())
707 return emitErrorFn();
709 // Build a new tuple expression using each of the elements of the current
711 SmallVector
<ast::Expr
*> newExprs
;
712 for (unsigned i
= 0, e
= exprType
.size(); i
< e
; ++i
) {
713 newExprs
.push_back(ast::MemberAccessExpr::create(
714 ctx
, expr
->getLoc(), expr
, llvm::to_string(i
),
715 exprType
.getElementTypes()[i
]));
717 auto diagFn
= [&](ast::Diagnostic
&diag
) {
718 diag
.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
723 if (failed(convertExpressionTo(newExprs
.back(),
724 tupleType
.getElementTypes()[i
], diagFn
)))
727 expr
= ast::TupleExpr::create(ctx
, expr
->getLoc(), newExprs
,
728 tupleType
.getElementNames());
732 // Handle conversion to a range.
733 auto convertToRange
= [&](ArrayRef
<ast::Type
> allowedElementTypes
,
734 ast::RangeType resultTy
) -> LogicalResult
{
735 // TODO: We currently only allow range conversion within a rewrite context.
736 if (parserContext
!= ParserContext::Rewrite
) {
737 return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
738 "only allowed within a rewrite context");
741 // All of the tuple elements must be allowed types.
742 for (ast::Type elementType
: exprType
.getElementTypes())
743 if (!llvm::is_contained(allowedElementTypes
, elementType
))
744 return emitErrorFn();
746 // Build a new tuple expression using each of the elements of the current
748 SmallVector
<ast::Expr
*> newExprs
;
749 for (unsigned i
= 0, e
= exprType
.size(); i
< e
; ++i
) {
750 newExprs
.push_back(ast::MemberAccessExpr::create(
751 ctx
, expr
->getLoc(), expr
, llvm::to_string(i
),
752 exprType
.getElementTypes()[i
]));
754 expr
= ast::RangeExpr::create(ctx
, expr
->getLoc(), newExprs
, resultTy
);
757 if (type
== valueRangeTy
)
758 return convertToRange({valueTy
, valueRangeTy
}, valueRangeTy
);
759 if (type
== typeRangeTy
)
760 return convertToRange({typeTy
, typeRangeTy
}, typeRangeTy
);
762 return emitErrorFn();
765 //===----------------------------------------------------------------------===//
768 LogicalResult
Parser::parseDirective(SmallVectorImpl
<ast::Decl
*> &decls
) {
769 StringRef directive
= curToken
.getSpelling();
770 if (directive
== "#include")
771 return parseInclude(decls
);
773 return emitError("unknown directive `" + directive
+ "`");
776 LogicalResult
Parser::parseInclude(SmallVectorImpl
<ast::Decl
*> &decls
) {
777 SMRange loc
= curToken
.getLoc();
778 consumeToken(Token::directive
);
780 // Handle code completion of the include file path.
781 if (curToken
.is(Token::code_complete_string
))
782 return codeCompleteIncludeFilename(curToken
.getStringValue());
784 // Parse the file being included.
785 if (!curToken
.isString())
786 return emitError(loc
,
787 "expected string file name after `include` directive");
788 SMRange fileLoc
= curToken
.getLoc();
789 std::string filenameStr
= curToken
.getStringValue();
790 StringRef filename
= filenameStr
;
793 // Check the type of include. If ending with `.pdll`, this is another pdl file
794 // to be parsed along with the current module.
795 if (filename
.ends_with(".pdll")) {
796 if (failed(lexer
.pushInclude(filename
, fileLoc
)))
797 return emitError(fileLoc
,
798 "unable to open include file `" + filename
+ "`");
800 // If we added the include successfully, parse it into the current module.
801 // Make sure to update to the next token after we finish parsing the nested
803 curToken
= lexer
.lexToken();
804 LogicalResult result
= parseModuleBody(decls
);
805 curToken
= lexer
.lexToken();
809 // Otherwise, this must be a `.td` include.
810 if (filename
.ends_with(".td"))
811 return parseTdInclude(filename
, fileLoc
, decls
);
813 return emitError(fileLoc
,
814 "expected include filename to end with `.pdll` or `.td`");
817 LogicalResult
Parser::parseTdInclude(StringRef filename
, llvm::SMRange fileLoc
,
818 SmallVectorImpl
<ast::Decl
*> &decls
) {
819 llvm::SourceMgr
&parserSrcMgr
= lexer
.getSourceMgr();
821 // Use the source manager to open the file, but don't yet add it.
822 std::string includedFile
;
823 llvm::ErrorOr
<std::unique_ptr
<llvm::MemoryBuffer
>> includeBuffer
=
824 parserSrcMgr
.OpenIncludeFile(filename
.str(), includedFile
);
826 return emitError(fileLoc
, "unable to open include file `" + filename
+ "`");
828 // Setup the source manager for parsing the tablegen file.
829 llvm::SourceMgr tdSrcMgr
;
830 tdSrcMgr
.AddNewSourceBuffer(std::move(*includeBuffer
), SMLoc());
831 tdSrcMgr
.setIncludeDirs(parserSrcMgr
.getIncludeDirs());
833 // This class provides a context argument for the llvm::SourceMgr diagnostic
835 struct DiagHandlerContext
{
839 } handlerContext
{*this, filename
, fileLoc
};
841 // Set the diagnostic handler for the tablegen source manager.
842 tdSrcMgr
.setDiagHandler(
843 [](const llvm::SMDiagnostic
&diag
, void *rawHandlerContext
) {
844 auto *ctx
= reinterpret_cast<DiagHandlerContext
*>(rawHandlerContext
);
845 (void)ctx
->parser
.emitError(
847 llvm::formatv("error while processing include file `{0}`: {1}",
848 ctx
->filename
, diag
.getMessage()));
852 // Parse the tablegen file.
853 llvm::RecordKeeper tdRecords
;
854 if (llvm::TableGenParseFile(tdSrcMgr
, tdRecords
))
857 // Process the parsed records.
858 processTdIncludeRecords(tdRecords
, decls
);
860 // After we are done processing, move all of the tablegen source buffers to
861 // the main parser source mgr. This allows for directly using source locations
862 // from the .td files without needing to remap them.
863 parserSrcMgr
.takeSourceBuffersFrom(tdSrcMgr
, fileLoc
.End
);
867 void Parser::processTdIncludeRecords(llvm::RecordKeeper
&tdRecords
,
868 SmallVectorImpl
<ast::Decl
*> &decls
) {
869 // Return the length kind of the given value.
870 auto getLengthKind
= [](const auto &value
) {
871 if (value
.isOptional())
872 return ods::VariableLengthKind::Optional
;
873 return value
.isVariadic() ? ods::VariableLengthKind::Variadic
874 : ods::VariableLengthKind::Single
;
877 // Insert a type constraint into the ODS context.
878 ods::Context
&odsContext
= ctx
.getODSContext();
879 auto addTypeConstraint
= [&](const tblgen::NamedTypeConstraint
&cst
)
880 -> const ods::TypeConstraint
& {
881 return odsContext
.insertTypeConstraint(
882 cst
.constraint
.getUniqueDefName(),
883 processDoc(cst
.constraint
.getSummary()),
884 cst
.constraint
.getCPPClassName());
886 auto convertLocToRange
= [&](llvm::SMLoc loc
) -> llvm::SMRange
{
887 return {loc
, llvm::SMLoc::getFromPointer(loc
.getPointer() + 1)};
890 // Process the parsed tablegen records to build ODS information.
892 for (llvm::Record
*def
: tdRecords
.getAllDerivedDefinitions("Op")) {
893 tblgen::Operator
op(def
);
895 // Check to see if this operation is known to support type inferrence.
896 bool supportsResultTypeInferrence
=
897 op
.getTrait("::mlir::InferTypeOpInterface::Trait");
899 auto [odsOp
, inserted
] = odsContext
.insertOperation(
900 op
.getOperationName(), processDoc(op
.getSummary()),
901 processAndFormatDoc(op
.getDescription()), op
.getQualCppClassName(),
902 supportsResultTypeInferrence
, op
.getLoc().front());
904 // Ignore operations that have already been added.
908 for (const tblgen::NamedAttribute
&attr
: op
.getAttributes()) {
909 odsOp
->appendAttribute(attr
.name
, attr
.attr
.isOptional(),
910 odsContext
.insertAttributeConstraint(
911 attr
.attr
.getUniqueDefName(),
912 processDoc(attr
.attr
.getSummary()),
913 attr
.attr
.getStorageType()));
915 for (const tblgen::NamedTypeConstraint
&operand
: op
.getOperands()) {
916 odsOp
->appendOperand(operand
.name
, getLengthKind(operand
),
917 addTypeConstraint(operand
));
919 for (const tblgen::NamedTypeConstraint
&result
: op
.getResults()) {
920 odsOp
->appendResult(result
.name
, getLengthKind(result
),
921 addTypeConstraint(result
));
925 auto shouldBeSkipped
= [this](llvm::Record
*def
) {
926 return def
->isAnonymous() || curDeclScope
->lookup(def
->getName()) ||
927 def
->isSubClassOf("DeclareInterfaceMethods");
930 /// Attr constraints.
931 for (llvm::Record
*def
: tdRecords
.getAllDerivedDefinitions("Attr")) {
932 if (shouldBeSkipped(def
))
935 tblgen::Attribute
constraint(def
);
936 decls
.push_back(createODSNativePDLLConstraintDecl
<ast::AttrConstraintDecl
>(
937 constraint
, convertLocToRange(def
->getLoc().front()), attrTy
,
938 constraint
.getStorageType()));
940 /// Type constraints.
941 for (llvm::Record
*def
: tdRecords
.getAllDerivedDefinitions("Type")) {
942 if (shouldBeSkipped(def
))
945 tblgen::TypeConstraint
constraint(def
);
946 decls
.push_back(createODSNativePDLLConstraintDecl
<ast::TypeConstraintDecl
>(
947 constraint
, convertLocToRange(def
->getLoc().front()), typeTy
,
948 constraint
.getCPPClassName()));
951 ast::Type opTy
= ast::OperationType::get(ctx
);
952 for (llvm::Record
*def
: tdRecords
.getAllDerivedDefinitions("OpInterface")) {
953 if (shouldBeSkipped(def
))
956 SMRange loc
= convertLocToRange(def
->getLoc().front());
958 std::string cppClassName
=
959 llvm::formatv("{0}::{1}", def
->getValueAsString("cppNamespace"),
960 def
->getValueAsString("cppInterfaceName"))
962 std::string codeBlock
=
963 llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
968 processAndFormatDoc(def
->getValueAsString("description"));
969 decls
.push_back(createODSNativePDLLConstraintDecl
<ast::OpConstraintDecl
>(
970 def
->getName(), codeBlock
, loc
, opTy
, cppClassName
, desc
));
974 template <typename ConstraintT
>
975 ast::Decl
*Parser::createODSNativePDLLConstraintDecl(
976 StringRef name
, StringRef codeBlock
, SMRange loc
, ast::Type type
,
977 StringRef nativeType
, StringRef docString
) {
978 // Build the single input parameter.
979 ast::DeclScope
*argScope
= pushDeclScope();
980 auto *paramVar
= ast::VariableDecl::create(
981 ctx
, ast::Name::create(ctx
, "self", loc
), type
,
982 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx
, loc
)));
983 argScope
->add(paramVar
);
986 // Build the native constraint.
987 auto *constraintDecl
= ast::UserConstraintDecl::createNative(
988 ctx
, ast::Name::create(ctx
, name
, loc
), paramVar
,
989 /*results=*/std::nullopt
, codeBlock
, ast::TupleType::get(ctx
),
991 constraintDecl
->setDocComment(ctx
, docString
);
992 curDeclScope
->add(constraintDecl
);
993 return constraintDecl
;
996 template <typename ConstraintT
>
998 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint
&constraint
,
999 SMRange loc
, ast::Type type
,
1000 StringRef nativeType
) {
1001 // Format the condition template.
1002 tblgen::FmtContext fmtContext
;
1003 fmtContext
.withSelf("self");
1004 std::string codeBlock
= tblgen::tgfmt(
1005 "return ::mlir::success(" + constraint
.getConditionTemplate() + ");",
1008 // If documentation was enabled, build the doc string for the generated
1009 // constraint. It would be nice to do this lazily, but TableGen information is
1010 // destroyed after we finish parsing the file.
1011 std::string docString
;
1012 if (enableDocumentation
) {
1013 StringRef desc
= constraint
.getDescription();
1014 docString
= processAndFormatDoc(
1015 constraint
.getSummary() +
1016 (desc
.empty() ? "" : ("\n\n" + constraint
.getDescription())));
1019 return createODSNativePDLLConstraintDecl
<ConstraintT
>(
1020 constraint
.getUniqueDefName(), codeBlock
, loc
, type
, nativeType
,
//===----------------------------------------------------------------------===//
// Decls
//===----------------------------------------------------------------------===//
1027 FailureOr
<ast::Decl
*> Parser::parseTopLevelDecl() {
1028 FailureOr
<ast::Decl
*> decl
;
1029 switch (curToken
.getKind()) {
1030 case Token::kw_Constraint
:
1031 decl
= parseUserConstraintDecl();
1033 case Token::kw_Pattern
:
1034 decl
= parsePatternDecl();
1036 case Token::kw_Rewrite
:
1037 decl
= parseUserRewriteDecl();
1040 return emitError("expected top-level declaration, such as a `Pattern`");
1045 // If the decl has a name, add it to the current scope.
1046 if (const ast::Name
*name
= (*decl
)->getName()) {
1047 if (failed(checkDefineNamedDecl(*name
)))
1049 curDeclScope
->add(*decl
);
1054 FailureOr
<ast::NamedAttributeDecl
*>
1055 Parser::parseNamedAttributeDecl(std::optional
<StringRef
> parentOpName
) {
1056 // Check for name code completion.
1057 if (curToken
.is(Token::code_complete
))
1058 return codeCompleteAttributeName(parentOpName
);
1060 std::string attrNameStr
;
1061 if (curToken
.isString())
1062 attrNameStr
= curToken
.getStringValue();
1063 else if (curToken
.is(Token::identifier
) || curToken
.isKeyword())
1064 attrNameStr
= curToken
.getSpelling().str();
1066 return emitError("expected identifier or string attribute name");
1067 const auto &name
= ast::Name::create(ctx
, attrNameStr
, curToken
.getLoc());
1070 // Check for a value of the attribute.
1071 ast::Expr
*attrValue
= nullptr;
1072 if (consumeIf(Token::equal
)) {
1073 FailureOr
<ast::Expr
*> attrExpr
= parseExpr();
1074 if (failed(attrExpr
))
1076 attrValue
= *attrExpr
;
1078 // If there isn't a concrete value, create an expression representing a
1080 attrValue
= ast::AttributeExpr::create(ctx
, name
.getLoc(), "unit");
1083 return ast::NamedAttributeDecl::create(ctx
, name
, attrValue
);
1086 FailureOr
<ast::CompoundStmt
*> Parser::parseLambdaBody(
1087 function_ref
<LogicalResult(ast::Stmt
*&)> processStatementFn
,
1088 bool expectTerminalSemicolon
) {
1089 consumeToken(Token::equal_arrow
);
1091 // Parse the single statement of the lambda body.
1092 SMLoc bodyStartLoc
= curToken
.getStartLoc();
1094 FailureOr
<ast::Stmt
*> singleStatement
= parseStmt(expectTerminalSemicolon
);
1095 bool failedToParse
=
1096 failed(singleStatement
) || failed(processStatementFn(*singleStatement
));
1101 SMRange
bodyLoc(bodyStartLoc
, curToken
.getStartLoc());
1102 return ast::CompoundStmt::create(ctx
, bodyLoc
, *singleStatement
);
1105 FailureOr
<ast::VariableDecl
*> Parser::parseArgumentDecl() {
1106 // Ensure that the argument is named.
1107 if (curToken
.isNot(Token::identifier
) && !curToken
.isDependentKeyword())
1108 return emitError("expected identifier argument name");
1110 // Parse the argument similarly to a normal variable.
1111 StringRef name
= curToken
.getSpelling();
1112 SMRange nameLoc
= curToken
.getLoc();
1116 parseToken(Token::colon
, "expected `:` before argument constraint")))
1119 FailureOr
<ast::ConstraintRef
> cst
= parseArgOrResultConstraint();
1123 return createArgOrResultVariableDecl(name
, nameLoc
, *cst
);
1126 FailureOr
<ast::VariableDecl
*> Parser::parseResultDecl(unsigned resultNum
) {
1127 // Check to see if this result is named.
1128 if (curToken
.is(Token::identifier
) || curToken
.isDependentKeyword()) {
1129 // Check to see if this name actually refers to a Constraint.
1130 if (!curDeclScope
->lookup
<ast::ConstraintDecl
>(curToken
.getSpelling())) {
1131 // If it wasn't a constraint, parse the result similarly to a variable. If
1132 // there is already an existing decl, we will emit an error when defining
1133 // this variable later.
1134 StringRef name
= curToken
.getSpelling();
1135 SMRange nameLoc
= curToken
.getLoc();
1138 if (failed(parseToken(Token::colon
,
1139 "expected `:` before result constraint")))
1142 FailureOr
<ast::ConstraintRef
> cst
= parseArgOrResultConstraint();
1146 return createArgOrResultVariableDecl(name
, nameLoc
, *cst
);
1150 // If it isn't named, we parse the constraint directly and create an unnamed
1152 FailureOr
<ast::ConstraintRef
> cst
= parseArgOrResultConstraint();
1156 return createArgOrResultVariableDecl("", cst
->referenceLoc
, *cst
);
1159 FailureOr
<ast::UserConstraintDecl
*>
1160 Parser::parseUserConstraintDecl(bool isInline
) {
1161 // Constraints and rewrites have very similar formats, dispatch to a shared
1162 // interface for parsing.
1163 return parseUserConstraintOrRewriteDecl
<ast::UserConstraintDecl
>(
1164 [&](auto &&...args
) {
1165 return this->parseUserPDLLConstraintDecl(args
...);
1167 ParserContext::Constraint
, "constraint", isInline
);
1170 FailureOr
<ast::UserConstraintDecl
*> Parser::parseInlineUserConstraintDecl() {
1171 FailureOr
<ast::UserConstraintDecl
*> decl
=
1172 parseUserConstraintDecl(/*isInline=*/true);
1173 if (failed(decl
) || failed(checkDefineNamedDecl((*decl
)->getName())))
1176 curDeclScope
->add(*decl
);
1180 FailureOr
<ast::UserConstraintDecl
*> Parser::parseUserPDLLConstraintDecl(
1181 const ast::Name
&name
, bool isInline
,
1182 ArrayRef
<ast::VariableDecl
*> arguments
, ast::DeclScope
*argumentScope
,
1183 ArrayRef
<ast::VariableDecl
*> results
, ast::Type resultType
) {
1184 // Push the argument scope back onto the list, so that the body can
1185 // reference arguments.
1186 pushDeclScope(argumentScope
);
1188 // Parse the body of the constraint. The body is either defined as a compound
1189 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1190 ast::CompoundStmt
*body
;
1191 if (curToken
.is(Token::equal_arrow
)) {
1192 FailureOr
<ast::CompoundStmt
*> bodyResult
= parseLambdaBody(
1193 [&](ast::Stmt
*&stmt
) -> LogicalResult
{
1194 ast::Expr
*stmtExpr
= dyn_cast
<ast::Expr
>(stmt
);
1196 return emitError(stmt
->getLoc(),
1197 "expected `Constraint` lambda body to contain a "
1198 "single expression");
1200 stmt
= ast::ReturnStmt::create(ctx
, stmt
->getLoc(), stmtExpr
);
1203 /*expectTerminalSemicolon=*/!isInline
);
1204 if (failed(bodyResult
))
1208 FailureOr
<ast::CompoundStmt
*> bodyResult
= parseCompoundStmt();
1209 if (failed(bodyResult
))
1213 // Verify the structure of the body.
1214 auto bodyIt
= body
->begin(), bodyE
= body
->end();
1215 for (; bodyIt
!= bodyE
; ++bodyIt
)
1216 if (isa
<ast::ReturnStmt
>(*bodyIt
))
1218 if (failed(validateUserConstraintOrRewriteReturn(
1219 "Constraint", body
, bodyIt
, bodyE
, results
, resultType
)))
1224 return createUserPDLLConstraintOrRewriteDecl
<ast::UserConstraintDecl
>(
1225 name
, arguments
, results
, resultType
, body
);
1228 FailureOr
<ast::UserRewriteDecl
*> Parser::parseUserRewriteDecl(bool isInline
) {
1229 // Constraints and rewrites have very similar formats, dispatch to a shared
1230 // interface for parsing.
1231 return parseUserConstraintOrRewriteDecl
<ast::UserRewriteDecl
>(
1232 [&](auto &&...args
) { return this->parseUserPDLLRewriteDecl(args
...); },
1233 ParserContext::Rewrite
, "rewrite", isInline
);
1236 FailureOr
<ast::UserRewriteDecl
*> Parser::parseInlineUserRewriteDecl() {
1237 FailureOr
<ast::UserRewriteDecl
*> decl
=
1238 parseUserRewriteDecl(/*isInline=*/true);
1239 if (failed(decl
) || failed(checkDefineNamedDecl((*decl
)->getName())))
1242 curDeclScope
->add(*decl
);
1246 FailureOr
<ast::UserRewriteDecl
*> Parser::parseUserPDLLRewriteDecl(
1247 const ast::Name
&name
, bool isInline
,
1248 ArrayRef
<ast::VariableDecl
*> arguments
, ast::DeclScope
*argumentScope
,
1249 ArrayRef
<ast::VariableDecl
*> results
, ast::Type resultType
) {
1250 // Push the argument scope back onto the list, so that the body can
1251 // reference arguments.
1252 curDeclScope
= argumentScope
;
1253 ast::CompoundStmt
*body
;
1254 if (curToken
.is(Token::equal_arrow
)) {
1255 FailureOr
<ast::CompoundStmt
*> bodyResult
= parseLambdaBody(
1256 [&](ast::Stmt
*&statement
) -> LogicalResult
{
1257 if (isa
<ast::OpRewriteStmt
>(statement
))
1260 ast::Expr
*statementExpr
= dyn_cast
<ast::Expr
>(statement
);
1261 if (!statementExpr
) {
1263 statement
->getLoc(),
1264 "expected `Rewrite` lambda body to contain a single expression "
1265 "or an operation rewrite statement; such as `erase`, "
1266 "`replace`, or `rewrite`");
1269 ast::ReturnStmt::create(ctx
, statement
->getLoc(), statementExpr
);
1272 /*expectTerminalSemicolon=*/!isInline
);
1273 if (failed(bodyResult
))
1277 FailureOr
<ast::CompoundStmt
*> bodyResult
= parseCompoundStmt();
1278 if (failed(bodyResult
))
1284 // Verify the structure of the body.
1285 auto bodyIt
= body
->begin(), bodyE
= body
->end();
1286 for (; bodyIt
!= bodyE
; ++bodyIt
)
1287 if (isa
<ast::ReturnStmt
>(*bodyIt
))
1289 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body
, bodyIt
,
1290 bodyE
, results
, resultType
)))
1292 return createUserPDLLConstraintOrRewriteDecl
<ast::UserRewriteDecl
>(
1293 name
, arguments
, results
, resultType
, body
);
1296 template <typename T
, typename ParseUserPDLLDeclFnT
>
1297 FailureOr
<T
*> Parser::parseUserConstraintOrRewriteDecl(
1298 ParseUserPDLLDeclFnT
&&parseUserPDLLFn
, ParserContext declContext
,
1299 StringRef anonymousNamePrefix
, bool isInline
) {
1300 SMRange loc
= curToken
.getLoc();
1302 llvm::SaveAndRestore
saveCtx(parserContext
, declContext
);
1304 // Parse the name of the decl.
1305 const ast::Name
*name
= nullptr;
1306 if (curToken
.isNot(Token::identifier
)) {
1307 // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1308 // in C++, so being unnamed is fine.
1310 return emitError("expected identifier name");
1312 // Create a unique anonymous name to use, as the name for this decl is not
1314 std::string anonName
=
1315 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix
,
1316 anonymousDeclNameCounter
++)
1318 name
= &ast::Name::create(ctx
, anonName
, loc
);
1320 // If a name was provided, we can use it directly.
1321 name
= &ast::Name::create(ctx
, curToken
.getSpelling(), curToken
.getLoc());
1322 consumeToken(Token::identifier
);
1325 // Parse the functional signature of the decl.
1326 SmallVector
<ast::VariableDecl
*> arguments
, results
;
1327 ast::DeclScope
*argumentScope
;
1328 ast::Type resultType
;
1329 if (failed(parseUserConstraintOrRewriteSignature(arguments
, results
,
1330 argumentScope
, resultType
)))
1333 // Check to see which type of constraint this is. If the constraint contains a
1334 // compound body, this is a PDLL decl.
1335 if (curToken
.isAny(Token::l_brace
, Token::equal_arrow
))
1336 return parseUserPDLLFn(*name
, isInline
, arguments
, argumentScope
, results
,
1339 // Otherwise, this is a native decl.
1340 return parseUserNativeConstraintOrRewriteDecl
<T
>(*name
, isInline
, arguments
,
1341 results
, resultType
);
1344 template <typename T
>
1345 FailureOr
<T
*> Parser::parseUserNativeConstraintOrRewriteDecl(
1346 const ast::Name
&name
, bool isInline
,
1347 ArrayRef
<ast::VariableDecl
*> arguments
,
1348 ArrayRef
<ast::VariableDecl
*> results
, ast::Type resultType
) {
1349 // If followed by a string, the native code body has also been specified.
1350 std::string codeStrStorage
;
1351 std::optional
<StringRef
> optCodeStr
;
1352 if (curToken
.isString()) {
1353 codeStrStorage
= curToken
.getStringValue();
1354 optCodeStr
= codeStrStorage
;
1356 } else if (isInline
) {
1357 return emitError(name
.getLoc(),
1358 "external declarations must be declared in global scope");
1359 } else if (curToken
.is(Token::error
)) {
1362 if (failed(parseToken(Token::semicolon
,
1363 "expected `;` after native declaration")))
1365 // TODO: PDL should be able to support constraint results in certain
1366 // situations, we should revise this.
1367 if (std::is_same
<ast::UserConstraintDecl
, T
>::value
&& !results
.empty()) {
1369 "native Constraints currently do not support returning results");
1371 return T::createNative(ctx
, name
, arguments
, results
, optCodeStr
, resultType
);
1374 LogicalResult
Parser::parseUserConstraintOrRewriteSignature(
1375 SmallVectorImpl
<ast::VariableDecl
*> &arguments
,
1376 SmallVectorImpl
<ast::VariableDecl
*> &results
,
1377 ast::DeclScope
*&argumentScope
, ast::Type
&resultType
) {
1378 // Parse the argument list of the decl.
1379 if (failed(parseToken(Token::l_paren
, "expected `(` to start argument list")))
1382 argumentScope
= pushDeclScope();
1383 if (curToken
.isNot(Token::r_paren
)) {
1385 FailureOr
<ast::VariableDecl
*> argument
= parseArgumentDecl();
1386 if (failed(argument
))
1388 arguments
.emplace_back(*argument
);
1389 } while (consumeIf(Token::comma
));
1392 if (failed(parseToken(Token::r_paren
, "expected `)` to end argument list")))
1395 // Parse the results of the decl.
1397 if (consumeIf(Token::arrow
)) {
1398 auto parseResultFn
= [&]() -> LogicalResult
{
1399 FailureOr
<ast::VariableDecl
*> result
= parseResultDecl(results
.size());
1402 results
.emplace_back(*result
);
1406 // Check for a list of results.
1407 if (consumeIf(Token::l_paren
)) {
1409 if (failed(parseResultFn()))
1411 } while (consumeIf(Token::comma
));
1412 if (failed(parseToken(Token::r_paren
, "expected `)` to end result list")))
1415 // Otherwise, there is only one result.
1416 } else if (failed(parseResultFn())) {
1422 // Compute the result type of the decl.
1423 resultType
= createUserConstraintRewriteResultType(results
);
1425 // Verify that results are only named if there are more than one.
1426 if (results
.size() == 1 && !results
.front()->getName().getName().empty()) {
1428 results
.front()->getLoc(),
1429 "cannot create a single-element tuple with an element label");
1434 LogicalResult
Parser::validateUserConstraintOrRewriteReturn(
1435 StringRef declType
, ast::CompoundStmt
*body
,
1436 ArrayRef
<ast::Stmt
*>::iterator bodyIt
,
1437 ArrayRef
<ast::Stmt
*>::iterator bodyE
,
1438 ArrayRef
<ast::VariableDecl
*> results
, ast::Type
&resultType
) {
1439 // Handle if a `return` was provided.
1440 if (bodyIt
!= bodyE
) {
1441 // Emit an error if we have trailing statements after the return.
1442 if (std::next(bodyIt
) != bodyE
) {
1444 (*std::next(bodyIt
))->getLoc(),
1445 llvm::formatv("`return` terminated the `{0}` body, but found "
1446 "trailing statements afterwards",
1450 // Otherwise if a return wasn't provided, check that no results are
1452 } else if (!results
.empty()) {
1454 {body
->getLoc().End
, body
->getLoc().End
},
1455 llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1456 declType
, resultType
));
1461 FailureOr
<ast::CompoundStmt
*> Parser::parsePatternLambdaBody() {
1462 return parseLambdaBody([&](ast::Stmt
*&statement
) -> LogicalResult
{
1463 if (isa
<ast::OpRewriteStmt
>(statement
))
1466 statement
->getLoc(),
1467 "expected Pattern lambda body to contain a single operation "
1468 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1472 FailureOr
<ast::Decl
*> Parser::parsePatternDecl() {
1473 SMRange loc
= curToken
.getLoc();
1474 consumeToken(Token::kw_Pattern
);
1475 llvm::SaveAndRestore
saveCtx(parserContext
, ParserContext::PatternMatch
);
1477 // Check for an optional identifier for the pattern name.
1478 const ast::Name
*name
= nullptr;
1479 if (curToken
.is(Token::identifier
)) {
1480 name
= &ast::Name::create(ctx
, curToken
.getSpelling(), curToken
.getLoc());
1481 consumeToken(Token::identifier
);
1484 // Parse any pattern metadata.
1485 ParsedPatternMetadata metadata
;
1486 if (consumeIf(Token::kw_with
) && failed(parsePatternDeclMetadata(metadata
)))
1489 // Parse the pattern body.
1490 ast::CompoundStmt
*body
;
1492 // Handle a lambda body.
1493 if (curToken
.is(Token::equal_arrow
)) {
1494 FailureOr
<ast::CompoundStmt
*> bodyResult
= parsePatternLambdaBody();
1495 if (failed(bodyResult
))
1499 if (curToken
.isNot(Token::l_brace
))
1500 return emitError("expected `{` or `=>` to start pattern body");
1501 FailureOr
<ast::CompoundStmt
*> bodyResult
= parseCompoundStmt();
1502 if (failed(bodyResult
))
1506 // Verify the body of the pattern.
1507 auto bodyIt
= body
->begin(), bodyE
= body
->end();
1508 for (; bodyIt
!= bodyE
; ++bodyIt
) {
1509 if (isa
<ast::ReturnStmt
>(*bodyIt
)) {
1510 return emitError((*bodyIt
)->getLoc(),
1511 "`return` statements are only permitted within a "
1512 "`Constraint` or `Rewrite` body");
1514 // Break when we've found the rewrite statement.
1515 if (isa
<ast::OpRewriteStmt
>(*bodyIt
))
1518 if (bodyIt
== bodyE
) {
1519 return emitError(loc
,
1520 "expected Pattern body to terminate with an operation "
1521 "rewrite statement, such as `erase`");
1523 if (std::next(bodyIt
) != bodyE
) {
1524 return emitError((*std::next(bodyIt
))->getLoc(),
1525 "Pattern body was terminated by an operation "
1526 "rewrite statement, but found trailing statements");
1530 return createPatternDecl(loc
, name
, metadata
, body
);
1534 Parser::parsePatternDeclMetadata(ParsedPatternMetadata
&metadata
) {
1535 std::optional
<SMRange
> benefitLoc
;
1536 std::optional
<SMRange
> hasBoundedRecursionLoc
;
1539 // Handle metadata code completion.
1540 if (curToken
.is(Token::code_complete
))
1541 return codeCompletePatternMetadata();
1543 if (curToken
.isNot(Token::identifier
))
1544 return emitError("expected pattern metadata identifier");
1545 StringRef metadataStr
= curToken
.getSpelling();
1546 SMRange metadataLoc
= curToken
.getLoc();
1547 consumeToken(Token::identifier
);
1549 // Parse the benefit metadata: benefit(<integer-value>)
1550 if (metadataStr
== "benefit") {
1552 return emitErrorAndNote(metadataLoc
,
1553 "pattern benefit has already been specified",
1554 *benefitLoc
, "see previous definition here");
1556 if (failed(parseToken(Token::l_paren
,
1557 "expected `(` before pattern benefit")))
1560 uint16_t benefitValue
= 0;
1561 if (curToken
.isNot(Token::integer
))
1562 return emitError("expected integral pattern benefit");
1563 if (curToken
.getSpelling().getAsInteger(/*Radix=*/10, benefitValue
))
1565 "expected pattern benefit to fit within a 16-bit integer");
1566 consumeToken(Token::integer
);
1568 metadata
.benefit
= benefitValue
;
1569 benefitLoc
= metadataLoc
;
1572 parseToken(Token::r_paren
, "expected `)` after pattern benefit")))
1577 // Parse the bounded recursion metadata: recursion
1578 if (metadataStr
== "recursion") {
1579 if (hasBoundedRecursionLoc
) {
1580 return emitErrorAndNote(
1582 "pattern recursion metadata has already been specified",
1583 *hasBoundedRecursionLoc
, "see previous definition here");
1585 metadata
.hasBoundedRecursion
= true;
1586 hasBoundedRecursionLoc
= metadataLoc
;
1590 return emitError(metadataLoc
, "unknown pattern metadata");
1591 } while (consumeIf(Token::comma
));
1596 FailureOr
<ast::Expr
*> Parser::parseTypeConstraintExpr() {
1597 consumeToken(Token::less
);
1599 FailureOr
<ast::Expr
*> typeExpr
= parseExpr();
1600 if (failed(typeExpr
) ||
1601 failed(parseToken(Token::greater
,
1602 "expected `>` after variable type constraint")))
1607 LogicalResult
Parser::checkDefineNamedDecl(const ast::Name
&name
) {
1608 assert(curDeclScope
&& "defining decl outside of a decl scope");
1609 if (ast::Decl
*lastDecl
= curDeclScope
->lookup(name
.getName())) {
1610 return emitErrorAndNote(
1611 name
.getLoc(), "`" + name
.getName() + "` has already been defined",
1612 lastDecl
->getName()->getLoc(), "see previous definition here");
1617 FailureOr
<ast::VariableDecl
*>
1618 Parser::defineVariableDecl(StringRef name
, SMRange nameLoc
, ast::Type type
,
1619 ast::Expr
*initExpr
,
1620 ArrayRef
<ast::ConstraintRef
> constraints
) {
1621 assert(curDeclScope
&& "defining variable outside of decl scope");
1622 const ast::Name
&nameDecl
= ast::Name::create(ctx
, name
, nameLoc
);
1624 // If the name of the variable indicates a special variable, we don't add it
1625 // to the scope. This variable is local to the definition point.
1626 if (name
.empty() || name
== "_") {
1627 return ast::VariableDecl::create(ctx
, nameDecl
, type
, initExpr
,
1630 if (failed(checkDefineNamedDecl(nameDecl
)))
1634 ast::VariableDecl::create(ctx
, nameDecl
, type
, initExpr
, constraints
);
1635 curDeclScope
->add(varDecl
);
1639 FailureOr
<ast::VariableDecl
*>
1640 Parser::defineVariableDecl(StringRef name
, SMRange nameLoc
, ast::Type type
,
1641 ArrayRef
<ast::ConstraintRef
> constraints
) {
1642 return defineVariableDecl(name
, nameLoc
, type
, /*initExpr=*/nullptr,
1646 LogicalResult
Parser::parseVariableDeclConstraintList(
1647 SmallVectorImpl
<ast::ConstraintRef
> &constraints
) {
1648 std::optional
<SMRange
> typeConstraint
;
1649 auto parseSingleConstraint
= [&] {
1650 FailureOr
<ast::ConstraintRef
> constraint
= parseConstraint(
1651 typeConstraint
, constraints
, /*allowInlineTypeConstraints=*/true);
1652 if (failed(constraint
))
1654 constraints
.push_back(*constraint
);
1658 // Check to see if this is a single constraint, or a list.
1659 if (!consumeIf(Token::l_square
))
1660 return parseSingleConstraint();
1663 if (failed(parseSingleConstraint()))
1665 } while (consumeIf(Token::comma
));
1666 return parseToken(Token::r_square
, "expected `]` after constraint list");
1669 FailureOr
<ast::ConstraintRef
>
1670 Parser::parseConstraint(std::optional
<SMRange
> &typeConstraint
,
1671 ArrayRef
<ast::ConstraintRef
> existingConstraints
,
1672 bool allowInlineTypeConstraints
) {
1673 auto parseTypeConstraint
= [&](ast::Expr
*&typeExpr
) -> LogicalResult
{
1674 if (!allowInlineTypeConstraints
) {
1677 "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1678 "permitted on arguments or results");
1681 return emitErrorAndNote(
1683 "the type of this variable has already been constrained",
1684 *typeConstraint
, "see previous constraint location here");
1685 FailureOr
<ast::Expr
*> constraintExpr
= parseTypeConstraintExpr();
1686 if (failed(constraintExpr
))
1688 typeExpr
= *constraintExpr
;
1689 typeConstraint
= typeExpr
->getLoc();
1693 SMRange loc
= curToken
.getLoc();
1694 switch (curToken
.getKind()) {
1695 case Token::kw_Attr
: {
1696 consumeToken(Token::kw_Attr
);
1698 // Check for a type constraint.
1699 ast::Expr
*typeExpr
= nullptr;
1700 if (curToken
.is(Token::less
) && failed(parseTypeConstraint(typeExpr
)))
1702 return ast::ConstraintRef(
1703 ast::AttrConstraintDecl::create(ctx
, loc
, typeExpr
), loc
);
1705 case Token::kw_Op
: {
1706 consumeToken(Token::kw_Op
);
1708 // Parse an optional operation name. If the name isn't provided, this refers
1709 // to "any" operation.
1710 FailureOr
<ast::OpNameDecl
*> opName
=
1711 parseWrappedOperationName(/*allowEmptyName=*/true);
1715 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx
, loc
, *opName
),
1718 case Token::kw_Type
:
1719 consumeToken(Token::kw_Type
);
1720 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx
, loc
), loc
);
1721 case Token::kw_TypeRange
:
1722 consumeToken(Token::kw_TypeRange
);
1723 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx
, loc
),
1725 case Token::kw_Value
: {
1726 consumeToken(Token::kw_Value
);
1728 // Check for a type constraint.
1729 ast::Expr
*typeExpr
= nullptr;
1730 if (curToken
.is(Token::less
) && failed(parseTypeConstraint(typeExpr
)))
1733 return ast::ConstraintRef(
1734 ast::ValueConstraintDecl::create(ctx
, loc
, typeExpr
), loc
);
1736 case Token::kw_ValueRange
: {
1737 consumeToken(Token::kw_ValueRange
);
1739 // Check for a type constraint.
1740 ast::Expr
*typeExpr
= nullptr;
1741 if (curToken
.is(Token::less
) && failed(parseTypeConstraint(typeExpr
)))
1744 return ast::ConstraintRef(
1745 ast::ValueRangeConstraintDecl::create(ctx
, loc
, typeExpr
), loc
);
1748 case Token::kw_Constraint
: {
1749 // Handle an inline constraint.
1750 FailureOr
<ast::UserConstraintDecl
*> decl
= parseInlineUserConstraintDecl();
1753 return ast::ConstraintRef(*decl
, loc
);
1755 case Token::identifier
: {
1756 StringRef constraintName
= curToken
.getSpelling();
1757 consumeToken(Token::identifier
);
1759 // Lookup the referenced constraint.
1760 ast::Decl
*cstDecl
= curDeclScope
->lookup
<ast::Decl
>(constraintName
);
1762 return emitError(loc
, "unknown reference to constraint `" +
1763 constraintName
+ "`");
1766 // Handle a reference to a proper constraint.
1767 if (auto *cst
= dyn_cast
<ast::ConstraintDecl
>(cstDecl
))
1768 return ast::ConstraintRef(cst
, loc
);
1770 return emitErrorAndNote(
1771 loc
, "invalid reference to non-constraint", cstDecl
->getLoc(),
1772 "see the definition of `" + constraintName
+ "` here");
1774 // Handle single entity constraint code completion.
1775 case Token::code_complete
: {
1776 // Try to infer the current type for use by code completion.
1777 ast::Type inferredType
;
1778 if (failed(validateVariableConstraints(existingConstraints
, inferredType
)))
1781 return codeCompleteConstraintName(inferredType
, allowInlineTypeConstraints
);
1786 return emitError(loc
, "expected identifier constraint");
1789 FailureOr
<ast::ConstraintRef
> Parser::parseArgOrResultConstraint() {
1790 std::optional
<SMRange
> typeConstraint
;
1791 return parseConstraint(typeConstraint
, /*existingConstraints=*/std::nullopt
,
1792 /*allowInlineTypeConstraints=*/false);
//===----------------------------------------------------------------------===//
// Exprs
//===----------------------------------------------------------------------===//
1798 FailureOr
<ast::Expr
*> Parser::parseExpr() {
1799 if (curToken
.is(Token::underscore
))
1800 return parseUnderscoreExpr();
1802 // Parse the LHS expression.
1803 FailureOr
<ast::Expr
*> lhsExpr
;
1804 switch (curToken
.getKind()) {
1805 case Token::kw_attr
:
1806 lhsExpr
= parseAttributeExpr();
1808 case Token::kw_Constraint
:
1809 lhsExpr
= parseInlineConstraintLambdaExpr();
1812 lhsExpr
= parseNegatedExpr();
1814 case Token::identifier
:
1815 lhsExpr
= parseIdentifierExpr();
1818 lhsExpr
= parseOperationExpr();
1820 case Token::kw_Rewrite
:
1821 lhsExpr
= parseInlineRewriteLambdaExpr();
1823 case Token::kw_type
:
1824 lhsExpr
= parseTypeExpr();
1826 case Token::l_paren
:
1827 lhsExpr
= parseTupleExpr();
1830 return emitError("expected expression");
1832 if (failed(lhsExpr
))
1835 // Check for an operator expression.
1837 switch (curToken
.getKind()) {
1839 lhsExpr
= parseMemberAccessExpr(*lhsExpr
);
1841 case Token::l_paren
:
1842 lhsExpr
= parseCallExpr(*lhsExpr
);
1847 if (failed(lhsExpr
))
1852 FailureOr
<ast::Expr
*> Parser::parseAttributeExpr() {
1853 SMRange loc
= curToken
.getLoc();
1854 consumeToken(Token::kw_attr
);
1856 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1858 if (!consumeIf(Token::less
)) {
1860 return parseIdentifierExpr();
1863 if (!curToken
.isString())
1864 return emitError("expected string literal containing MLIR attribute");
1865 std::string attrExpr
= curToken
.getStringValue();
1868 loc
.End
= curToken
.getEndLoc();
1870 parseToken(Token::greater
, "expected `>` after attribute literal")))
1872 return ast::AttributeExpr::create(ctx
, loc
, attrExpr
);
1875 FailureOr
<ast::Expr
*> Parser::parseCallExpr(ast::Expr
*parentExpr
,
1877 consumeToken(Token::l_paren
);
1879 // Parse the arguments of the call.
1880 SmallVector
<ast::Expr
*> arguments
;
1881 if (curToken
.isNot(Token::r_paren
)) {
1883 // Handle code completion for the call arguments.
1884 if (curToken
.is(Token::code_complete
)) {
1885 codeCompleteCallSignature(parentExpr
, arguments
.size());
1889 FailureOr
<ast::Expr
*> argument
= parseExpr();
1890 if (failed(argument
))
1892 arguments
.push_back(*argument
);
1893 } while (consumeIf(Token::comma
));
1896 SMRange
loc(parentExpr
->getLoc().Start
, curToken
.getEndLoc());
1897 if (failed(parseToken(Token::r_paren
, "expected `)` after argument list")))
1900 return createCallExpr(loc
, parentExpr
, arguments
, isNegated
);
1903 FailureOr
<ast::Expr
*> Parser::parseDeclRefExpr(StringRef name
, SMRange loc
) {
1904 ast::Decl
*decl
= curDeclScope
->lookup(name
);
1906 return emitError(loc
, "undefined reference to `" + name
+ "`");
1908 return createDeclRefExpr(loc
, decl
);
1911 FailureOr
<ast::Expr
*> Parser::parseIdentifierExpr() {
1912 StringRef name
= curToken
.getSpelling();
1913 SMRange nameLoc
= curToken
.getLoc();
1916 // Check to see if this is a decl ref expression that defines a variable
1918 if (consumeIf(Token::colon
)) {
1919 SmallVector
<ast::ConstraintRef
> constraints
;
1920 if (failed(parseVariableDeclConstraintList(constraints
)))
1923 if (failed(validateVariableConstraints(constraints
, type
)))
1925 return createInlineVariableExpr(type
, name
, nameLoc
, constraints
);
1928 return parseDeclRefExpr(name
, nameLoc
);
1931 FailureOr
<ast::Expr
*> Parser::parseInlineConstraintLambdaExpr() {
1932 FailureOr
<ast::UserConstraintDecl
*> decl
= parseInlineUserConstraintDecl();
1936 return ast::DeclRefExpr::create(ctx
, (*decl
)->getLoc(), *decl
,
1937 ast::ConstraintType::get(ctx
));
1940 FailureOr
<ast::Expr
*> Parser::parseInlineRewriteLambdaExpr() {
1941 FailureOr
<ast::UserRewriteDecl
*> decl
= parseInlineUserRewriteDecl();
1945 return ast::DeclRefExpr::create(ctx
, (*decl
)->getLoc(), *decl
,
1946 ast::RewriteType::get(ctx
));
1949 FailureOr
<ast::Expr
*> Parser::parseMemberAccessExpr(ast::Expr
*parentExpr
) {
1950 SMRange dotLoc
= curToken
.getLoc();
1951 consumeToken(Token::dot
);
1953 // Check for code completion of the member name.
1954 if (curToken
.is(Token::code_complete
))
1955 return codeCompleteMemberAccess(parentExpr
);
1957 // Parse the member name.
1958 Token memberNameTok
= curToken
;
1959 if (memberNameTok
.isNot(Token::identifier
, Token::integer
) &&
1960 !memberNameTok
.isKeyword())
1961 return emitError(dotLoc
, "expected identifier or numeric member name");
1962 StringRef memberName
= memberNameTok
.getSpelling();
1963 SMRange
loc(parentExpr
->getLoc().Start
, curToken
.getEndLoc());
1966 return createMemberAccessExpr(parentExpr
, memberName
, loc
);
1969 FailureOr
<ast::Expr
*> Parser::parseNegatedExpr() {
1970 consumeToken(Token::kw_not
);
1971 // Only native constraints are supported after negation
1972 if (!curToken
.is(Token::identifier
))
1973 return emitError("expected native constraint");
1974 FailureOr
<ast::Expr
*> identifierExpr
= parseIdentifierExpr();
1975 if (failed(identifierExpr
))
1977 return parseCallExpr(*identifierExpr
, /*isNegated = */ true);
1980 FailureOr
<ast::OpNameDecl
*> Parser::parseOperationName(bool allowEmptyName
) {
1981 SMRange loc
= curToken
.getLoc();
1983 // Check for code completion for the dialect name.
1984 if (curToken
.is(Token::code_complete
))
1985 return codeCompleteDialectName();
1987 // Handle the case of an no operation name.
1988 if (curToken
.isNot(Token::identifier
) && !curToken
.isKeyword()) {
1990 return ast::OpNameDecl::create(ctx
, SMRange());
1991 return emitError("expected dialect namespace");
1993 StringRef name
= curToken
.getSpelling();
1996 // Otherwise, this is a literal operation name.
1997 if (failed(parseToken(Token::dot
, "expected `.` after dialect namespace")))
2000 // Check for code completion for the operation name.
2001 if (curToken
.is(Token::code_complete
))
2002 return codeCompleteOperationName(name
);
2004 if (curToken
.isNot(Token::identifier
) && !curToken
.isKeyword())
2005 return emitError("expected operation name after dialect namespace");
2007 name
= StringRef(name
.data(), name
.size() + 1);
2009 name
= StringRef(name
.data(), name
.size() + curToken
.getSpelling().size());
2010 loc
.End
= curToken
.getEndLoc();
2012 } while (curToken
.isAny(Token::identifier
, Token::dot
) ||
2013 curToken
.isKeyword());
2014 return ast::OpNameDecl::create(ctx
, ast::Name::create(ctx
, name
, loc
));
2017 FailureOr
<ast::OpNameDecl
*>
2018 Parser::parseWrappedOperationName(bool allowEmptyName
) {
2019 if (!consumeIf(Token::less
))
2020 return ast::OpNameDecl::create(ctx
, SMRange());
2022 FailureOr
<ast::OpNameDecl
*> opNameDecl
= parseOperationName(allowEmptyName
);
2023 if (failed(opNameDecl
))
2026 if (failed(parseToken(Token::greater
, "expected `>` after operation name")))
2031 FailureOr
<ast::Expr
*>
2032 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext
) {
2033 SMRange loc
= curToken
.getLoc();
2034 consumeToken(Token::kw_op
);
2036 // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2038 if (curToken
.isNot(Token::less
)) {
2040 return parseIdentifierExpr();
2043 // Parse the operation name. The name may be elided, in which case the
2044 // operation refers to "any" operation(i.e. a difference between `MyOp` and
2045 // `Operation*`). Operation names within a rewrite context must be named.
2046 bool allowEmptyName
= parserContext
!= ParserContext::Rewrite
;
2047 FailureOr
<ast::OpNameDecl
*> opNameDecl
=
2048 parseWrappedOperationName(allowEmptyName
);
2049 if (failed(opNameDecl
))
2051 std::optional
<StringRef
> opName
= (*opNameDecl
)->getName();
2053 // Functor used to create an implicit range variable, used for implicit "all"
2054 // operand or results variables.
2055 auto createImplicitRangeVar
= [&](ast::ConstraintDecl
*cst
, ast::Type type
) {
2056 FailureOr
<ast::VariableDecl
*> rangeVar
=
2057 defineVariableDecl("_", loc
, type
, ast::ConstraintRef(cst
, loc
));
2058 assert(succeeded(rangeVar
) && "expected range variable to be valid");
2059 return ast::DeclRefExpr::create(ctx
, loc
, *rangeVar
, type
);
2062 // Check for the optional list of operands.
2063 SmallVector
<ast::Expr
*> operands
;
2064 if (!consumeIf(Token::l_paren
)) {
2065 // If the operand list isn't specified and we are in a match context, define
2066 // an inplace unconstrained operand range corresponding to all of the
2067 // operands of the operation. This avoids treating zero operands the same
2068 // way as "unconstrained operands".
2069 if (parserContext
!= ParserContext::Rewrite
) {
2070 operands
.push_back(createImplicitRangeVar(
2071 ast::ValueRangeConstraintDecl::create(ctx
, loc
), valueRangeTy
));
2073 } else if (!consumeIf(Token::r_paren
)) {
2074 // If the operand list was specified and non-empty, parse the operands.
2076 // Check for operand signature code completion.
2077 if (curToken
.is(Token::code_complete
)) {
2078 codeCompleteOperationOperandsSignature(opName
, operands
.size());
2082 FailureOr
<ast::Expr
*> operand
= parseExpr();
2083 if (failed(operand
))
2085 operands
.push_back(*operand
);
2086 } while (consumeIf(Token::comma
));
2088 if (failed(parseToken(Token::r_paren
,
2089 "expected `)` after operation operand list")))
2093 // Check for the optional list of attributes.
2094 SmallVector
<ast::NamedAttributeDecl
*> attributes
;
2095 if (consumeIf(Token::l_brace
)) {
2097 FailureOr
<ast::NamedAttributeDecl
*> decl
=
2098 parseNamedAttributeDecl(opName
);
2101 attributes
.emplace_back(*decl
);
2102 } while (consumeIf(Token::comma
));
2104 if (failed(parseToken(Token::r_brace
,
2105 "expected `}` after operation attribute list")))
2109 // Handle the result types of the operation.
2110 SmallVector
<ast::Expr
*> resultTypes
;
2111 OpResultTypeContext resultTypeContext
= inputResultTypeContext
;
2113 // Check for an explicit list of result types.
2114 if (consumeIf(Token::arrow
)) {
2115 if (failed(parseToken(Token::l_paren
,
2116 "expected `(` before operation result type list")))
2119 // If result types are provided, initially assume that the operation does
2120 // not rely on type inferrence. We don't assert that it isn't, because we
2121 // may be inferring the value of some type/type range variables, but given
2122 // that these variables may be defined in calls we can't always discern when
2123 // this is the case.
2124 resultTypeContext
= OpResultTypeContext::Explicit
;
2126 // Handle the case of an empty result list.
2127 if (!consumeIf(Token::r_paren
)) {
2129 // Check for result signature code completion.
2130 if (curToken
.is(Token::code_complete
)) {
2131 codeCompleteOperationResultsSignature(opName
, resultTypes
.size());
2135 FailureOr
<ast::Expr
*> resultTypeExpr
= parseExpr();
2136 if (failed(resultTypeExpr
))
2138 resultTypes
.push_back(*resultTypeExpr
);
2139 } while (consumeIf(Token::comma
));
2141 if (failed(parseToken(Token::r_paren
,
2142 "expected `)` after operation result type list")))
2145 } else if (parserContext
!= ParserContext::Rewrite
) {
2146 // If the result list isn't specified and we are in a match context, define
2147 // an inplace unconstrained result range corresponding to all of the results
2148 // of the operation. This avoids treating zero results the same way as
2149 // "unconstrained results".
2150 resultTypes
.push_back(createImplicitRangeVar(
2151 ast::TypeRangeConstraintDecl::create(ctx
, loc
), typeRangeTy
));
2152 } else if (resultTypeContext
== OpResultTypeContext::Explicit
) {
2153 // If the result list isn't specified and we are in a rewrite, try to infer
2154 // them at runtime instead.
2155 resultTypeContext
= OpResultTypeContext::Interface
;
2158 return createOperationExpr(loc
, *opNameDecl
, resultTypeContext
, operands
,
2159 attributes
, resultTypes
);
2162 FailureOr
<ast::Expr
*> Parser::parseTupleExpr() {
2163 SMRange loc
= curToken
.getLoc();
2164 consumeToken(Token::l_paren
);
2166 DenseMap
<StringRef
, SMRange
> usedNames
;
2167 SmallVector
<StringRef
> elementNames
;
2168 SmallVector
<ast::Expr
*> elements
;
2169 if (curToken
.isNot(Token::r_paren
)) {
2171 // Check for the optional element name assignment before the value.
2172 StringRef elementName
;
2173 if (curToken
.is(Token::identifier
) || curToken
.isDependentKeyword()) {
2174 Token elementNameTok
= curToken
;
2177 // The element name is only present if followed by an `=`.
2178 if (consumeIf(Token::equal
)) {
2179 elementName
= elementNameTok
.getSpelling();
2181 // Check to see if this name is already used.
2182 auto elementNameIt
=
2183 usedNames
.try_emplace(elementName
, elementNameTok
.getLoc());
2184 if (!elementNameIt
.second
) {
2185 return emitErrorAndNote(
2186 elementNameTok
.getLoc(),
2187 llvm::formatv("duplicate tuple element label `{0}`",
2189 elementNameIt
.first
->getSecond(),
2190 "see previous label use here");
2193 // Otherwise, we treat this as part of an expression so reset the
2195 resetToken(elementNameTok
.getLoc());
2198 elementNames
.push_back(elementName
);
2200 // Parse the tuple element value.
2201 FailureOr
<ast::Expr
*> element
= parseExpr();
2202 if (failed(element
))
2204 elements
.push_back(*element
);
2205 } while (consumeIf(Token::comma
));
2207 loc
.End
= curToken
.getEndLoc();
2209 parseToken(Token::r_paren
, "expected `)` after tuple element list")))
2211 return createTupleExpr(loc
, elements
, elementNames
);
2214 FailureOr
<ast::Expr
*> Parser::parseTypeExpr() {
2215 SMRange loc
= curToken
.getLoc();
2216 consumeToken(Token::kw_type
);
2218 // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2220 if (!consumeIf(Token::less
)) {
2222 return parseIdentifierExpr();
2225 if (!curToken
.isString())
2226 return emitError("expected string literal containing MLIR type");
2227 std::string attrExpr
= curToken
.getStringValue();
2230 loc
.End
= curToken
.getEndLoc();
2231 if (failed(parseToken(Token::greater
, "expected `>` after type literal")))
2233 return ast::TypeExpr::create(ctx
, loc
, attrExpr
);
2236 FailureOr
<ast::Expr
*> Parser::parseUnderscoreExpr() {
2237 StringRef name
= curToken
.getSpelling();
2238 SMRange nameLoc
= curToken
.getLoc();
2239 consumeToken(Token::underscore
);
2241 // Underscore expressions require a constraint list.
2242 if (failed(parseToken(Token::colon
, "expected `:` after `_` variable")))
2245 // Parse the constraints for the expression.
2246 SmallVector
<ast::ConstraintRef
> constraints
;
2247 if (failed(parseVariableDeclConstraintList(constraints
)))
2251 if (failed(validateVariableConstraints(constraints
, type
)))
2253 return createInlineVariableExpr(type
, name
, nameLoc
, constraints
);
//===----------------------------------------------------------------------===//
// Stmts
//===----------------------------------------------------------------------===//
2259 FailureOr
<ast::Stmt
*> Parser::parseStmt(bool expectTerminalSemicolon
) {
2260 FailureOr
<ast::Stmt
*> stmt
;
2261 switch (curToken
.getKind()) {
2262 case Token::kw_erase
:
2263 stmt
= parseEraseStmt();
2266 stmt
= parseLetStmt();
2268 case Token::kw_replace
:
2269 stmt
= parseReplaceStmt();
2271 case Token::kw_return
:
2272 stmt
= parseReturnStmt();
2274 case Token::kw_rewrite
:
2275 stmt
= parseRewriteStmt();
2282 (expectTerminalSemicolon
&&
2283 failed(parseToken(Token::semicolon
, "expected `;` after statement"))))
2288 FailureOr
<ast::CompoundStmt
*> Parser::parseCompoundStmt() {
2289 SMLoc startLoc
= curToken
.getStartLoc();
2290 consumeToken(Token::l_brace
);
2292 // Push a new block scope and parse any nested statements.
2294 SmallVector
<ast::Stmt
*> statements
;
2295 while (curToken
.isNot(Token::r_brace
)) {
2296 FailureOr
<ast::Stmt
*> statement
= parseStmt();
2297 if (failed(statement
))
2298 return popDeclScope(), failure();
2299 statements
.push_back(*statement
);
2303 // Consume the end brace.
2304 SMRange
location(startLoc
, curToken
.getEndLoc());
2305 consumeToken(Token::r_brace
);
2307 return ast::CompoundStmt::create(ctx
, location
, statements
);
2310 FailureOr
<ast::EraseStmt
*> Parser::parseEraseStmt() {
2311 if (parserContext
== ParserContext::Constraint
)
2312 return emitError("`erase` cannot be used within a Constraint");
2313 SMRange loc
= curToken
.getLoc();
2314 consumeToken(Token::kw_erase
);
2316 // Parse the root operation expression.
2317 FailureOr
<ast::Expr
*> rootOp
= parseExpr();
2321 return createEraseStmt(loc
, *rootOp
);
2324 FailureOr
<ast::LetStmt
*> Parser::parseLetStmt() {
2325 SMRange loc
= curToken
.getLoc();
2326 consumeToken(Token::kw_let
);
2328 // Parse the name of the new variable.
2329 SMRange varLoc
= curToken
.getLoc();
2330 if (curToken
.isNot(Token::identifier
) && !curToken
.isDependentKeyword()) {
2331 // `_` is a reserved variable name.
2332 if (curToken
.is(Token::underscore
)) {
2333 return emitError(varLoc
,
2334 "`_` may only be used to define \"inline\" variables");
2336 return emitError(varLoc
,
2337 "expected identifier after `let` to name a new variable");
2339 StringRef varName
= curToken
.getSpelling();
2342 // Parse the optional set of constraints.
2343 SmallVector
<ast::ConstraintRef
> constraints
;
2344 if (consumeIf(Token::colon
) &&
2345 failed(parseVariableDeclConstraintList(constraints
)))
2348 // Parse the optional initializer expression.
2349 ast::Expr
*initializer
= nullptr;
2350 if (consumeIf(Token::equal
)) {
2351 FailureOr
<ast::Expr
*> initOrFailure
= parseExpr();
2352 if (failed(initOrFailure
))
2354 initializer
= *initOrFailure
;
2356 // Check that the constraints are compatible with having an initializer,
2357 // e.g. type constraints cannot be used with initializers.
2358 for (ast::ConstraintRef constraint
: constraints
) {
2359 LogicalResult result
=
2360 TypeSwitch
<const ast::Node
*, LogicalResult
>(constraint
.constraint
)
2361 .Case
<ast::AttrConstraintDecl
, ast::ValueConstraintDecl
,
2362 ast::ValueRangeConstraintDecl
>([&](const auto *cst
) {
2363 if (cst
->getTypeExpr()) {
2364 return this->emitError(
2365 constraint
.referenceLoc
,
2366 "type constraints are not permitted on variables with "
2371 .Default(success());
2377 FailureOr
<ast::VariableDecl
*> varDecl
=
2378 createVariableDecl(varName
, varLoc
, initializer
, constraints
);
2379 if (failed(varDecl
))
2381 return ast::LetStmt::create(ctx
, loc
, *varDecl
);
2384 FailureOr
<ast::ReplaceStmt
*> Parser::parseReplaceStmt() {
2385 if (parserContext
== ParserContext::Constraint
)
2386 return emitError("`replace` cannot be used within a Constraint");
2387 SMRange loc
= curToken
.getLoc();
2388 consumeToken(Token::kw_replace
);
2390 // Parse the root operation expression.
2391 FailureOr
<ast::Expr
*> rootOp
= parseExpr();
2396 parseToken(Token::kw_with
, "expected `with` after root operation")))
2399 // The replacement portion of this statement is within a rewrite context.
2400 llvm::SaveAndRestore
saveCtx(parserContext
, ParserContext::Rewrite
);
2402 // Parse the replacement values.
2403 SmallVector
<ast::Expr
*> replValues
;
2404 if (consumeIf(Token::l_paren
)) {
2405 if (consumeIf(Token::r_paren
)) {
2407 loc
, "expected at least one replacement value, consider using "
2408 "`erase` if no replacement values are desired");
2412 FailureOr
<ast::Expr
*> replExpr
= parseExpr();
2413 if (failed(replExpr
))
2415 replValues
.emplace_back(*replExpr
);
2416 } while (consumeIf(Token::comma
));
2418 if (failed(parseToken(Token::r_paren
,
2419 "expected `)` after replacement values")))
2422 // Handle replacement with an operation uniquely, as the replacement
2423 // operation supports type inferrence from the root operation.
2424 FailureOr
<ast::Expr
*> replExpr
;
2425 if (curToken
.is(Token::kw_op
))
2426 replExpr
= parseOperationExpr(OpResultTypeContext::Replacement
);
2428 replExpr
= parseExpr();
2429 if (failed(replExpr
))
2431 replValues
.emplace_back(*replExpr
);
2434 return createReplaceStmt(loc
, *rootOp
, replValues
);
2437 FailureOr
<ast::ReturnStmt
*> Parser::parseReturnStmt() {
2438 SMRange loc
= curToken
.getLoc();
2439 consumeToken(Token::kw_return
);
2441 // Parse the result value.
2442 FailureOr
<ast::Expr
*> resultExpr
= parseExpr();
2443 if (failed(resultExpr
))
2446 return ast::ReturnStmt::create(ctx
, loc
, *resultExpr
);
2449 FailureOr
<ast::RewriteStmt
*> Parser::parseRewriteStmt() {
2450 if (parserContext
== ParserContext::Constraint
)
2451 return emitError("`rewrite` cannot be used within a Constraint");
2452 SMRange loc
= curToken
.getLoc();
2453 consumeToken(Token::kw_rewrite
);
2455 // Parse the root operation.
2456 FailureOr
<ast::Expr
*> rootOp
= parseExpr();
2460 if (failed(parseToken(Token::kw_with
, "expected `with` before rewrite body")))
2463 if (curToken
.isNot(Token::l_brace
))
2464 return emitError("expected `{` to start rewrite body");
2466 // The rewrite body of this statement is within a rewrite context.
2467 llvm::SaveAndRestore
saveCtx(parserContext
, ParserContext::Rewrite
);
2469 FailureOr
<ast::CompoundStmt
*> rewriteBody
= parseCompoundStmt();
2470 if (failed(rewriteBody
))
2473 // Verify the rewrite body.
2474 for (const ast::Stmt
*stmt
: (*rewriteBody
)->getChildren()) {
2475 if (isa
<ast::ReturnStmt
>(stmt
)) {
2476 return emitError(stmt
->getLoc(),
2477 "`return` statements are only permitted within a "
2478 "`Constraint` or `Rewrite` body");
2482 return createRewriteStmt(loc
, *rootOp
, *rewriteBody
);
//===----------------------------------------------------------------------===//
// Creation+Analysis
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Decls
//===----------------------------------------------------------------------===//
2492 ast::CallableDecl
*Parser::tryExtractCallableDecl(ast::Node
*node
) {
2493 // Unwrap reference expressions.
2494 if (auto *init
= dyn_cast
<ast::DeclRefExpr
>(node
))
2495 node
= init
->getDecl();
2496 return dyn_cast
<ast::CallableDecl
>(node
);
2499 FailureOr
<ast::PatternDecl
*>
2500 Parser::createPatternDecl(SMRange loc
, const ast::Name
*name
,
2501 const ParsedPatternMetadata
&metadata
,
2502 ast::CompoundStmt
*body
) {
2503 return ast::PatternDecl::create(ctx
, loc
, name
, metadata
.benefit
,
2504 metadata
.hasBoundedRecursion
, body
);
2507 ast::Type
Parser::createUserConstraintRewriteResultType(
2508 ArrayRef
<ast::VariableDecl
*> results
) {
2509 // Single result decls use the type of the single result.
2510 if (results
.size() == 1)
2511 return results
[0]->getType();
2513 // Multiple results use a tuple type, with the types and names grabbed from
2514 // the result variable decls.
2515 auto resultTypes
= llvm::map_range(
2516 results
, [&](const auto *result
) { return result
->getType(); });
2517 auto resultNames
= llvm::map_range(
2518 results
, [&](const auto *result
) { return result
->getName().getName(); });
2519 return ast::TupleType::get(ctx
, llvm::to_vector(resultTypes
),
2520 llvm::to_vector(resultNames
));
2523 template <typename T
>
2524 FailureOr
<T
*> Parser::createUserPDLLConstraintOrRewriteDecl(
2525 const ast::Name
&name
, ArrayRef
<ast::VariableDecl
*> arguments
,
2526 ArrayRef
<ast::VariableDecl
*> results
, ast::Type resultType
,
2527 ast::CompoundStmt
*body
) {
2528 if (!body
->getChildren().empty()) {
2529 if (auto *retStmt
= dyn_cast
<ast::ReturnStmt
>(body
->getChildren().back())) {
2530 ast::Expr
*resultExpr
= retStmt
->getResultExpr();
2532 // Process the result of the decl. If no explicit signature results
2533 // were provided, check for return type inference. Otherwise, check that
2534 // the return expression can be converted to the expected type.
2535 if (results
.empty())
2536 resultType
= resultExpr
->getType();
2537 else if (failed(convertExpressionTo(resultExpr
, resultType
)))
2540 retStmt
->setResultExpr(resultExpr
);
2543 return T::createPDLL(ctx
, name
, arguments
, results
, body
, resultType
);
2546 FailureOr
<ast::VariableDecl
*>
2547 Parser::createVariableDecl(StringRef name
, SMRange loc
, ast::Expr
*initializer
,
2548 ArrayRef
<ast::ConstraintRef
> constraints
) {
2549 // The type of the variable, which is expected to be inferred by either a
2550 // constraint or an initializer expression.
2552 if (failed(validateVariableConstraints(constraints
, type
)))
2556 // Update the variable type based on the initializer, or try to convert the
2557 // initializer to the existing type.
2559 type
= initializer
->getType();
2560 else if (ast::Type mergedType
= type
.refineWith(initializer
->getType()))
2562 else if (failed(convertExpressionTo(initializer
, type
)))
2565 // Otherwise, if there is no initializer check that the type has already
2566 // been resolved from the constraint list.
2568 return emitErrorAndNote(
2569 loc
, "unable to infer type for variable `" + name
+ "`", loc
,
2570 "the type of a variable must be inferable from the constraint "
2571 "list or the initializer");
2574 // Constraint types cannot be used when defining variables.
2575 if (type
.isa
<ast::ConstraintType
, ast::RewriteType
>()) {
2577 loc
, llvm::formatv("unable to define variable of `{0}` type", type
));
2580 // Try to define a variable with the given name.
2581 FailureOr
<ast::VariableDecl
*> varDecl
=
2582 defineVariableDecl(name
, loc
, type
, initializer
, constraints
);
2583 if (failed(varDecl
))
2589 FailureOr
<ast::VariableDecl
*>
2590 Parser::createArgOrResultVariableDecl(StringRef name
, SMRange loc
,
2591 const ast::ConstraintRef
&constraint
) {
2593 if (failed(validateVariableConstraint(constraint
, argType
)))
2595 return defineVariableDecl(name
, loc
, argType
, constraint
);
2599 Parser::validateVariableConstraints(ArrayRef
<ast::ConstraintRef
> constraints
,
2600 ast::Type
&inferredType
) {
2601 for (const ast::ConstraintRef
&ref
: constraints
)
2602 if (failed(validateVariableConstraint(ref
, inferredType
)))
2607 LogicalResult
Parser::validateVariableConstraint(const ast::ConstraintRef
&ref
,
2608 ast::Type
&inferredType
) {
2609 ast::Type constraintType
;
2610 if (const auto *cst
= dyn_cast
<ast::AttrConstraintDecl
>(ref
.constraint
)) {
2611 if (const ast::Expr
*typeExpr
= cst
->getTypeExpr()) {
2612 if (failed(validateTypeConstraintExpr(typeExpr
)))
2615 constraintType
= ast::AttributeType::get(ctx
);
2616 } else if (const auto *cst
=
2617 dyn_cast
<ast::OpConstraintDecl
>(ref
.constraint
)) {
2618 constraintType
= ast::OperationType::get(
2619 ctx
, cst
->getName(), lookupODSOperation(cst
->getName()));
2620 } else if (isa
<ast::TypeConstraintDecl
>(ref
.constraint
)) {
2621 constraintType
= typeTy
;
2622 } else if (isa
<ast::TypeRangeConstraintDecl
>(ref
.constraint
)) {
2623 constraintType
= typeRangeTy
;
2624 } else if (const auto *cst
=
2625 dyn_cast
<ast::ValueConstraintDecl
>(ref
.constraint
)) {
2626 if (const ast::Expr
*typeExpr
= cst
->getTypeExpr()) {
2627 if (failed(validateTypeConstraintExpr(typeExpr
)))
2630 constraintType
= valueTy
;
2631 } else if (const auto *cst
=
2632 dyn_cast
<ast::ValueRangeConstraintDecl
>(ref
.constraint
)) {
2633 if (const ast::Expr
*typeExpr
= cst
->getTypeExpr()) {
2634 if (failed(validateTypeRangeConstraintExpr(typeExpr
)))
2637 constraintType
= valueRangeTy
;
2638 } else if (const auto *cst
=
2639 dyn_cast
<ast::UserConstraintDecl
>(ref
.constraint
)) {
2640 ArrayRef
<ast::VariableDecl
*> inputs
= cst
->getInputs();
2641 if (inputs
.size() != 1) {
2642 return emitErrorAndNote(ref
.referenceLoc
,
2643 "`Constraint`s applied via a variable constraint "
2644 "list must take a single input, but got " +
2645 Twine(inputs
.size()),
2647 "see definition of constraint here");
2649 constraintType
= inputs
.front()->getType();
2651 llvm_unreachable("unknown constraint type");
2654 // Check that the constraint type is compatible with the current inferred
2656 if (!inferredType
) {
2657 inferredType
= constraintType
;
2658 } else if (ast::Type mergedTy
= inferredType
.refineWith(constraintType
)) {
2659 inferredType
= mergedTy
;
2661 return emitError(ref
.referenceLoc
,
2662 llvm::formatv("constraint type `{0}` is incompatible "
2663 "with the previously inferred type `{1}`",
2664 constraintType
, inferredType
));
2669 LogicalResult
Parser::validateTypeConstraintExpr(const ast::Expr
*typeExpr
) {
2670 ast::Type typeExprType
= typeExpr
->getType();
2671 if (typeExprType
!= typeTy
) {
2672 return emitError(typeExpr
->getLoc(),
2673 "expected expression of `Type` in type constraint");
2679 Parser::validateTypeRangeConstraintExpr(const ast::Expr
*typeExpr
) {
2680 ast::Type typeExprType
= typeExpr
->getType();
2681 if (typeExprType
!= typeRangeTy
) {
2682 return emitError(typeExpr
->getLoc(),
2683 "expected expression of `TypeRange` in type constraint");
//===----------------------------------------------------------------------===//
// Exprs
//===----------------------------------------------------------------------===//
2691 FailureOr
<ast::CallExpr
*>
2692 Parser::createCallExpr(SMRange loc
, ast::Expr
*parentExpr
,
2693 MutableArrayRef
<ast::Expr
*> arguments
, bool isNegated
) {
2694 ast::Type parentType
= parentExpr
->getType();
2696 ast::CallableDecl
*callableDecl
= tryExtractCallableDecl(parentExpr
);
2697 if (!callableDecl
) {
2698 return emitError(loc
,
2699 llvm::formatv("expected a reference to a callable "
2700 "`Constraint` or `Rewrite`, but got: `{0}`",
2703 if (parserContext
== ParserContext::Rewrite
) {
2704 if (isa
<ast::UserConstraintDecl
>(callableDecl
))
2706 loc
, "unable to invoke `Constraint` within a rewrite section");
2708 return emitError(loc
, "unable to negate a Rewrite");
2710 if (isa
<ast::UserRewriteDecl
>(callableDecl
))
2711 return emitError(loc
,
2712 "unable to invoke `Rewrite` within a match section");
2713 if (isNegated
&& cast
<ast::UserConstraintDecl
>(callableDecl
)->getBody())
2714 return emitError(loc
, "unable to negate non native constraints");
2717 // Verify the arguments of the call.
2718 /// Handle size mismatch.
2719 ArrayRef
<ast::VariableDecl
*> callArgs
= callableDecl
->getInputs();
2720 if (callArgs
.size() != arguments
.size()) {
2721 return emitErrorAndNote(
2723 llvm::formatv("invalid number of arguments for {0} call; expected "
2725 callableDecl
->getCallableType(), callArgs
.size(),
2727 callableDecl
->getLoc(),
2728 llvm::formatv("see the definition of {0} here",
2729 callableDecl
->getName()->getName()));
2732 /// Handle argument type mismatch.
2733 auto attachDiagFn
= [&](ast::Diagnostic
&diag
) {
2734 diag
.attachNote(llvm::formatv("see the definition of `{0}` here",
2735 callableDecl
->getName()->getName()),
2736 callableDecl
->getLoc());
2738 for (auto it
: llvm::zip(callArgs
, arguments
)) {
2739 if (failed(convertExpressionTo(std::get
<1>(it
), std::get
<0>(it
)->getType(),
2744 return ast::CallExpr::create(ctx
, loc
, parentExpr
, arguments
,
2745 callableDecl
->getResultType(), isNegated
);
2748 FailureOr
<ast::DeclRefExpr
*> Parser::createDeclRefExpr(SMRange loc
,
2750 // Check the type of decl being referenced.
2752 if (isa
<ast::ConstraintDecl
>(decl
))
2753 declType
= ast::ConstraintType::get(ctx
);
2754 else if (isa
<ast::UserRewriteDecl
>(decl
))
2755 declType
= ast::RewriteType::get(ctx
);
2756 else if (auto *varDecl
= dyn_cast
<ast::VariableDecl
>(decl
))
2757 declType
= varDecl
->getType();
2759 return emitError(loc
, "invalid reference to `" +
2760 decl
->getName()->getName() + "`");
2762 return ast::DeclRefExpr::create(ctx
, loc
, decl
, declType
);
2765 FailureOr
<ast::DeclRefExpr
*>
2766 Parser::createInlineVariableExpr(ast::Type type
, StringRef name
, SMRange loc
,
2767 ArrayRef
<ast::ConstraintRef
> constraints
) {
2768 FailureOr
<ast::VariableDecl
*> decl
=
2769 defineVariableDecl(name
, loc
, type
, constraints
);
2772 return ast::DeclRefExpr::create(ctx
, loc
, *decl
, type
);
2775 FailureOr
<ast::MemberAccessExpr
*>
2776 Parser::createMemberAccessExpr(ast::Expr
*parentExpr
, StringRef name
,
2778 // Validate the member name for the given parent expression.
2779 FailureOr
<ast::Type
> memberType
= validateMemberAccess(parentExpr
, name
, loc
);
2780 if (failed(memberType
))
2783 return ast::MemberAccessExpr::create(ctx
, loc
, parentExpr
, name
, *memberType
);
2786 FailureOr
<ast::Type
> Parser::validateMemberAccess(ast::Expr
*parentExpr
,
2787 StringRef name
, SMRange loc
) {
2788 ast::Type parentType
= parentExpr
->getType();
2789 if (ast::OperationType opType
= parentType
.dyn_cast
<ast::OperationType
>()) {
2790 if (name
== ast::AllResultsMemberAccessExpr::getMemberName())
2791 return valueRangeTy
;
2793 // Verify member access based on the operation type.
2794 if (const ods::Operation
*odsOp
= opType
.getODSOperation()) {
2795 auto results
= odsOp
->getResults();
2797 // Handle indexed results.
2799 if (llvm::isDigit(name
[0]) && !name
.getAsInteger(/*Radix=*/10, index
) &&
2800 index
< results
.size()) {
2801 return results
[index
].isVariadic() ? valueRangeTy
: valueTy
;
2804 // Handle named results.
2805 const auto *it
= llvm::find_if(results
, [&](const auto &result
) {
2806 return result
.getName() == name
;
2808 if (it
!= results
.end())
2809 return it
->isVariadic() ? valueRangeTy
: valueTy
;
2810 } else if (llvm::isDigit(name
[0])) {
2811 // Allow unchecked numeric indexing of the results of unregistered
2812 // operations. It returns a single value.
2815 } else if (auto tupleType
= parentType
.dyn_cast
<ast::TupleType
>()) {
2816 // Handle indexed results.
2818 if (llvm::isDigit(name
[0]) && !name
.getAsInteger(/*Radix=*/10, index
) &&
2819 index
< tupleType
.size()) {
2820 return tupleType
.getElementTypes()[index
];
2823 // Handle named results.
2824 auto elementNames
= tupleType
.getElementNames();
2825 const auto *it
= llvm::find(elementNames
, name
);
2826 if (it
!= elementNames
.end())
2827 return tupleType
.getElementTypes()[it
- elementNames
.begin()];
2831 llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2835 FailureOr
<ast::OperationExpr
*> Parser::createOperationExpr(
2836 SMRange loc
, const ast::OpNameDecl
*name
,
2837 OpResultTypeContext resultTypeContext
,
2838 SmallVectorImpl
<ast::Expr
*> &operands
,
2839 MutableArrayRef
<ast::NamedAttributeDecl
*> attributes
,
2840 SmallVectorImpl
<ast::Expr
*> &results
) {
2841 std::optional
<StringRef
> opNameRef
= name
->getName();
2842 const ods::Operation
*odsOp
= lookupODSOperation(opNameRef
);
2844 // Verify the inputs operands.
2845 if (failed(validateOperationOperands(loc
, opNameRef
, odsOp
, operands
)))
2848 // Verify the attribute list.
2849 for (ast::NamedAttributeDecl
*attr
: attributes
) {
2850 // Check for an attribute type, or a type awaiting resolution.
2851 ast::Type attrType
= attr
->getValue()->getType();
2852 if (!attrType
.isa
<ast::AttributeType
>()) {
2854 attr
->getValue()->getLoc(),
2855 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType
));
2860 (resultTypeContext
== OpResultTypeContext::Explicit
|| results
.empty()) &&
2861 "unexpected inferrence when results were explicitly specified");
2863 // If we aren't relying on type inferrence, or explicit results were provided,
2865 if (resultTypeContext
== OpResultTypeContext::Explicit
) {
2866 if (failed(validateOperationResults(loc
, opNameRef
, odsOp
, results
)))
2869 // Validate the use of interface based type inferrence for this operation.
2870 } else if (resultTypeContext
== OpResultTypeContext::Interface
) {
2872 "expected valid operation name when inferring operation results");
2873 checkOperationResultTypeInferrence(loc
, *opNameRef
, odsOp
);
2876 return ast::OperationExpr::create(ctx
, loc
, odsOp
, name
, operands
, results
,
2881 Parser::validateOperationOperands(SMRange loc
, std::optional
<StringRef
> name
,
2882 const ods::Operation
*odsOp
,
2883 SmallVectorImpl
<ast::Expr
*> &operands
) {
2884 return validateOperationOperandsOrResults(
2885 "operand", loc
, odsOp
? odsOp
->getLoc() : std::optional
<SMRange
>(), name
,
2886 operands
, odsOp
? odsOp
->getOperands() : std::nullopt
, valueTy
,
2891 Parser::validateOperationResults(SMRange loc
, std::optional
<StringRef
> name
,
2892 const ods::Operation
*odsOp
,
2893 SmallVectorImpl
<ast::Expr
*> &results
) {
2894 return validateOperationOperandsOrResults(
2895 "result", loc
, odsOp
? odsOp
->getLoc() : std::optional
<SMRange
>(), name
,
2896 results
, odsOp
? odsOp
->getResults() : std::nullopt
, typeTy
, typeRangeTy
);
2899 void Parser::checkOperationResultTypeInferrence(SMRange loc
, StringRef opName
,
2900 const ods::Operation
*odsOp
) {
2901 // If the operation might not have inferrence support, emit a warning to the
2902 // user. We don't emit an error because the interface might be added to the
2903 // operation at runtime. It's rare, but it could still happen. We emit a
2904 // warning here instead.
2906 // Handle inferrence warnings for unknown operations.
2908 ctx
.getDiagEngine().emitWarning(
2910 "operation result types are marked to be inferred, but "
2911 "`{0}` is unknown. Ensure that `{0}` supports zero "
2912 "results or implements `InferTypeOpInterface`. Include "
2913 "the ODS definition of this operation to remove this warning.",
2918 // Handle inferrence warnings for known operations that expected at least one
2919 // result, but don't have inference support. An elided results list can mean
2920 // "zero-results", and we don't want to warn when that is the expected
2922 bool requiresInferrence
=
2923 llvm::any_of(odsOp
->getResults(), [](const ods::OperandOrResult
&result
) {
2924 return !result
.isVariableLength();
2926 if (requiresInferrence
&& !odsOp
->hasResultTypeInferrence()) {
2927 ast::InFlightDiagnostic diag
= ctx
.getDiagEngine().emitWarning(
2929 llvm::formatv("operation result types are marked to be inferred, but "
2930 "`{0}` does not provide an implementation of "
2931 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2932 "`InferTypeOpInterface` at runtime, or add support to "
2933 "the ODS definition to remove this warning.",
2935 diag
->attachNote(llvm::formatv("see the definition of `{0}` here", opName
),
2941 LogicalResult
Parser::validateOperationOperandsOrResults(
2942 StringRef groupName
, SMRange loc
, std::optional
<SMRange
> odsOpLoc
,
2943 std::optional
<StringRef
> name
, SmallVectorImpl
<ast::Expr
*> &values
,
2944 ArrayRef
<ods::OperandOrResult
> odsValues
, ast::Type singleTy
,
2945 ast::RangeType rangeTy
) {
2946 // All operation types accept a single range parameter.
2947 if (values
.size() == 1) {
2948 if (failed(convertExpressionTo(values
[0], rangeTy
)))
2953 /// If the operation has ODS information, we can more accurately verify the
2956 auto emitSizeMismatchError
= [&] {
2957 return emitErrorAndNote(
2959 llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2961 groupName
, *name
, odsValues
.size(), values
.size()),
2962 *odsOpLoc
, llvm::formatv("see the definition of `{0}` here", *name
));
2965 // Handle the case where no values were provided.
2966 if (values
.empty()) {
2967 // If we don't expect any on the ODS side, we are done.
2968 if (odsValues
.empty())
2971 // If we do, check if we actually need to provide values (i.e. if any of
2972 // the values are actually required).
2973 unsigned numVariadic
= 0;
2974 for (const auto &odsValue
: odsValues
) {
2975 if (!odsValue
.isVariableLength())
2976 return emitSizeMismatchError();
2980 // If we are in a non-rewrite context, we don't need to do anything more.
2981 // Zero-values is a valid constraint on the operation.
2982 if (parserContext
!= ParserContext::Rewrite
)
2985 // Otherwise, when in a rewrite we may need to provide values to match the
2986 // ODS signature of the operation to create.
2988 // If we only have one variadic value, just use an empty list.
2989 if (numVariadic
== 1)
2992 // Otherwise, create dummy values for each of the entries so that we
2993 // adhere to the ODS signature.
2994 for (unsigned i
= 0, e
= odsValues
.size(); i
< e
; ++i
) {
2995 values
.push_back(ast::RangeExpr::create(
2996 ctx
, loc
, /*elements=*/std::nullopt
, rangeTy
));
3001 // Verify that the number of values provided matches the number of value
3002 // groups ODS expects.
3003 if (odsValues
.size() != values
.size())
3004 return emitSizeMismatchError();
3006 auto diagFn
= [&](ast::Diagnostic
&diag
) {
3007 diag
.attachNote(llvm::formatv("see the definition of `{0}` here", *name
),
3010 for (unsigned i
= 0, e
= values
.size(); i
< e
; ++i
) {
3011 ast::Type expectedType
= odsValues
[i
].isVariadic() ? rangeTy
: singleTy
;
3012 if (failed(convertExpressionTo(values
[i
], expectedType
, diagFn
)))
3018 // Otherwise, accept the value groups as they have been defined and just
3019 // ensure they are one of the expected types.
3020 for (ast::Expr
*&valueExpr
: values
) {
3021 ast::Type valueExprType
= valueExpr
->getType();
3023 // Check if this is one of the expected types.
3024 if (valueExprType
== rangeTy
|| valueExprType
== singleTy
)
3027 // If the operand is an Operation, allow converting to a Value or
3028 // ValueRange. This situations arises quite often with nested operation
3029 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3030 if (singleTy
== valueTy
) {
3031 if (valueExprType
.isa
<ast::OperationType
>()) {
3032 valueExpr
= convertOpToValue(valueExpr
);
3037 // Otherwise, try to convert the expression to a range.
3038 if (succeeded(convertExpressionTo(valueExpr
, rangeTy
)))
3042 valueExpr
->getLoc(),
3044 "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3045 singleTy
, rangeTy
, valueExprType
));
3050 FailureOr
<ast::TupleExpr
*>
3051 Parser::createTupleExpr(SMRange loc
, ArrayRef
<ast::Expr
*> elements
,
3052 ArrayRef
<StringRef
> elementNames
) {
3053 for (const ast::Expr
*element
: elements
) {
3054 ast::Type eleTy
= element
->getType();
3055 if (eleTy
.isa
<ast::ConstraintType
, ast::RewriteType
, ast::TupleType
>()) {
3058 llvm::formatv("unable to build a tuple with `{0}` element", eleTy
));
3061 return ast::TupleExpr::create(ctx
, loc
, elements
, elementNames
);
3064 //===----------------------------------------------------------------------===//
3067 FailureOr
<ast::EraseStmt
*> Parser::createEraseStmt(SMRange loc
,
3068 ast::Expr
*rootOp
) {
3069 // Check that root is an Operation.
3070 ast::Type rootType
= rootOp
->getType();
3071 if (!rootType
.isa
<ast::OperationType
>())
3072 return emitError(rootOp
->getLoc(), "expected `Op` expression");
3074 return ast::EraseStmt::create(ctx
, loc
, rootOp
);
3077 FailureOr
<ast::ReplaceStmt
*>
3078 Parser::createReplaceStmt(SMRange loc
, ast::Expr
*rootOp
,
3079 MutableArrayRef
<ast::Expr
*> replValues
) {
3080 // Check that root is an Operation.
3081 ast::Type rootType
= rootOp
->getType();
3082 if (!rootType
.isa
<ast::OperationType
>()) {
3085 llvm::formatv("expected `Op` expression, but got `{0}`", rootType
));
3088 // If there are multiple replacement values, we implicitly convert any Op
3089 // expressions to the value form.
3090 bool shouldConvertOpToValues
= replValues
.size() > 1;
3091 for (ast::Expr
*&replExpr
: replValues
) {
3092 ast::Type replType
= replExpr
->getType();
3094 // Check that replExpr is an Operation, Value, or ValueRange.
3095 if (replType
.isa
<ast::OperationType
>()) {
3096 if (shouldConvertOpToValues
)
3097 replExpr
= convertOpToValue(replExpr
);
3101 if (replType
!= valueTy
&& replType
!= valueRangeTy
) {
3102 return emitError(replExpr
->getLoc(),
3103 llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3104 "expression, but got `{0}`",
3109 return ast::ReplaceStmt::create(ctx
, loc
, rootOp
, replValues
);
3112 FailureOr
<ast::RewriteStmt
*>
3113 Parser::createRewriteStmt(SMRange loc
, ast::Expr
*rootOp
,
3114 ast::CompoundStmt
*rewriteBody
) {
3115 // Check that root is an Operation.
3116 ast::Type rootType
= rootOp
->getType();
3117 if (!rootType
.isa
<ast::OperationType
>()) {
3120 llvm::formatv("expected `Op` expression, but got `{0}`", rootType
));
3123 return ast::RewriteStmt::create(ctx
, loc
, rootOp
, rewriteBody
);
3126 //===----------------------------------------------------------------------===//
3128 //===----------------------------------------------------------------------===//
3130 LogicalResult
Parser::codeCompleteMemberAccess(ast::Expr
*parentExpr
) {
3131 ast::Type parentType
= parentExpr
->getType();
3132 if (ast::OperationType opType
= parentType
.dyn_cast
<ast::OperationType
>())
3133 codeCompleteContext
->codeCompleteOperationMemberAccess(opType
);
3134 else if (ast::TupleType tupleType
= parentType
.dyn_cast
<ast::TupleType
>())
3135 codeCompleteContext
->codeCompleteTupleMemberAccess(tupleType
);
3140 Parser::codeCompleteAttributeName(std::optional
<StringRef
> opName
) {
3142 codeCompleteContext
->codeCompleteOperationAttributeName(*opName
);
3147 Parser::codeCompleteConstraintName(ast::Type inferredType
,
3148 bool allowInlineTypeConstraints
) {
3149 codeCompleteContext
->codeCompleteConstraintName(
3150 inferredType
, allowInlineTypeConstraints
, curDeclScope
);
3154 LogicalResult
Parser::codeCompleteDialectName() {
3155 codeCompleteContext
->codeCompleteDialectName();
3159 LogicalResult
Parser::codeCompleteOperationName(StringRef dialectName
) {
3160 codeCompleteContext
->codeCompleteOperationName(dialectName
);
3164 LogicalResult
Parser::codeCompletePatternMetadata() {
3165 codeCompleteContext
->codeCompletePatternMetadata();
3169 LogicalResult
Parser::codeCompleteIncludeFilename(StringRef curPath
) {
3170 codeCompleteContext
->codeCompleteIncludeFilename(curPath
);
3174 void Parser::codeCompleteCallSignature(ast::Node
*parent
,
3175 unsigned currentNumArgs
) {
3176 ast::CallableDecl
*callableDecl
= tryExtractCallableDecl(parent
);
3180 codeCompleteContext
->codeCompleteCallSignature(callableDecl
, currentNumArgs
);
3183 void Parser::codeCompleteOperationOperandsSignature(
3184 std::optional
<StringRef
> opName
, unsigned currentNumOperands
) {
3185 codeCompleteContext
->codeCompleteOperationOperandsSignature(
3186 opName
, currentNumOperands
);
3189 void Parser::codeCompleteOperationResultsSignature(
3190 std::optional
<StringRef
> opName
, unsigned currentNumResults
) {
3191 codeCompleteContext
->codeCompleteOperationResultsSignature(opName
,
3195 //===----------------------------------------------------------------------===//
3197 //===----------------------------------------------------------------------===//
3199 FailureOr
<ast::Module
*>
3200 mlir::pdll::parsePDLLAST(ast::Context
&ctx
, llvm::SourceMgr
&sourceMgr
,
3201 bool enableDocumentation
,
3202 CodeCompleteContext
*codeCompleteContext
) {
3203 Parser
parser(ctx
, sourceMgr
, enableDocumentation
, codeCompleteContext
);
3204 return parser
.parseModule();