1 //===- Predicate.cpp - Predicate class ------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Wrapper around predicates defined in TableGen.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/TableGen/Predicate.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/Record.h"
22 using namespace tblgen
;
24 // Construct a Predicate from a record.
25 Pred::Pred(const llvm::Record
*record
) : def(record
) {
26 assert(def
->isSubClassOf("Pred") &&
27 "must be a subclass of TableGen 'Pred' class");
30 // Construct a Predicate from an initializer.
31 Pred::Pred(const llvm::Init
*init
) {
32 if (const auto *defInit
= dyn_cast_or_null
<llvm::DefInit
>(init
))
33 def
= defInit
->getDef();
36 std::string
Pred::getCondition() const {
37 // Static dispatch to subclasses.
38 if (def
->isSubClassOf("CombinedPred"))
39 return static_cast<const CombinedPred
*>(this)->getConditionImpl();
40 if (def
->isSubClassOf("CPred"))
41 return static_cast<const CPred
*>(this)->getConditionImpl();
42 llvm_unreachable("Pred::getCondition must be overridden in subclasses");
45 bool Pred::isCombined() const {
46 return def
&& def
->isSubClassOf("CombinedPred");
49 ArrayRef
<SMLoc
> Pred::getLoc() const { return def
->getLoc(); }
51 CPred::CPred(const llvm::Record
*record
) : Pred(record
) {
52 assert(def
->isSubClassOf("CPred") &&
53 "must be a subclass of Tablegen 'CPred' class");
56 CPred::CPred(const llvm::Init
*init
) : Pred(init
) {
57 assert((!def
|| def
->isSubClassOf("CPred")) &&
58 "must be a subclass of Tablegen 'CPred' class");
61 // Get condition of the C Predicate.
62 std::string
CPred::getConditionImpl() const {
63 assert(!isNull() && "null predicate does not have a condition");
64 return std::string(def
->getValueAsString("predExpr"));
67 CombinedPred::CombinedPred(const llvm::Record
*record
) : Pred(record
) {
68 assert(def
->isSubClassOf("CombinedPred") &&
69 "must be a subclass of Tablegen 'CombinedPred' class");
72 CombinedPred::CombinedPred(const llvm::Init
*init
) : Pred(init
) {
73 assert((!def
|| def
->isSubClassOf("CombinedPred")) &&
74 "must be a subclass of Tablegen 'CombinedPred' class");
77 const llvm::Record
*CombinedPred::getCombinerDef() const {
78 assert(def
->getValue("kind") && "CombinedPred must have a value 'kind'");
79 return def
->getValueAsDef("kind");
82 std::vector
<llvm::Record
*> CombinedPred::getChildren() const {
83 assert(def
->getValue("children") &&
84 "CombinedPred must have a value 'children'");
85 return def
->getValueAsListOfDefs("children");
namespace {
// Kinds of nodes in a logical predicate tree. The type is local to this file,
// so it gets internal linkage to avoid clashing with same-named types in
// other translation units.
enum class PredCombinerKind {
  // A non-combined predicate (see getPredCombinerKind).
  Leaf,
  And,
  Or,
  Not,
  SubstLeaves,
  Concat,
  // Special kinds that are used in simplification; they are produced by
  // propagateGroundTruth for predicates whose value is known.
  False,
  True
};
} // namespace
102 // A node in a logical predicate tree.
104 PredCombinerKind kind
;
105 const Pred
*predicate
;
106 SmallVector
<PredNode
*, 4> children
;
109 // Prefix and suffix are used by ConcatPred.
115 // Get a predicate tree node kind based on the kind used in the predicate
117 static PredCombinerKind
getPredCombinerKind(const Pred
&pred
) {
118 if (!pred
.isCombined())
119 return PredCombinerKind::Leaf
;
121 const auto &combinedPred
= static_cast<const CombinedPred
&>(pred
);
122 return StringSwitch
<PredCombinerKind
>(
123 combinedPred
.getCombinerDef()->getName())
124 .Case("PredCombinerAnd", PredCombinerKind::And
)
125 .Case("PredCombinerOr", PredCombinerKind::Or
)
126 .Case("PredCombinerNot", PredCombinerKind::Not
)
127 .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves
)
128 .Case("PredCombinerConcat", PredCombinerKind::Concat
);
132 // Substitution<pattern, replacement>.
133 using Subst
= std::pair
<StringRef
, StringRef
>;
136 /// Perform the given substitutions on 'str' in-place.
137 static void performSubstitutions(std::string
&str
,
138 ArrayRef
<Subst
> substitutions
) {
139 // Apply all parent substitutions from innermost to outermost.
140 for (const auto &subst
: llvm::reverse(substitutions
)) {
141 auto pos
= str
.find(std::string(subst
.first
));
142 while (pos
!= std::string::npos
) {
143 str
.replace(pos
, subst
.first
.size(), std::string(subst
.second
));
144 // Skip the newly inserted substring, which itself may consider the
146 pos
+= subst
.second
.size();
147 // Find the next possible match position.
148 pos
= str
.find(std::string(subst
.first
), pos
);
153 // Build the predicate tree starting from the top-level predicate, which may
154 // have children, and perform leaf substitutions inplace. Note that after
155 // substitution, nodes are still pointing to the original TableGen record.
156 // All nodes are created within "allocator".
158 buildPredicateTree(const Pred
&root
,
159 llvm::SpecificBumpPtrAllocator
<PredNode
> &allocator
,
160 ArrayRef
<Subst
> substitutions
) {
161 auto *rootNode
= allocator
.Allocate();
162 new (rootNode
) PredNode
;
163 rootNode
->kind
= getPredCombinerKind(root
);
164 rootNode
->predicate
= &root
;
165 if (!root
.isCombined()) {
166 rootNode
->expr
= root
.getCondition();
167 performSubstitutions(rootNode
->expr
, substitutions
);
171 // If the current combined predicate is a leaf substitution, append it to the
172 // list before continuing.
173 auto allSubstitutions
= llvm::to_vector
<4>(substitutions
);
174 if (rootNode
->kind
== PredCombinerKind::SubstLeaves
) {
175 const auto &substPred
= static_cast<const SubstLeavesPred
&>(root
);
176 allSubstitutions
.push_back(
177 {substPred
.getPattern(), substPred
.getReplacement()});
179 // If the current predicate is a ConcatPred, record the prefix and suffix.
180 } else if (rootNode
->kind
== PredCombinerKind::Concat
) {
181 const auto &concatPred
= static_cast<const ConcatPred
&>(root
);
182 rootNode
->prefix
= std::string(concatPred
.getPrefix());
183 performSubstitutions(rootNode
->prefix
, substitutions
);
184 rootNode
->suffix
= std::string(concatPred
.getSuffix());
185 performSubstitutions(rootNode
->suffix
, substitutions
);
188 // Build child subtrees.
189 auto combined
= static_cast<const CombinedPred
&>(root
);
190 for (const auto *record
: combined
.getChildren()) {
192 buildPredicateTree(Pred(record
), allocator
, allSubstitutions
);
193 rootNode
->children
.push_back(childTree
);
198 // Simplify a predicate tree rooted at "node" using the predicates that are
199 // known to be true(false). For AND(OR) combined predicates, if any of the
200 // children is known to be false(true), the result is also false(true).
201 // Furthermore, for AND(OR) combined predicates, children that are known to be
202 // true(false) don't have to be checked dynamically.
204 propagateGroundTruth(PredNode
*node
,
205 const llvm::SmallPtrSetImpl
<Pred
*> &knownTruePreds
,
206 const llvm::SmallPtrSetImpl
<Pred
*> &knownFalsePreds
) {
207 // If the current predicate is known to be true or false, change the kind of
208 // the node and return immediately.
209 if (knownTruePreds
.count(node
->predicate
) != 0) {
210 node
->kind
= PredCombinerKind::True
;
211 node
->children
.clear();
214 if (knownFalsePreds
.count(node
->predicate
) != 0) {
215 node
->kind
= PredCombinerKind::False
;
216 node
->children
.clear();
220 // If the current node is a substitution, stop recursion now.
221 // The expressions in the leaves below this node were rewritten, but the nodes
222 // still point to the original predicate records. While the original
223 // predicate may be known to be true or false, it is not necessarily the case
225 // TODO: we can support ground truth for rewritten
226 // predicates by either (a) having our own unique'ing of the predicates
227 // instead of relying on TableGen record pointers or (b) taking ground truth
228 // values optionally prefixed with a list of substitutions to apply, e.g.
229 // "predX is true by itself as well as predSubY leaf substitution had been
231 if (node
->kind
== PredCombinerKind::SubstLeaves
) {
235 // Otherwise, look at child nodes.
237 // Move child nodes into some local variable so that they can be optimized
238 // separately and re-added if necessary.
239 llvm::SmallVector
<PredNode
*, 4> children
;
240 std::swap(node
->children
, children
);
242 for (auto &child
: children
) {
243 // First, simplify the child. This maintains the predicate as it was.
244 auto *simplifiedChild
=
245 propagateGroundTruth(child
, knownTruePreds
, knownFalsePreds
);
247 // Just add the child if we don't know how to simplify the current node.
248 if (node
->kind
!= PredCombinerKind::And
&&
249 node
->kind
!= PredCombinerKind::Or
) {
250 node
->children
.push_back(simplifiedChild
);
254 // Second, based on the type define which known values of child predicates
255 // immediately collapse this predicate to a known value, and which others
256 // may be safely ignored.
257 // OR(..., True, ...) = True
258 // OR(..., False, ...) = OR(..., ...)
259 // AND(..., False, ...) = False
260 // AND(..., True, ...) = AND(..., ...)
261 auto collapseKind
= node
->kind
== PredCombinerKind::And
262 ? PredCombinerKind::False
263 : PredCombinerKind::True
;
264 auto eraseKind
= node
->kind
== PredCombinerKind::And
265 ? PredCombinerKind::True
266 : PredCombinerKind::False
;
267 const auto &collapseList
=
268 node
->kind
== PredCombinerKind::And
? knownFalsePreds
: knownTruePreds
;
269 const auto &eraseList
=
270 node
->kind
== PredCombinerKind::And
? knownTruePreds
: knownFalsePreds
;
271 if (simplifiedChild
->kind
== collapseKind
||
272 collapseList
.count(simplifiedChild
->predicate
) != 0) {
273 node
->kind
= collapseKind
;
274 node
->children
.clear();
277 if (simplifiedChild
->kind
== eraseKind
||
278 eraseList
.count(simplifiedChild
->predicate
) != 0) {
281 node
->children
.push_back(simplifiedChild
);
286 // Combine a list of predicate expressions using a binary combiner. If a list
287 // is empty, return "init".
288 static std::string
combineBinary(ArrayRef
<std::string
> children
,
289 const std::string
&combiner
,
291 if (children
.empty())
294 auto size
= children
.size();
296 return children
.front();
299 llvm::raw_string_ostream
os(str
);
300 os
<< '(' << children
.front() << ')';
301 for (unsigned i
= 1; i
< size
; ++i
) {
302 os
<< ' ' << combiner
<< " (" << children
[i
] << ')';
307 // Prepend negation to the only condition in the predicate expression list.
308 static std::string
combineNot(ArrayRef
<std::string
> children
) {
309 assert(children
.size() == 1 && "expected exactly one child predicate of Neg");
310 return (Twine("!(") + children
.front() + Twine(')')).str();
313 // Recursively traverse the predicate tree in depth-first post-order and build
314 // the final expression.
315 static std::string
getCombinedCondition(const PredNode
&root
) {
316 // Immediately return for non-combiner predicates that don't have children.
317 if (root
.kind
== PredCombinerKind::Leaf
)
319 if (root
.kind
== PredCombinerKind::True
)
321 if (root
.kind
== PredCombinerKind::False
)
324 // Recurse into children.
325 llvm::SmallVector
<std::string
, 4> childExpressions
;
326 childExpressions
.reserve(root
.children
.size());
327 for (const auto &child
: root
.children
)
328 childExpressions
.push_back(getCombinedCondition(*child
));
330 // Combine the expressions based on the predicate node kind.
331 if (root
.kind
== PredCombinerKind::And
)
332 return combineBinary(childExpressions
, "&&", "true");
333 if (root
.kind
== PredCombinerKind::Or
)
334 return combineBinary(childExpressions
, "||", "false");
335 if (root
.kind
== PredCombinerKind::Not
)
336 return combineNot(childExpressions
);
337 if (root
.kind
== PredCombinerKind::Concat
) {
338 assert(childExpressions
.size() == 1 &&
339 "ConcatPred should only have one child");
340 return root
.prefix
+ childExpressions
.front() + root
.suffix
;
343 // Substitutions were applied before so just ignore them.
344 if (root
.kind
== PredCombinerKind::SubstLeaves
) {
345 assert(childExpressions
.size() == 1 &&
346 "substitution predicate must have one child");
347 return childExpressions
[0];
350 llvm::PrintFatalError(root
.predicate
->getLoc(), "unsupported predicate kind");
353 std::string
CombinedPred::getConditionImpl() const {
354 llvm::SpecificBumpPtrAllocator
<PredNode
> allocator
;
355 auto *predicateTree
= buildPredicateTree(*this, allocator
, {});
357 propagateGroundTruth(predicateTree
,
358 /*knownTruePreds=*/llvm::SmallPtrSet
<Pred
*, 2>(),
359 /*knownFalsePreds=*/llvm::SmallPtrSet
<Pred
*, 2>());
361 return getCombinedCondition(*predicateTree
);
364 StringRef
SubstLeavesPred::getPattern() const {
365 return def
->getValueAsString("pattern");
368 StringRef
SubstLeavesPred::getReplacement() const {
369 return def
->getValueAsString("replacement");
372 StringRef
ConcatPred::getPrefix() const {
373 return def
->getValueAsString("prefix");
376 StringRef
ConcatPred::getSuffix() const {
377 return def
->getValueAsString("suffix");