1 //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/TableGen/CodeGenHelpers.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/Path.h"
19 #include "llvm/TableGen/Record.h"
23 using namespace mlir::tblgen
;
25 /// Generate a unique label based on the current file name to prevent name
26 /// collisions if multiple generated files are included at once.
27 static std::string
getUniqueOutputLabel(const RecordKeeper
&records
,
29 // Use the input file name when generating a unique name.
30 std::string inputFilename
= records
.getInputFilename();
32 // Drop all but the base filename.
33 StringRef nameRef
= sys::path::filename(inputFilename
);
34 nameRef
.consume_back(".td");
36 // Sanitize any invalid characters.
37 std::string
uniqueName(tag
);
38 for (char c
: nameRef
) {
39 if (isAlnum(c
) || c
== '_')
40 uniqueName
.push_back(c
);
42 uniqueName
.append(utohexstr((unsigned char)c
));
47 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
48 raw_ostream
&os
, const RecordKeeper
&records
, StringRef tag
)
49 : os(os
), uniqueOutputLabel(getUniqueOutputLabel(records
, tag
)) {}
51 void StaticVerifierFunctionEmitter::emitOpConstraints(
52 ArrayRef
<const Record
*> opDefs
) {
53 NamespaceEmitter
namespaceEmitter(os
, Operator(*opDefs
[0]).getCppNamespace());
54 emitTypeConstraints();
55 emitAttrConstraints();
56 emitSuccessorConstraints();
57 emitRegionConstraints();
60 void StaticVerifierFunctionEmitter::emitPatternConstraints(
61 const ArrayRef
<DagLeaf
> constraints
) {
62 collectPatternConstraints(constraints
);
63 emitPatternConstraints();
66 //===----------------------------------------------------------------------===//
69 StringRef
StaticVerifierFunctionEmitter::getTypeConstraintFn(
70 const Constraint
&constraint
) const {
71 const auto *it
= typeConstraints
.find(constraint
);
72 assert(it
!= typeConstraints
.end() && "expected to find a type constraint");
76 // Find a uniqued attribute constraint. Since not all attribute constraints can
77 // be uniqued, return std::nullopt if one was not found.
78 std::optional
<StringRef
> StaticVerifierFunctionEmitter::getAttrConstraintFn(
79 const Constraint
&constraint
) const {
80 const auto *it
= attrConstraints
.find(constraint
);
81 return it
== attrConstraints
.end() ? std::optional
<StringRef
>()
82 : StringRef(it
->second
);
85 StringRef
StaticVerifierFunctionEmitter::getSuccessorConstraintFn(
86 const Constraint
&constraint
) const {
87 const auto *it
= successorConstraints
.find(constraint
);
88 assert(it
!= successorConstraints
.end() &&
89 "expected to find a sucessor constraint");
93 StringRef
StaticVerifierFunctionEmitter::getRegionConstraintFn(
94 const Constraint
&constraint
) const {
95 const auto *it
= regionConstraints
.find(constraint
);
96 assert(it
!= regionConstraints
.end() &&
97 "expected to find a region constraint");
101 //===----------------------------------------------------------------------===//
102 // Constraint Emission
/// Code templates for emitting type, attribute, successor, and region
/// constraints. Each of these templates requires the following arguments:
///
/// {0}: The unique constraint name.
/// {1}: The constraint code.
/// {2}: The constraint description.
/// Code for a type constraint. These may be called on the type of either
/// operands or results.
///
/// The generated function takes the op, the type to check, a string naming
/// the kind of value, and the value's index. It emits an op error that
/// includes the description `{2}` and the offending type when the condition
/// `{1}` does not hold, and returns success otherwise.
static const char *const typeConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {
  if (!({1})) {
    return op->emitOpError(valueKind) << " #" << valueIndex
        << " must be {2}, but got " << type;
  }
  return ::mlir::success();
}
)";
/// Template for the functions generated for a uniqued attribute constraint.
///
/// As far as is visible here, the template declares two overloads named `{0}`:
/// one taking the attribute, its name, and an `emitError` callback, and one
/// taking the op, which forwards to the first with a callback that calls
/// `op->emitOpError()`.
///
/// NOTE(review): this copy of the template is damaged. The line carrying the
/// `{1}` condition, the closing braces, and the raw-string terminator are
/// missing, and two string literals inside the template are broken across
/// lines. Restore it from the original file; the missing guard cannot be
/// reconstructed from what is visible here.
125 /// Code for an attribute constraint. These may be called from ops only.
126 /// Attribute constraints cannot reference anything other than `$_self` and
129 /// TODO: Unique constraints for adaptors. However, most Adaptor::verify
130 /// functions are stripped anyways.
131 static const char *const attrConstraintCode
= R
"(
132 static ::llvm::LogicalResult {0}(
133 ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
135 return emitError() << "attribute
'" << attrName
136 << "' failed to satisfy constraint
: {2}";
137 return ::mlir::success();
139 static ::llvm::LogicalResult {0}(
140 ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {{
141 return {0}(attr, attrName, [op]() {{
142 return op->emitOpError();
/// Code for a successor constraint.
///
/// The generated function takes the op, the successor block, the successor's
/// name, and its index. It emits an op error naming the successor and the
/// description `{2}` when the condition `{1}` does not hold, and returns
/// success otherwise. The successor name is wrapped as `('name')`, matching
/// the form used by the region constraint template.
static const char *const successorConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {
  if (!({1})) {
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
/// Code for a region constraint. Callers will need to pass in the region's name
/// for emitting an error message.
///
/// The generated function takes the op, the region, the region's name, and
/// its index. It emits an op error with the description `{2}` when the
/// condition `{1}` does not hold, and returns success otherwise. The quoted
/// region name is left out of the message when `regionName` is empty.
static const char *const regionConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {
  if (!({1})) {
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
/// Code for a pattern type or attribute constraint.
///
/// {3}: "Type type" or "Attribute attr".
///
/// The generated function takes the rewriter, the op, the value described by
/// `{3}`, and a failure string. When the condition `{1}` does not hold it
/// reports a match failure whose diagnostic is the failure string followed by
/// the description `{2}`; otherwise it returns success.
static const char *const patternAttrOrTypeConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
    ::llvm::StringRef failureStr) {
  if (!({1})) {
    return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
      diag << failureStr << ": {2}";
    });
  }
  return ::mlir::success();
}
)";
191 void StaticVerifierFunctionEmitter::emitConstraints(
192 const ConstraintMap
&constraints
, StringRef selfName
,
193 const char *const codeTemplate
) {
195 ctx
.addSubst("_op", "*op").withSelf(selfName
);
196 for (auto &it
: constraints
) {
197 os
<< formatv(codeTemplate
, it
.second
,
198 tgfmt(it
.first
.getConditionTemplate(), &ctx
),
199 escapeString(it
.first
.getSummary()));
203 void StaticVerifierFunctionEmitter::emitTypeConstraints() {
204 emitConstraints(typeConstraints
, "type", typeConstraintCode
);
207 void StaticVerifierFunctionEmitter::emitAttrConstraints() {
208 emitConstraints(attrConstraints
, "attr", attrConstraintCode
);
211 void StaticVerifierFunctionEmitter::emitSuccessorConstraints() {
212 emitConstraints(successorConstraints
, "successor", successorConstraintCode
);
215 void StaticVerifierFunctionEmitter::emitRegionConstraints() {
216 emitConstraints(regionConstraints
, "region", regionConstraintCode
);
219 void StaticVerifierFunctionEmitter::emitPatternConstraints() {
221 ctx
.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type");
222 for (auto &it
: typeConstraints
) {
223 os
<< formatv(patternAttrOrTypeConstraintCode
, it
.second
,
224 tgfmt(it
.first
.getConditionTemplate(), &ctx
),
225 escapeString(it
.first
.getSummary()), "Type type");
227 ctx
.withSelf("attr");
228 for (auto &it
: attrConstraints
) {
229 os
<< formatv(patternAttrOrTypeConstraintCode
, it
.second
,
230 tgfmt(it
.first
.getConditionTemplate(), &ctx
),
231 escapeString(it
.first
.getSummary()), "Attribute attr");
235 //===----------------------------------------------------------------------===//
236 // Constraint Uniquing
238 /// An attribute constraint that references anything other than itself and the
239 /// current op cannot be generically extracted into a function. Most
240 /// prohibitive are operands and results, which require calls to
241 /// `getODSOperands` or `getODSResults`. Attribute references are tricky too
242 /// because ops use cached identifiers.
243 static bool canUniqueAttrConstraint(Attribute attr
) {
245 auto test
= tgfmt(attr
.getConditionTemplate(),
246 &ctx
.withSelf("attr").addSubst("_op", "*op"))
248 return !StringRef(test
).contains("<no-subst-found>");
/// Build the name of a uniqued constraint function. The visible part of the
/// expression concatenates "__mlir_ods_local_", `kind`, "_constraint_", and
/// `uniqueOutputLabel`.
///
/// NOTE(review): this copy of the function is truncated. The remainder of the
/// parameter list, the tail of the return expression, and the closing brace
/// are not visible here. The only visible caller passes `map.size()` as a
/// second argument, so presumably a numeric suffix is appended; confirm
/// against the original file before restoring.
251 std::string
StaticVerifierFunctionEmitter::getUniqueName(StringRef kind
,
253 return ("__mlir_ods_local_" + kind
+ "_constraint_" + uniqueOutputLabel
+
258 void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap
&map
,
260 Constraint constraint
) {
261 auto [it
, inserted
] = map
.try_emplace(constraint
);
263 it
->second
= getUniqueName(kind
, map
.size());
266 void StaticVerifierFunctionEmitter::collectOpConstraints(
267 ArrayRef
<const Record
*> opDefs
) {
268 const auto collectTypeConstraints
= [&](Operator::const_value_range values
) {
269 for (const NamedTypeConstraint
&value
: values
)
270 if (value
.hasPredicate())
271 collectConstraint(typeConstraints
, "type", value
.constraint
);
274 for (const Record
*def
: opDefs
) {
276 /// Collect type constraints.
277 collectTypeConstraints(op
.getOperands());
278 collectTypeConstraints(op
.getResults());
279 /// Collect attribute constraints.
280 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
281 if (!namedAttr
.attr
.getPredicate().isNull() &&
282 !namedAttr
.attr
.isDerivedAttr() &&
283 canUniqueAttrConstraint(namedAttr
.attr
))
284 collectConstraint(attrConstraints
, "attr", namedAttr
.attr
);
286 /// Collect successor constraints.
287 for (const NamedSuccessor
&successor
: op
.getSuccessors()) {
288 if (!successor
.constraint
.getPredicate().isNull()) {
289 collectConstraint(successorConstraints
, "successor",
290 successor
.constraint
);
293 /// Collect region constraints.
294 for (const NamedRegion
®ion
: op
.getRegions())
295 if (!region
.constraint
.getPredicate().isNull())
296 collectConstraint(regionConstraints
, "region", region
.constraint
);
300 void StaticVerifierFunctionEmitter::collectPatternConstraints(
301 const ArrayRef
<DagLeaf
> constraints
) {
302 for (auto &leaf
: constraints
) {
303 assert(leaf
.isOperandMatcher() || leaf
.isAttrMatcher());
305 leaf
.isOperandMatcher() ? typeConstraints
: attrConstraints
,
306 leaf
.isOperandMatcher() ? "type" : "attr", leaf
.getAsConstraint());
310 //===----------------------------------------------------------------------===//
311 // Public Utility Functions
312 //===----------------------------------------------------------------------===//
314 std::string
mlir::tblgen::escapeString(StringRef value
) {
316 raw_string_ostream
os(ret
);
317 os
.write_escaped(value
);