1 //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements an ODS (and C++) generator from a YAML form
10 // derived from the mathematical expression of linalg named ops. Typically a
11 // math oriented DSL will be used to export the essential representation to
12 // this form, and maintaining the SOT at the math level (versus recreating it
13 // in MLIR) is deemed to have systemic value.
15 //===----------------------------------------------------------------------===//
17 #include "mlir/AsmParser/AsmParser.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/ToolOutputFile.h"
28 #include "llvm/Support/YAMLTraits.h"
33 using llvm::yaml::Input
;
35 #define DEBUG_TYPE "linalg-ods-gen"
37 //===----------------------------------------------------------------------===//
38 // Mapping structs (correspond to data types in the YAML description).
39 // TODO: Since this is a schema/part of the contract, it should be moved to
41 //===----------------------------------------------------------------------===//
45 struct LinalgYAMLContext
{
46 MLIRContext
*mlirContext
;
49 struct LinalgOpMetadata
{
51 std::string cppClassName
;
52 std::optional
<std::string
> doc
;
53 SmallVector
<std::string
> implements
;
54 SmallVector
<std::string
> defines
;
57 struct SerializedAffineMap
{
58 AffineMapAttr affineMapAttr
;
60 AffineMap
affineMap() { return affineMapAttr
.getValue(); }
63 enum class LinalgOperandDefKind
{
74 struct LinalgOperandDef
{
76 LinalgOperandDefKind kind
;
77 std::optional
<std::string
> typeVar
;
78 std::optional
<SerializedAffineMap
> shapeMap
;
79 std::optional
<SerializedAffineMap
> indexAttrMap
;
80 std::optional
<SmallVector
<int64_t>> defaultIndices
;
81 std::optional
<std::string
> defaultFn
;
84 enum class LinalgIteratorTypeDef
{
89 struct LinalgIndexingMapsConfig
{
90 std::optional
<SmallVector
<SerializedAffineMap
>> staticIndexingMaps
;
93 struct ScalarExpression
;
/// Category of a scalar function. The YAML enumeration traits for this type
/// spell the values "unary", "binary", "ternary" and "type".
enum class ScalarFnKind {
  Unary,
  Binary,
  Ternary,
  Type,
};
99 std::optional
<std::string
> fnName
;
100 std::optional
<std::string
> attrName
;
101 std::optional
<std::string
> typeVar
;
102 // NOTE: This must be of arity 1, but to break the self-referential cycle,
103 // we use a heap allocated vector.
104 std::vector
<ScalarExpression
> operands
;
107 struct ScalarExpression
{
108 std::optional
<std::string
> arg
;
109 std::optional
<std::string
> constant
;
110 std::optional
<int64_t> index
;
111 std::optional
<ScalarFn
> scalarFn
;
114 struct ScalarAssign
{
116 ScalarExpression value
;
119 struct LinalgStructuredOpConfig
{
120 SmallVector
<LinalgOperandDef
> args
;
121 LinalgIndexingMapsConfig indexingMaps
;
122 SmallVector
<LinalgIteratorTypeDef
> iteratorTypes
;
123 std::vector
<ScalarAssign
> assignments
;
126 struct LinalgOpConfig
{
127 std::optional
<LinalgOpMetadata
> metadata
;
128 std::optional
<LinalgStructuredOpConfig
> structuredOp
;
133 //===----------------------------------------------------------------------===//
135 //===----------------------------------------------------------------------===//
137 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef
)
138 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap
)
139 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef
)
140 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign
)
141 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression
)
142 LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig
)
147 /// Top-level type containing op metadata and one of a concrete op type.
148 /// Currently, the only defined op type is `structured_op` (maps to
149 /// `LinalgStructuredOpConfig`).
151 struct MappingTraits
<LinalgOpConfig
> {
152 static void mapping(IO
&io
, LinalgOpConfig
&info
) {
153 io
.mapOptional("metadata", info
.metadata
);
154 io
.mapOptional("structured_op", info
.structuredOp
);
158 /// A structured op models (at most) a single contraction by modeling
159 /// - A list of named arguments (`LinalgOperandDef`), which can be inputs,
160 /// outputs, or index attributes.
161 /// - List of indexing maps (see `LinalgIndexingMaps`).
162 /// - Iterator types (see `LinalgIteratorTypeDef`).
163 /// - List of scalar level assignment (see `ScalarAssign`).
165 struct MappingTraits
<LinalgStructuredOpConfig
> {
166 static void mapping(IO
&io
, LinalgStructuredOpConfig
&info
) {
167 io
.mapRequired("args", info
.args
);
168 io
.mapRequired("indexing_maps", info
.indexingMaps
);
169 io
.mapRequired("iterator_types", info
.iteratorTypes
);
170 io
.mapRequired("assignments", info
.assignments
);
174 /// Maps a named tensor, scalar or attribute argument to an operation,
176 /// - `name`: Must be unique within the operation.
177 /// - `usage`: How the argument is used (input, output, attribute, etc).
178 /// - `type_var`: The symbolic type variable that binds to the element or self
179 /// type of the tensor or scalar argument, respectively.
180 /// - `shape_map`: An optional AffineMap from all op symbols to the shape of
181 /// the argument. Only tensor arguments have a `shape_map`. Each shape must
182 /// be normalized over the same list of symbols and have no dimension
184 /// - `index_attr_map`: An optional AffineMap from all op symbols to the
185 /// index attribute symbols. During op creation these symbols are replaced
186 /// by the corresponding `name` index attribute values. Only index attribute
187 /// arguments have an `index_attr_map`.
188 /// - `default_indices`: An optional default initialization for index
189 /// attribute arguments.
190 /// - `default_fn`: An optional default initialization for function attribute
193 struct MappingTraits
<LinalgOperandDef
> {
194 static void mapping(IO
&io
, LinalgOperandDef
&info
) {
195 io
.mapRequired("name", info
.name
);
196 io
.mapRequired("kind", info
.kind
);
197 io
.mapOptional("type_var", info
.typeVar
);
198 io
.mapOptional("shape_map", info
.shapeMap
);
199 io
.mapOptional("index_attr_map", info
.indexAttrMap
);
200 io
.mapOptional("default_indices", info
.defaultIndices
);
201 io
.mapOptional("default_fn", info
.defaultFn
);
205 /// Usage enum for a named argument.
207 struct ScalarEnumerationTraits
<LinalgOperandDefKind
> {
208 static void enumeration(IO
&io
, LinalgOperandDefKind
&value
) {
209 io
.enumCase(value
, "input_tensor", LinalgOperandDefKind::InputTensor
);
210 io
.enumCase(value
, "scalar", LinalgOperandDefKind::Scalar
);
211 io
.enumCase(value
, "output_tensor", LinalgOperandDefKind::OutputTensor
);
212 io
.enumCase(value
, "index_attr", LinalgOperandDefKind::IndexAttr
);
213 io
.enumCase(value
, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr
);
214 io
.enumCase(value
, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr
);
215 io
.enumCase(value
, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr
);
216 io
.enumCase(value
, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr
);
220 /// Iterator type enum.
222 struct ScalarEnumerationTraits
<LinalgIteratorTypeDef
> {
223 static void enumeration(IO
&io
, LinalgIteratorTypeDef
&value
) {
224 io
.enumCase(value
, "parallel", LinalgIteratorTypeDef::parallel
);
225 io
.enumCase(value
, "reduction", LinalgIteratorTypeDef::reduction
);
229 /// Metadata about the op (name, C++ name, and documentation).
231 struct MappingTraits
<LinalgOpMetadata
> {
232 static void mapping(IO
&io
, LinalgOpMetadata
&info
) {
233 io
.mapRequired("name", info
.name
);
234 io
.mapRequired("cpp_class_name", info
.cppClassName
);
235 io
.mapOptional("doc", info
.doc
);
236 io
.mapOptional("implements", info
.implements
);
237 io
.mapOptional("defines", info
.defines
);
241 /// How the op's indexing maps are produced. Must be one of:
242 /// - static_indexing_maps: A static list of AffineMaps, possibly with
243 /// some symbols that bind to attributes of the op. Each indexing map must
244 /// be normalized over the same list of dimensions, and its symbols must
245 /// match the symbols for argument shapes.
247 struct MappingTraits
<LinalgIndexingMapsConfig
> {
248 static void mapping(IO
&io
, LinalgIndexingMapsConfig
&info
) {
249 io
.mapOptional("static_indexing_maps", info
.staticIndexingMaps
);
253 /// Models an assignment to a named output.
254 /// - The `arg` name must match a named output.
255 /// - The `value` is a scalar expression for computing the value to
256 /// assign (see `ScalarExpression`).
258 struct MappingTraits
<ScalarAssign
> {
259 static void mapping(IO
&io
, ScalarAssign
&info
) {
260 io
.mapRequired("arg", info
.arg
);
261 io
.mapRequired("value", info
.value
);
265 /// A scalar expression (RHS of an assignment). Must be one of:
266 /// - `scalar_arg`: An operation argument.
267 /// - `scalar_const`: A constant definition.
268 /// - `scalar_index`: An iteration index.
269 /// - `scalar_fn`: A named function (see `ScalarFn`).
271 struct MappingTraits
<ScalarExpression
> {
272 static void mapping(IO
&io
, ScalarExpression
&info
) {
273 io
.mapOptional("scalar_arg", info
.arg
);
274 io
.mapOptional("scalar_const", info
.constant
);
275 io
.mapOptional("scalar_index", info
.index
);
276 io
.mapOptional("scalar_fn", info
.scalarFn
);
280 /// Scalar function kind enum.
282 struct ScalarEnumerationTraits
<ScalarFnKind
> {
283 static void enumeration(IO
&io
, ScalarFnKind
&value
) {
284 io
.enumCase(value
, "unary", ScalarFnKind::Unary
);
285 io
.enumCase(value
, "binary", ScalarFnKind::Binary
);
286 io
.enumCase(value
, "ternary", ScalarFnKind::Ternary
);
287 io
.enumCase(value
, "type", ScalarFnKind::Type
);
291 /// A scalar expression that evaluates a named function.
292 /// Functions are generally "math" level and type polymorphic. Builtin
293 /// functions include:
294 /// - `add(lhs, rhs)`
295 /// - `mul(lhs, rhs)`
297 struct MappingTraits
<ScalarFn
> {
298 static void mapping(IO
&io
, ScalarFn
&info
) {
299 io
.mapRequired("kind", info
.kind
);
300 io
.mapOptional("fn_name", info
.fnName
);
301 io
.mapOptional("attr_name", info
.attrName
);
302 io
.mapOptional("type_var", info
.typeVar
);
303 io
.mapRequired("operands", info
.operands
);
307 /// Helper mapping which accesses an AffineMapAttr as a serialized string of
310 struct ScalarTraits
<SerializedAffineMap
> {
311 static void output(const SerializedAffineMap
&value
, void *rawYamlContext
,
313 assert(value
.affineMapAttr
);
314 value
.affineMapAttr
.print(out
);
316 static StringRef
input(StringRef scalar
, void *rawYamlContext
,
317 SerializedAffineMap
&value
) {
318 assert(rawYamlContext
);
319 auto *yamlContext
= static_cast<LinalgYAMLContext
*>(rawYamlContext
);
320 if (auto attr
= dyn_cast_or_null
<AffineMapAttr
>(
321 mlir::parseAttribute(scalar
, yamlContext
->mlirContext
)))
322 value
.affineMapAttr
= attr
;
323 else if (!value
.affineMapAttr
|| !isa
<AffineMapAttr
>(value
.affineMapAttr
))
324 return "could not parse as an affine map attribute";
327 static QuotingType
mustQuote(StringRef
) { return QuotingType::None
; }
335 //===----------------------------------------------------------------------===//
336 // Generation utilities
337 //===----------------------------------------------------------------------===//
339 class GenerationContext
{
341 GenerationContext(MLIRContext
*context
, raw_ostream
*odsOut
,
342 raw_ostream
*defnOut
)
343 : context(context
), loc(UnknownLoc::get(context
)), odsOut(odsOut
),
346 MLIRContext
*getContext() { return context
; }
348 void setLoc(Location loc
) { this->loc
= loc
; }
349 Location
getLoc() { return loc
; }
351 bool shouldGenerateOds() { return odsOut
; }
352 bool shouldGenerateDefns() { return defnOut
; }
354 raw_ostream
&odss() {
355 assert(odsOut
&& "ODS stream not defined");
359 raw_ostream
&defns() {
360 assert(defnOut
&& "Definition stream not defined");
365 MLIRContext
*context
;
368 raw_ostream
*defnOut
;
373 static std::string
generateCppExpression(SerializedAffineMap self
,
374 StringRef contextName
) {
375 std::string printedStr
;
376 llvm::raw_string_ostream
printedSs(printedStr
);
377 self
.affineMapAttr
.print(printedSs
);
379 static const char exprFormat
[] =
380 R
"FMT(llvm::cast<AffineMapAttr>(mlir::parseAttribute("{0}", {1})).getValue())FMT";
381 return llvm::formatv(exprFormat
, printedStr
, contextName
);
384 template <typename Container
>
385 static std::string
interleaveToString(Container
&container
,
386 StringRef separator
) {
388 llvm::raw_string_ostream
ss(result
);
389 llvm::interleave(container
, ss
, separator
);
393 static std::optional
<int>
394 findTensorDefArgIndex(StringRef name
, SmallVectorImpl
<LinalgOperandDef
> &args
) {
395 for (const auto &it
: llvm::enumerate(args
)) {
396 if (it
.value().name
== name
)
402 // Try to map the TypeVar to a predefined or an argument type.
403 static std::optional
<std::string
>
404 findTypeValue(StringRef typeVar
, SmallVectorImpl
<LinalgOperandDef
> &args
) {
405 // Handle all predefined types.
406 if (typeVar
== "I32")
407 return std::string("helper.getIntegerType(32)");
408 if (typeVar
== "I64")
409 return std::string("helper.getIntegerType(64)");
410 if (typeVar
== "F32")
411 return std::string("helper.getFloat32Type()");
412 if (typeVar
== "F64")
413 return std::string("helper.getFloat64Type()");
415 // Search all argument types.
416 for (const auto &it
: llvm::enumerate(args
)) {
417 if (it
.value().kind
!= LinalgOperandDefKind::InputTensor
&&
418 it
.value().kind
!= LinalgOperandDefKind::Scalar
&&
419 it
.value().kind
!= LinalgOperandDefKind::OutputTensor
)
421 if (*it
.value().typeVar
== typeVar
)
422 return llvm::formatv("block.getArgument({0}).getType()", it
.index())
429 static ScalarAssign
*findAssignment(StringRef name
,
430 std::vector
<ScalarAssign
> &assignments
) {
431 for (auto &assign
: assignments
) {
432 if (assign
.arg
== name
)
438 // Return true if the operand is a function attribute.
439 static bool isFunctionAttribute(LinalgOperandDefKind kind
) {
440 return kind
== LinalgOperandDefKind::UnaryFnAttr
||
441 kind
== LinalgOperandDefKind::BinaryFnAttr
||
442 kind
== LinalgOperandDefKind::TernaryFnAttr
||
443 kind
== LinalgOperandDefKind::TypeFnAttr
;
446 // Return true if the operand is an attribute.
447 static bool isAttribute(LinalgOperandDefKind kind
) {
448 return kind
== LinalgOperandDefKind::IndexAttr
|| isFunctionAttribute(kind
);
451 // Get the enum name for the given operand kind.
452 std::string
convertOperandKindToEnumName(LinalgOperandDefKind kind
) {
454 case LinalgOperandDefKind::UnaryFnAttr
:
455 return std::string("UnaryFn");
456 case LinalgOperandDefKind::BinaryFnAttr
:
457 return std::string("BinaryFn");
458 case LinalgOperandDefKind::TernaryFnAttr
:
459 return std::string("TernaryFn");
460 case LinalgOperandDefKind::TypeFnAttr
:
461 return std::string("TypeFn");
465 llvm_unreachable("unsupported function attribute kind");
468 // Get the enum name for the given function kind.
469 std::string
convertFunctionKindToEnumName(ScalarFnKind kind
) {
471 case ScalarFnKind::Unary
:
472 return std::string("UnaryFn");
473 case ScalarFnKind::Binary
:
474 return std::string("BinaryFn");
475 case ScalarFnKind::Ternary
:
476 return std::string("TernaryFn");
477 case ScalarFnKind::Type
:
478 return std::string("TypeFn");
480 llvm_unreachable("unsupported function kind");
483 //===----------------------------------------------------------------------===//
485 //===----------------------------------------------------------------------===//
487 // A single line banner format. Parameters:
488 // {0}: Single line comment
489 static const char bannerFormat
[] = R
"FMT(
490 //===----------------------------------------------------------------------===//
492 //===----------------------------------------------------------------------===//
495 //===----------------------------------------------------------------------===//
496 // Named generic op generation.
497 // These ops map at most a single contraction that complies with the limitations
498 // of a linalg.generic.
499 //===----------------------------------------------------------------------===//
501 // Template for Linalg named ops' ODS definitions. Parameters:
502 // {0}: ODS/C++ op name
503 // {1}: assembly op mnemonic
504 // {2}: op interface list
505 // {3}: documentation (summary + description)
506 // {4}: op attribute list
507 // {5}: builder methods taking standalone attribute parameters
508 // {6}: additional method definitions
509 // {7}: additional methods for attributes used by indexing maps
510 static const char structuredOpOdsHeaderFormat
[] = R
"FMT(
511 //===----------------------------------------------------------------------===//
512 // Op definition for {0}
513 //===----------------------------------------------------------------------===//
515 def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
516 /*extraInterfaces=*/[{2}])> {
519 Variadic<AnyType>:$inputs,
520 Variadic<AnyShaped>:$outputs{4}
522 let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
523 let regions = (region AnyRegion:$region);
525 let skipDefaultBuilders = 1;
528 (ins "ValueRange
":$inputs, "ValueRange
":$outputs,
529 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
531 buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
532 attributes, {0}::getRegionBuilder());
535 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$inputs,
536 "ValueRange
":$outputs,
537 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
539 buildStructuredOp($_builder, $_state, resultTensorTypes,
540 inputs, outputs, attributes, {0}::getRegionBuilder());
543 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$operands,
544 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
546 $_state.addOperands(operands);
547 $_state.addAttributes(attributes);
548 $_state.addTypes(resultTensorTypes);
549 (void)$_state.addRegion();
553 let hasCustomAssemblyFormat = 1;
557 let extraClassDeclaration = structuredOpsBaseDecls # [{{
559 SmallVector<utils::IteratorType> getIteratorTypesArray();
560 ArrayAttr getIndexingMaps();
561 static void regionBuilder(ImplicitLocOpBuilder &b,
562 Block &block, ArrayRef<NamedAttribute> attrs);
563 static std::function<void(ImplicitLocOpBuilder &,
564 Block &, ArrayRef<NamedAttribute>)>
565 getRegionBuilder() {{
566 return regionBuilder;
569 ::mlir::MutableOperandRange getDpsInitsMutable() {{
570 return getOutputsMutable();
574 static unsigned getNumRegionArgs();
575 std::string getLibraryCallName();
581 // Builder method taking attribute parameters. Parameters:
583 // {1}: Comma interleaved attribute parameters
584 // {2}: Attribute initialization
585 static const char structuredOpBuilderFormat
[] = R
"FMT(
587 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$inputs,
588 "ValueRange
":$outputs, {1},
589 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
592 buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
593 attributes, {0}::getRegionBuilder());
597 // The getIteratorTypesArray() method for structured ops. Parameters:
599 // {1}: Comma interleaved iterator type names.
600 static const char structuredOpIteratorTypesFormat
[] =
602 SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
603 return SmallVector<utils::IteratorType>{{ {1} };
607 // The getIteratorTypesArray() method for rank polymorphic structured ops.
610 static const char rankPolyStructuredOpIteratorTypesFormat
[] =
612 SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
613 int64_t rank = getRank(getDpsInitOperand(0));
614 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
618 // The indexing_maps() method for structured ops. Parameters:
620 // {1}: Comma-separated list of dimension variable names.
622 static const char structuredOpIndexingMapsFormat
[] = R
"FMT(
623 ArrayAttr {0}::getIndexingMaps() {{
624 static const char memoizeAttr[] = "linalg
.memoized_indexing_maps
";
625 ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
629 MLIRContext *context = getContext();
630 auto symbolBindings = getSymbolBindings(*this);
631 SmallVector<AffineMap> maps;
633 cached = Builder(context).getAffineMapArrayAttr(maps);
634 getOperation()->setAttr(memoizeAttr, cached);
639 // The indexing_maps() method for rank polymorphic structured ops. Parameters:
641 static const char rankPolyStructuredOpIndexingMapsFormat
[] = R
"FMT(
642 ArrayAttr {0}::getIndexingMaps() {{
643 MLIRContext *context = getContext();
644 AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
645 AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
646 getNumParallelLoops(), context);
647 SmallVector<AffineMap> indexingMaps;
648 for (OpOperand &opOperand : getOperation()->getOpOperands())
649 indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
650 return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
654 // Implementations of fold, getEffects and getSpeculatability.
657 const char structuredOpFoldersFormat
[] = R
"FMT(
658 LogicalResult {0}::fold(FoldAdaptor,
659 SmallVectorImpl<OpFoldResult> &) {{
660 return memref::foldMemRefCast(*this);
662 void {0}::getEffects(SmallVectorImpl<
663 SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
664 if (hasPureTensorSemantics()) return;
665 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
667 Speculation::Speculatability {0}::getSpeculatability() {{
668 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
672 // Implementation of parse/print.
675 static const char structuredOpParserFormat
[] = R
"FMT(
676 ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
677 return ::parseNamedStructuredOp(parser, result,
678 {0}::getNumRegionArgs(), {0}::getRegionBuilder());
680 void {0}::print(OpAsmPrinter &p) {{
681 SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes
",
682 "linalg
.memoized_indexing_maps
"};
683 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
688 static LogicalResult
generateNamedGenericOpOds(LinalgOpConfig
&opConfig
,
689 GenerationContext
&genContext
) {
690 if (!genContext
.shouldGenerateOds())
693 raw_ostream
&os
= genContext
.odss();
695 std::string interfaceNameList
;
696 std::string attrList
;
697 std::string attrMethods
;
698 std::string attrBuilder
;
701 if (opConfig
.metadata
->doc
) {
702 static const char structuredOpDocFmt
[] = R
"FMT(
703 let summary = [{{{0}}];
704 let description = [{{{1}}];
706 StringRef summary
, description
;
707 std::tie(summary
, description
) =
708 StringRef(*opConfig
.metadata
->doc
).trim().split("\n\n");
710 doc
= llvm::formatv(structuredOpDocFmt
, summary
.trim(), description
.trim());
713 interfaceNameList
= interleaveToString(opConfig
.metadata
->implements
, ", ");
715 std::string definitionList
;
716 for (const std::string
&definition
: opConfig
.metadata
->defines
) {
717 static const char definitionFmt
[] = "let {0} = 1;\n";
718 definitionList
.append(llvm::formatv(definitionFmt
, definition
));
721 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
722 return isAttribute(arg
.kind
);
724 SmallVector
<std::string
> attrDefs
;
725 SmallVector
<std::string
> attrParams
;
726 SmallVector
<std::string
> attrStmts
;
727 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
728 static const char paramFmt
[] = "\"Attribute\":${0}";
729 static const char stmtFmt
[] = "$_state.addAttribute(\"{0}\", {0});";
730 // Add the type conversion attributes to the op definition and builders.
731 if (isFunctionAttribute(arg
.kind
)) {
732 assert(arg
.defaultFn
);
733 std::string enumName
= convertOperandKindToEnumName(arg
.kind
);
734 static const char typeFmt
[] = "{0}::{1}";
735 static const char defFmt
[] =
736 "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}";
737 attrDefs
.push_back(llvm::formatv(
738 defFmt
, llvm::formatv("{0}Attr", enumName
),
739 llvm::formatv(typeFmt
, enumName
, arg
.defaultFn
), arg
.name
));
740 attrParams
.push_back(llvm::formatv(paramFmt
, arg
.name
));
741 attrStmts
.push_back(llvm::formatv(stmtFmt
, arg
.name
));
743 // Add the index attributes to the op definition and builders.
744 if (arg
.kind
== LinalgOperandDefKind::IndexAttr
) {
745 assert(arg
.indexAttrMap
.has_value());
746 assert(arg
.defaultIndices
.has_value());
747 size_t size
= arg
.indexAttrMap
->affineMap().getNumResults();
748 assert(arg
.defaultIndices
->size() == size
);
749 static const char typeFmt
[] = "RankedI64ElementsAttr<[{0}]>";
750 static const char defFmt
[] =
751 "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}";
752 std::string defaultVals
;
753 llvm::raw_string_ostream
ss(defaultVals
);
755 *arg
.defaultIndices
, ss
,
756 [&](int64_t val
) { ss
<< "static_cast<int64_t>(" << val
<< ")"; },
758 attrDefs
.push_back(llvm::formatv(defFmt
, llvm::formatv(typeFmt
, size
),
759 ss
.str(), arg
.name
));
760 attrParams
.push_back(llvm::formatv(paramFmt
, arg
.name
));
761 attrStmts
.push_back(llvm::formatv(stmtFmt
, arg
.name
));
764 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
765 return arg
.kind
== LinalgOperandDefKind::IndexAttr
;
768 bool hasDynamicIndexingMaps();
769 LogicalResult verifyIndexingMapRequiredAttributes();
772 attrList
= ",\n" + llvm::join(attrDefs
, ",\n");
773 attrBuilder
= llvm::formatv(
774 structuredOpBuilderFormat
, opConfig
.metadata
->cppClassName
,
775 llvm::join(attrParams
, ", "), llvm::join(attrStmts
, "\n"));
778 os
<< llvm::formatv(structuredOpOdsHeaderFormat
,
779 opConfig
.metadata
->cppClassName
, opConfig
.metadata
->name
,
780 interfaceNameList
, doc
, attrList
, attrBuilder
,
781 definitionList
, attrMethods
);
787 generateNamedGenericOpDefns(LinalgOpConfig
&opConfig
,
788 GenerationContext
&genContext
) {
789 if (!genContext
.shouldGenerateDefns())
792 raw_ostream
&os
= genContext
.defns();
793 StringRef className
= opConfig
.metadata
->cppClassName
;
795 // Implementation banner.
796 std::string bannerComment
= llvm::formatv("Implementation of {0}", className
);
797 os
<< llvm::formatv(bannerFormat
, bannerComment
);
799 // Compute the number of scalar and tensor arguments.
801 llvm::count_if(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
802 return arg
.kind
== LinalgOperandDefKind::InputTensor
||
803 arg
.kind
== LinalgOperandDefKind::Scalar
||
804 arg
.kind
== LinalgOperandDefKind::OutputTensor
;
807 // An operation that accesses only scalars and scalar/rank zero tensors is
808 // rank polymorphic. We implement rank polymorphism by generating different
809 // indexing maps and iterators that match the rank of the first output tensor.
810 // An operation is rank polymorphic if the iteration domain has rank zero.
811 bool isRankPolymorphic
= opConfig
.structuredOp
->iteratorTypes
.empty();
813 // Generate the iterator_types() method.
814 if (!isRankPolymorphic
) {
815 std::string iteratorsStr
;
816 llvm::raw_string_ostream
ss(iteratorsStr
);
817 llvm::interleaveComma(opConfig
.structuredOp
->iteratorTypes
, ss
,
818 [&](LinalgIteratorTypeDef it
) {
820 case LinalgIteratorTypeDef::parallel
:
821 ss
<< "utils::IteratorType::parallel";
823 case LinalgIteratorTypeDef::reduction
:
824 ss
<< "utils::IteratorType::reduction";
828 os
<< llvm::formatv(structuredOpIteratorTypesFormat
, className
,
831 os
<< llvm::formatv(rankPolyStructuredOpIteratorTypesFormat
, className
);
834 // Generating the getIndexingMaps() method.
835 if (auto &staticMaps
=
836 opConfig
.structuredOp
->indexingMaps
.staticIndexingMaps
) {
837 if (staticMaps
->empty())
838 return emitError(genContext
.getLoc()) << "op has no indexing maps";
839 if (!isRankPolymorphic
) {
840 AffineMap firstMap
= staticMaps
->front().affineMap();
844 // For each symbol, generate a declaration for it, either with an
845 // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
847 // TODO: Possibly lift into a top-level method.
848 static const char structuredOpSymbolBindingsFormat
[] = R
"FMT(
849 static SmallVector<AffineExpr> getSymbolBindings({0} self) {
850 MLIRContext *context = self.getContext();
851 SmallVector<AffineExpr> exprs;
857 unsigned symbolCount
= firstMap
.getNumSymbols();
858 SmallVector
<std::string
> symbolBindings
;
859 for (unsigned i
= 0; i
< symbolCount
; ++i
) {
860 symbolBindings
.push_back(llvm::formatv(
861 " exprs.push_back(getAffineSymbolExpr({0}, context));", i
));
864 // Access an index attribute. Parameters:
865 // {0}: Attribute name
866 // {1}: Symbol position
867 // {2}: Attribute index
868 static const char structuredOpAccessAttrFormat
[] = R
"FMT(
869 int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}];
870 exprs.push_back(getAffineConstantExpr(cst{1}, context));
872 // Update all symbol bindings mapped to an attribute.
873 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
874 if (arg
.kind
!= LinalgOperandDefKind::IndexAttr
)
876 assert(arg
.indexAttrMap
);
877 for (auto [idx
, result
] :
878 llvm::enumerate(arg
.indexAttrMap
->affineMap().getResults())) {
879 if (auto symbol
= dyn_cast
<AffineSymbolExpr
>(result
)) {
880 std::string argName
= arg
.name
;
881 argName
[0] = toupper(argName
[0]);
882 symbolBindings
[symbol
.getPosition()] =
883 llvm::formatv(structuredOpAccessAttrFormat
, argName
,
884 symbol
.getPosition(), idx
);
889 std::string symbolBindingsStr
;
890 llvm::raw_string_ostream
symbolBindingsSs(symbolBindingsStr
);
891 llvm::interleave(symbolBindings
, symbolBindingsSs
, "\n");
893 os
<< llvm::formatv(structuredOpSymbolBindingsFormat
, className
,
899 unsigned dimCount
= firstMap
.getNumDims();
901 // Generate a comma-separated list of dim identifiers to be passed to
902 // bindDims, ensuring that AffineExpr identifiers are bound in the right
903 // order to the proper AffineDimExpr.
904 // This results in vars in scope like: d0, d1, d2...
905 SmallVector
<unsigned> dimIndices
;
906 for (unsigned i
= 0; i
< dimCount
; ++i
)
907 dimIndices
.push_back(i
);
908 std::string dimIdentsStr
;
909 llvm::raw_string_ostream
dimIdentsSs(dimIdentsStr
);
910 llvm::interleaveComma(dimIndices
, dimIdentsSs
,
911 [&](unsigned i
) { dimIdentsSs
<< "d" << i
; });
913 // Statements to add and simplify each affine map.
914 SmallVector
<std::string
> stmts
;
915 for (auto &indexingMap
: *staticMaps
) {
916 // TODO: Assert that dim and symbol count match the first.
918 llvm::formatv("maps.push_back({0});",
919 generateCppExpression(indexingMap
, "context")));
920 stmts
.push_back(llvm::formatv(
922 "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
923 "symbolBindings, {0}, 0));",
927 // TODO: This needs to be memoized and/or converted to non-parser based
928 // C++ codegen prior to real use.
929 os
<< llvm::formatv(structuredOpIndexingMapsFormat
, className
,
930 interleaveToString(stmts
, "\n "));
933 os
<< llvm::formatv(rankPolyStructuredOpIndexingMapsFormat
, className
);
936 return emitError(genContext
.getLoc())
937 << "generating code for non static indexing maps not currently "
941 // getNumRegionArgs()
943 // Generates a getNumRegionArgs() method. Parameters:
945 // {1}: Number of region args
946 static const char structuredOpGetNumRegionArgsFormat
[] = R
"FMT(
947 unsigned {0}::getNumRegionArgs() {{ return {1}; }
949 os
<< llvm::formatv(structuredOpGetNumRegionArgsFormat
, className
,
953 // getLibraryCallName()
955 // Generates a getLibraryCallName method. Parameters:
957 static const char structuredOpGetLibraryCallFormat
[] = R
"FMT(
958 std::string {0}::getLibraryCallName() {{
959 return generateLibraryCallName(getOperation());
962 os
<< llvm::formatv(structuredOpGetLibraryCallFormat
, className
);
965 // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
966 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
967 return arg
.kind
== LinalgOperandDefKind::IndexAttr
;
969 std::vector
<std::string
> attrVerifications
;
970 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
971 if (arg
.kind
!= LinalgOperandDefKind::IndexAttr
)
973 assert(arg
.indexAttrMap
);
974 // Verify index attribute. Parameters:
975 // {0}: Attribute name
976 // {1}: Attribute size
977 static const char attrFmt
[] = R
"FMT(
978 if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
979 if (!attr.getType().getElementType().isInteger(64))
980 return op->emitError("incorrect element type
for index attribute
'{0}'");
981 if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
982 return op->emitError("incorrect shape
for index attribute
'{0}'");
985 attrVerifications
.push_back(llvm::formatv(
986 attrFmt
, arg
.name
, arg
.indexAttrMap
->affineMap().getNumResults()));
989 // Generates the verifyIndexingMapRequiredAttributes method. Parameters:
991 // {1}: Attribute verification
992 static const char structuredOpVerifyIndexingMapRequiredAttributes
[] = R
"FMT(
993 bool {0}::hasDynamicIndexingMaps() {{ return true; }
994 LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
995 Operation *op = getOperation();
1000 os
<< llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes
,
1001 className
, llvm::join(attrVerifications
, "\n"));
1006 // Generates a regionBuilder method. Parameters.
1008 // {1}: Number of args
1011 static const char structuredOpRegionBuilderFormat
[] = R
"FMT(
1012 void {0}::regionBuilder(ImplicitLocOpBuilder &b,
1013 Block &block, ArrayRef<NamedAttribute> attrs) {{
1014 assert({1} > 0 && block.getNumArguments() == {1} &&
1015 "{0} regionBuilder expects
{1} (>=0) args
");
1016 RegionBuilderHelper helper(b, block);
1017 SmallVector<Value> yields;
1020 helper.yieldOutputs(yields);
1023 auto &args
= opConfig
.structuredOp
->args
;
1024 auto &assignments
= opConfig
.structuredOp
->assignments
;
1025 size_t generatedAssignmentCount
= 0;
1026 int localCounter
= 0;
1027 SmallVector
<std::string
> attrs
;
1028 SmallVector
<std::string
> stmts
;
1029 for (LinalgOperandDef
&arg
: args
) {
1030 if (!isFunctionAttribute(arg
.kind
))
1032 // Obtain the type function attribute values. Parameters.
1034 // {1}: attribute name
1035 // {2}: default type function name
1036 static const char attrDef
[] = R
"FMT(
1037 {0} {1}Val = {0}::{2};
1038 auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1039 return attr.getName() == "{1}"; });
1040 if ({1}Iter != attrs.end()) {{
1041 if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
1042 {1}Val = attr.getValue();
1045 std::string enumName
= convertOperandKindToEnumName(arg
.kind
);
1047 llvm::formatv(attrDef
, enumName
, arg
.name
, arg
.defaultFn
));
1049 for (LinalgOperandDef
&arg
: args
) {
1050 if (arg
.kind
!= LinalgOperandDefKind::OutputTensor
)
1053 // Find the assignment that correlates with the argument.
1054 ScalarAssign
*assignment
= findAssignment(arg
.name
, assignments
);
1056 return emitError(genContext
.getLoc())
1057 << "no assignment found for output argument " << arg
.name
;
1058 ++generatedAssignmentCount
;
1060 // Recursively generate the expression.
1061 std::function
<std::optional
<std::string
>(ScalarExpression
&)>
1062 generateExpression
=
1063 [&](ScalarExpression
&expression
) -> std::optional
<std::string
> {
1064 if (expression
.arg
) {
1065 // Argument reference.
1066 std::optional
<int> argIndex
=
1067 findTensorDefArgIndex(*expression
.arg
, args
);
1069 emitError(genContext
.getLoc())
1070 << "scalar argument not defined on the op: " << *expression
.arg
;
1071 return std::nullopt
;
1074 llvm::formatv("block.getArgument({0})", *argIndex
));
1076 if (expression
.constant
) {
1077 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1079 llvm::formatv(R
"FMT(Value {0} = helper.constant("{1}");)FMT",
1080 cppIdent
, expression
.constant
));
1083 if (expression
.index
) {
1084 // Access an iteration index.
1085 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1086 stmts
.push_back(llvm::formatv("Value {0} = helper.index({1});",
1087 cppIdent
, *expression
.index
));
1090 if (expression
.scalarFn
) {
1091 std::string enumName
=
1092 convertFunctionKindToEnumName(expression
.scalarFn
->kind
);
1094 // Get the function or attribute name.
1095 assert(expression
.scalarFn
->fnName
|| expression
.scalarFn
->attrName
);
1096 std::string funcType
;
1097 if (expression
.scalarFn
->fnName
) {
1098 funcType
= llvm::formatv("{0}::{1}", enumName
,
1099 *expression
.scalarFn
->fnName
);
1101 if (expression
.scalarFn
->attrName
) {
1102 if (llvm::none_of(args
, [&](LinalgOperandDef
&arg
) {
1103 return isFunctionAttribute(arg
.kind
) &&
1104 arg
.name
== *expression
.scalarFn
->attrName
;
1106 emitError(genContext
.getLoc()) << "missing function attribute "
1107 << *expression
.scalarFn
->attrName
;
1109 funcType
= llvm::formatv("{0}Val", *expression
.scalarFn
->attrName
);
1111 assert(!funcType
.empty());
1113 // Add the optional type parameter to the operands.
1114 SmallVector
<std::string
> operandCppValues
;
1115 if (expression
.scalarFn
->kind
== ScalarFnKind::Type
) {
1116 assert(expression
.scalarFn
->typeVar
.has_value());
1117 std::optional
<std::string
> typeCppValue
=
1118 findTypeValue(*expression
.scalarFn
->typeVar
, args
);
1119 if (!typeCppValue
) {
1120 emitError(genContext
.getLoc())
1121 << "type variable " << *expression
.scalarFn
->typeVar
1122 << ", used in a type conversion, must map to a predefined or "
1123 << "an argument type but it does not";
1124 return std::nullopt
;
1126 operandCppValues
.push_back(*typeCppValue
);
1129 // Collect the scalar operands.
1130 for (ScalarExpression
&operand
: expression
.scalarFn
->operands
) {
1131 auto operandCppValue
= generateExpression(operand
);
1132 if (!operandCppValue
)
1133 return std::nullopt
;
1134 operandCppValues
.push_back(*operandCppValue
);
1137 // Call the function builder.
1138 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1139 stmts
.push_back(llvm::formatv(
1140 "Value {0} = helper.build{1}({2}, {3});", cppIdent
, enumName
,
1141 funcType
, interleaveToString(operandCppValues
, ", ")));
1144 emitError(genContext
.getLoc()) << "unknown ScalarExpression type";
1145 return std::nullopt
;
1147 std::optional
<std::string
> cppValue
=
1148 generateExpression(assignment
->value
);
1151 stmts
.push_back(llvm::formatv("yields.push_back({0});", *cppValue
));
1154 if (generatedAssignmentCount
!= assignments
.size())
1155 return emitError(genContext
.getLoc())
1156 << "mismatched number of assignments vs output arguments";
1158 os
<< llvm::formatv(structuredOpRegionBuilderFormat
, className
, numOfArgs
,
1159 interleaveToString(attrs
, "\n "),
1160 interleaveToString(stmts
, "\n "));
1163 // Parser and printer.
1164 os
<< llvm::formatv(structuredOpParserFormat
, className
);
1166 // Canonicalizers and folders.
1167 os
<< llvm::formatv(structuredOpFoldersFormat
, className
);
1172 static LogicalResult
generateOp(LinalgOpConfig
&opConfig
,
1173 GenerationContext
&genContext
) {
1174 // Switch on op type being generated.
1175 if (opConfig
.structuredOp
) {
1177 succeeded(generateNamedGenericOpOds(opConfig
, genContext
)) &&
1178 succeeded(generateNamedGenericOpDefns(opConfig
, genContext
)));
1180 return emitError(genContext
.getLoc()) << "unsupported operation type";
1183 //===----------------------------------------------------------------------===//
1184 // Command line options and main
1185 //===----------------------------------------------------------------------===//
1187 static llvm::cl::opt
<std::string
>
1188 inputFilename(llvm::cl::Positional
, llvm::cl::desc("<input file>"),
1189 llvm::cl::init("-"), llvm::cl::value_desc("YAML filename"));
1191 static llvm::cl::opt
<std::string
>
1192 outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
1193 llvm::cl::value_desc("filename"), llvm::cl::init(""));
1195 static llvm::cl::opt
<std::string
>
1196 outputCppImplFilename("o-impl",
1197 llvm::cl::desc("C++ implementation file name"),
1198 llvm::cl::value_desc("filename"), llvm::cl::init(""));
1200 int main(int argc
, char **argv
) {
1201 llvm::cl::ParseCommandLineOptions(argc
, argv
, "Linalg ODS Gen from YAML");
1203 // Set up the input file.
1204 std::string errorMessage
;
1205 std::unique_ptr
<llvm::MemoryBuffer
> file
=
1206 mlir::openInputFile(inputFilename
, &errorMessage
);
1208 llvm::errs() << errorMessage
<< "\n";
1212 MLIRContext mlirContext
;
1213 LinalgYAMLContext yamlContext
{&mlirContext
};
1215 std::vector
<LinalgOpConfig
> opConfigs
;
1218 Input
yin(file
->getBuffer(), &yamlContext
);
1224 // Open output files.
1225 std::unique_ptr
<llvm::ToolOutputFile
> outputOdsDecl
;
1226 if (!outputOdsDeclFilename
.empty()) {
1227 outputOdsDecl
= openOutputFile(outputOdsDeclFilename
, &errorMessage
);
1228 if (!outputOdsDecl
) {
1229 llvm::errs() << errorMessage
<< "\n";
1234 std::unique_ptr
<llvm::ToolOutputFile
> outputCppImpl
;
1235 if (!outputCppImplFilename
.empty()) {
1236 outputCppImpl
= openOutputFile(outputCppImplFilename
, &errorMessage
);
1237 if (!outputCppImpl
) {
1238 llvm::errs() << errorMessage
<< "\n";
1243 if (!outputOdsDecl
&& !outputCppImpl
) {
1244 llvm::errs() << "error: No output files specified\n";
1249 GenerationContext
genContext(&mlirContext
,
1250 outputOdsDecl
? &outputOdsDecl
->os() : nullptr,
1251 outputCppImpl
? &outputCppImpl
->os() : nullptr);
1253 for (auto &opConfig
: opConfigs
) {
1254 if (!opConfig
.metadata
) {
1255 emitError(genContext
.getLoc())
1256 << "missing operation metadata on subsequent op";
1260 genContext
.setLoc(NameLoc::get(
1261 StringAttr::get(&mlirContext
, opConfig
.metadata
->cppClassName
)));
1262 if (failed(generateOp(opConfig
, genContext
))) {
1268 outputOdsDecl
->keep();
1270 outputCppImpl
->keep();