[Infra] Fix version-check workflow (#100090)
[llvm-project.git] / mlir / tools / mlir-linalg-ods-gen / mlir-linalg-ods-yaml-gen.cpp
blob7311cdd39d07558bce51862fdb80ea156389cb9e
1 //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an ODS (and C++) generator from a YAML form
10 // derived from the mathematical expression of linalg named ops. Typically a
11 // math oriented DSL will be used to export the essential representation to
// this form, and maintaining the single source of truth (SOT) at the math
// level (versus recreating it in MLIR) is deemed to have systemic value.
15 //===----------------------------------------------------------------------===//
17 #include "mlir/AsmParser/AsmParser.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/ToolOutputFile.h"
28 #include "llvm/Support/YAMLTraits.h"
29 #include <optional>
31 using namespace mlir;
33 using llvm::yaml::Input;
34 using llvm::yaml::MappingTraits;
35 using llvm::yaml::ScalarEnumerationTraits;
36 using llvm::yaml::ScalarTraits;
38 #define DEBUG_TYPE "linalg-ods-gen"
40 //===----------------------------------------------------------------------===//
41 // Mapping structs (correspond to data types in the YAML description).
42 // TODO: Since this is a schema/part of the contract, it should be moved to
43 // a real header.
44 //===----------------------------------------------------------------------===//
46 namespace {
48 struct LinalgYAMLContext {
49 MLIRContext *mlirContext;
52 struct LinalgOpMetadata {
53 std::string name;
54 std::string cppClassName;
55 std::optional<std::string> doc;
56 SmallVector<std::string> implements;
57 SmallVector<std::string> defines;
60 struct SerializedAffineMap {
61 AffineMapAttr affineMapAttr;
63 AffineMap affineMap() { return affineMapAttr.getValue(); }
66 enum class LinalgOperandDefKind {
67 InputTensor,
68 Scalar,
69 OutputTensor,
70 IndexAttr,
71 UnaryFnAttr,
72 BinaryFnAttr,
73 TernaryFnAttr,
74 TypeFnAttr
77 struct LinalgOperandDef {
78 std::string name;
79 LinalgOperandDefKind kind;
80 std::optional<std::string> typeVar;
81 std::optional<SerializedAffineMap> shapeMap;
82 std::optional<SerializedAffineMap> indexAttrMap;
83 std::optional<SmallVector<int64_t>> defaultIndices;
84 std::optional<std::string> defaultFn;
87 enum class LinalgIteratorTypeDef {
88 parallel,
89 reduction,
92 struct LinalgIndexingMapsConfig {
93 std::optional<SmallVector<SerializedAffineMap>> staticIndexingMaps;
96 struct ScalarExpression;
98 enum class ScalarFnKind { Unary, Binary, Ternary, Type };
100 struct ScalarFn {
101 ScalarFnKind kind;
102 std::optional<std::string> fnName;
103 std::optional<std::string> attrName;
104 std::optional<std::string> typeVar;
105 // NOTE: This must be of arity 1, but to break the self-referential cycle,
106 // we use a heap allocated vector.
107 std::vector<ScalarExpression> operands;
110 struct ScalarExpression {
111 std::optional<std::string> arg;
112 std::optional<std::string> constant;
113 std::optional<int64_t> index;
114 std::optional<ScalarFn> scalarFn;
117 struct ScalarAssign {
118 std::string arg;
119 ScalarExpression value;
122 struct LinalgStructuredOpConfig {
123 SmallVector<LinalgOperandDef> args;
124 LinalgIndexingMapsConfig indexingMaps;
125 SmallVector<LinalgIteratorTypeDef> iteratorTypes;
126 std::vector<ScalarAssign> assignments;
129 struct LinalgOpConfig {
130 std::optional<LinalgOpMetadata> metadata;
131 std::optional<LinalgStructuredOpConfig> structuredOp;
134 } // namespace
136 //===----------------------------------------------------------------------===//
137 // Mapping traits.
138 //===----------------------------------------------------------------------===//
140 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef)
141 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap)
142 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef)
143 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign)
144 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression)
145 LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig)
147 namespace llvm {
148 namespace yaml {
150 /// Top-level type containing op metadata and one of a concrete op type.
151 /// Currently, the only defined op type is `structured_op` (maps to
152 /// `LinalgStructuredOpConfig`).
153 template <>
154 struct MappingTraits<LinalgOpConfig> {
155 static void mapping(IO &io, LinalgOpConfig &info) {
156 io.mapOptional("metadata", info.metadata);
157 io.mapOptional("structured_op", info.structuredOp);
161 /// A structured op models (at most) a single contraction by modeling
162 /// - A list of named arguments (`LinalgOperandDef`), which can be inputs,
163 /// outputs, or index attributes.
164 /// - List of indexing maps (see `LinalgIndexingMaps`).
165 /// - Iterator types (see `LinalgIteratorTypeDef`).
166 /// - List of scalar level assignment (see `ScalarAssign`).
167 template <>
168 struct MappingTraits<LinalgStructuredOpConfig> {
169 static void mapping(IO &io, LinalgStructuredOpConfig &info) {
170 io.mapRequired("args", info.args);
171 io.mapRequired("indexing_maps", info.indexingMaps);
172 io.mapRequired("iterator_types", info.iteratorTypes);
173 io.mapRequired("assignments", info.assignments);
177 /// Maps a named tensor, scalar or attribute argument to an operation,
178 /// consisting of:
179 /// - `name`: Must be unique within the operation.
180 /// - `usage`: How the argument is used (input, output, attribute, etc).
181 /// - `type_var`: The symbolic type variable that binds to the element or self
182 /// type of the tensor or scalar argument, respectively.
183 /// - `shape_map`: An optional AffineMap from all op symbols to the shape of
184 /// the argument. Only tensor arguments have a `shape_map`. Each shape must
185 /// be normalized over the same list of symbols and have no dimension
186 /// inputs.
187 /// - `index_attr_map`: An optional AffineMap from all op symbols to the
188 /// index attribute symbols. During op creation these symbols are replaced
189 /// by the corresponding `name` index attribue values. Only index attribute
190 /// arguments have an `index_attr_map`.
191 /// - `default_indices`: An optional default initialization for index
192 /// attribute arguments.
193 /// - `default_fn`: An optional default initialization for function attribute
194 /// arguments.
195 template <>
196 struct MappingTraits<LinalgOperandDef> {
197 static void mapping(IO &io, LinalgOperandDef &info) {
198 io.mapRequired("name", info.name);
199 io.mapRequired("kind", info.kind);
200 io.mapOptional("type_var", info.typeVar);
201 io.mapOptional("shape_map", info.shapeMap);
202 io.mapOptional("index_attr_map", info.indexAttrMap);
203 io.mapOptional("default_indices", info.defaultIndices);
204 io.mapOptional("default_fn", info.defaultFn);
208 /// Usage enum for a named argument.
209 template <>
210 struct ScalarEnumerationTraits<LinalgOperandDefKind> {
211 static void enumeration(IO &io, LinalgOperandDefKind &value) {
212 io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor);
213 io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
214 io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
215 io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
216 io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
217 io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
218 io.enumCase(value, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr);
219 io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
223 /// Iterator type enum.
224 template <>
225 struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
226 static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
227 io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel);
228 io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction);
232 /// Metadata about the op (name, C++ name, and documentation).
233 template <>
234 struct MappingTraits<LinalgOpMetadata> {
235 static void mapping(IO &io, LinalgOpMetadata &info) {
236 io.mapRequired("name", info.name);
237 io.mapRequired("cpp_class_name", info.cppClassName);
238 io.mapOptional("doc", info.doc);
239 io.mapOptional("implements", info.implements);
240 io.mapOptional("defines", info.defines);
244 /// How the ops indexing maps are produced. Must be one of:
245 /// - static_indexing_maps: A static list of AffineMaps, possibly with
246 /// some symbols that bind to attributes of the op. Each indexing map must
247 /// be normalized over the same list of dimensions, and its symbols must
248 /// match the symbols for argument shapes.
249 template <>
250 struct MappingTraits<LinalgIndexingMapsConfig> {
251 static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
252 io.mapOptional("static_indexing_maps", info.staticIndexingMaps);
256 /// Models an assignment to a named output.
257 /// - The `arg` name must match a named output.
258 /// - The `value` is a scalar expression for computing the value to
259 /// assign (see `ScalarExpression`).
260 template <>
261 struct MappingTraits<ScalarAssign> {
262 static void mapping(IO &io, ScalarAssign &info) {
263 io.mapRequired("arg", info.arg);
264 io.mapRequired("value", info.value);
268 /// A scalar expression (RHS of an assignment). Must be one of:
269 /// - `scalar_arg`: An operation argument.
270 /// - `scalar_const`: A constant definition.
271 /// - `scalar_index`: An iteration index.
272 /// - `scalar_fn`: A named function (see `ScalarFn`).
273 template <>
274 struct MappingTraits<ScalarExpression> {
275 static void mapping(IO &io, ScalarExpression &info) {
276 io.mapOptional("scalar_arg", info.arg);
277 io.mapOptional("scalar_const", info.constant);
278 io.mapOptional("scalar_index", info.index);
279 io.mapOptional("scalar_fn", info.scalarFn);
283 /// Scalar function kind enum.
284 template <>
285 struct ScalarEnumerationTraits<ScalarFnKind> {
286 static void enumeration(IO &io, ScalarFnKind &value) {
287 io.enumCase(value, "unary", ScalarFnKind::Unary);
288 io.enumCase(value, "binary", ScalarFnKind::Binary);
289 io.enumCase(value, "ternary", ScalarFnKind::Ternary);
290 io.enumCase(value, "type", ScalarFnKind::Type);
294 /// A scalar expression that evaluates a named function.
295 /// Functions are generally "math" level and type polymorphic. Builtin
296 /// functions include:
297 /// - `add(lhs, rhs)`
298 /// - `mul(lhs, rhs)`
299 template <>
300 struct MappingTraits<ScalarFn> {
301 static void mapping(IO &io, ScalarFn &info) {
302 io.mapRequired("kind", info.kind);
303 io.mapOptional("fn_name", info.fnName);
304 io.mapOptional("attr_name", info.attrName);
305 io.mapOptional("type_var", info.typeVar);
306 io.mapRequired("operands", info.operands);
310 /// Helper mapping which accesses an AffineMapAttr as a serialized string of
311 /// the same.
312 template <>
313 struct ScalarTraits<SerializedAffineMap> {
314 static void output(const SerializedAffineMap &value, void *rawYamlContext,
315 raw_ostream &out) {
316 assert(value.affineMapAttr);
317 value.affineMapAttr.print(out);
319 static StringRef input(StringRef scalar, void *rawYamlContext,
320 SerializedAffineMap &value) {
321 assert(rawYamlContext);
322 auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
323 if (auto attr = dyn_cast_or_null<AffineMapAttr>(
324 mlir::parseAttribute(scalar, yamlContext->mlirContext)))
325 value.affineMapAttr = attr;
326 else if (!value.affineMapAttr || !isa<AffineMapAttr>(value.affineMapAttr))
327 return "could not parse as an affine map attribute";
328 return StringRef();
330 static QuotingType mustQuote(StringRef) { return QuotingType::None; }
333 } // namespace yaml
334 } // namespace llvm
336 namespace {
338 //===----------------------------------------------------------------------===//
339 // Generation utilities
340 //===----------------------------------------------------------------------===//
342 class GenerationContext {
343 public:
344 GenerationContext(MLIRContext *context, raw_ostream *odsOut,
345 raw_ostream *defnOut)
346 : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut),
347 defnOut(defnOut) {}
349 MLIRContext *getContext() { return context; }
351 void setLoc(Location loc) { this->loc = loc; }
352 Location getLoc() { return loc; }
354 bool shouldGenerateOds() { return odsOut; }
355 bool shouldGenerateDefns() { return defnOut; }
357 raw_ostream &odss() {
358 assert(odsOut && "ODS stream not defined");
359 return *odsOut;
362 raw_ostream &defns() {
363 assert(defnOut && "Definition stream not defined");
364 return *defnOut;
367 private:
368 MLIRContext *context;
369 Location loc;
370 raw_ostream *odsOut;
371 raw_ostream *defnOut;
374 } // namespace
376 static std::string generateCppExpression(SerializedAffineMap self,
377 StringRef contextName) {
378 std::string printedStr;
379 llvm::raw_string_ostream printedSs(printedStr);
380 self.affineMapAttr.print(printedSs);
381 printedSs.flush();
383 static const char exprFormat[] =
384 R"FMT(llvm::cast<AffineMapAttr>(mlir::parseAttribute("{0}", {1})).getValue())FMT";
385 return llvm::formatv(exprFormat, printedStr, contextName);
388 template <typename Container>
389 static std::string interleaveToString(Container &container,
390 StringRef separator) {
391 std::string result;
392 llvm::raw_string_ostream ss(result);
393 llvm::interleave(container, ss, separator);
394 ss.flush();
395 return result;
398 static std::optional<int>
399 findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) {
400 for (const auto &it : llvm::enumerate(args)) {
401 if (it.value().name == name)
402 return it.index();
404 return std::nullopt;
407 // Try to map the TypeVar to a predefined or an argument type.
408 static std::optional<std::string>
409 findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
410 // Handle all predefined types.
411 if (typeVar == "I32")
412 return std::string("helper.getIntegerType(32)");
413 if (typeVar == "I64")
414 return std::string("helper.getIntegerType(64)");
415 if (typeVar == "F32")
416 return std::string("helper.getFloat32Type()");
417 if (typeVar == "F64")
418 return std::string("helper.getFloat64Type()");
420 // Search all argument types.
421 for (const auto &it : llvm::enumerate(args)) {
422 if (it.value().kind != LinalgOperandDefKind::InputTensor &&
423 it.value().kind != LinalgOperandDefKind::Scalar &&
424 it.value().kind != LinalgOperandDefKind::OutputTensor)
425 continue;
426 if (*it.value().typeVar == typeVar)
427 return llvm::formatv("block.getArgument({0}).getType()", it.index())
428 .str();
431 return std::nullopt;
434 static ScalarAssign *findAssignment(StringRef name,
435 std::vector<ScalarAssign> &assignments) {
436 for (auto &assign : assignments) {
437 if (assign.arg == name)
438 return &assign;
440 return nullptr;
443 // Return true if the operand is a function attribute.
444 static bool isFunctionAttribute(LinalgOperandDefKind kind) {
445 return kind == LinalgOperandDefKind::UnaryFnAttr ||
446 kind == LinalgOperandDefKind::BinaryFnAttr ||
447 kind == LinalgOperandDefKind::TernaryFnAttr ||
448 kind == LinalgOperandDefKind::TypeFnAttr;
451 // Return true if the operand is an attribute.
452 static bool isAttribute(LinalgOperandDefKind kind) {
453 return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind);
456 // Get the enum name for the given operand kind.
457 std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
458 switch (kind) {
459 case LinalgOperandDefKind::UnaryFnAttr:
460 return std::string("UnaryFn");
461 case LinalgOperandDefKind::BinaryFnAttr:
462 return std::string("BinaryFn");
463 case LinalgOperandDefKind::TernaryFnAttr:
464 return std::string("TernaryFn");
465 case LinalgOperandDefKind::TypeFnAttr:
466 return std::string("TypeFn");
467 default:
468 break;
470 llvm_unreachable("unsupported function attribute kind");
473 // Get the enum name for the given function kind.
474 std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
475 switch (kind) {
476 case ScalarFnKind::Unary:
477 return std::string("UnaryFn");
478 case ScalarFnKind::Binary:
479 return std::string("BinaryFn");
480 case ScalarFnKind::Ternary:
481 return std::string("TernaryFn");
482 case ScalarFnKind::Type:
483 return std::string("TypeFn");
485 llvm_unreachable("unsupported function kind");
488 //===----------------------------------------------------------------------===//
489 // Templates
490 //===----------------------------------------------------------------------===//
// A single line banner format, written to the definitions output before each
// op's generated implementation. Expanded with llvm::formatv. Parameters:
// {0}: Single line comment
static const char bannerFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)FMT";
500 //===----------------------------------------------------------------------===//
501 // Named generic op generation.
502 // These ops map at most a single contraction that complies with the limitations
503 // of a linalg.generic.
504 //===----------------------------------------------------------------------===//
// Template for Linalg named ops' ODS definitions. Parameters:
// {0}: ODS/C++ op name
// {1}: assembly op mnemonic
// {2}: op interface list
// {3}: documentation (summary + description)
// {4}: op attribute list
// {5}: builder methods taking standalone attribute parameters
// {6}: additional method definitions
// {7}: additional methods for attributes used by indexing maps
// Expanded with llvm::formatv: `{{` emits a literal `{`, so `[{{` opens a
// TableGen code block. Every placeholder above must appear in the text;
// otherwise the caller's doc, builders, definitions and attribute methods are
// silently dropped.
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
//===----------------------------------------------------------------------===//

def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
/*extraInterfaces=*/[{2}])> {
{3}
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs{4}
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, {0}::getRegionBuilder());
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
buildStructuredOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, {0}::getRegionBuilder());
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
$_state.addOperands(operands);
$_state.addAttributes(attributes);
$_state.addTypes(resultTensorTypes);
(void)$_state.addRegion();
}]>
{5}
];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
{6}

let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {{
return regionBuilder;
}

::mlir::MutableOperandRange getDpsInitsMutable() {{
return getOutputsMutable();
}

// Generic methods.
static unsigned getNumRegionArgs();
std::string getLibraryCallName();
{7}
}];
}
)FMT";
// Builder method taking attribute parameters. Parameters:
// {0}: Class name
// {1}: Comma interleaved attribute parameters
// {2}: Attribute initialization
// generateNamedGenericOpOds passes the expansion as {5} of
// structuredOpOdsHeaderFormat. It begins with a comma, presumably so it can
// follow the preceding builder in the `builders` list. `[{{` opens the
// TableGen code block holding the builder body (formatv `{{` escape).
static const char structuredOpBuilderFormat[] = R"FMT(
, OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, {1},
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
{2}
buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
attributes, {0}::getRegionBuilder());
}]>
)FMT";
// The getIteratorTypesArray() method for structured ops. Parameters:
// {0}: Class name
// {1}: Comma interleaved iterator type names.
// `{{` is formatv's escape for a literal `{`; the generated function body is
// closed by the final `}`.
static const char structuredOpIteratorTypesFormat[] =
R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
return SmallVector<utils::IteratorType>{{ {1} };
}
)FMT";
// The getIteratorTypesArray() method for rank polymorphic structured ops:
// every loop of the first output's rank is parallel.
// Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIteratorTypesFormat[] =
R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
int64_t rank = getRank(getDpsInitOperand(0));
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
)FMT";
// The getIndexingMaps() method for structured ops. The result is memoized in
// the "linalg.memoized_indexing_maps" attribute. Parameters:
// {0}: Class name
// {1}: Comma-separated list of dimension variable names.
// {2}: Statements that append each indexing map to `maps`; they may use the
//      `context` and `symbolBindings` locals declared before them.
// NOTE(review): {1} is passed by the caller but not referenced by this
// template — confirm it is intentionally unused.
static const char structuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
if (cached)
return cached;

MLIRContext *context = getContext();
auto symbolBindings = getSymbolBindings(*this);
SmallVector<AffineMap> maps;
{2}
cached = Builder(context).getAffineMapArrayAttr(maps);
getOperation()->setAttr(memoizeAttr, cached);
return cached;
}
)FMT";
// The getIndexingMaps() method for rank polymorphic structured ops: rank-0
// operands get a scalar map, all others an identity map over the parallel
// loops. Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
MLIRContext *context = getContext();
AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
getNumParallelLoops(), context);
SmallVector<AffineMap> indexingMaps;
for (OpOperand &opOperand : getOperation()->getOpOperands())
indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
}
)FMT";
// Implementations of fold and getEffects.
// Parameters:
// {0}: Class name
// Declared `static` like the sibling templates (namespace-scope `const`
// already had internal linkage; this is for consistency).
static const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(FoldAdaptor,
SmallVectorImpl<OpFoldResult> &) {{
return memref::foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
if (hasPureTensorSemantics()) return;
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
)FMT";
// Implementation of parse/print, forwarding to the shared
// parseNamedStructuredOp/printNamedStructuredOp helpers.
// Parameters:
// {0}: Class name
static const char structuredOpParserFormat[] = R"FMT(
ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
return ::parseNamedStructuredOp(parser, result,
{0}::getNumRegionArgs(), {0}::getRegionBuilder());
}
void {0}::print(OpAsmPrinter &p) {{
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
}
)FMT";
687 static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
688 GenerationContext &genContext) {
689 if (!genContext.shouldGenerateOds())
690 return success();
692 raw_ostream &os = genContext.odss();
694 std::string interfaceNameList;
695 std::string attrList;
696 std::string attrMethods;
697 std::string attrBuilder;
699 std::string doc;
700 if (opConfig.metadata->doc) {
701 static const char structuredOpDocFmt[] = R"FMT(
702 let summary = [{{{0}}];
703 let description = [{{{1}}];
704 )FMT";
705 StringRef summary, description;
706 std::tie(summary, description) =
707 StringRef(*opConfig.metadata->doc).trim().split("\n\n");
709 doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim());
712 interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
714 std::string definitionList;
715 for (const std::string &definition : opConfig.metadata->defines) {
716 static const char definitionFmt[] = "let {0} = 1;\n";
717 definitionList.append(llvm::formatv(definitionFmt, definition));
720 if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
721 return isAttribute(arg.kind);
722 })) {
723 SmallVector<std::string> attrDefs;
724 SmallVector<std::string> attrParams;
725 SmallVector<std::string> attrStmts;
726 for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
727 static const char paramFmt[] = "\"Attribute\":${0}";
728 static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
729 // Add the type conversion attributes to the op definition and builders.
730 if (isFunctionAttribute(arg.kind)) {
731 assert(arg.defaultFn);
732 std::string enumName = convertOperandKindToEnumName(arg.kind);
733 static const char typeFmt[] = "{0}::{1}";
734 static const char defFmt[] =
735 "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}";
736 attrDefs.push_back(llvm::formatv(
737 defFmt, llvm::formatv("{0}Attr", enumName),
738 llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name));
739 attrParams.push_back(llvm::formatv(paramFmt, arg.name));
740 attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
742 // Add the index attributes to the op definition and builders.
743 if (arg.kind == LinalgOperandDefKind::IndexAttr) {
744 assert(arg.indexAttrMap.has_value());
745 assert(arg.defaultIndices.has_value());
746 size_t size = arg.indexAttrMap->affineMap().getNumResults();
747 assert(arg.defaultIndices->size() == size);
748 static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
749 static const char defFmt[] =
750 "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}";
751 std::string defaultVals;
752 llvm::raw_string_ostream ss(defaultVals);
753 llvm::interleave(
754 *arg.defaultIndices, ss,
755 [&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
756 ", ");
757 attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
758 ss.str(), arg.name));
759 attrParams.push_back(llvm::formatv(paramFmt, arg.name));
760 attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
763 if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
764 return arg.kind == LinalgOperandDefKind::IndexAttr;
765 })) {
766 attrMethods = R"(
767 bool hasDynamicIndexingMaps();
768 LogicalResult verifyIndexingMapRequiredAttributes();
771 attrList = ",\n" + llvm::join(attrDefs, ",\n");
772 attrBuilder = llvm::formatv(
773 structuredOpBuilderFormat, opConfig.metadata->cppClassName,
774 llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n"));
777 os << llvm::formatv(structuredOpOdsHeaderFormat,
778 opConfig.metadata->cppClassName, opConfig.metadata->name,
779 interfaceNameList, doc, attrList, attrBuilder,
780 definitionList, attrMethods);
782 return success();
785 static LogicalResult
786 generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
787 GenerationContext &genContext) {
788 if (!genContext.shouldGenerateDefns())
789 return success();
791 raw_ostream &os = genContext.defns();
792 StringRef className = opConfig.metadata->cppClassName;
794 // Implementation banner.
795 std::string bannerComment = llvm::formatv("Implementation of {0}", className);
796 os << llvm::formatv(bannerFormat, bannerComment);
798 // Compute the number of scalar and tensor arguments.
799 int64_t numOfArgs =
800 llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
801 return arg.kind == LinalgOperandDefKind::InputTensor ||
802 arg.kind == LinalgOperandDefKind::Scalar ||
803 arg.kind == LinalgOperandDefKind::OutputTensor;
806 // An operation that accesses only scalars and scalar/rank zero tensors is
807 // rank polymorhpic. We implement rank polymorphism by generating different
808 // indexing maps and iterators that match the rank of the first output tensor.
809 // An operation is rank polymorphic if the iteration domain has rank zero.
810 bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty();
812 // Generate the iterator_types() method.
813 if (!isRankPolymorphic) {
814 std::string iteratorsStr;
815 llvm::raw_string_ostream ss(iteratorsStr);
816 llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss,
817 [&](LinalgIteratorTypeDef it) {
818 switch (it) {
819 case LinalgIteratorTypeDef::parallel:
820 ss << "utils::IteratorType::parallel";
821 break;
822 case LinalgIteratorTypeDef::reduction:
823 ss << "utils::IteratorType::reduction";
824 break;
827 ss.flush();
828 os << llvm::formatv(structuredOpIteratorTypesFormat, className,
829 iteratorsStr);
830 } else {
831 os << llvm::formatv(rankPolyStructuredOpIteratorTypesFormat, className);
834 // Generating the getIndexingMaps() method.
835 if (auto &staticMaps =
836 opConfig.structuredOp->indexingMaps.staticIndexingMaps) {
837 if (staticMaps->empty())
838 return emitError(genContext.getLoc()) << "op has no indexing maps";
839 if (!isRankPolymorphic) {
840 AffineMap firstMap = staticMaps->front().affineMap();
842 // Symbol bindings.
844 // For each symbol, generate a declaration for it, either with an
845 // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
846 // an attribute).
847 // TODO: Possibly lift into a top-level method.
848 static const char structuredOpSymbolBindingsFormat[] = R"FMT(
849 static SmallVector<AffineExpr> getSymbolBindings({0} self) {
850 MLIRContext *context = self.getContext();
851 SmallVector<AffineExpr> exprs;
853 return exprs;
855 )FMT";
857 unsigned symbolCount = firstMap.getNumSymbols();
858 SmallVector<std::string> symbolBindings;
859 for (unsigned i = 0; i < symbolCount; ++i) {
860 symbolBindings.push_back(llvm::formatv(
861 " exprs.push_back(getAffineSymbolExpr({0}, context));", i));
864 // Access an index attribute. Parameters:
865 // {0}: Attribute name
866 // {1}: Symbol position
867 // {2}: Attribute index
868 static const char structuredOpAccessAttrFormat[] = R"FMT(
869 int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}];
870 exprs.push_back(getAffineConstantExpr(cst{1}, context));
871 )FMT";
872 // Update all symbol bindings mapped to an attribute.
873 for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
874 if (arg.kind != LinalgOperandDefKind::IndexAttr)
875 continue;
876 assert(arg.indexAttrMap);
877 for (auto [idx, result] :
878 llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) {
879 if (auto symbol = dyn_cast<AffineSymbolExpr>(result)) {
880 std::string argName = arg.name;
881 argName[0] = toupper(argName[0]);
882 symbolBindings[symbol.getPosition()] =
883 llvm::formatv(structuredOpAccessAttrFormat, argName,
884 symbol.getPosition(), idx);
889 std::string symbolBindingsStr;
890 llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
891 llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
892 symbolBindingsSs.flush();
894 os << llvm::formatv(structuredOpSymbolBindingsFormat, className,
895 symbolBindingsStr);
898 // Indexing maps.
900 unsigned dimCount = firstMap.getNumDims();
902 // Generate a comma-separated list of dim identifiers to be passed to
903 // bindDims, ensuring tht AffineExpr identifiers are bound in the right
904 // order to the proper AffineDimExpr.
905 // This results in vars in scope like: d0, d1, d2...
906 SmallVector<unsigned> dimIndices;
907 for (unsigned i = 0; i < dimCount; ++i)
908 dimIndices.push_back(i);
909 std::string dimIdentsStr;
910 llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
911 llvm::interleaveComma(dimIndices, dimIdentsSs,
912 [&](unsigned i) { dimIdentsSs << "d" << i; });
913 dimIdentsSs.flush();
915 // Statements to add and simplify each affine map.
916 SmallVector<std::string> stmts;
917 for (auto &indexingMap : *staticMaps) {
918 // TODO: Assert that dim and symbol count match the first.
919 stmts.push_back(
920 llvm::formatv("maps.push_back({0});",
921 generateCppExpression(indexingMap, "context")));
922 stmts.push_back(llvm::formatv(
923 "maps.back() = "
924 "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
925 "symbolBindings, {0}, 0));",
926 dimCount));
929 // TODO: This needs to be memoized and/or converted to non-parser based
930 // C++ codegen prior to real use.
931 os << llvm::formatv(structuredOpIndexingMapsFormat, className,
932 dimIdentsStr, interleaveToString(stmts, "\n "));
934 } else {
935 os << llvm::formatv(rankPolyStructuredOpIndexingMapsFormat, className);
937 } else {
938 return emitError(genContext.getLoc())
939 << "generating code for non static indexing maps not currently "
940 "supported";
943 // getNumRegionArgs()
945 // Generates a getNumRegionArgs() method. Parameters:
946 // {0}: Class name
947 // {1}: Number of region args
948 static const char structuredOpGetNumRegionArgsFormat[] = R"FMT(
949 unsigned {0}::getNumRegionArgs() {{ return {1}; }
950 )FMT";
951 os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className,
952 numOfArgs);
955 // getLibraryCallName()
957 // Generates a getLibraryCallName method. Parameters:
958 // {0}: Class name
959 static const char structuredOpGetLibraryCallFormat[] = R"FMT(
960 std::string {0}::getLibraryCallName() {{
961 return generateLibraryCallName(getOperation());
963 )FMT";
964 os << llvm::formatv(structuredOpGetLibraryCallFormat, className);
967 // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
968 if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
969 return arg.kind == LinalgOperandDefKind::IndexAttr;
970 })) {
971 std::vector<std::string> attrVerifications;
972 for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
973 if (arg.kind != LinalgOperandDefKind::IndexAttr)
974 continue;
975 assert(arg.indexAttrMap);
976 // Verify index attribute. Parameters:
977 // {0}: Attribute name
978 // {1}: Attribute size
979 static const char attrFmt[] = R"FMT(
980 if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
981 if (!attr.getType().getElementType().isInteger(64))
982 return op->emitError("incorrect element type for index attribute '{0}'");
983 if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
984 return op->emitError("incorrect shape for index attribute '{0}'");
986 )FMT";
987 attrVerifications.push_back(llvm::formatv(
988 attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults()));
991 // Generates the verifyIndexingMapRequiredAttributes method. Parameters:
992 // {0}: Class name
993 // {1}: Attribute verification
994 static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT(
995 bool {0}::hasDynamicIndexingMaps() {{ return true; }
996 LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
997 Operation *op = getOperation();
999 return success();
1001 )FMT";
1002 os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes,
1003 className, llvm::join(attrVerifications, "\n"));
1006 // regionBuilder()
1008 // Generates a regionBuilder method. Parameters.
1009 // {0}: Class name
1010 // {1}: Number of args
1011 // {2}: Attributes
1012 // {3}: Statements
1013 static const char structuredOpRegionBuilderFormat[] = R"FMT(
1014 void {0}::regionBuilder(ImplicitLocOpBuilder &b,
1015 Block &block, ArrayRef<NamedAttribute> attrs) {{
1016 assert({1} > 0 && block.getNumArguments() == {1} &&
1017 "{0} regionBuilder expects {1} (>=0) args");
1018 RegionBuilderHelper helper(b, block);
1019 SmallVector<Value> yields;
1022 helper.yieldOutputs(yields);
1024 )FMT";
1025 auto &args = opConfig.structuredOp->args;
1026 auto &assignments = opConfig.structuredOp->assignments;
1027 size_t generatedAssignmentCount = 0;
1028 int localCounter = 0;
1029 SmallVector<std::string> attrs;
1030 SmallVector<std::string> stmts;
1031 for (LinalgOperandDef &arg : args) {
1032 if (!isFunctionAttribute(arg.kind))
1033 continue;
1034 // Obtain the type function attribute values. Parameters.
1035 // {0}: enum name
1036 // {1}: attribute name
1037 // {2}: default type function name
1038 static const char attrDef[] = R"FMT(
1039 {0} {1}Val = {0}::{2};
1040 auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1041 return attr.getName() == "{1}"; });
1042 if ({1}Iter != attrs.end()) {{
1043 if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
1044 {1}Val = attr.getValue();
1046 )FMT";
1047 std::string enumName = convertOperandKindToEnumName(arg.kind);
1048 attrs.push_back(
1049 llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn));
1051 for (LinalgOperandDef &arg : args) {
1052 if (arg.kind != LinalgOperandDefKind::OutputTensor)
1053 continue;
1055 // Find the assignment that correlates with the argument.
1056 ScalarAssign *assignment = findAssignment(arg.name, assignments);
1057 if (!assignment)
1058 return emitError(genContext.getLoc())
1059 << "no assignment found for output argument " << arg.name;
1060 ++generatedAssignmentCount;
1062 // Recursively generate the expression.
1063 std::function<std::optional<std::string>(ScalarExpression &)>
1064 generateExpression =
1065 [&](ScalarExpression &expression) -> std::optional<std::string> {
1066 if (expression.arg) {
1067 // Argument reference.
1068 std::optional<int> argIndex =
1069 findTensorDefArgIndex(*expression.arg, args);
1070 if (!argIndex) {
1071 emitError(genContext.getLoc())
1072 << "scalar argument not defined on the op: " << *expression.arg;
1073 return std::nullopt;
1075 return std::string(
1076 llvm::formatv("block.getArgument({0})", *argIndex));
1078 if (expression.constant) {
1079 std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
1080 stmts.push_back(
1081 llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT",
1082 cppIdent, expression.constant));
1083 return cppIdent;
1085 if (expression.index) {
1086 // Access an iteration index.
1087 std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
1088 stmts.push_back(llvm::formatv("Value {0} = helper.index({1});",
1089 cppIdent, *expression.index));
1090 return cppIdent;
1092 if (expression.scalarFn) {
1093 std::string enumName =
1094 convertFunctionKindToEnumName(expression.scalarFn->kind);
1096 // Get the function or attribute name.
1097 assert(expression.scalarFn->fnName || expression.scalarFn->attrName);
1098 std::string funcType;
1099 if (expression.scalarFn->fnName) {
1100 funcType = llvm::formatv("{0}::{1}", enumName,
1101 *expression.scalarFn->fnName);
1103 if (expression.scalarFn->attrName) {
1104 if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
1105 return isFunctionAttribute(arg.kind) &&
1106 arg.name == *expression.scalarFn->attrName;
1107 })) {
1108 emitError(genContext.getLoc()) << "missing function attribute "
1109 << *expression.scalarFn->attrName;
1111 funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
1113 assert(!funcType.empty());
1115 // Add the optional type parameter to the operands.
1116 SmallVector<std::string> operandCppValues;
1117 if (expression.scalarFn->kind == ScalarFnKind::Type) {
1118 assert(expression.scalarFn->typeVar.has_value());
1119 std::optional<std::string> typeCppValue =
1120 findTypeValue(*expression.scalarFn->typeVar, args);
1121 if (!typeCppValue) {
1122 emitError(genContext.getLoc())
1123 << "type variable " << *expression.scalarFn->typeVar
1124 << ", used in a type conversion, must map to a predefined or "
1125 << "an argument type but it does not";
1126 return std::nullopt;
1128 operandCppValues.push_back(*typeCppValue);
1131 // Collect the scalar operands.
1132 for (ScalarExpression &operand : expression.scalarFn->operands) {
1133 auto operandCppValue = generateExpression(operand);
1134 if (!operandCppValue)
1135 return std::nullopt;
1136 operandCppValues.push_back(*operandCppValue);
1139 // Call the function builder.
1140 std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
1141 stmts.push_back(llvm::formatv(
1142 "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName,
1143 funcType, interleaveToString(operandCppValues, ", ")));
1144 return cppIdent;
1146 emitError(genContext.getLoc()) << "unknown ScalarExpression type";
1147 return std::nullopt;
1149 std::optional<std::string> cppValue =
1150 generateExpression(assignment->value);
1151 if (!cppValue)
1152 return failure();
1153 stmts.push_back(llvm::formatv("yields.push_back({0});", *cppValue));
1156 if (generatedAssignmentCount != assignments.size())
1157 return emitError(genContext.getLoc())
1158 << "mismatched number of assignments vs output arguments";
1160 os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
1161 interleaveToString(attrs, "\n "),
1162 interleaveToString(stmts, "\n "));
1165 // Parser and printer.
1166 os << llvm::formatv(structuredOpParserFormat, className);
1168 // Canonicalizers and folders.
1169 os << llvm::formatv(structuredOpFoldersFormat, className);
1171 return success();
1174 static LogicalResult generateOp(LinalgOpConfig &opConfig,
1175 GenerationContext &genContext) {
1176 // Switch on op type being generated.
1177 if (opConfig.structuredOp) {
1178 return success(
1179 succeeded(generateNamedGenericOpOds(opConfig, genContext)) &&
1180 succeeded(generateNamedGenericOpDefns(opConfig, genContext)));
1182 return emitError(genContext.getLoc()) << "unsupported operation type";
1185 //===----------------------------------------------------------------------===//
1186 // Command line options and main
1187 //===----------------------------------------------------------------------===//
1189 static llvm::cl::opt<std::string>
1190 inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
1191 llvm::cl::init("-"), llvm::cl::value_desc("YAML filename"));
1193 static llvm::cl::opt<std::string>
1194 outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
1195 llvm::cl::value_desc("filename"), llvm::cl::init(""));
1197 static llvm::cl::opt<std::string>
1198 outputCppImplFilename("o-impl",
1199 llvm::cl::desc("C++ implementation file name"),
1200 llvm::cl::value_desc("filename"), llvm::cl::init(""));
1202 int main(int argc, char **argv) {
1203 llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML");
1205 // Set up the input file.
1206 std::string errorMessage;
1207 std::unique_ptr<llvm::MemoryBuffer> file =
1208 mlir::openInputFile(inputFilename, &errorMessage);
1209 if (!file) {
1210 llvm::errs() << errorMessage << "\n";
1211 return 1;
1214 MLIRContext mlirContext;
1215 LinalgYAMLContext yamlContext{&mlirContext};
1217 std::vector<LinalgOpConfig> opConfigs;
1219 // Parse input.
1220 Input yin(file->getBuffer(), &yamlContext);
1221 yin >> opConfigs;
1223 if (yin.error())
1224 return 1;
1226 // Open output files.
1227 std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl;
1228 if (!outputOdsDeclFilename.empty()) {
1229 outputOdsDecl = openOutputFile(outputOdsDeclFilename, &errorMessage);
1230 if (!outputOdsDecl) {
1231 llvm::errs() << errorMessage << "\n";
1232 return 1;
1236 std::unique_ptr<llvm::ToolOutputFile> outputCppImpl;
1237 if (!outputCppImplFilename.empty()) {
1238 outputCppImpl = openOutputFile(outputCppImplFilename, &errorMessage);
1239 if (!outputCppImpl) {
1240 llvm::errs() << errorMessage << "\n";
1241 return 1;
1245 if (!outputOdsDecl && !outputCppImpl) {
1246 llvm::errs() << "error: No output files specified\n";
1247 return 1;
1250 // Generate.
1251 GenerationContext genContext(&mlirContext,
1252 outputOdsDecl ? &outputOdsDecl->os() : nullptr,
1253 outputCppImpl ? &outputCppImpl->os() : nullptr);
1255 for (auto &opConfig : opConfigs) {
1256 if (!opConfig.metadata) {
1257 emitError(genContext.getLoc())
1258 << "missing operation metadata on subsequent op";
1259 return 1;
1262 genContext.setLoc(NameLoc::get(
1263 StringAttr::get(&mlirContext, opConfig.metadata->cppClassName)));
1264 if (failed(generateOp(opConfig, genContext))) {
1265 return 1;
1269 if (outputOdsDecl)
1270 outputOdsDecl->keep();
1271 if (outputCppImpl)
1272 outputCppImpl->keep();
1274 return 0;