1 //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements an ODS (and C++) generator from a YAML form
10 // derived from the mathematical expression of linalg named ops. Typically a
11 // math oriented DSL will be used to export the essential representation to
12 // this form, and maintaining the SOT at the math level (versus recreating it
13 // in MLIR) is deemed to have systemic value.
15 //===----------------------------------------------------------------------===//
17 #include "mlir/AsmParser/AsmParser.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/ToolOutputFile.h"
28 #include "llvm/Support/YAMLTraits.h"
33 using llvm::yaml::Input
;
34 using llvm::yaml::MappingTraits
;
35 using llvm::yaml::ScalarEnumerationTraits
;
36 using llvm::yaml::ScalarTraits
;
38 #define DEBUG_TYPE "linalg-ods-gen"
40 //===----------------------------------------------------------------------===//
41 // Mapping structs (correspond to data types in the YAML description).
42 // TODO: Since this is a schema/part of the contract, it should be moved to a real header.
44 //===----------------------------------------------------------------------===//
48 struct LinalgYAMLContext
{
49 MLIRContext
*mlirContext
;
52 struct LinalgOpMetadata
{
54 std::string cppClassName
;
55 std::optional
<std::string
> doc
;
56 SmallVector
<std::string
> implements
;
57 SmallVector
<std::string
> defines
;
60 struct SerializedAffineMap
{
61 AffineMapAttr affineMapAttr
;
63 AffineMap
affineMap() { return affineMapAttr
.getValue(); }
/// Kind of a named operand definition. Each enumerator corresponds to one of
/// the YAML `kind` strings accepted by the enumeration traits.
enum class LinalgOperandDefKind {
  InputTensor,   // "input_tensor"
  Scalar,        // "scalar"
  OutputTensor,  // "output_tensor"
  IndexAttr,     // "index_attr"
  UnaryFnAttr,   // "unary_fn_attr"
  BinaryFnAttr,  // "binary_fn_attr"
  TernaryFnAttr, // "ternary_fn_attr"
  TypeFnAttr     // "type_fn_attr"
};
77 struct LinalgOperandDef
{
79 LinalgOperandDefKind kind
;
80 std::optional
<std::string
> typeVar
;
81 std::optional
<SerializedAffineMap
> shapeMap
;
82 std::optional
<SerializedAffineMap
> indexAttrMap
;
83 std::optional
<SmallVector
<int64_t>> defaultIndices
;
84 std::optional
<std::string
> defaultFn
;
/// Iterator type of one loop dimension, as spelled in the YAML
/// `iterator_types` list.
enum class LinalgIteratorTypeDef {
  parallel,  // "parallel"
  reduction, // "reduction"
};
92 struct LinalgIndexingMapsConfig
{
93 std::optional
<SmallVector
<SerializedAffineMap
>> staticIndexingMaps
;
96 struct ScalarExpression
;
/// Kind of a scalar function; mirrors the YAML strings "unary", "binary",
/// "ternary" and "type".
enum class ScalarFnKind {
  Unary,
  Binary,
  Ternary,
  Type
};
102 std::optional
<std::string
> fnName
;
103 std::optional
<std::string
> attrName
;
104 std::optional
<std::string
> typeVar
;
105 // NOTE: This must be of arity 1, but to break the self-referential cycle,
106 // we use a heap allocated vector.
107 std::vector
<ScalarExpression
> operands
;
110 struct ScalarExpression
{
111 std::optional
<std::string
> arg
;
112 std::optional
<std::string
> constant
;
113 std::optional
<int64_t> index
;
114 std::optional
<ScalarFn
> scalarFn
;
117 struct ScalarAssign
{
119 ScalarExpression value
;
122 struct LinalgStructuredOpConfig
{
123 SmallVector
<LinalgOperandDef
> args
;
124 LinalgIndexingMapsConfig indexingMaps
;
125 SmallVector
<LinalgIteratorTypeDef
> iteratorTypes
;
126 std::vector
<ScalarAssign
> assignments
;
129 struct LinalgOpConfig
{
130 std::optional
<LinalgOpMetadata
> metadata
;
131 std::optional
<LinalgStructuredOpConfig
> structuredOp
;
136 //===----------------------------------------------------------------------===//
138 //===----------------------------------------------------------------------===//
140 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef
)
141 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap
)
142 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef
)
143 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign
)
144 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression
)
145 LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig
)
150 /// Top-level type containing op metadata and one of a concrete op type.
151 /// Currently, the only defined op type is `structured_op` (maps to
152 /// `LinalgStructuredOpConfig`).
154 struct MappingTraits
<LinalgOpConfig
> {
155 static void mapping(IO
&io
, LinalgOpConfig
&info
) {
156 io
.mapOptional("metadata", info
.metadata
);
157 io
.mapOptional("structured_op", info
.structuredOp
);
161 /// A structured op models (at most) a single contraction by modeling
162 /// - A list of named arguments (`LinalgOperandDef`), which can be inputs,
163 /// outputs, or index attributes.
164 /// - List of indexing maps (see `LinalgIndexingMaps`).
165 /// - Iterator types (see `LinalgIteratorTypeDef`).
166 /// - List of scalar level assignment (see `ScalarAssign`).
168 struct MappingTraits
<LinalgStructuredOpConfig
> {
169 static void mapping(IO
&io
, LinalgStructuredOpConfig
&info
) {
170 io
.mapRequired("args", info
.args
);
171 io
.mapRequired("indexing_maps", info
.indexingMaps
);
172 io
.mapRequired("iterator_types", info
.iteratorTypes
);
173 io
.mapRequired("assignments", info
.assignments
);
177 /// Maps a named tensor, scalar or attribute argument to an operation,
179 /// - `name`: Must be unique within the operation.
180 /// - `usage`: How the argument is used (input, output, attribute, etc).
181 /// - `type_var`: The symbolic type variable that binds to the element or self
182 /// type of the tensor or scalar argument, respectively.
183 /// - `shape_map`: An optional AffineMap from all op symbols to the shape of
184 /// the argument. Only tensor arguments have a `shape_map`. Each shape must
185 /// be normalized over the same list of symbols and have no dimension
187 /// - `index_attr_map`: An optional AffineMap from all op symbols to the
188 /// index attribute symbols. During op creation these symbols are replaced
189 /// by the corresponding `name` index attribue values. Only index attribute
190 /// arguments have an `index_attr_map`.
191 /// - `default_indices`: An optional default initialization for index
192 /// attribute arguments.
193 /// - `default_fn`: An optional default initialization for function attribute
196 struct MappingTraits
<LinalgOperandDef
> {
197 static void mapping(IO
&io
, LinalgOperandDef
&info
) {
198 io
.mapRequired("name", info
.name
);
199 io
.mapRequired("kind", info
.kind
);
200 io
.mapOptional("type_var", info
.typeVar
);
201 io
.mapOptional("shape_map", info
.shapeMap
);
202 io
.mapOptional("index_attr_map", info
.indexAttrMap
);
203 io
.mapOptional("default_indices", info
.defaultIndices
);
204 io
.mapOptional("default_fn", info
.defaultFn
);
208 /// Usage enum for a named argument.
210 struct ScalarEnumerationTraits
<LinalgOperandDefKind
> {
211 static void enumeration(IO
&io
, LinalgOperandDefKind
&value
) {
212 io
.enumCase(value
, "input_tensor", LinalgOperandDefKind::InputTensor
);
213 io
.enumCase(value
, "scalar", LinalgOperandDefKind::Scalar
);
214 io
.enumCase(value
, "output_tensor", LinalgOperandDefKind::OutputTensor
);
215 io
.enumCase(value
, "index_attr", LinalgOperandDefKind::IndexAttr
);
216 io
.enumCase(value
, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr
);
217 io
.enumCase(value
, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr
);
218 io
.enumCase(value
, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr
);
219 io
.enumCase(value
, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr
);
223 /// Iterator type enum.
225 struct ScalarEnumerationTraits
<LinalgIteratorTypeDef
> {
226 static void enumeration(IO
&io
, LinalgIteratorTypeDef
&value
) {
227 io
.enumCase(value
, "parallel", LinalgIteratorTypeDef::parallel
);
228 io
.enumCase(value
, "reduction", LinalgIteratorTypeDef::reduction
);
232 /// Metadata about the op (name, C++ name, and documentation).
234 struct MappingTraits
<LinalgOpMetadata
> {
235 static void mapping(IO
&io
, LinalgOpMetadata
&info
) {
236 io
.mapRequired("name", info
.name
);
237 io
.mapRequired("cpp_class_name", info
.cppClassName
);
238 io
.mapOptional("doc", info
.doc
);
239 io
.mapOptional("implements", info
.implements
);
240 io
.mapOptional("defines", info
.defines
);
244 /// How the ops indexing maps are produced. Must be one of:
245 /// - static_indexing_maps: A static list of AffineMaps, possibly with
246 /// some symbols that bind to attributes of the op. Each indexing map must
247 /// be normalized over the same list of dimensions, and its symbols must
248 /// match the symbols for argument shapes.
250 struct MappingTraits
<LinalgIndexingMapsConfig
> {
251 static void mapping(IO
&io
, LinalgIndexingMapsConfig
&info
) {
252 io
.mapOptional("static_indexing_maps", info
.staticIndexingMaps
);
256 /// Models an assignment to a named output.
257 /// - The `arg` name must match a named output.
258 /// - The `value` is a scalar expression for computing the value to
259 /// assign (see `ScalarExpression`).
261 struct MappingTraits
<ScalarAssign
> {
262 static void mapping(IO
&io
, ScalarAssign
&info
) {
263 io
.mapRequired("arg", info
.arg
);
264 io
.mapRequired("value", info
.value
);
268 /// A scalar expression (RHS of an assignment). Must be one of:
269 /// - `scalar_arg`: An operation argument.
270 /// - `scalar_const`: A constant definition.
271 /// - `scalar_index`: An iteration index.
272 /// - `scalar_fn`: A named function (see `ScalarFn`).
274 struct MappingTraits
<ScalarExpression
> {
275 static void mapping(IO
&io
, ScalarExpression
&info
) {
276 io
.mapOptional("scalar_arg", info
.arg
);
277 io
.mapOptional("scalar_const", info
.constant
);
278 io
.mapOptional("scalar_index", info
.index
);
279 io
.mapOptional("scalar_fn", info
.scalarFn
);
283 /// Scalar function kind enum.
285 struct ScalarEnumerationTraits
<ScalarFnKind
> {
286 static void enumeration(IO
&io
, ScalarFnKind
&value
) {
287 io
.enumCase(value
, "unary", ScalarFnKind::Unary
);
288 io
.enumCase(value
, "binary", ScalarFnKind::Binary
);
289 io
.enumCase(value
, "ternary", ScalarFnKind::Ternary
);
290 io
.enumCase(value
, "type", ScalarFnKind::Type
);
294 /// A scalar expression that evaluates a named function.
295 /// Functions are generally "math" level and type polymorphic. Builtin
296 /// functions include:
297 /// - `add(lhs, rhs)`
298 /// - `mul(lhs, rhs)`
300 struct MappingTraits
<ScalarFn
> {
301 static void mapping(IO
&io
, ScalarFn
&info
) {
302 io
.mapRequired("kind", info
.kind
);
303 io
.mapOptional("fn_name", info
.fnName
);
304 io
.mapOptional("attr_name", info
.attrName
);
305 io
.mapOptional("type_var", info
.typeVar
);
306 io
.mapRequired("operands", info
.operands
);
310 /// Helper mapping which accesses an AffineMapAttr as a serialized string of
313 struct ScalarTraits
<SerializedAffineMap
> {
314 static void output(const SerializedAffineMap
&value
, void *rawYamlContext
,
316 assert(value
.affineMapAttr
);
317 value
.affineMapAttr
.print(out
);
319 static StringRef
input(StringRef scalar
, void *rawYamlContext
,
320 SerializedAffineMap
&value
) {
321 assert(rawYamlContext
);
322 auto *yamlContext
= static_cast<LinalgYAMLContext
*>(rawYamlContext
);
323 if (auto attr
= dyn_cast_or_null
<AffineMapAttr
>(
324 mlir::parseAttribute(scalar
, yamlContext
->mlirContext
)))
325 value
.affineMapAttr
= attr
;
326 else if (!value
.affineMapAttr
|| !isa
<AffineMapAttr
>(value
.affineMapAttr
))
327 return "could not parse as an affine map attribute";
330 static QuotingType
mustQuote(StringRef
) { return QuotingType::None
; }
338 //===----------------------------------------------------------------------===//
339 // Generation utilities
340 //===----------------------------------------------------------------------===//
342 class GenerationContext
{
344 GenerationContext(MLIRContext
*context
, raw_ostream
*odsOut
,
345 raw_ostream
*defnOut
)
346 : context(context
), loc(UnknownLoc::get(context
)), odsOut(odsOut
),
349 MLIRContext
*getContext() { return context
; }
351 void setLoc(Location loc
) { this->loc
= loc
; }
352 Location
getLoc() { return loc
; }
354 bool shouldGenerateOds() { return odsOut
; }
355 bool shouldGenerateDefns() { return defnOut
; }
357 raw_ostream
&odss() {
358 assert(odsOut
&& "ODS stream not defined");
362 raw_ostream
&defns() {
363 assert(defnOut
&& "Definition stream not defined");
368 MLIRContext
*context
;
371 raw_ostream
*defnOut
;
376 static std::string
generateCppExpression(SerializedAffineMap self
,
377 StringRef contextName
) {
378 std::string printedStr
;
379 llvm::raw_string_ostream
printedSs(printedStr
);
380 self
.affineMapAttr
.print(printedSs
);
383 static const char exprFormat
[] =
384 R
"FMT(llvm::cast<AffineMapAttr>(mlir::parseAttribute("{0}", {1})).getValue())FMT";
385 return llvm::formatv(exprFormat
, printedStr
, contextName
);
388 template <typename Container
>
389 static std::string
interleaveToString(Container
&container
,
390 StringRef separator
) {
392 llvm::raw_string_ostream
ss(result
);
393 llvm::interleave(container
, ss
, separator
);
398 static std::optional
<int>
399 findTensorDefArgIndex(StringRef name
, SmallVectorImpl
<LinalgOperandDef
> &args
) {
400 for (const auto &it
: llvm::enumerate(args
)) {
401 if (it
.value().name
== name
)
407 // Try to map the TypeVar to a predefined or an argument type.
408 static std::optional
<std::string
>
409 findTypeValue(StringRef typeVar
, SmallVectorImpl
<LinalgOperandDef
> &args
) {
410 // Handle all predefined types.
411 if (typeVar
== "I32")
412 return std::string("helper.getIntegerType(32)");
413 if (typeVar
== "I64")
414 return std::string("helper.getIntegerType(64)");
415 if (typeVar
== "F32")
416 return std::string("helper.getFloat32Type()");
417 if (typeVar
== "F64")
418 return std::string("helper.getFloat64Type()");
420 // Search all argument types.
421 for (const auto &it
: llvm::enumerate(args
)) {
422 if (it
.value().kind
!= LinalgOperandDefKind::InputTensor
&&
423 it
.value().kind
!= LinalgOperandDefKind::Scalar
&&
424 it
.value().kind
!= LinalgOperandDefKind::OutputTensor
)
426 if (*it
.value().typeVar
== typeVar
)
427 return llvm::formatv("block.getArgument({0}).getType()", it
.index())
434 static ScalarAssign
*findAssignment(StringRef name
,
435 std::vector
<ScalarAssign
> &assignments
) {
436 for (auto &assign
: assignments
) {
437 if (assign
.arg
== name
)
443 // Return true if the operand is a function attribute.
444 static bool isFunctionAttribute(LinalgOperandDefKind kind
) {
445 return kind
== LinalgOperandDefKind::UnaryFnAttr
||
446 kind
== LinalgOperandDefKind::BinaryFnAttr
||
447 kind
== LinalgOperandDefKind::TernaryFnAttr
||
448 kind
== LinalgOperandDefKind::TypeFnAttr
;
451 // Return true if the operand is an attribute.
452 static bool isAttribute(LinalgOperandDefKind kind
) {
453 return kind
== LinalgOperandDefKind::IndexAttr
|| isFunctionAttribute(kind
);
456 // Get the enum name for the given operand kind.
457 std::string
convertOperandKindToEnumName(LinalgOperandDefKind kind
) {
459 case LinalgOperandDefKind::UnaryFnAttr
:
460 return std::string("UnaryFn");
461 case LinalgOperandDefKind::BinaryFnAttr
:
462 return std::string("BinaryFn");
463 case LinalgOperandDefKind::TernaryFnAttr
:
464 return std::string("TernaryFn");
465 case LinalgOperandDefKind::TypeFnAttr
:
466 return std::string("TypeFn");
470 llvm_unreachable("unsupported function attribute kind");
473 // Get the enum name for the given function kind.
474 std::string
convertFunctionKindToEnumName(ScalarFnKind kind
) {
476 case ScalarFnKind::Unary
:
477 return std::string("UnaryFn");
478 case ScalarFnKind::Binary
:
479 return std::string("BinaryFn");
480 case ScalarFnKind::Ternary
:
481 return std::string("TernaryFn");
482 case ScalarFnKind::Type
:
483 return std::string("TypeFn");
485 llvm_unreachable("unsupported function kind");
488 //===----------------------------------------------------------------------===//
490 //===----------------------------------------------------------------------===//
492 // A single line banner format. Parameters:
493 // {0}: Single line comment
494 static const char bannerFormat
[] = R
"FMT(
495 //===----------------------------------------------------------------------===//
497 //===----------------------------------------------------------------------===//
500 //===----------------------------------------------------------------------===//
501 // Named generic op generation.
502 // These ops map at most a single contraction that complies with the limitations
503 // of a linalg.generic.
504 //===----------------------------------------------------------------------===//
506 // Template for Linalg named ops' ODS definitions. Parameters:
507 // {0}: ODS/C++ op name
508 // {1}: assembly op mnemonic
509 // {2}: op interface list
510 // {3}: documentation (summary + description)
511 // {4}: op attribute list
512 // {5}: builder methods taking standalone attribute parameters
513 // {6}: additional method definitions
514 // {7}: additional methods for attributes used by indexing maps
515 static const char structuredOpOdsHeaderFormat
[] = R
"FMT(
516 //===----------------------------------------------------------------------===//
517 // Op definition for {0}
518 //===----------------------------------------------------------------------===//
520 def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
521 /*extraInterfaces=*/[{2}])> {
524 Variadic<AnyType>:$inputs,
525 Variadic<AnyShaped>:$outputs{4}
527 let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
528 let regions = (region AnyRegion:$region);
530 let skipDefaultBuilders = 1;
533 (ins "ValueRange
":$inputs, "ValueRange
":$outputs,
534 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
536 buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
537 attributes, {0}::getRegionBuilder());
540 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$inputs,
541 "ValueRange
":$outputs,
542 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
544 buildStructuredOp($_builder, $_state, resultTensorTypes,
545 inputs, outputs, attributes, {0}::getRegionBuilder());
548 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$operands,
549 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
551 $_state.addOperands(operands);
552 $_state.addAttributes(attributes);
553 $_state.addTypes(resultTensorTypes);
554 (void)$_state.addRegion();
558 let hasCustomAssemblyFormat = 1;
562 let extraClassDeclaration = structuredOpsBaseDecls # [{{
564 SmallVector<utils::IteratorType> getIteratorTypesArray();
565 ArrayAttr getIndexingMaps();
566 static void regionBuilder(ImplicitLocOpBuilder &b,
567 Block &block, ArrayRef<NamedAttribute> attrs);
568 static std::function<void(ImplicitLocOpBuilder &,
569 Block &, ArrayRef<NamedAttribute>)>
570 getRegionBuilder() {{
571 return regionBuilder;
574 ::mlir::MutableOperandRange getDpsInitsMutable() {{
575 return getOutputsMutable();
579 static unsigned getNumRegionArgs();
580 std::string getLibraryCallName();
586 // Builder method taking attribute parameters. Parameters:
588 // {1}: Comma interleaved attribute parameters
589 // {2}: Attribute initialization
590 static const char structuredOpBuilderFormat
[] = R
"FMT(
592 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$inputs,
593 "ValueRange
":$outputs, {1},
594 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
597 buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
598 attributes, {0}::getRegionBuilder());
602 // The getIteratorTypesArray() method for structured ops. Parameters:
604 // {1}: Comma interleaved iterator type names.
605 static const char structuredOpIteratorTypesFormat
[] =
607 SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
608 return SmallVector<utils::IteratorType>{{ {1} };
612 // The getIteratorTypesArray() method for rank polymorphic structured ops.
615 static const char rankPolyStructuredOpIteratorTypesFormat
[] =
617 SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
618 int64_t rank = getRank(getDpsInitOperand(0));
619 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
623 // The indexing_maps() method for structured ops. Parameters:
625 // {1}: Comma-separated list of dimension variable names.
627 static const char structuredOpIndexingMapsFormat
[] = R
"FMT(
628 ArrayAttr {0}::getIndexingMaps() {{
629 static const char memoizeAttr[] = "linalg
.memoized_indexing_maps
";
630 ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
634 MLIRContext *context = getContext();
635 auto symbolBindings = getSymbolBindings(*this);
636 SmallVector<AffineMap> maps;
638 cached = Builder(context).getAffineMapArrayAttr(maps);
639 getOperation()->setAttr(memoizeAttr, cached);
644 // The indexing_maps() method for rank polymorphic structured ops. Parameters:
646 static const char rankPolyStructuredOpIndexingMapsFormat
[] = R
"FMT(
647 ArrayAttr {0}::getIndexingMaps() {{
648 MLIRContext *context = getContext();
649 AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
650 AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
651 getNumParallelLoops(), context);
652 SmallVector<AffineMap> indexingMaps;
653 for (OpOperand &opOperand : getOperation()->getOpOperands())
654 indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
655 return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
659 // Implementations of fold and getEffects.
662 const char structuredOpFoldersFormat
[] = R
"FMT(
663 LogicalResult {0}::fold(FoldAdaptor,
664 SmallVectorImpl<OpFoldResult> &) {{
665 return memref::foldMemRefCast(*this);
667 void {0}::getEffects(SmallVectorImpl<
668 SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
669 if (hasPureTensorSemantics()) return;
670 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
674 // Implementation of parse/print.
677 static const char structuredOpParserFormat
[] = R
"FMT(
678 ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
679 return ::parseNamedStructuredOp(parser, result,
680 {0}::getNumRegionArgs(), {0}::getRegionBuilder());
682 void {0}::print(OpAsmPrinter &p) {{
683 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
687 static LogicalResult
generateNamedGenericOpOds(LinalgOpConfig
&opConfig
,
688 GenerationContext
&genContext
) {
689 if (!genContext
.shouldGenerateOds())
692 raw_ostream
&os
= genContext
.odss();
694 std::string interfaceNameList
;
695 std::string attrList
;
696 std::string attrMethods
;
697 std::string attrBuilder
;
700 if (opConfig
.metadata
->doc
) {
701 static const char structuredOpDocFmt
[] = R
"FMT(
702 let summary = [{{{0}}];
703 let description = [{{{1}}];
705 StringRef summary
, description
;
706 std::tie(summary
, description
) =
707 StringRef(*opConfig
.metadata
->doc
).trim().split("\n\n");
709 doc
= llvm::formatv(structuredOpDocFmt
, summary
.trim(), description
.trim());
712 interfaceNameList
= interleaveToString(opConfig
.metadata
->implements
, ", ");
714 std::string definitionList
;
715 for (const std::string
&definition
: opConfig
.metadata
->defines
) {
716 static const char definitionFmt
[] = "let {0} = 1;\n";
717 definitionList
.append(llvm::formatv(definitionFmt
, definition
));
720 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
721 return isAttribute(arg
.kind
);
723 SmallVector
<std::string
> attrDefs
;
724 SmallVector
<std::string
> attrParams
;
725 SmallVector
<std::string
> attrStmts
;
726 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
727 static const char paramFmt
[] = "\"Attribute\":${0}";
728 static const char stmtFmt
[] = "$_state.addAttribute(\"{0}\", {0});";
729 // Add the type conversion attributes to the op definition and builders.
730 if (isFunctionAttribute(arg
.kind
)) {
731 assert(arg
.defaultFn
);
732 std::string enumName
= convertOperandKindToEnumName(arg
.kind
);
733 static const char typeFmt
[] = "{0}::{1}";
734 static const char defFmt
[] =
735 "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}";
736 attrDefs
.push_back(llvm::formatv(
737 defFmt
, llvm::formatv("{0}Attr", enumName
),
738 llvm::formatv(typeFmt
, enumName
, arg
.defaultFn
), arg
.name
));
739 attrParams
.push_back(llvm::formatv(paramFmt
, arg
.name
));
740 attrStmts
.push_back(llvm::formatv(stmtFmt
, arg
.name
));
742 // Add the index attributes to the op definition and builders.
743 if (arg
.kind
== LinalgOperandDefKind::IndexAttr
) {
744 assert(arg
.indexAttrMap
.has_value());
745 assert(arg
.defaultIndices
.has_value());
746 size_t size
= arg
.indexAttrMap
->affineMap().getNumResults();
747 assert(arg
.defaultIndices
->size() == size
);
748 static const char typeFmt
[] = "RankedI64ElementsAttr<[{0}]>";
749 static const char defFmt
[] =
750 "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}";
751 std::string defaultVals
;
752 llvm::raw_string_ostream
ss(defaultVals
);
754 *arg
.defaultIndices
, ss
,
755 [&](int64_t val
) { ss
<< "static_cast<int64_t>(" << val
<< ")"; },
757 attrDefs
.push_back(llvm::formatv(defFmt
, llvm::formatv(typeFmt
, size
),
758 ss
.str(), arg
.name
));
759 attrParams
.push_back(llvm::formatv(paramFmt
, arg
.name
));
760 attrStmts
.push_back(llvm::formatv(stmtFmt
, arg
.name
));
763 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
764 return arg
.kind
== LinalgOperandDefKind::IndexAttr
;
767 bool hasDynamicIndexingMaps();
768 LogicalResult verifyIndexingMapRequiredAttributes();
771 attrList
= ",\n" + llvm::join(attrDefs
, ",\n");
772 attrBuilder
= llvm::formatv(
773 structuredOpBuilderFormat
, opConfig
.metadata
->cppClassName
,
774 llvm::join(attrParams
, ", "), llvm::join(attrStmts
, "\n"));
777 os
<< llvm::formatv(structuredOpOdsHeaderFormat
,
778 opConfig
.metadata
->cppClassName
, opConfig
.metadata
->name
,
779 interfaceNameList
, doc
, attrList
, attrBuilder
,
780 definitionList
, attrMethods
);
786 generateNamedGenericOpDefns(LinalgOpConfig
&opConfig
,
787 GenerationContext
&genContext
) {
788 if (!genContext
.shouldGenerateDefns())
791 raw_ostream
&os
= genContext
.defns();
792 StringRef className
= opConfig
.metadata
->cppClassName
;
794 // Implementation banner.
795 std::string bannerComment
= llvm::formatv("Implementation of {0}", className
);
796 os
<< llvm::formatv(bannerFormat
, bannerComment
);
798 // Compute the number of scalar and tensor arguments.
800 llvm::count_if(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
801 return arg
.kind
== LinalgOperandDefKind::InputTensor
||
802 arg
.kind
== LinalgOperandDefKind::Scalar
||
803 arg
.kind
== LinalgOperandDefKind::OutputTensor
;
806 // An operation that accesses only scalars and scalar/rank zero tensors is
807 // rank polymorphic. We implement rank polymorphism by generating different
808 // indexing maps and iterators that match the rank of the first output tensor.
809 // An operation is rank polymorphic if the iteration domain has rank zero.
810 bool isRankPolymorphic
= opConfig
.structuredOp
->iteratorTypes
.empty();
812 // Generate the iterator_types() method.
813 if (!isRankPolymorphic
) {
814 std::string iteratorsStr
;
815 llvm::raw_string_ostream
ss(iteratorsStr
);
816 llvm::interleaveComma(opConfig
.structuredOp
->iteratorTypes
, ss
,
817 [&](LinalgIteratorTypeDef it
) {
819 case LinalgIteratorTypeDef::parallel
:
820 ss
<< "utils::IteratorType::parallel";
822 case LinalgIteratorTypeDef::reduction
:
823 ss
<< "utils::IteratorType::reduction";
828 os
<< llvm::formatv(structuredOpIteratorTypesFormat
, className
,
831 os
<< llvm::formatv(rankPolyStructuredOpIteratorTypesFormat
, className
);
834 // Generating the getIndexingMaps() method.
835 if (auto &staticMaps
=
836 opConfig
.structuredOp
->indexingMaps
.staticIndexingMaps
) {
837 if (staticMaps
->empty())
838 return emitError(genContext
.getLoc()) << "op has no indexing maps";
839 if (!isRankPolymorphic
) {
840 AffineMap firstMap
= staticMaps
->front().affineMap();
844 // For each symbol, generate a declaration for it, either with an
845 // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
847 // TODO: Possibly lift into a top-level method.
848 static const char structuredOpSymbolBindingsFormat
[] = R
"FMT(
849 static SmallVector<AffineExpr> getSymbolBindings({0} self) {
850 MLIRContext *context = self.getContext();
851 SmallVector<AffineExpr> exprs;
857 unsigned symbolCount
= firstMap
.getNumSymbols();
858 SmallVector
<std::string
> symbolBindings
;
859 for (unsigned i
= 0; i
< symbolCount
; ++i
) {
860 symbolBindings
.push_back(llvm::formatv(
861 " exprs.push_back(getAffineSymbolExpr({0}, context));", i
));
864 // Access an index attribute. Parameters:
865 // {0}: Attribute name
866 // {1}: Symbol position
867 // {2}: Attribute index
868 static const char structuredOpAccessAttrFormat
[] = R
"FMT(
869 int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}];
870 exprs.push_back(getAffineConstantExpr(cst{1}, context));
872 // Update all symbol bindings mapped to an attribute.
873 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
874 if (arg
.kind
!= LinalgOperandDefKind::IndexAttr
)
876 assert(arg
.indexAttrMap
);
877 for (auto [idx
, result
] :
878 llvm::enumerate(arg
.indexAttrMap
->affineMap().getResults())) {
879 if (auto symbol
= dyn_cast
<AffineSymbolExpr
>(result
)) {
880 std::string argName
= arg
.name
;
881 argName
[0] = toupper(argName
[0]);
882 symbolBindings
[symbol
.getPosition()] =
883 llvm::formatv(structuredOpAccessAttrFormat
, argName
,
884 symbol
.getPosition(), idx
);
889 std::string symbolBindingsStr
;
890 llvm::raw_string_ostream
symbolBindingsSs(symbolBindingsStr
);
891 llvm::interleave(symbolBindings
, symbolBindingsSs
, "\n");
892 symbolBindingsSs
.flush();
894 os
<< llvm::formatv(structuredOpSymbolBindingsFormat
, className
,
900 unsigned dimCount
= firstMap
.getNumDims();
902 // Generate a comma-separated list of dim identifiers to be passed to
903 // bindDims, ensuring that AffineExpr identifiers are bound in the right
904 // order to the proper AffineDimExpr.
905 // This results in vars in scope like: d0, d1, d2...
906 SmallVector
<unsigned> dimIndices
;
907 for (unsigned i
= 0; i
< dimCount
; ++i
)
908 dimIndices
.push_back(i
);
909 std::string dimIdentsStr
;
910 llvm::raw_string_ostream
dimIdentsSs(dimIdentsStr
);
911 llvm::interleaveComma(dimIndices
, dimIdentsSs
,
912 [&](unsigned i
) { dimIdentsSs
<< "d" << i
; });
915 // Statements to add and simplify each affine map.
916 SmallVector
<std::string
> stmts
;
917 for (auto &indexingMap
: *staticMaps
) {
918 // TODO: Assert that dim and symbol count match the first.
920 llvm::formatv("maps.push_back({0});",
921 generateCppExpression(indexingMap
, "context")));
922 stmts
.push_back(llvm::formatv(
924 "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
925 "symbolBindings, {0}, 0));",
929 // TODO: This needs to be memoized and/or converted to non-parser based
930 // C++ codegen prior to real use.
931 os
<< llvm::formatv(structuredOpIndexingMapsFormat
, className
,
932 dimIdentsStr
, interleaveToString(stmts
, "\n "));
935 os
<< llvm::formatv(rankPolyStructuredOpIndexingMapsFormat
, className
);
938 return emitError(genContext
.getLoc())
939 << "generating code for non static indexing maps not currently "
943 // getNumRegionArgs()
945 // Generates a getNumRegionArgs() method. Parameters:
947 // {1}: Number of region args
948 static const char structuredOpGetNumRegionArgsFormat
[] = R
"FMT(
949 unsigned {0}::getNumRegionArgs() {{ return {1}; }
951 os
<< llvm::formatv(structuredOpGetNumRegionArgsFormat
, className
,
955 // getLibraryCallName()
957 // Generates a getLibraryCallName method. Parameters:
959 static const char structuredOpGetLibraryCallFormat
[] = R
"FMT(
960 std::string {0}::getLibraryCallName() {{
961 return generateLibraryCallName(getOperation());
964 os
<< llvm::formatv(structuredOpGetLibraryCallFormat
, className
);
967 // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
968 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
969 return arg
.kind
== LinalgOperandDefKind::IndexAttr
;
971 std::vector
<std::string
> attrVerifications
;
972 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
973 if (arg
.kind
!= LinalgOperandDefKind::IndexAttr
)
975 assert(arg
.indexAttrMap
);
976 // Verify index attribute. Parameters:
977 // {0}: Attribute name
978 // {1}: Attribute size
979 static const char attrFmt
[] = R
"FMT(
980 if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
981 if (!attr.getType().getElementType().isInteger(64))
982 return op->emitError("incorrect element type
for index attribute
'{0}'");
983 if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
984 return op->emitError("incorrect shape
for index attribute
'{0}'");
987 attrVerifications
.push_back(llvm::formatv(
988 attrFmt
, arg
.name
, arg
.indexAttrMap
->affineMap().getNumResults()));
991 // Generates the verifyIndexingMapRequiredAttributes method. Parameters:
993 // {1}: Attribute verification
994 static const char structuredOpVerifyIndexingMapRequiredAttributes
[] = R
"FMT(
995 bool {0}::hasDynamicIndexingMaps() {{ return true; }
996 LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
997 Operation *op = getOperation();
1002 os
<< llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes
,
1003 className
, llvm::join(attrVerifications
, "\n"));
1008 // Generates a regionBuilder method. Parameters.
1010 // {1}: Number of args
1013 static const char structuredOpRegionBuilderFormat
[] = R
"FMT(
1014 void {0}::regionBuilder(ImplicitLocOpBuilder &b,
1015 Block &block, ArrayRef<NamedAttribute> attrs) {{
1016 assert({1} > 0 && block.getNumArguments() == {1} &&
1017 "{0} regionBuilder expects
{1} (>=0) args
");
1018 RegionBuilderHelper helper(b, block);
1019 SmallVector<Value> yields;
1022 helper.yieldOutputs(yields);
1025 auto &args
= opConfig
.structuredOp
->args
;
1026 auto &assignments
= opConfig
.structuredOp
->assignments
;
1027 size_t generatedAssignmentCount
= 0;
1028 int localCounter
= 0;
1029 SmallVector
<std::string
> attrs
;
1030 SmallVector
<std::string
> stmts
;
1031 for (LinalgOperandDef
&arg
: args
) {
1032 if (!isFunctionAttribute(arg
.kind
))
1034 // Obtain the type function attribute values. Parameters.
1036 // {1}: attribute name
1037 // {2}: default type function name
1038 static const char attrDef
[] = R
"FMT(
1039 {0} {1}Val = {0}::{2};
1040 auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1041 return attr.getName() == "{1}"; });
1042 if ({1}Iter != attrs.end()) {{
1043 if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
1044 {1}Val = attr.getValue();
1047 std::string enumName
= convertOperandKindToEnumName(arg
.kind
);
1049 llvm::formatv(attrDef
, enumName
, arg
.name
, arg
.defaultFn
));
1051 for (LinalgOperandDef
&arg
: args
) {
1052 if (arg
.kind
!= LinalgOperandDefKind::OutputTensor
)
1055 // Find the assignment that correlates with the argument.
1056 ScalarAssign
*assignment
= findAssignment(arg
.name
, assignments
);
1058 return emitError(genContext
.getLoc())
1059 << "no assignment found for output argument " << arg
.name
;
1060 ++generatedAssignmentCount
;
1062 // Recursively generate the expression.
1063 std::function
<std::optional
<std::string
>(ScalarExpression
&)>
1064 generateExpression
=
1065 [&](ScalarExpression
&expression
) -> std::optional
<std::string
> {
1066 if (expression
.arg
) {
1067 // Argument reference.
1068 std::optional
<int> argIndex
=
1069 findTensorDefArgIndex(*expression
.arg
, args
);
1071 emitError(genContext
.getLoc())
1072 << "scalar argument not defined on the op: " << *expression
.arg
;
1073 return std::nullopt
;
1076 llvm::formatv("block.getArgument({0})", *argIndex
));
1078 if (expression
.constant
) {
1079 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1081 llvm::formatv(R
"FMT(Value {0} = helper.constant("{1}");)FMT",
1082 cppIdent
, expression
.constant
));
1085 if (expression
.index
) {
1086 // Access an iteration index.
1087 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1088 stmts
.push_back(llvm::formatv("Value {0} = helper.index({1});",
1089 cppIdent
, *expression
.index
));
1092 if (expression
.scalarFn
) {
1093 std::string enumName
=
1094 convertFunctionKindToEnumName(expression
.scalarFn
->kind
);
1096 // Get the function or attribute name.
1097 assert(expression
.scalarFn
->fnName
|| expression
.scalarFn
->attrName
);
1098 std::string funcType
;
1099 if (expression
.scalarFn
->fnName
) {
1100 funcType
= llvm::formatv("{0}::{1}", enumName
,
1101 *expression
.scalarFn
->fnName
);
1103 if (expression
.scalarFn
->attrName
) {
1104 if (llvm::none_of(args
, [&](LinalgOperandDef
&arg
) {
1105 return isFunctionAttribute(arg
.kind
) &&
1106 arg
.name
== *expression
.scalarFn
->attrName
;
1108 emitError(genContext
.getLoc()) << "missing function attribute "
1109 << *expression
.scalarFn
->attrName
;
1111 funcType
= llvm::formatv("{0}Val", *expression
.scalarFn
->attrName
);
1113 assert(!funcType
.empty());
1115 // Add the optional type parameter to the operands.
1116 SmallVector
<std::string
> operandCppValues
;
1117 if (expression
.scalarFn
->kind
== ScalarFnKind::Type
) {
1118 assert(expression
.scalarFn
->typeVar
.has_value());
1119 std::optional
<std::string
> typeCppValue
=
1120 findTypeValue(*expression
.scalarFn
->typeVar
, args
);
1121 if (!typeCppValue
) {
1122 emitError(genContext
.getLoc())
1123 << "type variable " << *expression
.scalarFn
->typeVar
1124 << ", used in a type conversion, must map to a predefined or "
1125 << "an argument type but it does not";
1126 return std::nullopt
;
1128 operandCppValues
.push_back(*typeCppValue
);
1131 // Collect the scalar operands.
1132 for (ScalarExpression
&operand
: expression
.scalarFn
->operands
) {
1133 auto operandCppValue
= generateExpression(operand
);
1134 if (!operandCppValue
)
1135 return std::nullopt
;
1136 operandCppValues
.push_back(*operandCppValue
);
1139 // Call the function builder.
1140 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1141 stmts
.push_back(llvm::formatv(
1142 "Value {0} = helper.build{1}({2}, {3});", cppIdent
, enumName
,
1143 funcType
, interleaveToString(operandCppValues
, ", ")));
1146 emitError(genContext
.getLoc()) << "unknown ScalarExpression type";
1147 return std::nullopt
;
1149 std::optional
<std::string
> cppValue
=
1150 generateExpression(assignment
->value
);
1153 stmts
.push_back(llvm::formatv("yields.push_back({0});", *cppValue
));
1156 if (generatedAssignmentCount
!= assignments
.size())
1157 return emitError(genContext
.getLoc())
1158 << "mismatched number of assignments vs output arguments";
1160 os
<< llvm::formatv(structuredOpRegionBuilderFormat
, className
, numOfArgs
,
1161 interleaveToString(attrs
, "\n "),
1162 interleaveToString(stmts
, "\n "));
1165 // Parser and printer.
1166 os
<< llvm::formatv(structuredOpParserFormat
, className
);
1168 // Canonicalizers and folders.
1169 os
<< llvm::formatv(structuredOpFoldersFormat
, className
);
1174 static LogicalResult
generateOp(LinalgOpConfig
&opConfig
,
1175 GenerationContext
&genContext
) {
1176 // Switch on op type being generated.
1177 if (opConfig
.structuredOp
) {
1179 succeeded(generateNamedGenericOpOds(opConfig
, genContext
)) &&
1180 succeeded(generateNamedGenericOpDefns(opConfig
, genContext
)));
1182 return emitError(genContext
.getLoc()) << "unsupported operation type";
1185 //===----------------------------------------------------------------------===//
1186 // Command line options and main
1187 //===----------------------------------------------------------------------===//
1189 static llvm::cl::opt
<std::string
>
1190 inputFilename(llvm::cl::Positional
, llvm::cl::desc("<input file>"),
1191 llvm::cl::init("-"), llvm::cl::value_desc("YAML filename"));
1193 static llvm::cl::opt
<std::string
>
1194 outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
1195 llvm::cl::value_desc("filename"), llvm::cl::init(""));
1197 static llvm::cl::opt
<std::string
>
1198 outputCppImplFilename("o-impl",
1199 llvm::cl::desc("C++ implementation file name"),
1200 llvm::cl::value_desc("filename"), llvm::cl::init(""));
1202 int main(int argc
, char **argv
) {
1203 llvm::cl::ParseCommandLineOptions(argc
, argv
, "Linalg ODS Gen from YAML");
1205 // Set up the input file.
1206 std::string errorMessage
;
1207 std::unique_ptr
<llvm::MemoryBuffer
> file
=
1208 mlir::openInputFile(inputFilename
, &errorMessage
);
1210 llvm::errs() << errorMessage
<< "\n";
1214 MLIRContext mlirContext
;
1215 LinalgYAMLContext yamlContext
{&mlirContext
};
1217 std::vector
<LinalgOpConfig
> opConfigs
;
1220 Input
yin(file
->getBuffer(), &yamlContext
);
1226 // Open output files.
1227 std::unique_ptr
<llvm::ToolOutputFile
> outputOdsDecl
;
1228 if (!outputOdsDeclFilename
.empty()) {
1229 outputOdsDecl
= openOutputFile(outputOdsDeclFilename
, &errorMessage
);
1230 if (!outputOdsDecl
) {
1231 llvm::errs() << errorMessage
<< "\n";
1236 std::unique_ptr
<llvm::ToolOutputFile
> outputCppImpl
;
1237 if (!outputCppImplFilename
.empty()) {
1238 outputCppImpl
= openOutputFile(outputCppImplFilename
, &errorMessage
);
1239 if (!outputCppImpl
) {
1240 llvm::errs() << errorMessage
<< "\n";
1245 if (!outputOdsDecl
&& !outputCppImpl
) {
1246 llvm::errs() << "error: No output files specified\n";
1251 GenerationContext
genContext(&mlirContext
,
1252 outputOdsDecl
? &outputOdsDecl
->os() : nullptr,
1253 outputCppImpl
? &outputCppImpl
->os() : nullptr);
1255 for (auto &opConfig
: opConfigs
) {
1256 if (!opConfig
.metadata
) {
1257 emitError(genContext
.getLoc())
1258 << "missing operation metadata on subsequent op";
1262 genContext
.setLoc(NameLoc::get(
1263 StringAttr::get(&mlirContext
, opConfig
.metadata
->cppClassName
)));
1264 if (failed(generateOp(opConfig
, genContext
))) {
1270 outputOdsDecl
->keep();
1272 outputCppImpl
->keep();