Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / PDLToPDLInterp / Predicate.h
blob5ad2c477573a5b637ade7887c7e637b486678197
1 //===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions for "predicates" used when converting PDL into
10 // a matcher tree. Predicates are composed of three different parts:
12 // * Positions
13 // - A position refers to a specific location on the input DAG, i.e. an
14 // existing MLIR entity being matched. These can be attributes, operands,
15 // operations, results, and types. Each position also defines a relation to
16 // its parent. For example, the operand `[0] -> 1` has a parent operation
17 // position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation
18 // position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge
19 // `[0] -> 1` (i.e. it is the defining op of operand 1). The only position
20 // without a parent is `[0]`, which refers to the root operation.
21 // * Questions
22 // - A question refers to a query on a specific positional value. For
23 // example, an operation name question checks the name of an operation
24 // position.
25 // * Answers
//   - An answer is the expected result of a question. For example, when
//     matching an operation with the name "foo.op", the question would be an
//     operation name question, with an expected answer of "foo.op".
30 //===----------------------------------------------------------------------===//
32 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
33 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
35 #include "mlir/IR/MLIRContext.h"
36 #include "mlir/IR/OperationSupport.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/Types.h"
40 namespace mlir {
41 namespace pdl_to_pdl_interp {
namespace Predicates {
/// Every kind of predicate node used by the PDL-to-PDLInterp matcher
/// generator. The enumerators fall into three groups: positions, questions and
/// answers. Within the first two groups, declaration order follows the
/// ordering stated in the group comment.
enum Kind : unsigned {
  /// Positions, ordered by decreasing priority.
  OperationPos = 0,
  OperandPos,
  OperandGroupPos,
  AttributePos,
  ConstraintResultPos,
  ResultPos,
  ResultGroupPos,
  TypePos,
  AttributeLiteralPos,
  TypeLiteralPos,
  UsersPos,
  ForEachPos,

  // Questions, ordered by dependency and decreasing priority.
  IsNotNullQuestion,
  OperationNameQuestion,
  TypeQuestion,
  AttributeQuestion,
  OperandCountAtLeastQuestion,
  OperandCountQuestion,
  ResultCountAtLeastQuestion,
  ResultCountQuestion,
  EqualToQuestion,
  ConstraintQuestion,

  // Answers.
  AttributeAnswer,
  FalseAnswer,
  OperationNameAnswer,
  TrueAnswer,
  TypeAnswer,
  UnsignedAnswer,
};
} // namespace Predicates
81 /// Base class for all predicates, used to allow efficient pointer comparison.
82 template <typename ConcreteT, typename BaseT, typename Key,
83 Predicates::Kind Kind>
84 class PredicateBase : public BaseT {
85 public:
86 using KeyTy = Key;
87 using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>;
89 template <typename KeyT>
90 explicit PredicateBase(KeyT &&key)
91 : BaseT(Kind), key(std::forward<KeyT>(key)) {}
93 /// Get an instance of this position.
94 template <typename... Args>
95 static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
96 return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
99 /// Construct an instance with the given storage allocator.
100 template <typename KeyT>
101 static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
102 KeyT &&key) {
103 return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key));
106 /// Utility methods required by the storage allocator.
107 bool operator==(const KeyTy &key) const { return this->key == key; }
108 static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
110 /// Return the key value of this predicate.
111 const KeyTy &getValue() const { return key; }
113 protected:
114 KeyTy key;
117 /// Base storage for simple predicates that only unique with the kind.
118 template <typename ConcreteT, typename BaseT, Predicates::Kind Kind>
119 class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT {
120 public:
121 using Base = PredicateBase<ConcreteT, BaseT, void, Kind>;
123 explicit PredicateBase() : BaseT(Kind) {}
125 static ConcreteT *get(StorageUniquer &uniquer) {
126 return uniquer.get<ConcreteT>();
128 static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
131 //===----------------------------------------------------------------------===//
132 // Positions
133 //===----------------------------------------------------------------------===//
135 struct OperationPosition;
137 /// A position describes a value on the input IR on which a predicate may be
138 /// applied, such as an operation or attribute. This enables re-use between
139 /// predicates, and assists generating bytecode and memory management.
141 /// Operation positions form the base of other positions, which are formed
142 /// relative to a parent operation. Operations are anchored at Operand nodes,
143 /// except for the root operation which is parentless.
144 class Position : public StorageUniquer::BaseStorage {
145 public:
146 explicit Position(Predicates::Kind kind) : kind(kind) {}
147 virtual ~Position();
149 /// Returns the depth of the first ancestor operation position.
150 unsigned getOperationDepth() const;
152 /// Returns the parent position. The root operation position has no parent.
153 Position *getParent() const { return parent; }
155 /// Returns the kind of this position.
156 Predicates::Kind getKind() const { return kind; }
158 protected:
159 /// Link to the parent position.
160 Position *parent = nullptr;
162 private:
163 /// The kind of this position.
164 Predicates::Kind kind;
167 //===----------------------------------------------------------------------===//
168 // AttributePosition
170 /// A position describing an attribute of an operation.
171 struct AttributePosition
172 : public PredicateBase<AttributePosition, Position,
173 std::pair<OperationPosition *, StringAttr>,
174 Predicates::AttributePos> {
175 explicit AttributePosition(const KeyTy &key);
177 /// Returns the attribute name of this position.
178 StringAttr getName() const { return key.second; }
181 //===----------------------------------------------------------------------===//
182 // AttributeLiteralPosition
184 /// A position describing a literal attribute.
185 struct AttributeLiteralPosition
186 : public PredicateBase<AttributeLiteralPosition, Position, Attribute,
187 Predicates::AttributeLiteralPos> {
188 using PredicateBase::PredicateBase;
191 //===----------------------------------------------------------------------===//
192 // ForEachPosition
194 /// A position describing an iterative choice of an operation.
195 struct ForEachPosition : public PredicateBase<ForEachPosition, Position,
196 std::pair<Position *, unsigned>,
197 Predicates::ForEachPos> {
198 explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; }
200 /// Returns the ID, for differentiating various loops.
201 /// For upward traversals, this is the index of the root.
202 unsigned getID() const { return key.second; }
205 //===----------------------------------------------------------------------===//
206 // OperandPosition
208 /// A position describing an operand of an operation.
209 struct OperandPosition
210 : public PredicateBase<OperandPosition, Position,
211 std::pair<OperationPosition *, unsigned>,
212 Predicates::OperandPos> {
213 explicit OperandPosition(const KeyTy &key);
215 /// Returns the operand number of this position.
216 unsigned getOperandNumber() const { return key.second; }
219 //===----------------------------------------------------------------------===//
220 // OperandGroupPosition
222 /// A position describing an operand group of an operation.
223 struct OperandGroupPosition
224 : public PredicateBase<
225 OperandGroupPosition, Position,
226 std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
227 Predicates::OperandGroupPos> {
228 explicit OperandGroupPosition(const KeyTy &key);
230 /// Returns a hash suitable for the given keytype.
231 static llvm::hash_code hashKey(const KeyTy &key) {
232 return llvm::hash_value(key);
235 /// Returns the group number of this position. If std::nullopt, this group
236 /// refers to all operands.
237 std::optional<unsigned> getOperandGroupNumber() const {
238 return std::get<1>(key);
241 /// Returns if the operand group has unknown size. If false, the operand group
242 /// has at max one element.
243 bool isVariadic() const { return std::get<2>(key); }
246 //===----------------------------------------------------------------------===//
247 // OperationPosition
249 /// An operation position describes an operation node in the IR. Other position
250 /// kinds are formed with respect to an operation position.
251 struct OperationPosition : public PredicateBase<OperationPosition, Position,
252 std::pair<Position *, unsigned>,
253 Predicates::OperationPos> {
254 explicit OperationPosition(const KeyTy &key) : Base(key) {
255 parent = key.first;
258 /// Returns a hash suitable for the given keytype.
259 static llvm::hash_code hashKey(const KeyTy &key) {
260 return llvm::hash_value(key);
263 /// Gets the root position.
264 static OperationPosition *getRoot(StorageUniquer &uniquer) {
265 return Base::get(uniquer, nullptr, 0);
268 /// Gets an operation position with the given parent.
269 static OperationPosition *get(StorageUniquer &uniquer, Position *parent) {
270 return Base::get(uniquer, parent, parent->getOperationDepth() + 1);
273 /// Returns the depth of this position.
274 unsigned getDepth() const { return key.second; }
276 /// Returns if this operation position corresponds to the root.
277 bool isRoot() const { return getDepth() == 0; }
279 /// Returns if this operation represents an operand defining op.
280 bool isOperandDefiningOp() const;
283 //===----------------------------------------------------------------------===//
284 // ConstraintPosition
286 struct ConstraintQuestion;
288 /// A position describing the result of a native constraint. It saves the
289 /// corresponding ConstraintQuestion and result index to enable referring
290 /// back to them
291 struct ConstraintPosition
292 : public PredicateBase<ConstraintPosition, Position,
293 std::pair<ConstraintQuestion *, unsigned>,
294 Predicates::ConstraintResultPos> {
295 using PredicateBase::PredicateBase;
297 /// Returns the ConstraintQuestion to enable keeping track of the native
298 /// constraint this position stems from.
299 ConstraintQuestion *getQuestion() const { return key.first; }
301 // Returns the result index of this position
302 unsigned getIndex() const { return key.second; }
305 //===----------------------------------------------------------------------===//
306 // ResultPosition
308 /// A position describing a result of an operation.
309 struct ResultPosition
310 : public PredicateBase<ResultPosition, Position,
311 std::pair<OperationPosition *, unsigned>,
312 Predicates::ResultPos> {
313 explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; }
315 /// Returns the result number of this position.
316 unsigned getResultNumber() const { return key.second; }
319 //===----------------------------------------------------------------------===//
320 // ResultGroupPosition
322 /// A position describing a result group of an operation.
323 struct ResultGroupPosition
324 : public PredicateBase<
325 ResultGroupPosition, Position,
326 std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
327 Predicates::ResultGroupPos> {
328 explicit ResultGroupPosition(const KeyTy &key) : Base(key) {
329 parent = std::get<0>(key);
332 /// Returns a hash suitable for the given keytype.
333 static llvm::hash_code hashKey(const KeyTy &key) {
334 return llvm::hash_value(key);
337 /// Returns the group number of this position. If std::nullopt, this group
338 /// refers to all results.
339 std::optional<unsigned> getResultGroupNumber() const {
340 return std::get<1>(key);
343 /// Returns if the result group has unknown size. If false, the result group
344 /// has at max one element.
345 bool isVariadic() const { return std::get<2>(key); }
348 //===----------------------------------------------------------------------===//
349 // TypePosition
351 /// A position describing the result type of an entity, i.e. an Attribute,
352 /// Operand, Result, etc.
353 struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
354 Predicates::TypePos> {
355 explicit TypePosition(const KeyTy &key) : Base(key) {
356 assert((isa<AttributePosition, OperandPosition, OperandGroupPosition,
357 ResultPosition, ResultGroupPosition>(key)) &&
358 "expected parent to be an attribute, operand, or result");
359 parent = key;
363 //===----------------------------------------------------------------------===//
364 // TypeLiteralPosition
366 /// A position describing a literal type or type range. The value is stored as
367 /// either a TypeAttr, or an ArrayAttr of TypeAttr.
368 struct TypeLiteralPosition
369 : public PredicateBase<TypeLiteralPosition, Position, Attribute,
370 Predicates::TypeLiteralPos> {
371 using PredicateBase::PredicateBase;
374 //===----------------------------------------------------------------------===//
375 // UsersPosition
377 /// A position describing the users of a value or a range of values. The second
378 /// value in the key indicates whether we choose users of a representative for
379 /// a range (this is true, e.g., in the upward traversals).
380 struct UsersPosition
381 : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>,
382 Predicates::UsersPos> {
383 explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; }
385 /// Returns a hash suitable for the given keytype.
386 static llvm::hash_code hashKey(const KeyTy &key) {
387 return llvm::hash_value(key);
390 /// Indicates whether to compute a range of a representative.
391 bool useRepresentative() const { return key.second; }
394 //===----------------------------------------------------------------------===//
395 // Qualifiers
396 //===----------------------------------------------------------------------===//
398 /// An ordinal predicate consists of a "Question" and a set of acceptable
399 /// "Answers" (later converted to ordinal values). A predicate will query some
400 /// property of a positional value and decide what to do based on the result.
402 /// This makes top-level predicate representations ordinal (SwitchOp). Later,
403 /// predicates that end up with only one acceptable answer (including all
404 /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the
405 /// matcher.
407 /// For simplicity, both are represented as "qualifiers", with a base kind and
408 /// perhaps additional properties. For example, all OperationName predicates ask
409 /// the same question, but GenericConstraint predicates may ask different ones.
410 class Qualifier : public StorageUniquer::BaseStorage {
411 public:
412 explicit Qualifier(Predicates::Kind kind) : kind(kind) {}
414 /// Returns the kind of this qualifier.
415 Predicates::Kind getKind() const { return kind; }
417 private:
418 /// The kind of this position.
419 Predicates::Kind kind;
422 //===----------------------------------------------------------------------===//
423 // Answers
425 /// An Answer representing an `Attribute` value.
426 struct AttributeAnswer
427 : public PredicateBase<AttributeAnswer, Qualifier, Attribute,
428 Predicates::AttributeAnswer> {
429 using Base::Base;
432 /// An Answer representing an `OperationName` value.
433 struct OperationNameAnswer
434 : public PredicateBase<OperationNameAnswer, Qualifier, OperationName,
435 Predicates::OperationNameAnswer> {
436 using Base::Base;
439 /// An Answer representing a boolean `true` value.
440 struct TrueAnswer
441 : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> {
442 using Base::Base;
445 /// An Answer representing a boolean 'false' value.
446 struct FalseAnswer
447 : PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> {
448 using Base::Base;
451 /// An Answer representing a `Type` value. The value is stored as either a
452 /// TypeAttr, or an ArrayAttr of TypeAttr.
453 struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute,
454 Predicates::TypeAnswer> {
455 using Base::Base;
458 /// An Answer representing an unsigned value.
459 struct UnsignedAnswer
460 : public PredicateBase<UnsignedAnswer, Qualifier, unsigned,
461 Predicates::UnsignedAnswer> {
462 using Base::Base;
465 //===----------------------------------------------------------------------===//
466 // Questions
468 /// Compare an `Attribute` to a constant value.
469 struct AttributeQuestion
470 : public PredicateBase<AttributeQuestion, Qualifier, void,
471 Predicates::AttributeQuestion> {};
473 /// Apply a parameterized constraint to multiple position values and possibly
474 /// produce results.
475 struct ConstraintQuestion
476 : public PredicateBase<
477 ConstraintQuestion, Qualifier,
478 std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
479 Predicates::ConstraintQuestion> {
480 using Base::Base;
482 /// Return the name of the constraint.
483 StringRef getName() const { return std::get<0>(key); }
485 /// Return the arguments of the constraint.
486 ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
488 /// Return the result types of the constraint.
489 ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
491 /// Return the negation status of the constraint.
492 bool getIsNegated() const { return std::get<3>(key); }
494 /// Construct an instance with the given storage allocator.
495 static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
496 KeyTy key) {
497 return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
498 alloc.copyInto(std::get<1>(key)),
499 alloc.copyInto(std::get<2>(key)),
500 std::get<3>(key)});
503 /// Returns a hash suitable for the given keytype.
504 static llvm::hash_code hashKey(const KeyTy &key) {
505 return llvm::hash_value(key);
509 /// Compare the equality of two values.
510 struct EqualToQuestion
511 : public PredicateBase<EqualToQuestion, Qualifier, Position *,
512 Predicates::EqualToQuestion> {
513 using Base::Base;
516 /// Compare a positional value with null, i.e. check if it exists.
517 struct IsNotNullQuestion
518 : public PredicateBase<IsNotNullQuestion, Qualifier, void,
519 Predicates::IsNotNullQuestion> {};
521 /// Compare the number of operands of an operation with a known value.
522 struct OperandCountQuestion
523 : public PredicateBase<OperandCountQuestion, Qualifier, void,
524 Predicates::OperandCountQuestion> {};
525 struct OperandCountAtLeastQuestion
526 : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void,
527 Predicates::OperandCountAtLeastQuestion> {};
529 /// Compare the name of an operation with a known value.
530 struct OperationNameQuestion
531 : public PredicateBase<OperationNameQuestion, Qualifier, void,
532 Predicates::OperationNameQuestion> {};
534 /// Compare the number of results of an operation with a known value.
535 struct ResultCountQuestion
536 : public PredicateBase<ResultCountQuestion, Qualifier, void,
537 Predicates::ResultCountQuestion> {};
538 struct ResultCountAtLeastQuestion
539 : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void,
540 Predicates::ResultCountAtLeastQuestion> {};
542 /// Compare the type of an attribute or value with a known type.
543 struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void,
544 Predicates::TypeQuestion> {};
546 //===----------------------------------------------------------------------===//
547 // PredicateUniquer
548 //===----------------------------------------------------------------------===//
550 /// This class provides a storage uniquer that is used to allocate predicate
551 /// instances.
552 class PredicateUniquer : public StorageUniquer {
553 public:
554 PredicateUniquer() {
555 // Register the types of Positions with the uniquer.
556 registerParametricStorageType<AttributePosition>();
557 registerParametricStorageType<AttributeLiteralPosition>();
558 registerParametricStorageType<ConstraintPosition>();
559 registerParametricStorageType<ForEachPosition>();
560 registerParametricStorageType<OperandPosition>();
561 registerParametricStorageType<OperandGroupPosition>();
562 registerParametricStorageType<OperationPosition>();
563 registerParametricStorageType<ResultPosition>();
564 registerParametricStorageType<ResultGroupPosition>();
565 registerParametricStorageType<TypePosition>();
566 registerParametricStorageType<TypeLiteralPosition>();
567 registerParametricStorageType<UsersPosition>();
569 // Register the types of Questions with the uniquer.
570 registerParametricStorageType<AttributeAnswer>();
571 registerParametricStorageType<OperationNameAnswer>();
572 registerParametricStorageType<TypeAnswer>();
573 registerParametricStorageType<UnsignedAnswer>();
574 registerSingletonStorageType<FalseAnswer>();
575 registerSingletonStorageType<TrueAnswer>();
577 // Register the types of Answers with the uniquer.
578 registerParametricStorageType<ConstraintQuestion>();
579 registerParametricStorageType<EqualToQuestion>();
580 registerSingletonStorageType<AttributeQuestion>();
581 registerSingletonStorageType<IsNotNullQuestion>();
582 registerSingletonStorageType<OperandCountQuestion>();
583 registerSingletonStorageType<OperandCountAtLeastQuestion>();
584 registerSingletonStorageType<OperationNameQuestion>();
585 registerSingletonStorageType<ResultCountQuestion>();
586 registerSingletonStorageType<ResultCountAtLeastQuestion>();
587 registerSingletonStorageType<TypeQuestion>();
591 //===----------------------------------------------------------------------===//
592 // PredicateBuilder
593 //===----------------------------------------------------------------------===//
595 /// This class provides utilities for constructing predicates.
596 class PredicateBuilder {
597 public:
598 PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx)
599 : uniquer(uniquer), ctx(ctx) {}
601 //===--------------------------------------------------------------------===//
602 // Positions
603 //===--------------------------------------------------------------------===//
605 /// Returns the root operation position.
606 Position *getRoot() { return OperationPosition::getRoot(uniquer); }
608 /// Returns the parent position defining the value held by the given operand.
609 OperationPosition *getOperandDefiningOp(Position *p) {
610 assert((isa<OperandPosition, OperandGroupPosition>(p)) &&
611 "expected operand position");
612 return OperationPosition::get(uniquer, p);
615 /// Returns the operation position equivalent to the given position.
616 OperationPosition *getPassthroughOp(Position *p) {
617 assert((isa<ForEachPosition>(p)) && "expected users position");
618 return OperationPosition::get(uniquer, p);
621 // Returns a position for a new value created by a constraint.
622 ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
623 unsigned index) {
624 return ConstraintPosition::get(uniquer, std::make_pair(q, index));
627 /// Returns an attribute position for an attribute of the given operation.
628 Position *getAttribute(OperationPosition *p, StringRef name) {
629 return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
632 /// Returns an attribute position for the given attribute.
633 Position *getAttributeLiteral(Attribute attr) {
634 return AttributeLiteralPosition::get(uniquer, attr);
637 Position *getForEach(Position *p, unsigned id) {
638 return ForEachPosition::get(uniquer, p, id);
641 /// Returns an operand position for an operand of the given operation.
642 Position *getOperand(OperationPosition *p, unsigned operand) {
643 return OperandPosition::get(uniquer, p, operand);
646 /// Returns a position for a group of operands of the given operation.
647 Position *getOperandGroup(OperationPosition *p, std::optional<unsigned> group,
648 bool isVariadic) {
649 return OperandGroupPosition::get(uniquer, p, group, isVariadic);
651 Position *getAllOperands(OperationPosition *p) {
652 return getOperandGroup(p, /*group=*/std::nullopt, /*isVariadic=*/true);
655 /// Returns a result position for a result of the given operation.
656 Position *getResult(OperationPosition *p, unsigned result) {
657 return ResultPosition::get(uniquer, p, result);
660 /// Returns a position for a group of results of the given operation.
661 Position *getResultGroup(OperationPosition *p, std::optional<unsigned> group,
662 bool isVariadic) {
663 return ResultGroupPosition::get(uniquer, p, group, isVariadic);
665 Position *getAllResults(OperationPosition *p) {
666 return getResultGroup(p, /*group=*/std::nullopt, /*isVariadic=*/true);
669 /// Returns a type position for the given entity.
670 Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
672 /// Returns a type position for the given type value. The value is stored
673 /// as either a TypeAttr, or an ArrayAttr of TypeAttr.
674 Position *getTypeLiteral(Attribute attr) {
675 return TypeLiteralPosition::get(uniquer, attr);
678 /// Returns the users of a position using the value at the given operand.
679 UsersPosition *getUsers(Position *p, bool useRepresentative) {
680 assert((isa<OperandPosition, OperandGroupPosition, ResultPosition,
681 ResultGroupPosition>(p)) &&
682 "expected result position");
683 return UsersPosition::get(uniquer, p, useRepresentative);
686 //===--------------------------------------------------------------------===//
687 // Qualifiers
688 //===--------------------------------------------------------------------===//
690 /// An ordinal predicate consists of a "Question" and a set of acceptable
691 /// "Answers" (later converted to ordinal values). A predicate will query some
692 /// property of a positional value and decide what to do based on the result.
693 using Predicate = std::pair<Qualifier *, Qualifier *>;
695 /// Create a predicate comparing an attribute to a known value.
696 Predicate getAttributeConstraint(Attribute attr) {
697 return {AttributeQuestion::get(uniquer),
698 AttributeAnswer::get(uniquer, attr)};
701 /// Create a predicate checking if two values are equal.
702 Predicate getEqualTo(Position *pos) {
703 return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)};
706 /// Create a predicate checking if two values are not equal.
707 Predicate getNotEqualTo(Position *pos) {
708 return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)};
711 /// Create a predicate that applies a generic constraint.
712 Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
713 ArrayRef<Type> resultTypes, bool isNegated) {
714 return {ConstraintQuestion::get(
715 uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
716 TrueAnswer::get(uniquer)};
719 /// Create a predicate comparing a value with null.
720 Predicate getIsNotNull() {
721 return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)};
724 /// Create a predicate comparing the number of operands of an operation to a
725 /// known value.
726 Predicate getOperandCount(unsigned count) {
727 return {OperandCountQuestion::get(uniquer),
728 UnsignedAnswer::get(uniquer, count)};
730 Predicate getOperandCountAtLeast(unsigned count) {
731 return {OperandCountAtLeastQuestion::get(uniquer),
732 UnsignedAnswer::get(uniquer, count)};
735 /// Create a predicate comparing the name of an operation to a known value.
736 Predicate getOperationName(StringRef name) {
737 return {OperationNameQuestion::get(uniquer),
738 OperationNameAnswer::get(uniquer, OperationName(name, ctx))};
741 /// Create a predicate comparing the number of results of an operation to a
742 /// known value.
743 Predicate getResultCount(unsigned count) {
744 return {ResultCountQuestion::get(uniquer),
745 UnsignedAnswer::get(uniquer, count)};
747 Predicate getResultCountAtLeast(unsigned count) {
748 return {ResultCountAtLeastQuestion::get(uniquer),
749 UnsignedAnswer::get(uniquer, count)};
752 /// Create a predicate comparing the type of an attribute or value to a known
753 /// type. The value is stored as either a TypeAttr, or an ArrayAttr of
754 /// TypeAttr.
755 Predicate getTypeConstraint(Attribute type) {
756 return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)};
759 private:
760 /// The uniquer used when allocating predicate nodes.
761 PredicateUniquer &uniquer;
763 /// The current MLIR context.
764 MLIRContext *ctx;
767 } // namespace pdl_to_pdl_interp
768 } // namespace mlir
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_