1 //===- Predicate.cpp - Predicate class ------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Wrapper around predicates defined in TableGen.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/TableGen/Predicate.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/Record.h"
22 using namespace tblgen
;
25 using llvm::SpecificBumpPtrAllocator
;
27 // Construct a Predicate from a record.
28 Pred::Pred(const Record
*record
) : def(record
) {
29 assert(def
->isSubClassOf("Pred") &&
30 "must be a subclass of TableGen 'Pred' class");
33 // Construct a Predicate from an initializer.
34 Pred::Pred(const Init
*init
) {
35 if (const auto *defInit
= dyn_cast_or_null
<llvm::DefInit
>(init
))
36 def
= defInit
->getDef();
39 std::string
Pred::getCondition() const {
40 // Static dispatch to subclasses.
41 if (def
->isSubClassOf("CombinedPred"))
42 return static_cast<const CombinedPred
*>(this)->getConditionImpl();
43 if (def
->isSubClassOf("CPred"))
44 return static_cast<const CPred
*>(this)->getConditionImpl();
45 llvm_unreachable("Pred::getCondition must be overridden in subclasses");
48 bool Pred::isCombined() const {
49 return def
&& def
->isSubClassOf("CombinedPred");
52 ArrayRef
<SMLoc
> Pred::getLoc() const { return def
->getLoc(); }
54 CPred::CPred(const Record
*record
) : Pred(record
) {
55 assert(def
->isSubClassOf("CPred") &&
56 "must be a subclass of Tablegen 'CPred' class");
59 CPred::CPred(const Init
*init
) : Pred(init
) {
60 assert((!def
|| def
->isSubClassOf("CPred")) &&
61 "must be a subclass of Tablegen 'CPred' class");
64 // Get condition of the C Predicate.
65 std::string
CPred::getConditionImpl() const {
66 assert(!isNull() && "null predicate does not have a condition");
67 return std::string(def
->getValueAsString("predExpr"));
70 CombinedPred::CombinedPred(const Record
*record
) : Pred(record
) {
71 assert(def
->isSubClassOf("CombinedPred") &&
72 "must be a subclass of Tablegen 'CombinedPred' class");
75 CombinedPred::CombinedPred(const Init
*init
) : Pred(init
) {
76 assert((!def
|| def
->isSubClassOf("CombinedPred")) &&
77 "must be a subclass of Tablegen 'CombinedPred' class");
80 const Record
*CombinedPred::getCombinerDef() const {
81 assert(def
->getValue("kind") && "CombinedPred must have a value 'kind'");
82 return def
->getValueAsDef("kind");
85 std::vector
<const Record
*> CombinedPred::getChildren() const {
86 assert(def
->getValue("children") &&
87 "CombinedPred must have a value 'children'");
88 return def
->getValueAsListOfDefs("children");
// Kinds of nodes in a logical predicate tree.
enum class PredCombinerKind {
  // A predicate that is not a combined predicate; it carries its own
  // expression.
  Leaf,
  // Combined predicates, one per recognized combiner definition.
  And,
  Or,
  Not,
  SubstLeaves,
  Concat,
  // Special kinds that are used in simplification.
  False,
  True
};
105 // A node in a logical predicate tree.
107 PredCombinerKind kind
;
108 const Pred
*predicate
;
109 SmallVector
<PredNode
*, 4> children
;
112 // Prefix and suffix are used by ConcatPred.
118 // Get a predicate tree node kind based on the kind used in the predicate
120 static PredCombinerKind
getPredCombinerKind(const Pred
&pred
) {
121 if (!pred
.isCombined())
122 return PredCombinerKind::Leaf
;
124 const auto &combinedPred
= static_cast<const CombinedPred
&>(pred
);
125 return StringSwitch
<PredCombinerKind
>(
126 combinedPred
.getCombinerDef()->getName())
127 .Case("PredCombinerAnd", PredCombinerKind::And
)
128 .Case("PredCombinerOr", PredCombinerKind::Or
)
129 .Case("PredCombinerNot", PredCombinerKind::Not
)
130 .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves
)
131 .Case("PredCombinerConcat", PredCombinerKind::Concat
);
135 // Substitution<pattern, replacement>.
136 using Subst
= std::pair
<StringRef
, StringRef
>;
139 /// Perform the given substitutions on 'str' in-place.
140 static void performSubstitutions(std::string
&str
,
141 ArrayRef
<Subst
> substitutions
) {
142 // Apply all parent substitutions from innermost to outermost.
143 for (const auto &subst
: llvm::reverse(substitutions
)) {
144 auto pos
= str
.find(std::string(subst
.first
));
145 while (pos
!= std::string::npos
) {
146 str
.replace(pos
, subst
.first
.size(), std::string(subst
.second
));
147 // Skip the newly inserted substring, which itself may consider the
149 pos
+= subst
.second
.size();
150 // Find the next possible match position.
151 pos
= str
.find(std::string(subst
.first
), pos
);
156 // Build the predicate tree starting from the top-level predicate, which may
157 // have children, and perform leaf substitutions inplace. Note that after
158 // substitution, nodes are still pointing to the original TableGen record.
159 // All nodes are created within "allocator".
161 buildPredicateTree(const Pred
&root
,
162 SpecificBumpPtrAllocator
<PredNode
> &allocator
,
163 ArrayRef
<Subst
> substitutions
) {
164 auto *rootNode
= allocator
.Allocate();
165 new (rootNode
) PredNode
;
166 rootNode
->kind
= getPredCombinerKind(root
);
167 rootNode
->predicate
= &root
;
168 if (!root
.isCombined()) {
169 rootNode
->expr
= root
.getCondition();
170 performSubstitutions(rootNode
->expr
, substitutions
);
174 // If the current combined predicate is a leaf substitution, append it to the
175 // list before continuing.
176 auto allSubstitutions
= llvm::to_vector
<4>(substitutions
);
177 if (rootNode
->kind
== PredCombinerKind::SubstLeaves
) {
178 const auto &substPred
= static_cast<const SubstLeavesPred
&>(root
);
179 allSubstitutions
.push_back(
180 {substPred
.getPattern(), substPred
.getReplacement()});
182 // If the current predicate is a ConcatPred, record the prefix and suffix.
183 } else if (rootNode
->kind
== PredCombinerKind::Concat
) {
184 const auto &concatPred
= static_cast<const ConcatPred
&>(root
);
185 rootNode
->prefix
= std::string(concatPred
.getPrefix());
186 performSubstitutions(rootNode
->prefix
, substitutions
);
187 rootNode
->suffix
= std::string(concatPred
.getSuffix());
188 performSubstitutions(rootNode
->suffix
, substitutions
);
191 // Build child subtrees.
192 auto combined
= static_cast<const CombinedPred
&>(root
);
193 for (const auto *record
: combined
.getChildren()) {
195 buildPredicateTree(Pred(record
), allocator
, allSubstitutions
);
196 rootNode
->children
.push_back(childTree
);
201 // Simplify a predicate tree rooted at "node" using the predicates that are
202 // known to be true(false). For AND(OR) combined predicates, if any of the
203 // children is known to be false(true), the result is also false(true).
204 // Furthermore, for AND(OR) combined predicates, children that are known to be
205 // true(false) don't have to be checked dynamically.
207 propagateGroundTruth(PredNode
*node
,
208 const llvm::SmallPtrSetImpl
<Pred
*> &knownTruePreds
,
209 const llvm::SmallPtrSetImpl
<Pred
*> &knownFalsePreds
) {
210 // If the current predicate is known to be true or false, change the kind of
211 // the node and return immediately.
212 if (knownTruePreds
.count(node
->predicate
) != 0) {
213 node
->kind
= PredCombinerKind::True
;
214 node
->children
.clear();
217 if (knownFalsePreds
.count(node
->predicate
) != 0) {
218 node
->kind
= PredCombinerKind::False
;
219 node
->children
.clear();
223 // If the current node is a substitution, stop recursion now.
224 // The expressions in the leaves below this node were rewritten, but the nodes
225 // still point to the original predicate records. While the original
226 // predicate may be known to be true or false, it is not necessarily the case
228 // TODO: we can support ground truth for rewritten
229 // predicates by either (a) having our own unique'ing of the predicates
230 // instead of relying on TableGen record pointers or (b) taking ground truth
231 // values optionally prefixed with a list of substitutions to apply, e.g.
232 // "predX is true by itself as well as predSubY leaf substitution had been
234 if (node
->kind
== PredCombinerKind::SubstLeaves
) {
238 if (node
->kind
== PredCombinerKind::And
&& node
->children
.empty()) {
239 node
->kind
= PredCombinerKind::True
;
243 if (node
->kind
== PredCombinerKind::Or
&& node
->children
.empty()) {
244 node
->kind
= PredCombinerKind::False
;
248 // Otherwise, look at child nodes.
250 // Move child nodes into some local variable so that they can be optimized
251 // separately and re-added if necessary.
252 llvm::SmallVector
<PredNode
*, 4> children
;
253 std::swap(node
->children
, children
);
255 for (auto &child
: children
) {
256 // First, simplify the child. This maintains the predicate as it was.
257 auto *simplifiedChild
=
258 propagateGroundTruth(child
, knownTruePreds
, knownFalsePreds
);
260 // Just add the child if we don't know how to simplify the current node.
261 if (node
->kind
!= PredCombinerKind::And
&&
262 node
->kind
!= PredCombinerKind::Or
) {
263 node
->children
.push_back(simplifiedChild
);
267 // Second, based on the type define which known values of child predicates
268 // immediately collapse this predicate to a known value, and which others
269 // may be safely ignored.
270 // OR(..., True, ...) = True
271 // OR(..., False, ...) = OR(..., ...)
272 // AND(..., False, ...) = False
273 // AND(..., True, ...) = AND(..., ...)
274 auto collapseKind
= node
->kind
== PredCombinerKind::And
275 ? PredCombinerKind::False
276 : PredCombinerKind::True
;
277 auto eraseKind
= node
->kind
== PredCombinerKind::And
278 ? PredCombinerKind::True
279 : PredCombinerKind::False
;
280 const auto &collapseList
=
281 node
->kind
== PredCombinerKind::And
? knownFalsePreds
: knownTruePreds
;
282 const auto &eraseList
=
283 node
->kind
== PredCombinerKind::And
? knownTruePreds
: knownFalsePreds
;
284 if (simplifiedChild
->kind
== collapseKind
||
285 collapseList
.count(simplifiedChild
->predicate
) != 0) {
286 node
->kind
= collapseKind
;
287 node
->children
.clear();
290 if (simplifiedChild
->kind
== eraseKind
||
291 eraseList
.count(simplifiedChild
->predicate
) != 0) {
294 node
->children
.push_back(simplifiedChild
);
299 // Combine a list of predicate expressions using a binary combiner. If a list
300 // is empty, return "init".
301 static std::string
combineBinary(ArrayRef
<std::string
> children
,
302 const std::string
&combiner
,
304 if (children
.empty())
307 auto size
= children
.size();
309 return children
.front();
312 llvm::raw_string_ostream
os(str
);
313 os
<< '(' << children
.front() << ')';
314 for (unsigned i
= 1; i
< size
; ++i
) {
315 os
<< ' ' << combiner
<< " (" << children
[i
] << ')';
320 // Prepend negation to the only condition in the predicate expression list.
321 static std::string
combineNot(ArrayRef
<std::string
> children
) {
322 assert(children
.size() == 1 && "expected exactly one child predicate of Neg");
323 return (Twine("!(") + children
.front() + Twine(')')).str();
326 // Recursively traverse the predicate tree in depth-first post-order and build
327 // the final expression.
328 static std::string
getCombinedCondition(const PredNode
&root
) {
329 // Immediately return for non-combiner predicates that don't have children.
330 if (root
.kind
== PredCombinerKind::Leaf
)
332 if (root
.kind
== PredCombinerKind::True
)
334 if (root
.kind
== PredCombinerKind::False
)
337 // Recurse into children.
338 llvm::SmallVector
<std::string
, 4> childExpressions
;
339 childExpressions
.reserve(root
.children
.size());
340 for (const auto &child
: root
.children
)
341 childExpressions
.push_back(getCombinedCondition(*child
));
343 // Combine the expressions based on the predicate node kind.
344 if (root
.kind
== PredCombinerKind::And
)
345 return combineBinary(childExpressions
, "&&", "true");
346 if (root
.kind
== PredCombinerKind::Or
)
347 return combineBinary(childExpressions
, "||", "false");
348 if (root
.kind
== PredCombinerKind::Not
)
349 return combineNot(childExpressions
);
350 if (root
.kind
== PredCombinerKind::Concat
) {
351 assert(childExpressions
.size() == 1 &&
352 "ConcatPred should only have one child");
353 return root
.prefix
+ childExpressions
.front() + root
.suffix
;
356 // Substitutions were applied before so just ignore them.
357 if (root
.kind
== PredCombinerKind::SubstLeaves
) {
358 assert(childExpressions
.size() == 1 &&
359 "substitution predicate must have one child");
360 return childExpressions
[0];
363 llvm::PrintFatalError(root
.predicate
->getLoc(), "unsupported predicate kind");
366 std::string
CombinedPred::getConditionImpl() const {
367 SpecificBumpPtrAllocator
<PredNode
> allocator
;
368 auto *predicateTree
= buildPredicateTree(*this, allocator
, {});
370 propagateGroundTruth(predicateTree
,
371 /*knownTruePreds=*/llvm::SmallPtrSet
<Pred
*, 2>(),
372 /*knownFalsePreds=*/llvm::SmallPtrSet
<Pred
*, 2>());
374 return getCombinedCondition(*predicateTree
);
377 StringRef
SubstLeavesPred::getPattern() const {
378 return def
->getValueAsString("pattern");
381 StringRef
SubstLeavesPred::getReplacement() const {
382 return def
->getValueAsString("replacement");
385 StringRef
ConcatPred::getPrefix() const {
386 return def
->getValueAsString("prefix");
389 StringRef
ConcatPred::getSuffix() const {
390 return def
->getValueAsString("suffix");